Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import operator
5
from ordereddict import OrderedDict
6
import re
7
import UserDict
8
import warnings
9

    
10
import dicts
11
import exc
12
import iters
13
import lists
14
import objects
15
import strings
16
import util
17

    
18
##### Names
19

    
20
# Length limit that concat() (and so truncate()) passes to strings.concat()
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21

    
22
def concat(str_, suffix):
    '''Appends `suffix` to `str_` using strings.concat(), passing it
    identifier_max_len.
    Any trailing run of "#<digits>", ")", ".<word>" and "::<type>" parts at the
    end of `str_` (e.g. a version) is first moved to the front of `suffix`, so
    that it won't be truncated off the string, leading to collisions.
    '''
    # Group 1 is the plain name, group 2 the trailing parts to preserve
    tail_match = re.match(
        r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
    if tail_match:
        str_ = tail_match.group(1)
        suffix = tail_match.group(2)+suffix
    
    return strings.concat(str_, suffix, identifier_max_len)
32

    
33
def truncate(str_):
    '''Runs `str_` through concat() with an empty suffix.'''
    return concat(str_, '')
34

    
35
def is_safe_name(name):
    r'''A name is safe *and unambiguous* if it:
    * contains only *lowercase* word (\w) characters
    * doesn't start with a digit
    * contains "_", so that it's not a keyword
    @return A match object (truthy) if the name is safe, otherwise None
    '''
    # Anchor with \Z rather than $: $ also matches just before a trailing
    # newline, which would wrongly accept a name ending in a newline
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+\Z', name)
42

    
43
def esc_name(name, quote='"'):
    '''Wraps `name` in `quote` characters, escaping any embedded ones.
    @param quote The quote character to use
    '''
    # doubling an embedded quote escapes it in both PostgreSQL and MySQL
    doubled = quote*2
    return ''.join((quote, name.replace(quote, doubled), quote))
46

    
47
def clean_name(name):
    '''Removes every double quote and backtick from `name`.'''
    for quote_char in ('"', '`'):
        name = name.replace(quote_char, '')
    return name
48

    
49
def esc_comment(comment):
    '''Wraps `comment` in /* */ so that it is a single, self-contained SQL block
    comment.
    * an embedded "*/" is broken up so it can't end the comment early
    * an embedded "/*" is broken up too, because PostgreSQL block comments nest
      and it would otherwise open an inner comment that is never closed
    '''
    body = comment.replace('*/', '* /').replace('/*', '/ *')
    # A trailing "/" would join the closing "*/" to form a nested "/*"
    if body.endswith('/'): body += ' '
    return '/*'+body+'*/'
50

    
51
def lstrip(str_):
    '''Strips leading whitespace. Also removes comments: if the string starts
    with "/*", everything up to and including the first "*/" is dropped first.
    '''
    if str_.startswith('/*'):
        str_ = str_.partition('*/')[2] # keep only what follows the comment
    return str_.lstrip()
55

    
56
##### General SQL code objects
57

    
58
class MockDb:
    '''Stand-in for a db object, providing the methods the code objects call
    when they are rendered without a real database (see mockDb).
    '''
    
    def esc_value(self, value):
        '''Renders a literal value using strings.repr_no_u().'''
        return strings.repr_no_u(value)
    
    def esc_name(self, name):
        '''Quotes a name with the module-level esc_name().'''
        return esc_name(name)
    
    def col_info(self, col):
        '''@return A placeholder, nullable TypedCol with the column's name'''
        placeholder_default = CustomCode('<default>')
        return TypedCol(col.name, '<type>', placeholder_default, True)
65

    
66
# Shared MockDb instance, passed as `db` by the __repr__()/__str__() methods
mockDb = MockDb()
67

    
68
class BasicObject(objects.BasicObject):
    '''Holds a single value. Its str() is its repr (via strings.repr_no_u())
    with the quote characters stripped by clean_name().
    '''
    
    def __init__(self, value):
        self.value = value
    
    def __str__(self):
        repr_ = strings.repr_no_u(self)
        return clean_name(repr_)
72

    
73
##### Unparameterized code objects
74

    
75
class Code(BasicObject):
    '''Base class of the objects that render themselves with to_str(db).'''
    
    def to_str(self, db):
        '''Must be overridden by subclasses.
        @param db The db object used for escaping (e.g. mockDb)
        '''
        raise NotImplementedError()
    
    def __repr__(self):
        # Render against the mock db, since no real one is available here
        return self.to_str(mockDb)
79

    
80
class CustomCode(Code):
    '''Code given as a string, which to_str() returns unchanged.'''
    
    def __init__(self, str_):
        self.str_ = str_
    
    def to_str(self, db):
        return self.str_ # `db` is not used
84

    
85
def as_Code(value, db=None):
    '''Turns a string into CustomCode and any other value into a Literal.
    @param db If set, runs db.std_code() on the value.
    '''
    if not util.is_str(value): return Literal(value)
    
    if db != None: value = db.std_code(value)
    return CustomCode(value)
93

    
94
class Expr(Code):
    '''Wraps another code object's output in parentheses.'''
    
    def __init__(self, expr):
        self.expr = expr
    
    def to_str(self, db):
        inner = self.expr.to_str(db)
        return '('+inner+')'
98

    
99
##### Names
100

    
101
class Name(Code):
    '''An identifier. It is run through truncate() when created and escaped
    with db.esc_name() when rendered.
    '''
    
    def __init__(self, name):
        self.name = truncate(name)
    
    def to_str(self, db):
        return db.esc_name(self.name)
108

    
109
def as_Name(value):
    '''Returns a Code object unchanged; wraps anything else in a Name.'''
    if not isinstance(value, Code): value = Name(value)
    return value
112

    
113
##### Literal values
114

    
115
class Literal(Code):
    '''A literal value, rendered with db.esc_value().'''
    
    def __init__(self, value):
        self.value = value
    
    def to_str(self, db):
        return db.esc_value(self.value)
119

    
120
def as_Value(value):
    '''Returns a Code object unchanged; wraps anything else in a Literal.'''
    if not isinstance(value, Code): value = Literal(value)
    return value
123

    
124
def is_null(value):
    '''@return Whether `value` is a Literal wrapping None'''
    # Identity test: `== None` would run the wrapped value's own __eq__()
    return isinstance(value, Literal) and value.value is None
125

    
126
##### Derived elements
127

    
128
# Sentinel for the `srcs` param: tells Col that it is its own source column
# (Derived.set_srcs() then stores (self,))
src_self = object()
129

    
130
class Derived(Code):
    '''An element which was derived from some other element(s).'''
    
    def __init__(self, srcs):
        '''
        @param srcs See self.set_srcs()
        '''
        self.set_srcs(srcs)
    
    def set_srcs(self, srcs, overwrite=True):
        '''
        @param srcs (self_type...)|src_self The element(s) this is derived from
        @param overwrite If False, srcs that are already non-empty are kept
        '''
        if not overwrite:
            # self.srcs is only read here, so __init__() can call this before
            # the attribute exists
            if self.srcs != (): return # already set
        
        if srcs == src_self: srcs = (self,)
        self.srcs = tuple(srcs) # make Col hashable
    
    def _compare_on(self):
        '''@return A copy of the instance's attributes, without `srcs`'''
        compare_on = dict(self.__dict__)
        compare_on.pop('srcs') # ignore
        return compare_on
151

    
152
def cols_srcs(cols):
    '''@return The srcs of all the cols, run through iters.flatten() and then
    lists.uniqify()
    '''
    all_srcs = iters.flatten((col.srcs for col in cols))
    return lists.uniqify(all_srcs)
153

    
154
##### Tables
155

    
156
class Table(Derived):
    '''A table name with an optional schema.'''
    
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
        '''
        @param name A string name is run through truncate()
        @param schema str|None (for no schema)
        @param srcs (Table...)|src_self See Derived.set_srcs()
        @param is_temp Stored as self.is_temp (read by is_temp_col())
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        
        self.name = name
        self.schema = schema
        self.is_temp = is_temp
        self.index_cols = {} # looked up by col name in index_col()
    
    def to_str(self, db):
        '''@return "<schema>.<name>", or just "<name>" if there is no schema'''
        parts = []
        if self.schema != None: parts.append(as_Name(self.schema).to_str(db))
        parts.append(as_Name(self.name).to_str(db))
        return '.'.join(parts)
    
    def to_Table(self):
        return self
    
    def _compare_on(self):
        '''@return Derived._compare_on(), also without `index_cols`'''
        compare_on = Derived._compare_on(self)
        compare_on.pop('index_cols') # ignore
        return compare_on
183

    
184
def is_underlying_table(table):
    '''@return Whether `table` is a Table whose to_Table() is the table itself
    (which is not the case for e.g. a NamedTable)
    '''
    if not isinstance(table, Table): return False
    return table.to_Table() is table
186

    
187
class NoUnderlyingTableException(Exception):
    '''Raised by underlying_table() and underlying_col() when there is no
    underlying table/column.
    '''
188

    
189
def underlying_table(table):
    '''Removes any table renaming and returns the resulting table.
    @throws NoUnderlyingTableException if the result is not an underlying table
        (see is_underlying_table())
    '''
    table = remove_table_rename(table)
    if is_underlying_table(table): return table
    raise NoUnderlyingTableException
193

    
194
def as_Table(table, schema=None):
    '''Returns None and Code objects unchanged; turns anything else into
    Table(table, schema).
    @param schema Only used when a new Table is created
    '''
    is_ready = table == None or isinstance(table, Code)
    if not is_ready: table = Table(table, schema)
    return table
197

    
198
def suffixed_table(table, suffix):
    '''@return A shallow copy of `table` whose name has `suffix` appended using
    concat()
    '''
    renamed = copy.copy(table) # don't modify input!
    renamed.name = concat(renamed.name, suffix)
    return renamed
202

    
203
class NamedTable(Table):
    '''Code given a table name: renders as "<code> AS <name>", optionally
    followed by a column list.
    '''
    
    def __init__(self, name, code, cols=None):
        '''
        @param code Passed through as_Table(). Anything that is then not a
            Table, FunctionCall or Expr is wrapped in an Expr.
        @param cols None|[...] Each is reduced to a name-only Col
        '''
        Table.__init__(self, name)
        
        code = as_Table(code)
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
        
        self.code = code
        self.cols = cols
    
    def to_str(self, db):
        code_str = self.code.to_str(db)
        # Multiline code gets the AS clause on a line of its own
        if '\n' in code_str: separator = '\n'
        else: separator = ' '
        
        pieces = [code_str, separator, 'AS ', Table.to_str(self, db)]
        if self.cols != None:
            col_strs = [c.to_str(db) for c in self.cols]
            pieces.append(' ('+', '.join(col_strs)+')')
        return ''.join(pieces)
    
    def to_Table(self):
        '''@return A plain Table with just this table's name'''
        return Table(self.name)
224

    
225
def remove_table_rename(table):
    '''@return The code of a NamedTable; any other value unchanged'''
    if not isinstance(table, NamedTable): return table
    return table.code
228

    
229
##### Columns
230

    
231
class Col(Derived):
    '''A column name with an optional table.'''
    
    def __init__(self, name, table=None, srcs=()):
        '''
        @param name A string name is run through truncate()
        @param table Table|None (for no table). A string is made into a Table.
        @param srcs (Col...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        if util.is_str(table): table = Table(table)
        assert table == None or isinstance(table, Table)
        
        self.name = name
        self.table = table
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str If set, produces the str() form: the quotes are removed
            from the column name and it is joined to str(table) with concat()
        '''
        col_str = as_Name(self.name).to_str(db)
        if for_str: col_str = clean_name(col_str)
        
        if self.table != None:
            table = self.table.to_Table()
            if for_str: return concat(str(table), '.'+col_str)
            return table.to_str(db)+'.'+col_str
        return col_str
    
    def __str__(self):
        return self.to_str(mockDb, for_str=True)
    
    def to_Col(self):
        return self
258

    
259
def is_table_col(col):
    '''@return Whether `col` is a Col that has a table'''
    if not isinstance(col, Col): return False
    return col.table != None
260

    
261
def index_col(col):
    '''Looks up `col`'s name in its table's index_cols.
    @return A Col with the name found there (and the same table and srcs), or
        None if `col` is not a table column or has no entry
    '''
    if is_table_col(col):
        table = col.table
        try: index_name = table.index_cols[col.name]
        except KeyError: pass # no entry for this column
        else: return Col(index_name, table, col.srcs)
    return None
268

    
269
def is_temp_col(col):
    '''@return The is_temp of the col's table if `col` is a table column (see
    is_table_col()); otherwise is_table_col()'s falsy result
    '''
    has_table = is_table_col(col)
    if not has_table: return has_table
    return col.table.is_temp
270

    
271
def as_Col(col, table=None, name=None):
    '''Returns a Code object unchanged; turns anything else into Col(col, table).
    @param table Only used when a new Col is created
    @param name If not None, any non-Col input will be renamed using NamedCol.
    '''
    if name != None:
        # as_Value() always returns a Code object, so both results below are
        # returned as is
        value = as_Value(col)
        if isinstance(value, Col): return value
        return NamedCol(name, value)
    
    if not isinstance(col, Code): col = Col(col, table)
    return col
281

    
282
def with_table(col, table):
    '''@return A copy of `col` with its table set to `table`
    * NamedCol: returned unchanged
    * FunctionCall: deep-copied, and the table is set on its first argument
    * anything else: shallow-copied
    '''
    if isinstance(col, NamedCol): return col # doesn't take a table
    
    if isinstance(col, FunctionCall):
        new_col = copy.deepcopy(col) # don't modify input!
        new_col.args[0].table = table
        return new_col
    
    new_col = copy.copy(col) # don't modify input!
    new_col.table = table
    return new_col
291

    
292
def with_default_table(col, table):
    '''Converts `col` with as_Col() and, if it has no table, gives it `table`
    using with_table().
    '''
    # NOTE(review): col.table is read on whatever as_Col() returns; presumably
    # callers only pass values that end up with a `table` attribute -- confirm
    col = as_Col(col)
    if col.table == None: return with_table(col, table)
    return col
296

    
297
def set_cols_table(table, cols):
    '''Converts every entry of `cols` with as_Col() and sets its table.
    @param table Passed through as_Table()
    @param cols Modified in place: each entry is replaced by its as_Col() form
    '''
    table = as_Table(table)
    
    for i, raw_col in enumerate(cols):
        converted = as_Col(raw_col)
        cols[i] = converted
        converted.table = table
303

    
304
def to_name_only_col(col, check_table=None):
    '''@return For a table column, a new Col with just its name. Anything else
    is returned in its as_Col() form.
    @param check_table If not None, asserts that this is the column's table
    '''
    col = as_Col(col)
    if is_table_col(col):
        if check_table != None:
            table = col.table
            assert table == None or table == check_table
        col = Col(col.name)
    return col
312

    
313
def suffixed_col(col, suffix):
    '''@return A new Col with the same table and srcs, whose name has `suffix`
    appended using concat()
    '''
    new_name = concat(col.name, suffix)
    return Col(new_name, col.table, col.srcs)
315

    
316
class NamedCol(Col):
    '''Code given a column name: renders as "<code> AS <name>".'''
    
    def __init__(self, name, code):
        '''
        @param code Converted with as_Value()
        '''
        Col.__init__(self, name)
        
        self.code = as_Value(code)
    
    def to_str(self, db):
        code_str = self.code.to_str(db)
        return code_str+' AS '+Col.to_str(self, db)
    
    def to_Col(self):
        '''@return A plain Col with just this column's name'''
        return Col(self.name)
328

    
329
def remove_col_rename(col):
    '''@return The code of a NamedCol; any other value unchanged'''
    if not isinstance(col, NamedCol): return col
    return col.code
332

    
333
def underlying_col(col):
    '''Removes any column renaming and returns a Col on the underlying table.
    @throws NoUnderlyingTableException if the unrenamed value is not a Col, or
        if underlying_table() raises it for the column's table
    '''
    col = remove_col_rename(col)
    if not isinstance(col, Col): raise NoUnderlyingTableException
    
    table = underlying_table(col.table)
    return Col(col.name, table, col.srcs)
338

    
339
def wrap(wrap_func, value):
    '''Wraps a value, propagating any column renaming to the returned value.
    @param wrap_func Called on the value (for a NamedCol, on its code)
    '''
    if not isinstance(value, NamedCol): return wrap_func(value)
    
    # Wrap the code inside the renaming, keeping the name on the outside
    return NamedCol(value.name, wrap_func(value.code))
344

    
345
class ColDict(dicts.DictProxy):
    '''A dict that automatically makes inserted entries Col objects'''
    
    def __init__(self, db, keys_table, dict_={}):
        '''
        @param db Its col_info() supplies the value for entries set to None
        @param keys_table The table used for keys (see self._key())
        @param dict_ Initial entries. Only passed to self.update().
        '''
        dicts.DictProxy.__init__(self, OrderedDict())
        
        self.db = db
        self.table = as_Table(keys_table)
        self.update(dict_) # after setting vars because __setitem__() needs them
    
    def copy(self):
        return ColDict(self.db, self.table, self.inner.copy())
    
    def __getitem__(self, key):
        col_key = self._key(key)
        return dicts.DictProxy.__getitem__(self, col_key)
    
    def __setitem__(self, key, value):
        '''A None value is replaced by the default in db.col_info() for the key.
        The value is stored in its as_Col() form, named after the key.
        '''
        col_key = self._key(key)
        if value == None: value = self.db.col_info(col_key).default
        col_value = as_Col(value, name=col_key.name)
        dicts.DictProxy.__setitem__(self, col_key, col_value)
    
    def _key(self, key):
        '''@return The key in its as_Col() form, using self.table'''
        return as_Col(key, self.table)
368

    
369
##### Functions
370

    
371
Function = Table
372
as_Function = as_Table
373

    
374
class InternalFunction(CustomCode):
    '''A function name that is output as is (it is CustomCode) instead of being
    escaped as a name, e.g. COUNT or COALESCE.
    '''
375

    
376
class NamedArg(NamedCol):
    '''A named function argument: renders as "<name> := <value>".'''
    
    def __init__(self, name, value):
        NamedCol.__init__(self, name, value)
    
    def to_str(self, db):
        name_str = Col.to_str(self, db)
        return name_str+' := '+self.code.to_str(db)
382

    
383
class FunctionCall(Code):
    '''A call to a function: renders as "<function>(<arg>, ...)".'''
    
    def __init__(self, function, *args, **kw_args):
        '''
        @param function Converted with as_Function()
        @param args [Code|literal-value...] The function's arguments
        @param kw_args Named arguments, which are added after `args` as
            NamedArg objects
        '''
        function = as_Function(function)
        
        def to_arg(arg):
            '''Converts to a value and removes any column renaming.'''
            return remove_col_rename(as_Value(arg))
        
        positional = [to_arg(arg) for arg in args]
        named = [NamedArg(k, to_arg(v)) for k, v in kw_args.items()]
        
        self.function = function
        self.args = positional+named
    
    def to_str(self, db):
        arg_strs = [arg.to_str(db) for arg in self.args]
        return self.function.to_str(db)+'('+', '.join(arg_strs)+')'
399

    
400
def wrap_in_func(function, value):
    '''Wraps a value inside a function call.
    Propagates any column renaming to the returned value.
    '''
    def call_on(inner):
        return FunctionCall(function, inner)
    
    return wrap(call_on, value)
405

    
406
def unwrap_func_call(func_call, check_name=None):
    '''Unwraps any function call to its first argument.
    Also removes any column renaming.
    @param check_name If not None, asserts that this is the function's name
    @return The first argument; or the unrenamed input if it is not a
        FunctionCall
    '''
    unrenamed = remove_col_rename(func_call)
    if isinstance(unrenamed, FunctionCall):
        if check_name != None:
            name = unrenamed.function.name
            assert name == None or name == check_name
        return unrenamed.args[0]
    return unrenamed
417

    
418
##### Casts
419

    
420
class Cast(FunctionCall):
    '''A type cast: renders as "CAST(<value> AS <type>)".'''
    # NOTE(review): FunctionCall.__init__() is not called, so self.function and
    # self.args are never set on a Cast; code that reads them (e.g.
    # with_table(), unwrap_func_call()) would fail -- confirm this is intended
    
    def __init__(self, type_, value):
        '''
        @param type_ str The type name, which is output as is
        @param value Converted with as_Value()
        '''
        value = as_Value(value)
        
        self.type_ = type_
        self.value = value
    
    def to_str(self, db):
        value_str = self.value.to_str(db)
        return 'CAST('+value_str+' AS '+self.type_+')'
429

    
430
##### Conditions
431

    
432
class ColValueCond(Code):
    '''A condition on a column: a ValueCond together with the column it is
    applied on.
    '''
    
    def __init__(self, col, value):
        '''
        @param value Converted with as_ValueCond()
        '''
        self.col = col
        self.value = as_ValueCond(value)
    
    def to_str(self, db):
        # The column is the left side of the value's condition
        return self.value.to_str(db, self.col)
440

    
441
def combine_conds(conds, keyword=None):
    '''Joins condition strings with AND, one per line.
    @param conds Iterable of str (a list, tuple, or iterator)
    @param keyword The keyword to add before the conditions, if any
    @return str The keyword is followed by nothing if there are no conditions,
        a space if there is one, and a newline if there are several
    '''
    # Materialize, so that tuples and iterators can be measured the same way
    # as lists
    conds = list(conds)
    
    str_ = ''
    if keyword is not None:
        if not conds: whitespace = ''
        elif len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        str_ += keyword+whitespace
    
    str_ += '\nAND '.join(conds)
    return str_
454

    
455
##### Condition column comparisons
456

    
457
class ValueCond(BasicObject):
    '''Base class of conditions that compare some left value to self.value.'''
    
    def __init__(self, value):
        '''
        @param value Converted with as_Value(); any column renaming is removed
        '''
        value = remove_col_rename(as_Value(value))
        
        self.value = value
    
    def to_str(self, db, left_value):
        '''Must be overridden by subclasses.
        @param left_value The Code object that the condition is being applied on
        '''
        # NotImplementedError is the exception; NotImplemented is a constant
        # that can't be called, so calling it raised a TypeError instead
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
470

    
471
class CompareCond(ValueCond):
    '''Compares a left value to self.value using an operator.'''
    
    def __init__(self, value, operator='='):
        '''
        @param operator By default, compares NULL values literally. Use '~=' or
            '~!=' to pass NULLs through.
        '''
        ValueCond.__init__(self, value)
        self.operator = operator
    
    def to_str(self, db, left_value):
        '''
        @param left_value The value on the left side of the comparison.
            Converted with as_Col(); any column renaming is removed.
        '''
        left_value = remove_col_rename(as_Col(left_value))
        
        right_value = self.value
        
        # Parse operator
        # The "~" and "!" prefixes are stripped, in that order. Presumably
        # strings.remove_prefix() records in the ref list whether the prefix
        # was there -- confirm against strings.remove_prefix()
        operator = self.operator
        passthru_null_ref = [False]
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
        neg_ref = [False]
        operator = strings.remove_prefix('!', operator, neg_ref)
        equals = operator.endswith('=') # also includes <=, >=
        
        # Handle nullable columns
        check_null = False
        if not passthru_null_ref[0]: # NULLs compare equal
            try: left_value = ensure_not_null(db, left_value)
            except ensure_not_null_excs: # fall back to alternate method
                # Only for an equality-type operator against another column
                check_null = equals and isinstance(right_value, Col)
            else:
                if isinstance(left_value, EnsureNotNull):
                    right_value = ensure_not_null(db, right_value,
                        left_value.type) # apply same function to both sides
        
        # A literal NULL on the right is compared with IS
        if equals and is_null(right_value): operator = 'IS'
        
        left = left_value.to_str(db)
        right = right_value.to_str(db)
        
        # Create str
        str_ = left+' '+operator+' '+right
        if check_null:
            # Also match when both sides are NULL
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
        if neg_ref[0]: str_ = 'NOT '+str_
        return str_
515

    
516
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
517
assume_literal = object()
518

    
519
def as_ValueCond(value, default_table=assume_literal):
    '''Returns a ValueCond unchanged; wraps anything else in a CompareCond.
    @param default_table Unless this is assume_literal, a non-ValueCond value
        is first passed through with_default_table() with this table
    '''
    if isinstance(value, ValueCond): return value
    
    if default_table is not assume_literal:
        value = with_default_table(value, default_table)
    return CompareCond(value)
525

    
526
##### Joins
527

    
528
join_same = object() # tells Join the left and right columns have the same name
529

    
530
# Tells Join the left and right columns have the same name and are never NULL
531
join_same_not_null = object()
532

    
533
filter_out = object() # tells Join to filter out rows that match the join
534

    
535
class Join(BasicObject):
    '''A join of a table onto the table to its left.'''
    
    def __init__(self, table, mapping=None, type_=None):
        '''
        @param table A string is made into a Table
        @param mapping dict(right_table_col=left_table_col, ...)
            * None (the default) means an empty mapping
            * if left_table_col is join_same: left_table_col = right_table_col
              * Note that right_table_col must be a string
            * if left_table_col is join_same_not_null:
              left_table_col = right_table_col and both have NOT NULL constraint
              * Note that right_table_col must be a string
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
            * filter_out: equivalent to 'LEFT' with the query filtered by
              `table_pkey IS NULL` (indicating no match)
        '''
        # The mapping is stored on the instance, so each Join needs a dict of
        # its own rather than one default dict shared by every Join
        if mapping is None: mapping = {}
        if util.is_str(table): table = Table(table)
        assert type_ == None or util.is_str(type_) or type_ is filter_out
        
        self.table = table
        self.mapping = mapping
        self.type_ = type_
    
    def to_str(self, db, left_table_):
        '''
        @param left_table_ The table that this one is being joined onto
        '''
        def join(entry):
            '''Parses non-USING joins'''
            right_table_col, left_table_col = entry
            
            # Switch order (right_table_col is on the left in the comparison)
            left = right_table_col
            right = left_table_col
            left_table = self.table
            right_table = left_table_
            
            # Parse left side
            left = with_default_table(left, left_table)
            
            # Parse special values
            left_on_right = Col(left.name, right_table)
            if right is join_same: right = left_on_right
            elif right is join_same_not_null:
                right = CompareCond(left_on_right, '~=')
            
            # Parse right side
            right = as_ValueCond(right, right_table)
            
            return right.to_str(db, left)
        
        # Create join condition
        type_ = self.type_
        joins = self.mapping
        if joins == {}: join_cond = None
        elif type_ is not filter_out and all(
            (v is join_same_not_null for v in joins.itervalues())):
            # all cols w/ USING, so can use simpler USING syntax
            cols = map(to_name_only_col, joins.iterkeys())
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
        
        # A NamedTable gets the table and the condition on lines of their own
        if isinstance(self.table, NamedTable): whitespace = '\n'
        else: whitespace = ' '
        
        # Create join
        if type_ is filter_out: type_ = 'LEFT'
        str_ = ''
        if type_ != None: str_ += type_+' '
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
        if join_cond != None: str_ += whitespace+join_cond
        return str_
    
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
603

    
604
##### Value exprs
605

    
606
all_cols = CustomCode('*')
607

    
608
default = CustomCode('DEFAULT')
609

    
610
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
611

    
612
class Coalesce(FunctionCall):
    '''A call to COALESCE with the given arguments.'''
    
    def __init__(self, *args):
        coalesce = InternalFunction('COALESCE')
        FunctionCall.__init__(self, coalesce, *args)
615

    
616
class Nullif(FunctionCall):
    '''A call to NULLIF with the given arguments.'''
    
    def __init__(self, *args):
        nullif = InternalFunction('NULLIF')
        FunctionCall.__init__(self, nullif, *args)
619

    
620
# Maps a type name to the value that EnsureNotNull substitutes for NULL in a
# column of that type. A type with no entry causes a KeyError there.
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
null_sentinels = {
    'character varying': r'\N',
    'double precision': 'NaN',
    'integer': 2147483647,
    'text': r'\N',
    'timestamp with time zone': 'infinity'
}
628

    
629
class EnsureNotNull(Coalesce):
    '''COALESCE of a column with the null sentinel of its type, cast to that
    type.
    '''
    
    def __init__(self, value, type_):
        '''
        @param value Converted with as_Col()
        @param type_ str Must be a key of null_sentinels
        @throws KeyError if there is no null sentinel for the type
        '''
        col = as_Col(value)
        sentinel = Cast(type_, null_sentinels[type_])
        Coalesce.__init__(self, col, sentinel)
        
        self.type = type_
    
    def to_str(self, db):
        # If the column has an entry in its table's index_cols, that column is
        # output instead of the COALESCE
        substitute = index_col(self.args[0])
        if substitute != None: return substitute.to_str(db)
        return Coalesce.to_str(self, db)
641

    
642
##### Table exprs
643

    
644
class Values(Code):
    '''A VALUES list of one or more rows.'''
    
    def __init__(self, values):
        '''
        @param values [...]|[[...], ...] Can be one or multiple rows.
            Not modified. Each value is converted with as_Value() and has any
            column renaming removed.
        '''
        rows = values
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
            rows = [values]
        
        # Build new lists instead of assigning into `rows`: with multiple rows
        # it is the caller's own object, which must not be changed (and which
        # can't be assigned into at all if it is a tuple)
        self.rows = [[remove_col_rename(as_Value(v)) for v in row]
            for row in rows]
    
    def to_str(self, db):
        def row_str(row):
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
661

    
662
def NamedValues(name, cols, values):
    '''Creates a NamedTable called `name` around Values(values).
    @param cols None|[...]
    @post `cols` will be changed to Col objects with the table set to `name`.
    '''
    named = NamedTable(name, Values(values), cols)
    if cols != None: set_cols_table(named, cols)
    return named
670

    
671
##### Database structure
672

    
673
class TypedCol(Col):
    '''A column definition: a name with a type and optional default,
    nullability and constraints.
    '''
    
    def __init__(self, name, type_, default=None, nullable=True,
        constraints=None):
        '''
        @param type_ str The type, which is output as is
        @param default None|Code
        @param nullable If False, NOT NULL is added to the definition
        @param constraints None|str Appended to the definition as is
        '''
        assert default == None or isinstance(default, Code)
        
        Col.__init__(self, name)
        
        self.type = type_
        self.default = default
        self.nullable = nullable
        self.constraints = constraints
    
    def to_str(self, db):
        parts = [Col.to_str(self, db)+' '+self.type]
        if not self.nullable: parts.append(' NOT NULL')
        if self.default != None:
            parts.append(' DEFAULT '+self.default.to_str(db))
        if self.constraints != None: parts.append(' '+self.constraints)
        return ''.join(parts)
    
    def to_Col(self):
        '''@return A plain Col with just this column's name'''
        return Col(self.name)
693

    
694
# The exceptions that callers of ensure_not_null() catch in order to fall back
# to another method (see CompareCond.to_str())
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
695

    
696
def ensure_not_null(db, col, type_=None):
    '''Wraps a column in EnsureNotNull, unless db.col_info() reports its
    underlying column as not nullable.
    @param col If type_ is not set, must have an underlying column.
    @param type_ If set, overrides the underlying column's type.
    @return EnsureNotNull|Col
    @throws ensure_not_null_excs
    '''
    nullable = True # assumed when there is no underlying column to look up
    try: typed_col = db.col_info(underlying_col(col))
    except NoUnderlyingTableException:
        if type_ == None: raise # no other source for the type
    else:
        if type_ == None: type_ = typed_col.type
        nullable = typed_col.nullable
    
    if nullable:
        try: col = EnsureNotNull(col, type_)
        except KeyError as e: # `as` form works on Python 2.6+ and 3
            # Warn of no null sentinel for type, even if caller catches error
            warnings.warn(UserWarning(exc.str_(e)))
            raise
    
    return col
(25-25/37)