Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 2222 aaronmk
import strings
17 2227 aaronmk
import util
18 2211 aaronmk
19 2587 aaronmk
##### Names
20 2499 aaronmk
21 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22 2587 aaronmk
23 2932 aaronmk
def concat(str_, suffix):
24 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26 2613 aaronmk
    # Preserve version
27 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28 2985 aaronmk
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31 2613 aaronmk
32 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
33 2587 aaronmk
34 2932 aaronmk
def truncate(str_): return concat(str_, '')
35 2842 aaronmk
36 2575 aaronmk
def is_safe_name(name):
37 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41 2984 aaronmk
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43 2568 aaronmk
44 2499 aaronmk
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47
48 3320 aaronmk
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56
57 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
58
59 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60
61 3182 aaronmk
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65
66 2659 aaronmk
##### General SQL code objects
67 2219 aaronmk
68 2349 aaronmk
class MockDb:
69 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
70 2349 aaronmk
71 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
72 2859 aaronmk
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75
76 2349 aaronmk
mockDb = MockDb()
77
78 2514 aaronmk
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80
81 2659 aaronmk
##### Unparameterized code objects
82
83 2514 aaronmk
class Code(BasicObject):
84 3446 aaronmk
    def __init__(self, lang='sql'):
85
        self.lang = lang
86 3445 aaronmk
87 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
88 2349 aaronmk
89 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
90 2211 aaronmk
91 2269 aaronmk
class CustomCode(Code):
92 3445 aaronmk
    def __init__(self, str_):
93
        Code.__init__(self)
94
95
        self.str_ = str_
96 2256 aaronmk
97
    def to_str(self, db): return self.str_
98
99 2815 aaronmk
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103 3447 aaronmk
    if isinstance(value, Code): return value
104
105 2815 aaronmk
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108 2659 aaronmk
    else: return Literal(value)
109
110 2540 aaronmk
class Expr(Code):
111 3445 aaronmk
    def __init__(self, expr):
112
        Code.__init__(self)
113
114
        self.expr = expr
115 2540 aaronmk
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117
118 3086 aaronmk
##### Names
119
120
class Name(Code):
121 3087 aaronmk
    def __init__(self, name):
122 3445 aaronmk
        Code.__init__(self)
123
124 3087 aaronmk
        name = truncate(name)
125
126
        self.name = name
127 3086 aaronmk
128
    def to_str(self, db): return db.esc_name(self.name)
129
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133
134 2335 aaronmk
##### Literal values
135
136 3519 aaronmk
#### Primitives
137
138 2216 aaronmk
class Literal(Code):
139 3445 aaronmk
    def __init__(self, value):
140
        Code.__init__(self)
141
142
        self.value = value
143 2213 aaronmk
144
    def to_str(self, db): return db.esc_value(self.value)
145 2211 aaronmk
146 2400 aaronmk
def as_Value(value):
147
    if isinstance(value, Code): return value
148
    else: return Literal(value)
149
150 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
151 2216 aaronmk
152 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
153
154 3519 aaronmk
#### Composites
155
156 3521 aaronmk
class List(Code):
157
    def __init__(self, values):
158 3519 aaronmk
        Code.__init__(self)
159
160
        self.values = values
161
162 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
163 3519 aaronmk
164 3521 aaronmk
class Tuple(List):
165
    def __init__(self, *values):
166
        List.__init__(self, values)
167
168
    def to_str(self, db): return '('+List.to_str(self, db)+')'
169
170 3520 aaronmk
class Row(Tuple):
171
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
172 3519 aaronmk
173 3522 aaronmk
### Arrays
174
175
class Array(List):
176
    def __init__(self, values):
177
        values = map(remove_col_rename, values)
178
179
        List.__init__(self, values)
180
181
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
182
183
def to_Array(value):
184
    if isinstance(value, Array): return value
185
    return Array(lists.mk_seq(value))
186
187 2711 aaronmk
##### Derived elements
188
189
src_self = object() # tells Col that it is its own source column
190
191
class Derived(Code):
192
    def __init__(self, srcs):
193 2712 aaronmk
        '''An element which was derived from some other element(s).
194 2711 aaronmk
        @param srcs See self.set_srcs()
195
        '''
196 3445 aaronmk
        Code.__init__(self)
197
198 2711 aaronmk
        self.set_srcs(srcs)
199
200 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
201 2711 aaronmk
        '''
202
        @param srcs (self_type...)|src_self The element(s) this is derived from
203
        '''
204 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
205
206 2711 aaronmk
        if srcs == src_self: srcs = (self,)
207
        srcs = tuple(srcs) # make Col hashable
208
        self.srcs = srcs
209
210
    def _compare_on(self):
211
        compare_on = self.__dict__.copy()
212
        del compare_on['srcs'] # ignore
213
        return compare_on
214
215
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
216
217 2335 aaronmk
##### Tables
218
219 2712 aaronmk
class Table(Derived):
220 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
221 2211 aaronmk
        '''
222
        @param schema str|None (for no schema)
223 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
224 2211 aaronmk
        '''
225 2712 aaronmk
        Derived.__init__(self, srcs)
226
227 3091 aaronmk
        if util.is_str(name): name = truncate(name)
228 2843 aaronmk
229 2211 aaronmk
        self.name = name
230
        self.schema = schema
231 2991 aaronmk
        self.is_temp = is_temp
232 3000 aaronmk
        self.index_cols = {}
233 2211 aaronmk
234 2348 aaronmk
    def to_str(self, db):
235
        str_ = ''
236 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
237
        str_ += as_Name(self.name).to_str(db)
238 2348 aaronmk
        return str_
239 2336 aaronmk
240
    def to_Table(self): return self
241 3000 aaronmk
242
    def _compare_on(self):
243
        compare_on = Derived._compare_on(self)
244
        del compare_on['index_cols'] # ignore
245
        return compare_on
246 2211 aaronmk
247 2835 aaronmk
def is_underlying_table(table):
248
    return isinstance(table, Table) and table.to_Table() is table
249 2832 aaronmk
250 4114 aaronmk
class NoUnderlyingTableException(Exception):
251
    def __init__(self, ref):
252
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
253
        self.ref = ref
254 2902 aaronmk
255
def underlying_table(table):
256
    table = remove_table_rename(table)
257 3532 aaronmk
    if table != None and table.srcs:
258
        table, = table.srcs # for derived tables or row vars
259 4114 aaronmk
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
260 2902 aaronmk
    return table
261
262 2776 aaronmk
def as_Table(table, schema=None):
263 2270 aaronmk
    if table == None or isinstance(table, Code): return table
264 2776 aaronmk
    else: return Table(table, schema)
265 2219 aaronmk
266 3101 aaronmk
def suffixed_table(table, suffix):
267 3128 aaronmk
    table = copy.copy(table) # don't modify input!
268
    table.name = concat(table.name, suffix)
269
    return table
270 2707 aaronmk
271 2336 aaronmk
class NamedTable(Table):
272
    def __init__(self, name, code, cols=None):
273
        Table.__init__(self, name)
274
275 3016 aaronmk
        code = as_Table(code)
276 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
277 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
278 2336 aaronmk
279
        self.code = code
280
        self.cols = cols
281
282
    def to_str(self, db):
283 3026 aaronmk
        str_ = self.code.to_str(db)
284
        if str_.find('\n') >= 0: whitespace = '\n'
285
        else: whitespace = ' '
286
        str_ += whitespace+'AS '+Table.to_str(self, db)
287 2742 aaronmk
        if self.cols != None:
288
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
289 2336 aaronmk
        return str_
290
291
    def to_Table(self): return Table(self.name)
292
293 2753 aaronmk
def remove_table_rename(table):
294
    if isinstance(table, NamedTable): table = table.code
295
    return table
296
297 2335 aaronmk
##### Columns
298
299 2711 aaronmk
class Col(Derived):
300 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
301 2211 aaronmk
        '''
302
        @param table Table|None (for no table)
303 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
304 2211 aaronmk
        '''
305 2711 aaronmk
        Derived.__init__(self, srcs)
306
307 3091 aaronmk
        if util.is_str(name): name = truncate(name)
308 2241 aaronmk
        if util.is_str(table): table = Table(table)
309 2211 aaronmk
        assert table == None or isinstance(table, Table)
310
311
        self.name = name
312
        self.table = table
313
314 2989 aaronmk
    def to_str(self, db, for_str=False):
315 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
316 2989 aaronmk
        if for_str: str_ = clean_name(str_)
317 2933 aaronmk
        if self.table != None:
318 2989 aaronmk
            table = self.table.to_Table()
319 3750 aaronmk
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
320 2989 aaronmk
            else: str_ = table.to_str(db)+'.'+str_
321 2211 aaronmk
        return str_
322 2314 aaronmk
323 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
324 2933 aaronmk
325 2314 aaronmk
    def to_Col(self): return self
326 2211 aaronmk
327 3512 aaronmk
def is_col(col): return isinstance(col, Col)
328 2393 aaronmk
329 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
330
331 3000 aaronmk
def index_col(col):
332
    if not is_table_col(col): return None
333 3104 aaronmk
334
    table = col.table
335
    try: name = table.index_cols[col.name]
336
    except KeyError: return None
337
    else: return Col(name, table, col.srcs)
338 2999 aaronmk
339 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
340 2996 aaronmk
341 2563 aaronmk
def as_Col(col, table=None, name=None):
342
    '''
343
    @param name If not None, any non-Col input will be renamed using NamedCol.
344
    '''
345
    if name != None:
346
        col = as_Value(col)
347
        if not isinstance(col, Col): col = NamedCol(name, col)
348 2333 aaronmk
349
    if isinstance(col, Code): return col
350 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
351
    else: return Literal(col)
352 2260 aaronmk
353 3093 aaronmk
def with_table(col, table):
354 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
355
    elif isinstance(col, FunctionCall):
356
        col = copy.deepcopy(col) # don't modify input!
357
        col.args[0].table = table
358 3493 aaronmk
    elif isinstance(col, Col):
359 3098 aaronmk
        col = copy.copy(col) # don't modify input!
360
        col.table = table
361 3093 aaronmk
    return col
362
363 3100 aaronmk
def with_default_table(col, table):
364 2747 aaronmk
    col = as_Col(col)
365 3100 aaronmk
    if col.table == None: col = with_table(col, table)
366 2747 aaronmk
    return col
367
368 2744 aaronmk
def set_cols_table(table, cols):
369
    table = as_Table(table)
370
371
    for i, col in enumerate(cols):
372
        col = cols[i] = as_Col(col)
373
        col.table = table
374
375 2401 aaronmk
def to_name_only_col(col, check_table=None):
376
    col = as_Col(col)
377 3020 aaronmk
    if not is_table_col(col): return col
378 2401 aaronmk
379
    if check_table != None:
380
        table = col.table
381
        assert table == None or table == check_table
382
    return Col(col.name)
383
384 2993 aaronmk
def suffixed_col(col, suffix):
385
    return Col(concat(col.name, suffix), col.table, col.srcs)
386
387 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
388
389 3518 aaronmk
def cross_join_srcs(cols):
390
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
391
    srcs = [[s.name for s in c.srcs] for c in cols]
392
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
393 3514 aaronmk
394 2323 aaronmk
class NamedCol(Col):
395 2229 aaronmk
    def __init__(self, name, code):
396 2310 aaronmk
        Col.__init__(self, name)
397
398 3016 aaronmk
        code = as_Value(code)
399 2229 aaronmk
400
        self.code = code
401
402
    def to_str(self, db):
403 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
404 2314 aaronmk
405
    def to_Col(self): return Col(self.name)
406 2229 aaronmk
407 2462 aaronmk
def remove_col_rename(col):
408
    if isinstance(col, NamedCol): col = col.code
409
    return col
410
411 2830 aaronmk
def underlying_col(col):
412
    col = remove_col_rename(col)
413 4114 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
414 2849 aaronmk
415 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
416 2830 aaronmk
417 2703 aaronmk
def wrap(wrap_func, value):
418
    '''Wraps a value, propagating any column renaming to the returned value.'''
419
    if isinstance(value, NamedCol):
420
        return NamedCol(value.name, wrap_func(value.code))
421
    else: return wrap_func(value)
422
423 2667 aaronmk
class ColDict(dicts.DictProxy):
424 3691 aaronmk
    '''A dict that automatically makes inserted entries Col objects.
425
    Anything that isn't a column is wrapped in a NamedCol with the key's column
426
    name by `as_Col(value, name=key.name)`.
427
    '''
428 2564 aaronmk
429 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
430 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
431 2667 aaronmk
432 2645 aaronmk
        keys_table = as_Table(keys_table)
433
434 2642 aaronmk
        self.db = db
435 2641 aaronmk
        self.table = keys_table
436 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
437 2641 aaronmk
438 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
439 2655 aaronmk
440 2667 aaronmk
    def __getitem__(self, key):
441
        return dicts.DictProxy.__getitem__(self, self._key(key))
442 2653 aaronmk
443 2564 aaronmk
    def __setitem__(self, key, value):
444 2642 aaronmk
        key = self._key(key)
445 3639 aaronmk
        if value == None:
446
            try: value = self.db.col_info(key).default
447
            except NoUnderlyingTableException: pass # not a table column
448 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
449 2564 aaronmk
450 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
451 2564 aaronmk
452 3519 aaronmk
##### Definitions
453 3491 aaronmk
454 3469 aaronmk
class TypedCol(Col):
455
    def __init__(self, name, type_, default=None, nullable=True,
456
        constraints=None):
457
        assert default == None or isinstance(default, Code)
458
459
        Col.__init__(self, name)
460
461
        self.type = type_
462
        self.default = default
463
        self.nullable = nullable
464
        self.constraints = constraints
465
466
    def to_str(self, db):
467 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
468 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
469
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
470
        if self.constraints != None: str_ += ' '+self.constraints
471
        return str_
472
473
    def to_Col(self): return Col(self.name)
474
475 3488 aaronmk
class SetOf(Code):
476
    def __init__(self, type_):
477
        Code.__init__(self)
478
479
        self.type = type_
480
481
    def to_str(self, db):
482
        return 'SETOF '+self.type.to_str(db)
483
484 3483 aaronmk
class RowType(Code):
485
    def __init__(self, table):
486
        Code.__init__(self)
487
488
        self.table = table
489
490
    def to_str(self, db):
491
        return self.table.to_str(db)+'%ROWTYPE'
492
493 3485 aaronmk
class ColType(Code):
494
    def __init__(self, col):
495
        Code.__init__(self)
496
497
        self.col = col
498
499
    def to_str(self, db):
500
        return self.col.to_str(db)+'%TYPE'
501
502 4935 aaronmk
class ArrayType(Code):
503
    def __init__(self, elem_type):
504
        Code.__init__(self)
505
        elem_type = as_Code(elem_type)
506
507
        self.elem_type = elem_type
508
509
    def to_str(self, db):
510
        return self.elem_type.to_str(db)+'[]'
511
512 2524 aaronmk
##### Functions
513
514 2912 aaronmk
Function = Table
515 2911 aaronmk
as_Function = as_Table
516
517 2691 aaronmk
class InternalFunction(CustomCode): pass
518
519 3442 aaronmk
#### Calls
520
521 2941 aaronmk
class NamedArg(NamedCol):
522
    def __init__(self, name, value):
523
        NamedCol.__init__(self, name, value)
524
525
    def to_str(self, db):
526
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
527
528 2524 aaronmk
class FunctionCall(Code):
529 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
530 2524 aaronmk
        '''
531 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
532 2524 aaronmk
        '''
533 3445 aaronmk
        Code.__init__(self)
534
535 3016 aaronmk
        function = as_Function(function)
536 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
537
        args = map(filter_, args)
538
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
539 2524 aaronmk
540
        self.function = function
541
        self.args = args
542
543
    def to_str(self, db):
544
        args_str = ', '.join((v.to_str(db) for v in self.args))
545
        return self.function.to_str(db)+'('+args_str+')'
546
547 2533 aaronmk
def wrap_in_func(function, value):
548
    '''Wraps a value inside a function call.
549
    Propagates any column renaming to the returned value.
550
    '''
551 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
552 2533 aaronmk
553 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
554
    '''Unwraps any function call to its first argument.
555
    Also removes any column renaming.
556
    '''
557
    func_call = remove_col_rename(func_call)
558
    if not isinstance(func_call, FunctionCall): return func_call
559
560
    if check_name != None:
561
        name = func_call.function.name
562
        assert name == None or name == check_name
563
    return func_call.args[0]
564
565 3442 aaronmk
#### Definitions
566
567
class FunctionDef(Code):
568 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
569 3445 aaronmk
        Code.__init__(self)
570
571 3487 aaronmk
        return_type = as_Code(return_type)
572 3444 aaronmk
        body = as_Code(body)
573
574 3442 aaronmk
        self.function = function
575
        self.return_type = return_type
576
        self.body = body
577 3471 aaronmk
        self.params = params
578 3456 aaronmk
        self.modifiers = modifiers
579 3442 aaronmk
580
    def to_str(self, db):
581 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
582 3442 aaronmk
        str_ = '''\
583 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
584 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
585 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
586 3456 aaronmk
'''
587
        if self.modifiers != None: str_ += self.modifiers+'\n'
588
        str_ += '''\
589 3442 aaronmk
AS $$
590 3444 aaronmk
'''+self.body.to_str(db)+'''
591 3442 aaronmk
$$;
592
'''
593
        return str_
594
595 3469 aaronmk
class FunctionParam(TypedCol):
596
    def __init__(self, name, type_, default=None, out=False):
597
        TypedCol.__init__(self, name, type_, default)
598
599
        self.out = out
600
601
    def to_str(self, db):
602
        str_ = TypedCol.to_str(self, db)
603
        if self.out: str_ = 'OUT '+str_
604
        return str_
605
606
    def to_Col(self): return Col(self.name)
607
608 3454 aaronmk
### PL/pgSQL
609
610 3496 aaronmk
class ReturnQuery(Code):
611
    def __init__(self, query):
612
        Code.__init__(self)
613
614
        query = as_Code(query)
615
616
        self.query = query
617
618
    def to_str(self, db):
619
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
620
621 3515 aaronmk
## Exceptions
622
623 3509 aaronmk
class BaseExcHandler(BasicObject):
624
    def to_str(self, db, body): raise NotImplementedError()
625 3510 aaronmk
626
    def __repr__(self): return self.to_str(mockDb, '<body>')
627 3509 aaronmk
628 3549 aaronmk
suppress_exc = 'NULL;\n';
629
630 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
631
632 3509 aaronmk
class ExcHandler(BaseExcHandler):
633 3454 aaronmk
    def __init__(self, exc, handler=None):
634
        if handler != None: handler = as_Code(handler)
635
636
        self.exc = exc
637
        self.handler = handler
638
639
    def to_str(self, db, body):
640
        body = as_Code(body)
641
642 3467 aaronmk
        if self.handler != None:
643
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
644 3549 aaronmk
        else: handler_str = ' '+suppress_exc
645 3454 aaronmk
646
        str_ = '''\
647
BEGIN
648 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
649 3454 aaronmk
EXCEPTION
650 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
651 3454 aaronmk
END;\
652
'''
653
        return str_
654
655 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
656
    def __init__(self, *handlers):
657
        '''
658
        @param handlers Sorted from outermost to innermost
659
        '''
660
        self.handlers = handlers
661
662
    def to_str(self, db, body):
663
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
664
        return body
665
666 3503 aaronmk
class ExcToWarning(Code):
667
    def __init__(self, return_):
668
        '''
669
        @param return_ Statement to return a default value in case of error
670
        '''
671
        Code.__init__(self)
672
673
        return_ = as_Code(return_)
674
675
        self.return_ = return_
676
677
    def to_str(self, db):
678
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
679
680 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
681
682 3555 aaronmk
# Note doubled "\"s because inside Python string
683 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
684 3591 aaronmk
-- Handle PL/Python exceptions
685 3555 aaronmk
DECLARE
686 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
687
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
688 3555 aaronmk
    exc_name text := matches[1];
689
    msg text := matches[2];
690
BEGIN
691 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
692
    This allows the exception to be parsed like a native exception.
693
    Always raise as data_exception so it goes in the errors table. */
694 3598 aaronmk
    IF exc_name IS NOT NULL THEN
695 4026 aaronmk
        RAISE data_exception USING MESSAGE = msg;
696 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
697
    ELSE
698
        '''+reraise_exc+'''\
699 3555 aaronmk
    END IF;
700
END;
701 3468 aaronmk
''')
702
703 3505 aaronmk
def data_exception_handler(handler):
704
    return ExcHandler('data_exception', handler)
705
706 3527 aaronmk
row_var = Table('row')
707
708 3449 aaronmk
class RowExcIgnore(Code):
709
    def __init__(self, row_type, select_query, with_row, cols=None,
710 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
711 3529 aaronmk
        '''
712
        @param row_type Ignored if a custom row_var is used.
713
        @pre If a custom row_var is used, it must already be defined.
714
        '''
715 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
716
717 3482 aaronmk
        row_type = as_Code(row_type)
718 3449 aaronmk
        select_query = as_Code(select_query)
719
        with_row = as_Code(with_row)
720 3452 aaronmk
        row_var = as_Table(row_var)
721 3449 aaronmk
722
        self.row_type = row_type
723
        self.select_query = select_query
724
        self.with_row = with_row
725
        self.cols = cols
726 3455 aaronmk
        self.exc_handler = exc_handler
727 3452 aaronmk
        self.row_var = row_var
728 3449 aaronmk
729
    def to_str(self, db):
730 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
731
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
732 3449 aaronmk
733 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
734
        # is caught by an EXCEPTION clause, [...] all changes to persistent
735
        # database state within the block are rolled back."
736
        # This is unfortunate because "A block containing an EXCEPTION clause is
737
        # significantly more expensive to enter and exit than a block without
738
        # one."
739
        # (http://www.postgresql.org/docs/8.3/static/\
740
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
741 3449 aaronmk
        str_ = '''\
742 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
743
'''+strings.indent(self.select_query.to_str(db))+'''\
744
LOOP
745
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
746
END LOOP;
747
'''
748 3533 aaronmk
        if self.row_var == row_var:
749 3529 aaronmk
            str_ = '''\
750 3449 aaronmk
DECLARE
751 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
752 3449 aaronmk
BEGIN
753 3529 aaronmk
'''+strings.indent(str_)+'''\
754
END;
755 3449 aaronmk
'''
756
        return str_
757
758 2986 aaronmk
##### Casts
759
760
class Cast(FunctionCall):
761
    def __init__(self, type_, value):
762 3539 aaronmk
        type_ = as_Code(type_)
763 2986 aaronmk
        value = as_Value(value)
764
765
        self.type_ = type_
766
        self.value = value
767
768
    def to_str(self, db):
769 3539 aaronmk
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
770 2986 aaronmk
771 3354 aaronmk
def cast_literal(value):
772 3429 aaronmk
    if not is_literal(value): return value
773 3354 aaronmk
774
    if util.is_str(value.value): value = Cast('text', value)
775
    return value
776
777 2335 aaronmk
##### Conditions
778 2259 aaronmk
779 3350 aaronmk
class NotCond(Code):
780
    def __init__(self, cond):
781 3445 aaronmk
        Code.__init__(self)
782
783 3350 aaronmk
        self.cond = cond
784
785
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
786
787 2398 aaronmk
class ColValueCond(Code):
788
    def __init__(self, col, value):
789 3445 aaronmk
        Code.__init__(self)
790
791 2398 aaronmk
        value = as_ValueCond(value)
792
793
        self.col = col
794
        self.value = value
795
796
    def to_str(self, db): return self.value.to_str(db, self.col)
797
798 2577 aaronmk
def combine_conds(conds, keyword=None):
799
    '''
800
    @param keyword The keyword to add before the conditions, if any
801
    '''
802
    str_ = ''
803
    if keyword != None:
804
        if conds == []: whitespace = ''
805
        elif len(conds) == 1: whitespace = ' '
806
        else: whitespace = '\n'
807
        str_ += keyword+whitespace
808
809
    str_ += '\nAND '.join(conds)
810
    return str_
811
812 2398 aaronmk
##### Condition column comparisons
813
814 2514 aaronmk
class ValueCond(BasicObject):
815 2213 aaronmk
    def __init__(self, value):
816 2858 aaronmk
        value = remove_col_rename(as_Value(value))
817 2213 aaronmk
818
        self.value = value
819 2214 aaronmk
820 2216 aaronmk
    def to_str(self, db, left_value):
821 2214 aaronmk
        '''
822 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
823 2214 aaronmk
        '''
824
        raise NotImplemented()
825 2228 aaronmk
826 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
827 2211 aaronmk
828
class CompareCond(ValueCond):
829
    def __init__(self, value, operator='='):
830 2222 aaronmk
        '''
831
        @param operator By default, compares NULL values literally. Use '~=' or
832
            '~!=' to pass NULLs through.
833
        '''
834 2211 aaronmk
        ValueCond.__init__(self, value)
835
        self.operator = operator
836
837 2216 aaronmk
    def to_str(self, db, left_value):
838 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
839 2216 aaronmk
840 2222 aaronmk
        right_value = self.value
841
842
        # Parse operator
843 2216 aaronmk
        operator = self.operator
844 2222 aaronmk
        passthru_null_ref = [False]
845
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
846
        neg_ref = [False]
847
        operator = strings.remove_prefix('!', operator, neg_ref)
848 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
849 2222 aaronmk
850 2825 aaronmk
        # Handle nullable columns
851
        check_null = False
852 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
853 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
854 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
855
                check_null = equals and isinstance(right_value, Col)
856 2837 aaronmk
            else:
857 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
858
                    right_value = ensure_not_null(db, right_value,
859
                        left_value.type) # apply same function to both sides
860 2825 aaronmk
861 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
862
863 2825 aaronmk
        left = left_value.to_str(db)
864
        right = right_value.to_str(db)
865
866 2222 aaronmk
        # Create str
867
        str_ = left+' '+operator+' '+right
868 2825 aaronmk
        if check_null:
869 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
870
        if neg_ref[0]: str_ = 'NOT '+str_
871 2222 aaronmk
        return str_
872 2216 aaronmk
873 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
874
assume_literal = object()
875
876
def as_ValueCond(value, default_table=assume_literal):
877
    if not isinstance(value, ValueCond):
878
        if default_table is not assume_literal:
879 2748 aaronmk
            value = with_default_table(value, default_table)
880 2260 aaronmk
        return CompareCond(value)
881 2216 aaronmk
    else: return value
882 2219 aaronmk
883 2335 aaronmk
##### Joins
884
885 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
886 2260 aaronmk
887 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
888
join_same_not_null = object()
889
890 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
891
892 2514 aaronmk
class Join(BasicObject):
893 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
894 2260 aaronmk
        '''
895
        @param mapping dict(right_table_col=left_table_col, ...)
896 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
897 2353 aaronmk
              * Note that right_table_col must be a string
898
            * if left_table_col is join_same_not_null:
899
              left_table_col = right_table_col and both have NOT NULL constraint
900
              * Note that right_table_col must be a string
901 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
902
            * filter_out: equivalent to 'LEFT' with the query filtered by
903
              `table_pkey IS NULL` (indicating no match)
904
        '''
905
        if util.is_str(table): table = Table(table)
906
        assert type_ == None or util.is_str(type_) or type_ is filter_out
907
908
        self.table = table
909
        self.mapping = mapping
910
        self.type_ = type_
911
912 2749 aaronmk
    def to_str(self, db, left_table_):
913 2260 aaronmk
        def join(entry):
914
            '''Parses non-USING joins'''
915
            right_table_col, left_table_col = entry
916
917 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
918
            left = right_table_col
919
            right = left_table_col
920 2749 aaronmk
            left_table = self.table
921
            right_table = left_table_
922 2353 aaronmk
923 2747 aaronmk
            # Parse left side
924 2748 aaronmk
            left = with_default_table(left, left_table)
925 2747 aaronmk
926 2260 aaronmk
            # Parse special values
927 2747 aaronmk
            left_on_right = Col(left.name, right_table)
928
            if right is join_same: right = left_on_right
929 2353 aaronmk
            elif right is join_same_not_null:
930 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
931 2260 aaronmk
932 2747 aaronmk
            # Parse right side
933 2353 aaronmk
            right = as_ValueCond(right, right_table)
934 2747 aaronmk
935
            return right.to_str(db, left)
936 2260 aaronmk
937 2265 aaronmk
        # Create join condition
938
        type_ = self.type_
939 2276 aaronmk
        joins = self.mapping
940 2746 aaronmk
        if joins == {}: join_cond = None
941
        elif type_ is not filter_out and reduce(operator.and_,
942 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
943 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
944 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
945
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
946 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
947 2260 aaronmk
948 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
949
        else: whitespace = ' '
950
951 2260 aaronmk
        # Create join
952
        if type_ is filter_out: type_ = 'LEFT'
953 2266 aaronmk
        str_ = ''
954
        if type_ != None: str_ += type_+' '
955 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
956
        if join_cond != None: str_ += whitespace+join_cond
957 2266 aaronmk
        return str_
958 2349 aaronmk
959 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
960 2424 aaronmk
961
##### Value exprs
962
963 3089 aaronmk
all_cols = CustomCode('*')
964
965 2737 aaronmk
default = CustomCode('DEFAULT')
966
967 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
968 2674 aaronmk
969 3061 aaronmk
class Coalesce(FunctionCall):
970
    def __init__(self, *args):
971
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
972 3060 aaronmk
973 3062 aaronmk
class Nullif(FunctionCall):
974
    def __init__(self, *args):
975
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
976
977 3706 aaronmk
null_as_str = Cast('text', 'NULL')
978
979
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
980
981 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
982 2958 aaronmk
null_sentinels = {
983
    'character varying': r'\N',
984
    'double precision': 'NaN',
985
    'integer': 2147483647,
986
    'text': r'\N',
987
    'timestamp with time zone': 'infinity'
988
}
989 2692 aaronmk
990 3061 aaronmk
class EnsureNotNull(Coalesce):
991 2850 aaronmk
    def __init__(self, value, type_):
992 4938 aaronmk
        if isinstance(type_, ArrayType): null = []
993
        else: null = null_sentinels[type_]
994
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
995 2850 aaronmk
996
        self.type = type_
997 3001 aaronmk
998
    def to_str(self, db):
999
        col = self.args[0]
1000
        index_col_ = index_col(col)
1001
        if index_col_ != None: return index_col_.to_str(db)
1002 3061 aaronmk
        return Coalesce.to_str(self, db)
1003 2850 aaronmk
1004 3523 aaronmk
#### Arrays
1005
1006 3535 aaronmk
class ArrayMerge(FunctionCall):
1007 3523 aaronmk
    def __init__(self, sep, array):
1008
        array = to_Array(array)
1009
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1010
            sep)
1011
1012 3537 aaronmk
def merge_not_null(db, sep, values):
1013 3707 aaronmk
    return ArrayMerge(sep, map(to_text, values))
1014 3537 aaronmk
1015 2737 aaronmk
##### Table exprs
1016
1017
class Values(Code):
1018
    def __init__(self, values):
1019 2739 aaronmk
        '''
1020
        @param values [...]|[[...], ...] Can be one or multiple rows.
1021
        '''
1022 3445 aaronmk
        Code.__init__(self)
1023
1024 2739 aaronmk
        rows = values
1025
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1026
            rows = [values]
1027
        for i, row in enumerate(rows):
1028
            rows[i] = map(remove_col_rename, map(as_Value, row))
1029 2737 aaronmk
1030 2739 aaronmk
        self.rows = rows
1031 2737 aaronmk
1032
    def to_str(self, db):
1033 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1034 2737 aaronmk
1035 2740 aaronmk
def NamedValues(name, cols, values):
1036 2745 aaronmk
    '''
1037 3048 aaronmk
    @param cols None|[...]
1038 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
1039
    '''
1040 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
1041 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
1042 2834 aaronmk
    return table
1043 2740 aaronmk
1044 2674 aaronmk
##### Database structure
1045
1046 5074 aaronmk
def is_nullable(db, value):
1047
    try: return is_null(value) or db.col_info(value).nullable
1048
    except NoUnderlyingTableException: return True # not a table column
1049
1050 4442 aaronmk
text_types = set(['character varying', 'text'])
1051 4406 aaronmk
1052 4442 aaronmk
def is_text_col(db, col): return db.col_info(col).type in text_types
1053
1054 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1055
1056 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1057 2840 aaronmk
    '''
1058 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1059 4488 aaronmk
    @param type_ If set, overrides the underlying column's type and casts the
1060
        column to it if needed.
1061 2840 aaronmk
    @return EnsureNotNull|Col
1062
    @throws ensure_not_null_excs
1063
    '''
1064 2855 aaronmk
    nullable = True
1065
    try: typed_col = db.col_info(underlying_col(col))
1066
    except NoUnderlyingTableException:
1067 3355 aaronmk
        col = remove_col_rename(col)
1068 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
1069 3355 aaronmk
        elif type_ == None: raise
1070 2855 aaronmk
    else:
1071
        if type_ == None: type_ = typed_col.type
1072 4488 aaronmk
        elif type_ != typed_col.type: col = Cast(type_, col)
1073 2855 aaronmk
        nullable = typed_col.nullable
1074
1075 2953 aaronmk
    if nullable:
1076
        try: col = EnsureNotNull(col, type_)
1077
        except KeyError, e:
1078
            # Warn of no null sentinel for type, even if caller catches error
1079
            warnings.warn(UserWarning(exc.str_(e)))
1080
            raise
1081
1082 2840 aaronmk
    return col
1083 3536 aaronmk
1084
def try_mk_not_null(db, value):
1085
    '''
1086
    Warning: This function does not guarantee that its result is NOT NULL.
1087
    '''
1088
    try: return ensure_not_null(db, value)
1089
    except ensure_not_null_excs: return value