Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 2222 aaronmk
import strings
17 2227 aaronmk
import util
18 2211 aaronmk
19 2587 aaronmk
##### Names
20 2499 aaronmk
21 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22 2587 aaronmk
23 2932 aaronmk
def concat(str_, suffix):
24 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26 2613 aaronmk
    # Preserve version
27 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28 2985 aaronmk
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31 2613 aaronmk
32 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
33 2587 aaronmk
34 2932 aaronmk
def truncate(str_): return concat(str_, '')
35 2842 aaronmk
36 2575 aaronmk
def is_safe_name(name):
37 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41 2984 aaronmk
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43 2568 aaronmk
44 2499 aaronmk
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47
48 3320 aaronmk
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56
57 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
58
59 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60
61 3182 aaronmk
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65
66 2659 aaronmk
##### General SQL code objects
67 2219 aaronmk
68 2349 aaronmk
class MockDb:
69 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
70 2349 aaronmk
71 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
72 2859 aaronmk
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75
76 2349 aaronmk
mockDb = MockDb()
77
78 2514 aaronmk
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80
81 2659 aaronmk
##### Unparameterized code objects
82
83 2514 aaronmk
class Code(BasicObject):
84 3446 aaronmk
    def __init__(self, lang='sql'):
85
        self.lang = lang
86 3445 aaronmk
87 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
88 2349 aaronmk
89 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
90 2211 aaronmk
91 2269 aaronmk
class CustomCode(Code):
92 3445 aaronmk
    def __init__(self, str_):
93
        Code.__init__(self)
94
95
        self.str_ = str_
96 2256 aaronmk
97
    def to_str(self, db): return self.str_
98
99 2815 aaronmk
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103 3447 aaronmk
    if isinstance(value, Code): return value
104
105 2815 aaronmk
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108 2659 aaronmk
    else: return Literal(value)
109
110 2540 aaronmk
class Expr(Code):
111 3445 aaronmk
    def __init__(self, expr):
112
        Code.__init__(self)
113
114
        self.expr = expr
115 2540 aaronmk
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117
118 3086 aaronmk
##### Names
119
120
class Name(Code):
121 3087 aaronmk
    def __init__(self, name):
122 3445 aaronmk
        Code.__init__(self)
123
124 3087 aaronmk
        name = truncate(name)
125
126
        self.name = name
127 3086 aaronmk
128
    def to_str(self, db): return db.esc_name(self.name)
129
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133
134 2335 aaronmk
##### Literal values
135
136 3519 aaronmk
#### Primitives
137
138 2216 aaronmk
class Literal(Code):
139 3445 aaronmk
    def __init__(self, value):
140
        Code.__init__(self)
141
142
        self.value = value
143 2213 aaronmk
144
    def to_str(self, db): return db.esc_value(self.value)
145 2211 aaronmk
146 2400 aaronmk
def as_Value(value):
147
    if isinstance(value, Code): return value
148
    else: return Literal(value)
149
150 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
151 2216 aaronmk
152 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
153
154 3519 aaronmk
#### Composites
155
156 3521 aaronmk
class List(Code):
157
    def __init__(self, values):
158 3519 aaronmk
        Code.__init__(self)
159
160
        self.values = values
161
162 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
163 3519 aaronmk
164 3521 aaronmk
class Tuple(List):
165
    def __init__(self, *values):
166
        List.__init__(self, values)
167
168
    def to_str(self, db): return '('+List.to_str(self, db)+')'
169
170 3520 aaronmk
class Row(Tuple):
171
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
172 3519 aaronmk
173 3522 aaronmk
### Arrays
174
175
class Array(List):
176
    def __init__(self, values):
177
        values = map(remove_col_rename, values)
178
179
        List.__init__(self, values)
180
181
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
182
183
def to_Array(value):
184
    if isinstance(value, Array): return value
185
    return Array(lists.mk_seq(value))
186
187 2711 aaronmk
##### Derived elements
188
189
src_self = object() # tells Col that it is its own source column
190
191
class Derived(Code):
192
    def __init__(self, srcs):
193 2712 aaronmk
        '''An element which was derived from some other element(s).
194 2711 aaronmk
        @param srcs See self.set_srcs()
195
        '''
196 3445 aaronmk
        Code.__init__(self)
197
198 2711 aaronmk
        self.set_srcs(srcs)
199
200 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
201 2711 aaronmk
        '''
202
        @param srcs (self_type...)|src_self The element(s) this is derived from
203
        '''
204 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
205
206 2711 aaronmk
        if srcs == src_self: srcs = (self,)
207
        srcs = tuple(srcs) # make Col hashable
208
        self.srcs = srcs
209
210
    def _compare_on(self):
211
        compare_on = self.__dict__.copy()
212
        del compare_on['srcs'] # ignore
213
        return compare_on
214
215
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
216
217 2335 aaronmk
##### Tables
218
219 2712 aaronmk
class Table(Derived):
220 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
221 2211 aaronmk
        '''
222
        @param schema str|None (for no schema)
223 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
224 2211 aaronmk
        '''
225 2712 aaronmk
        Derived.__init__(self, srcs)
226
227 3091 aaronmk
        if util.is_str(name): name = truncate(name)
228 2843 aaronmk
229 2211 aaronmk
        self.name = name
230
        self.schema = schema
231 2991 aaronmk
        self.is_temp = is_temp
232 3000 aaronmk
        self.index_cols = {}
233 2211 aaronmk
234 2348 aaronmk
    def to_str(self, db):
235
        str_ = ''
236 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
237
        str_ += as_Name(self.name).to_str(db)
238 2348 aaronmk
        return str_
239 2336 aaronmk
240
    def to_Table(self): return self
241 3000 aaronmk
242
    def _compare_on(self):
243
        compare_on = Derived._compare_on(self)
244
        del compare_on['index_cols'] # ignore
245
        return compare_on
246 2211 aaronmk
247 2835 aaronmk
def is_underlying_table(table):
248
    return isinstance(table, Table) and table.to_Table() is table
249 2832 aaronmk
250 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
251
252
def underlying_table(table):
253
    table = remove_table_rename(table)
254
    if not is_underlying_table(table): raise NoUnderlyingTableException
255
    return table
256
257 2776 aaronmk
def as_Table(table, schema=None):
258 2270 aaronmk
    if table == None or isinstance(table, Code): return table
259 2776 aaronmk
    else: return Table(table, schema)
260 2219 aaronmk
261 3101 aaronmk
def suffixed_table(table, suffix):
262 3128 aaronmk
    table = copy.copy(table) # don't modify input!
263
    table.name = concat(table.name, suffix)
264
    return table
265 2707 aaronmk
266 2336 aaronmk
class NamedTable(Table):
267
    def __init__(self, name, code, cols=None):
268
        Table.__init__(self, name)
269
270 3016 aaronmk
        code = as_Table(code)
271 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
272 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
273 2336 aaronmk
274
        self.code = code
275
        self.cols = cols
276
277
    def to_str(self, db):
278 3026 aaronmk
        str_ = self.code.to_str(db)
279
        if str_.find('\n') >= 0: whitespace = '\n'
280
        else: whitespace = ' '
281
        str_ += whitespace+'AS '+Table.to_str(self, db)
282 2742 aaronmk
        if self.cols != None:
283
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
284 2336 aaronmk
        return str_
285
286
    def to_Table(self): return Table(self.name)
287
288 2753 aaronmk
def remove_table_rename(table):
289
    if isinstance(table, NamedTable): table = table.code
290
    return table
291
292 2335 aaronmk
##### Columns
293
294 2711 aaronmk
class Col(Derived):
295 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
296 2211 aaronmk
        '''
297
        @param table Table|None (for no table)
298 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
299 2211 aaronmk
        '''
300 2711 aaronmk
        Derived.__init__(self, srcs)
301
302 3091 aaronmk
        if util.is_str(name): name = truncate(name)
303 2241 aaronmk
        if util.is_str(table): table = Table(table)
304 2211 aaronmk
        assert table == None or isinstance(table, Table)
305
306
        self.name = name
307
        self.table = table
308
309 2989 aaronmk
    def to_str(self, db, for_str=False):
310 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
311 2989 aaronmk
        if for_str: str_ = clean_name(str_)
312 2933 aaronmk
        if self.table != None:
313 2989 aaronmk
            table = self.table.to_Table()
314
            if for_str: str_ = concat(str(table), '.'+str_)
315
            else: str_ = table.to_str(db)+'.'+str_
316 2211 aaronmk
        return str_
317 2314 aaronmk
318 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
319 2933 aaronmk
320 2314 aaronmk
    def to_Col(self): return self
321 2211 aaronmk
322 3512 aaronmk
def is_col(col): return isinstance(col, Col)
323 2393 aaronmk
324 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
325
326 3000 aaronmk
def index_col(col):
327
    if not is_table_col(col): return None
328 3104 aaronmk
329
    table = col.table
330
    try: name = table.index_cols[col.name]
331
    except KeyError: return None
332
    else: return Col(name, table, col.srcs)
333 2999 aaronmk
334 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
335 2996 aaronmk
336 2563 aaronmk
def as_Col(col, table=None, name=None):
337
    '''
338
    @param name If not None, any non-Col input will be renamed using NamedCol.
339
    '''
340
    if name != None:
341
        col = as_Value(col)
342
        if not isinstance(col, Col): col = NamedCol(name, col)
343 2333 aaronmk
344
    if isinstance(col, Code): return col
345 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
346
    else: return Literal(col)
347 2260 aaronmk
348 3093 aaronmk
def with_table(col, table):
349 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
350
    elif isinstance(col, FunctionCall):
351
        col = copy.deepcopy(col) # don't modify input!
352
        col.args[0].table = table
353 3493 aaronmk
    elif isinstance(col, Col):
354 3098 aaronmk
        col = copy.copy(col) # don't modify input!
355
        col.table = table
356 3093 aaronmk
    return col
357
358 3100 aaronmk
def with_default_table(col, table):
359 2747 aaronmk
    col = as_Col(col)
360 3100 aaronmk
    if col.table == None: col = with_table(col, table)
361 2747 aaronmk
    return col
362
363 2744 aaronmk
def set_cols_table(table, cols):
364
    table = as_Table(table)
365
366
    for i, col in enumerate(cols):
367
        col = cols[i] = as_Col(col)
368
        col.table = table
369
370 2401 aaronmk
def to_name_only_col(col, check_table=None):
371
    col = as_Col(col)
372 3020 aaronmk
    if not is_table_col(col): return col
373 2401 aaronmk
374
    if check_table != None:
375
        table = col.table
376
        assert table == None or table == check_table
377
    return Col(col.name)
378
379 2993 aaronmk
def suffixed_col(col, suffix):
380
    return Col(concat(col.name, suffix), col.table, col.srcs)
381
382 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
383
384 3518 aaronmk
def cross_join_srcs(cols):
385
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
386
    srcs = [[s.name for s in c.srcs] for c in cols]
387
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
388 3514 aaronmk
389 2323 aaronmk
class NamedCol(Col):
390 2229 aaronmk
    def __init__(self, name, code):
391 2310 aaronmk
        Col.__init__(self, name)
392
393 3016 aaronmk
        code = as_Value(code)
394 2229 aaronmk
395
        self.code = code
396
397
    def to_str(self, db):
398 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
399 2314 aaronmk
400
    def to_Col(self): return Col(self.name)
401 2229 aaronmk
402 2462 aaronmk
def remove_col_rename(col):
403
    if isinstance(col, NamedCol): col = col.code
404
    return col
405
406 2830 aaronmk
def underlying_col(col):
407
    col = remove_col_rename(col)
408 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
409
410 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
411 2830 aaronmk
412 2703 aaronmk
def wrap(wrap_func, value):
413
    '''Wraps a value, propagating any column renaming to the returned value.'''
414
    if isinstance(value, NamedCol):
415
        return NamedCol(value.name, wrap_func(value.code))
416
    else: return wrap_func(value)
417
418 2667 aaronmk
class ColDict(dicts.DictProxy):
419 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
420
421 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
422 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
423 2667 aaronmk
424 2645 aaronmk
        keys_table = as_Table(keys_table)
425
426 2642 aaronmk
        self.db = db
427 2641 aaronmk
        self.table = keys_table
428 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
429 2641 aaronmk
430 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
431 2655 aaronmk
432 2667 aaronmk
    def __getitem__(self, key):
433
        return dicts.DictProxy.__getitem__(self, self._key(key))
434 2653 aaronmk
435 2564 aaronmk
    def __setitem__(self, key, value):
436 2642 aaronmk
        key = self._key(key)
437 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
438 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
439 2564 aaronmk
440 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
441 2564 aaronmk
442 3519 aaronmk
##### Definitions
443 3491 aaronmk
444 3469 aaronmk
class TypedCol(Col):
445
    def __init__(self, name, type_, default=None, nullable=True,
446
        constraints=None):
447
        assert default == None or isinstance(default, Code)
448
449
        Col.__init__(self, name)
450
451
        self.type = type_
452
        self.default = default
453
        self.nullable = nullable
454
        self.constraints = constraints
455
456
    def to_str(self, db):
457 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
458 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
459
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
460
        if self.constraints != None: str_ += ' '+self.constraints
461
        return str_
462
463
    def to_Col(self): return Col(self.name)
464
465 3488 aaronmk
class SetOf(Code):
466
    def __init__(self, type_):
467
        Code.__init__(self)
468
469
        self.type = type_
470
471
    def to_str(self, db):
472
        return 'SETOF '+self.type.to_str(db)
473
474 3483 aaronmk
class RowType(Code):
475
    def __init__(self, table):
476
        Code.__init__(self)
477
478
        self.table = table
479
480
    def to_str(self, db):
481
        return self.table.to_str(db)+'%ROWTYPE'
482
483 3485 aaronmk
class ColType(Code):
484
    def __init__(self, col):
485
        Code.__init__(self)
486
487
        self.col = col
488
489
    def to_str(self, db):
490
        return self.col.to_str(db)+'%TYPE'
491
492 2524 aaronmk
##### Functions
493
494 2912 aaronmk
Function = Table
495 2911 aaronmk
as_Function = as_Table
496
497 2691 aaronmk
class InternalFunction(CustomCode): pass
498
499 3442 aaronmk
#### Calls
500
501 2941 aaronmk
class NamedArg(NamedCol):
502
    def __init__(self, name, value):
503
        NamedCol.__init__(self, name, value)
504
505
    def to_str(self, db):
506
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
507
508 2524 aaronmk
class FunctionCall(Code):
509 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
510 2524 aaronmk
        '''
511 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
512 2524 aaronmk
        '''
513 3445 aaronmk
        Code.__init__(self)
514
515 3016 aaronmk
        function = as_Function(function)
516 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
517
        args = map(filter_, args)
518
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
519 2524 aaronmk
520
        self.function = function
521
        self.args = args
522
523
    def to_str(self, db):
524
        args_str = ', '.join((v.to_str(db) for v in self.args))
525
        return self.function.to_str(db)+'('+args_str+')'
526
527 2533 aaronmk
def wrap_in_func(function, value):
528
    '''Wraps a value inside a function call.
529
    Propagates any column renaming to the returned value.
530
    '''
531 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
532 2533 aaronmk
533 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
534
    '''Unwraps any function call to its first argument.
535
    Also removes any column renaming.
536
    '''
537
    func_call = remove_col_rename(func_call)
538
    if not isinstance(func_call, FunctionCall): return func_call
539
540
    if check_name != None:
541
        name = func_call.function.name
542
        assert name == None or name == check_name
543
    return func_call.args[0]
544
545 3442 aaronmk
#### Definitions
546
547
class FunctionDef(Code):
548 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
549 3445 aaronmk
        Code.__init__(self)
550
551 3487 aaronmk
        return_type = as_Code(return_type)
552 3444 aaronmk
        body = as_Code(body)
553
554 3442 aaronmk
        self.function = function
555
        self.return_type = return_type
556
        self.body = body
557 3471 aaronmk
        self.params = params
558 3456 aaronmk
        self.modifiers = modifiers
559 3442 aaronmk
560
    def to_str(self, db):
561 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
562 3442 aaronmk
        str_ = '''\
563 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
564 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
565 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
566 3456 aaronmk
'''
567
        if self.modifiers != None: str_ += self.modifiers+'\n'
568
        str_ += '''\
569 3442 aaronmk
AS $$
570 3444 aaronmk
'''+self.body.to_str(db)+'''
571 3442 aaronmk
$$;
572
'''
573
        return str_
574
575 3469 aaronmk
class FunctionParam(TypedCol):
576
    def __init__(self, name, type_, default=None, out=False):
577
        TypedCol.__init__(self, name, type_, default)
578
579
        self.out = out
580
581
    def to_str(self, db):
582
        str_ = TypedCol.to_str(self, db)
583
        if self.out: str_ = 'OUT '+str_
584
        return str_
585
586
    def to_Col(self): return Col(self.name)
587
588 3454 aaronmk
### PL/pgSQL
589
590 3496 aaronmk
class ReturnQuery(Code):
591
    def __init__(self, query):
592
        Code.__init__(self)
593
594
        query = as_Code(query)
595
596
        self.query = query
597
598
    def to_str(self, db):
599
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
600
601 3515 aaronmk
## Exceptions
602
603 3509 aaronmk
class BaseExcHandler(BasicObject):
604
    def to_str(self, db, body): raise NotImplementedError()
605 3510 aaronmk
606
    def __repr__(self): return self.to_str(mockDb, '<body>')
607 3509 aaronmk
608
class ExcHandler(BaseExcHandler):
609 3454 aaronmk
    def __init__(self, exc, handler=None):
610
        if handler != None: handler = as_Code(handler)
611
612
        self.exc = exc
613
        self.handler = handler
614
615
    def to_str(self, db, body):
616
        body = as_Code(body)
617
618 3467 aaronmk
        if self.handler != None:
619
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
620 3463 aaronmk
        else: handler_str = ' NULL;\n'
621 3454 aaronmk
622
        str_ = '''\
623
BEGIN
624 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
625 3454 aaronmk
EXCEPTION
626 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
627 3454 aaronmk
END;\
628
'''
629
        return str_
630
631 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
632
    def __init__(self, *handlers):
633
        '''
634
        @param handlers Sorted from outermost to innermost
635
        '''
636
        self.handlers = handlers
637
638
    def to_str(self, db, body):
639
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
640
        return body
641
642 3503 aaronmk
class ExcToWarning(Code):
643
    def __init__(self, return_):
644
        '''
645
        @param return_ Statement to return a default value in case of error
646
        '''
647
        Code.__init__(self)
648
649
        return_ = as_Code(return_)
650
651
        self.return_ = return_
652
653
    def to_str(self, db):
654
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
655
656 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
657
658 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
659
RAISE data_exception USING MESSAGE =
660
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
661
''')
662
663 3505 aaronmk
def data_exception_handler(handler):
664
    return ExcHandler('data_exception', handler)
665
666 3527 aaronmk
row_var = Table('row')
667
668 3449 aaronmk
class RowExcIgnore(Code):
669
    def __init__(self, row_type, select_query, with_row, cols=None,
670 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
671 3529 aaronmk
        '''
672
        @param row_type Ignored if a custom row_var is used.
673
        @pre If a custom row_var is used, it must already be defined.
674
        '''
675 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
676
677 3482 aaronmk
        row_type = as_Code(row_type)
678 3449 aaronmk
        select_query = as_Code(select_query)
679
        with_row = as_Code(with_row)
680 3452 aaronmk
        row_var = as_Table(row_var)
681 3449 aaronmk
682
        self.row_type = row_type
683
        self.select_query = select_query
684
        self.with_row = with_row
685
        self.cols = cols
686 3455 aaronmk
        self.exc_handler = exc_handler
687 3452 aaronmk
        self.row_var = row_var
688 3449 aaronmk
689
    def to_str(self, db):
690 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
691
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
692 3449 aaronmk
693 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
694
        # is caught by an EXCEPTION clause, [...] all changes to persistent
695
        # database state within the block are rolled back."
696
        # This is unfortunate because "A block containing an EXCEPTION clause is
697
        # significantly more expensive to enter and exit than a block without
698
        # one."
699
        # (http://www.postgresql.org/docs/8.3/static/\
700
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
701 3449 aaronmk
        str_ = '''\
702 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
703
'''+strings.indent(self.select_query.to_str(db))+'''\
704
LOOP
705
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
706
END LOOP;
707
'''
708
        if self.row_var is row_var:
709
            str_ = '''\
710 3449 aaronmk
DECLARE
711 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
712 3449 aaronmk
BEGIN
713 3529 aaronmk
'''+strings.indent(str_)+'''\
714
END;
715 3449 aaronmk
'''
716
        return str_
717
718 2986 aaronmk
##### Casts
719
720
class Cast(FunctionCall):
721
    def __init__(self, type_, value):
722
        value = as_Value(value)
723
724
        self.type_ = type_
725
        self.value = value
726
727
    def to_str(self, db):
728
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
729
730 3354 aaronmk
def cast_literal(value):
731 3429 aaronmk
    if not is_literal(value): return value
732 3354 aaronmk
733
    if util.is_str(value.value): value = Cast('text', value)
734
    return value
735
736 2335 aaronmk
##### Conditions
737 2259 aaronmk
738 3350 aaronmk
class NotCond(Code):
739
    def __init__(self, cond):
740 3445 aaronmk
        Code.__init__(self)
741
742 3350 aaronmk
        self.cond = cond
743
744
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
745
746 2398 aaronmk
class ColValueCond(Code):
747
    def __init__(self, col, value):
748 3445 aaronmk
        Code.__init__(self)
749
750 2398 aaronmk
        value = as_ValueCond(value)
751
752
        self.col = col
753
        self.value = value
754
755
    def to_str(self, db): return self.value.to_str(db, self.col)
756
757 2577 aaronmk
def combine_conds(conds, keyword=None):
758
    '''
759
    @param keyword The keyword to add before the conditions, if any
760
    '''
761
    str_ = ''
762
    if keyword != None:
763
        if conds == []: whitespace = ''
764
        elif len(conds) == 1: whitespace = ' '
765
        else: whitespace = '\n'
766
        str_ += keyword+whitespace
767
768
    str_ += '\nAND '.join(conds)
769
    return str_
770
771 2398 aaronmk
##### Condition column comparisons
772
773 2514 aaronmk
class ValueCond(BasicObject):
774 2213 aaronmk
    def __init__(self, value):
775 2858 aaronmk
        value = remove_col_rename(as_Value(value))
776 2213 aaronmk
777
        self.value = value
778 2214 aaronmk
779 2216 aaronmk
    def to_str(self, db, left_value):
780 2214 aaronmk
        '''
781 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
782 2214 aaronmk
        '''
783
        raise NotImplemented()
784 2228 aaronmk
785 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
786 2211 aaronmk
787
class CompareCond(ValueCond):
788
    def __init__(self, value, operator='='):
789 2222 aaronmk
        '''
790
        @param operator By default, compares NULL values literally. Use '~=' or
791
            '~!=' to pass NULLs through.
792
        '''
793 2211 aaronmk
        ValueCond.__init__(self, value)
794
        self.operator = operator
795
796 2216 aaronmk
    def to_str(self, db, left_value):
797 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
798 2216 aaronmk
799 2222 aaronmk
        right_value = self.value
800
801
        # Parse operator
802 2216 aaronmk
        operator = self.operator
803 2222 aaronmk
        passthru_null_ref = [False]
804
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
805
        neg_ref = [False]
806
        operator = strings.remove_prefix('!', operator, neg_ref)
807 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
808 2222 aaronmk
809 2825 aaronmk
        # Handle nullable columns
810
        check_null = False
811 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
812 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
813 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
814
                check_null = equals and isinstance(right_value, Col)
815 2837 aaronmk
            else:
816 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
817
                    right_value = ensure_not_null(db, right_value,
818
                        left_value.type) # apply same function to both sides
819 2825 aaronmk
820 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
821
822 2825 aaronmk
        left = left_value.to_str(db)
823
        right = right_value.to_str(db)
824
825 2222 aaronmk
        # Create str
826
        str_ = left+' '+operator+' '+right
827 2825 aaronmk
        if check_null:
828 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
829
        if neg_ref[0]: str_ = 'NOT '+str_
830 2222 aaronmk
        return str_
831 2216 aaronmk
832 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
833
assume_literal = object()
834
835
def as_ValueCond(value, default_table=assume_literal):
836
    if not isinstance(value, ValueCond):
837
        if default_table is not assume_literal:
838 2748 aaronmk
            value = with_default_table(value, default_table)
839 2260 aaronmk
        return CompareCond(value)
840 2216 aaronmk
    else: return value
841 2219 aaronmk
842 2335 aaronmk
##### Joins
843
844 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
845 2260 aaronmk
846 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
847
join_same_not_null = object()
848
849 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
850
851 2514 aaronmk
class Join(BasicObject):
852 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
853 2260 aaronmk
        '''
854
        @param mapping dict(right_table_col=left_table_col, ...)
855 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
856 2353 aaronmk
              * Note that right_table_col must be a string
857
            * if left_table_col is join_same_not_null:
858
              left_table_col = right_table_col and both have NOT NULL constraint
859
              * Note that right_table_col must be a string
860 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
861
            * filter_out: equivalent to 'LEFT' with the query filtered by
862
              `table_pkey IS NULL` (indicating no match)
863
        '''
864
        if util.is_str(table): table = Table(table)
865
        assert type_ == None or util.is_str(type_) or type_ is filter_out
866
867
        self.table = table
868
        self.mapping = mapping
869
        self.type_ = type_
870
871 2749 aaronmk
    def to_str(self, db, left_table_):
872 2260 aaronmk
        def join(entry):
873
            '''Parses non-USING joins'''
874
            right_table_col, left_table_col = entry
875
876 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
877
            left = right_table_col
878
            right = left_table_col
879 2749 aaronmk
            left_table = self.table
880
            right_table = left_table_
881 2353 aaronmk
882 2747 aaronmk
            # Parse left side
883 2748 aaronmk
            left = with_default_table(left, left_table)
884 2747 aaronmk
885 2260 aaronmk
            # Parse special values
886 2747 aaronmk
            left_on_right = Col(left.name, right_table)
887
            if right is join_same: right = left_on_right
888 2353 aaronmk
            elif right is join_same_not_null:
889 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
890 2260 aaronmk
891 2747 aaronmk
            # Parse right side
892 2353 aaronmk
            right = as_ValueCond(right, right_table)
893 2747 aaronmk
894
            return right.to_str(db, left)
895 2260 aaronmk
896 2265 aaronmk
        # Create join condition
897
        type_ = self.type_
898 2276 aaronmk
        joins = self.mapping
899 2746 aaronmk
        if joins == {}: join_cond = None
900
        elif type_ is not filter_out and reduce(operator.and_,
901 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
902 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
903 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
904
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
905 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
906 2260 aaronmk
907 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
908
        else: whitespace = ' '
909
910 2260 aaronmk
        # Create join
911
        if type_ is filter_out: type_ = 'LEFT'
912 2266 aaronmk
        str_ = ''
913
        if type_ != None: str_ += type_+' '
914 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
915
        if join_cond != None: str_ += whitespace+join_cond
916 2266 aaronmk
        return str_
917 2349 aaronmk
918 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
919 2424 aaronmk
920
##### Value exprs
921
922 3089 aaronmk
all_cols = CustomCode('*')
923
924 2737 aaronmk
default = CustomCode('DEFAULT')
925
926 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
927 2674 aaronmk
928 3061 aaronmk
class Coalesce(FunctionCall):
929
    def __init__(self, *args):
930
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
931 3060 aaronmk
932 3062 aaronmk
class Nullif(FunctionCall):
933
    def __init__(self, *args):
934
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
935
936 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
937 2958 aaronmk
null_sentinels = {
938
    'character varying': r'\N',
939
    'double precision': 'NaN',
940
    'integer': 2147483647,
941
    'text': r'\N',
942
    'timestamp with time zone': 'infinity'
943
}
944 2692 aaronmk
945 3061 aaronmk
class EnsureNotNull(Coalesce):
946 2850 aaronmk
    def __init__(self, value, type_):
947 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
948 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
949 2850 aaronmk
950
        self.type = type_
951 3001 aaronmk
952
    def to_str(self, db):
953
        col = self.args[0]
954
        index_col_ = index_col(col)
955
        if index_col_ != None: return index_col_.to_str(db)
956 3061 aaronmk
        return Coalesce.to_str(self, db)
957 2850 aaronmk
958 3523 aaronmk
#### Arrays
959
960
class ArrayJoin(FunctionCall):
961
    def __init__(self, sep, array):
962
        array = to_Array(array)
963
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
964
            sep)
965
966 2737 aaronmk
##### Table exprs
967
968
class Values(Code):
969
    def __init__(self, values):
970 2739 aaronmk
        '''
971
        @param values [...]|[[...], ...] Can be one or multiple rows.
972
        '''
973 3445 aaronmk
        Code.__init__(self)
974
975 2739 aaronmk
        rows = values
976
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
977
            rows = [values]
978
        for i, row in enumerate(rows):
979
            rows[i] = map(remove_col_rename, map(as_Value, row))
980 2737 aaronmk
981 2739 aaronmk
        self.rows = rows
982 2737 aaronmk
983
    def to_str(self, db):
984 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
985 2737 aaronmk
986 2740 aaronmk
def NamedValues(name, cols, values):
987 2745 aaronmk
    '''
988 3048 aaronmk
    @param cols None|[...]
989 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
990
    '''
991 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
992 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
993 2834 aaronmk
    return table
994 2740 aaronmk
995 2674 aaronmk
##### Database structure
996
997 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
998
999 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1000 2840 aaronmk
    '''
1001 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1002 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
1003 2840 aaronmk
    @return EnsureNotNull|Col
1004
    @throws ensure_not_null_excs
1005
    '''
1006 2855 aaronmk
    nullable = True
1007
    try: typed_col = db.col_info(underlying_col(col))
1008
    except NoUnderlyingTableException:
1009 3355 aaronmk
        col = remove_col_rename(col)
1010 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
1011 3355 aaronmk
        elif type_ == None: raise
1012 2855 aaronmk
    else:
1013
        if type_ == None: type_ = typed_col.type
1014
        nullable = typed_col.nullable
1015
1016 2953 aaronmk
    if nullable:
1017
        try: col = EnsureNotNull(col, type_)
1018
        except KeyError, e:
1019
            # Warn of no null sentinel for type, even if caller catches error
1020
            warnings.warn(UserWarning(exc.str_(e)))
1021
            raise
1022
1023 2840 aaronmk
    return col