Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 3158 aaronmk
from ordereddict import OrderedDict
6 2568 aaronmk
import re
7 2653 aaronmk
import UserDict
8 2953 aaronmk
import warnings
9 2276 aaronmk
10 2667 aaronmk
import dicts
11 2953 aaronmk
import exc
12 2701 aaronmk
import iters
13
import lists
14 2360 aaronmk
import objects
15 2222 aaronmk
import strings
16 2227 aaronmk
import util
17 2211 aaronmk
18 2587 aaronmk
##### Names
19 2499 aaronmk
20 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21 2587 aaronmk
22 2932 aaronmk
def concat(str_, suffix):
23 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25 2613 aaronmk
    # Preserve version
26 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27 2985 aaronmk
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30 2613 aaronmk
31 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
32 2587 aaronmk
33 2932 aaronmk
def truncate(str_): return concat(str_, '')
34 2842 aaronmk
35 2575 aaronmk
def is_safe_name(name):
36 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40 2984 aaronmk
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42 2568 aaronmk
43 2499 aaronmk
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46
47 3320 aaronmk
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55
56 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
57
58 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59
60 3182 aaronmk
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64
65 2659 aaronmk
##### General SQL code objects
66 2219 aaronmk
67 2349 aaronmk
class MockDb:
68 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
69 2349 aaronmk
70 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
71 2859 aaronmk
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74
75 2349 aaronmk
mockDb = MockDb()
76
77 2514 aaronmk
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79
80 2659 aaronmk
##### Unparameterized code objects
81
82 2514 aaronmk
class Code(BasicObject):
83 3446 aaronmk
    def __init__(self, lang='sql'):
84
        self.lang = lang
85 3445 aaronmk
86 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
87 2349 aaronmk
88 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
89 2211 aaronmk
90 2269 aaronmk
class CustomCode(Code):
91 3445 aaronmk
    def __init__(self, str_):
92
        Code.__init__(self)
93
94
        self.str_ = str_
95 2256 aaronmk
96
    def to_str(self, db): return self.str_
97
98 2815 aaronmk
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102 3447 aaronmk
    if isinstance(value, Code): return value
103
104 2815 aaronmk
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107 2659 aaronmk
    else: return Literal(value)
108
109 2540 aaronmk
class Expr(Code):
110 3445 aaronmk
    def __init__(self, expr):
111
        Code.__init__(self)
112
113
        self.expr = expr
114 2540 aaronmk
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116
117 3086 aaronmk
##### Names
118
119
class Name(Code):
120 3087 aaronmk
    def __init__(self, name):
121 3445 aaronmk
        Code.__init__(self)
122
123 3087 aaronmk
        name = truncate(name)
124
125
        self.name = name
126 3086 aaronmk
127
    def to_str(self, db): return db.esc_name(self.name)
128
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132
133 2335 aaronmk
##### Literal values
134
135 2216 aaronmk
class Literal(Code):
136 3445 aaronmk
    def __init__(self, value):
137
        Code.__init__(self)
138
139
        self.value = value
140 2213 aaronmk
141
    def to_str(self, db): return db.esc_value(self.value)
142 2211 aaronmk
143 2400 aaronmk
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146
147 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
148 2216 aaronmk
149 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
150
151 2711 aaronmk
##### Derived elements
152
153
src_self = object() # tells Col that it is its own source column
154
155
class Derived(Code):
156
    def __init__(self, srcs):
157 2712 aaronmk
        '''An element which was derived from some other element(s).
158 2711 aaronmk
        @param srcs See self.set_srcs()
159
        '''
160 3445 aaronmk
        Code.__init__(self)
161
162 2711 aaronmk
        self.set_srcs(srcs)
163
164 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
165 2711 aaronmk
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
169
170 2711 aaronmk
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180
181 2335 aaronmk
##### Tables
182
183 2712 aaronmk
class Table(Derived):
184 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185 2211 aaronmk
        '''
186
        @param schema str|None (for no schema)
187 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
188 2211 aaronmk
        '''
189 2712 aaronmk
        Derived.__init__(self, srcs)
190
191 3091 aaronmk
        if util.is_str(name): name = truncate(name)
192 2843 aaronmk
193 2211 aaronmk
        self.name = name
194
        self.schema = schema
195 2991 aaronmk
        self.is_temp = is_temp
196 3000 aaronmk
        self.index_cols = {}
197 2211 aaronmk
198 2348 aaronmk
    def to_str(self, db):
199
        str_ = ''
200 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202 2348 aaronmk
        return str_
203 2336 aaronmk
204
    def to_Table(self): return self
205 3000 aaronmk
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210 2211 aaronmk
211 2835 aaronmk
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213 2832 aaronmk
214 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
215
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220
221 2776 aaronmk
def as_Table(table, schema=None):
222 2270 aaronmk
    if table == None or isinstance(table, Code): return table
223 2776 aaronmk
    else: return Table(table, schema)
224 2219 aaronmk
225 3101 aaronmk
def suffixed_table(table, suffix):
226 3128 aaronmk
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229 2707 aaronmk
230 2336 aaronmk
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
234 3016 aaronmk
        code = as_Table(code)
235 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237 2336 aaronmk
238
        self.code = code
239
        self.cols = cols
240
241
    def to_str(self, db):
242 3026 aaronmk
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246 2742 aaronmk
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248 2336 aaronmk
        return str_
249
250
    def to_Table(self): return Table(self.name)
251
252 2753 aaronmk
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255
256 2335 aaronmk
##### Columns
257
258 2711 aaronmk
class Col(Derived):
259 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
260 2211 aaronmk
        '''
261
        @param table Table|None (for no table)
262 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
263 2211 aaronmk
        '''
264 2711 aaronmk
        Derived.__init__(self, srcs)
265
266 3091 aaronmk
        if util.is_str(name): name = truncate(name)
267 2241 aaronmk
        if util.is_str(table): table = Table(table)
268 2211 aaronmk
        assert table == None or isinstance(table, Table)
269
270
        self.name = name
271
        self.table = table
272
273 2989 aaronmk
    def to_str(self, db, for_str=False):
274 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
275 2989 aaronmk
        if for_str: str_ = clean_name(str_)
276 2933 aaronmk
        if self.table != None:
277 2989 aaronmk
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280 2211 aaronmk
        return str_
281 2314 aaronmk
282 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
283 2933 aaronmk
284 2314 aaronmk
    def to_Col(self): return self
285 2211 aaronmk
286 3512 aaronmk
def is_col(col): return isinstance(col, Col)
287 2393 aaronmk
288 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
289
290 3000 aaronmk
def index_col(col):
291
    if not is_table_col(col): return None
292 3104 aaronmk
293
    table = col.table
294
    try: name = table.index_cols[col.name]
295
    except KeyError: return None
296
    else: return Col(name, table, col.srcs)
297 2999 aaronmk
298 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
299 2996 aaronmk
300 2563 aaronmk
def as_Col(col, table=None, name=None):
301
    '''
302
    @param name If not None, any non-Col input will be renamed using NamedCol.
303
    '''
304
    if name != None:
305
        col = as_Value(col)
306
        if not isinstance(col, Col): col = NamedCol(name, col)
307 2333 aaronmk
308
    if isinstance(col, Code): return col
309 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
310
    else: return Literal(col)
311 2260 aaronmk
312 3093 aaronmk
def with_table(col, table):
313 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
314
    elif isinstance(col, FunctionCall):
315
        col = copy.deepcopy(col) # don't modify input!
316
        col.args[0].table = table
317 3493 aaronmk
    elif isinstance(col, Col):
318 3098 aaronmk
        col = copy.copy(col) # don't modify input!
319
        col.table = table
320 3093 aaronmk
    return col
321
322 3100 aaronmk
def with_default_table(col, table):
323 2747 aaronmk
    col = as_Col(col)
324 3100 aaronmk
    if col.table == None: col = with_table(col, table)
325 2747 aaronmk
    return col
326
327 2744 aaronmk
def set_cols_table(table, cols):
328
    table = as_Table(table)
329
330
    for i, col in enumerate(cols):
331
        col = cols[i] = as_Col(col)
332
        col.table = table
333
334 2401 aaronmk
def to_name_only_col(col, check_table=None):
335
    col = as_Col(col)
336 3020 aaronmk
    if not is_table_col(col): return col
337 2401 aaronmk
338
    if check_table != None:
339
        table = col.table
340
        assert table == None or table == check_table
341
    return Col(col.name)
342
343 2993 aaronmk
def suffixed_col(col, suffix):
344
    return Col(concat(col.name, suffix), col.table, col.srcs)
345
346 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
347
348 3514 aaronmk
def srcs_str(cols):
349
    cols = filter(is_col, cols)
350
    return ','.join(('+'.join((s.name for s in c.srcs)) for c in cols))
351
352 2323 aaronmk
class NamedCol(Col):
353 2229 aaronmk
    def __init__(self, name, code):
354 2310 aaronmk
        Col.__init__(self, name)
355
356 3016 aaronmk
        code = as_Value(code)
357 2229 aaronmk
358
        self.code = code
359
360
    def to_str(self, db):
361 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
362 2314 aaronmk
363
    def to_Col(self): return Col(self.name)
364 2229 aaronmk
365 2462 aaronmk
def remove_col_rename(col):
366
    if isinstance(col, NamedCol): col = col.code
367
    return col
368
369 2830 aaronmk
def underlying_col(col):
370
    col = remove_col_rename(col)
371 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
372
373 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
374 2830 aaronmk
375 2703 aaronmk
def wrap(wrap_func, value):
376
    '''Wraps a value, propagating any column renaming to the returned value.'''
377
    if isinstance(value, NamedCol):
378
        return NamedCol(value.name, wrap_func(value.code))
379
    else: return wrap_func(value)
380
381 2667 aaronmk
class ColDict(dicts.DictProxy):
382 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
383
384 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
385 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
386 2667 aaronmk
387 2645 aaronmk
        keys_table = as_Table(keys_table)
388
389 2642 aaronmk
        self.db = db
390 2641 aaronmk
        self.table = keys_table
391 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
392 2641 aaronmk
393 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
394 2655 aaronmk
395 2667 aaronmk
    def __getitem__(self, key):
396
        return dicts.DictProxy.__getitem__(self, self._key(key))
397 2653 aaronmk
398 2564 aaronmk
    def __setitem__(self, key, value):
399 2642 aaronmk
        key = self._key(key)
400 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
401 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
402 2564 aaronmk
403 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
404 2564 aaronmk
405 3491 aaronmk
##### Composite types
406
407
class List(Code):
408
    def __init__(self, *values):
409
        Code.__init__(self)
410
411
        self.values = values
412
413
    def to_str(self, db):
414
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
415
416 3492 aaronmk
class Tuple(List):
417
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
418
419 3469 aaronmk
#### Definitions
420
421
class TypedCol(Col):
422
    def __init__(self, name, type_, default=None, nullable=True,
423
        constraints=None):
424
        assert default == None or isinstance(default, Code)
425
426
        Col.__init__(self, name)
427
428
        self.type = type_
429
        self.default = default
430
        self.nullable = nullable
431
        self.constraints = constraints
432
433
    def to_str(self, db):
434 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
435 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
436
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
437
        if self.constraints != None: str_ += ' '+self.constraints
438
        return str_
439
440
    def to_Col(self): return Col(self.name)
441
442 3488 aaronmk
class SetOf(Code):
443
    def __init__(self, type_):
444
        Code.__init__(self)
445
446
        self.type = type_
447
448
    def to_str(self, db):
449
        return 'SETOF '+self.type.to_str(db)
450
451 3483 aaronmk
class RowType(Code):
452
    def __init__(self, table):
453
        Code.__init__(self)
454
455
        self.table = table
456
457
    def to_str(self, db):
458
        return self.table.to_str(db)+'%ROWTYPE'
459
460 3485 aaronmk
class ColType(Code):
461
    def __init__(self, col):
462
        Code.__init__(self)
463
464
        self.col = col
465
466
    def to_str(self, db):
467
        return self.col.to_str(db)+'%TYPE'
468
469 2524 aaronmk
##### Functions
470
471 2912 aaronmk
Function = Table
472 2911 aaronmk
as_Function = as_Table
473
474 2691 aaronmk
class InternalFunction(CustomCode): pass
475
476 3442 aaronmk
#### Calls
477
478 2941 aaronmk
class NamedArg(NamedCol):
479
    def __init__(self, name, value):
480
        NamedCol.__init__(self, name, value)
481
482
    def to_str(self, db):
483
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
484
485 2524 aaronmk
class FunctionCall(Code):
486 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
487 2524 aaronmk
        '''
488 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
489 2524 aaronmk
        '''
490 3445 aaronmk
        Code.__init__(self)
491
492 3016 aaronmk
        function = as_Function(function)
493 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
494
        args = map(filter_, args)
495
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
496 2524 aaronmk
497
        self.function = function
498
        self.args = args
499
500
    def to_str(self, db):
501
        args_str = ', '.join((v.to_str(db) for v in self.args))
502
        return self.function.to_str(db)+'('+args_str+')'
503
504 2533 aaronmk
def wrap_in_func(function, value):
505
    '''Wraps a value inside a function call.
506
    Propagates any column renaming to the returned value.
507
    '''
508 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
509 2533 aaronmk
510 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
511
    '''Unwraps any function call to its first argument.
512
    Also removes any column renaming.
513
    '''
514
    func_call = remove_col_rename(func_call)
515
    if not isinstance(func_call, FunctionCall): return func_call
516
517
    if check_name != None:
518
        name = func_call.function.name
519
        assert name == None or name == check_name
520
    return func_call.args[0]
521
522 3442 aaronmk
#### Definitions
523
524
class FunctionDef(Code):
525 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
526 3445 aaronmk
        Code.__init__(self)
527
528 3487 aaronmk
        return_type = as_Code(return_type)
529 3444 aaronmk
        body = as_Code(body)
530
531 3442 aaronmk
        self.function = function
532
        self.return_type = return_type
533
        self.body = body
534 3471 aaronmk
        self.params = params
535 3456 aaronmk
        self.modifiers = modifiers
536 3442 aaronmk
537
    def to_str(self, db):
538 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
539 3442 aaronmk
        str_ = '''\
540 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
541 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
542 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
543 3456 aaronmk
'''
544
        if self.modifiers != None: str_ += self.modifiers+'\n'
545
        str_ += '''\
546 3442 aaronmk
AS $$
547 3444 aaronmk
'''+self.body.to_str(db)+'''
548 3442 aaronmk
$$;
549
'''
550
        return str_
551
552 3469 aaronmk
class FunctionParam(TypedCol):
553
    def __init__(self, name, type_, default=None, out=False):
554
        TypedCol.__init__(self, name, type_, default)
555
556
        self.out = out
557
558
    def to_str(self, db):
559
        str_ = TypedCol.to_str(self, db)
560
        if self.out: str_ = 'OUT '+str_
561
        return str_
562
563
    def to_Col(self): return Col(self.name)
564
565 3454 aaronmk
### PL/pgSQL
566
567 3496 aaronmk
class ReturnQuery(Code):
568
    def __init__(self, query):
569
        Code.__init__(self)
570
571
        query = as_Code(query)
572
573
        self.query = query
574
575
    def to_str(self, db):
576
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
577
578 3515 aaronmk
## Exceptions
579
580 3509 aaronmk
class BaseExcHandler(BasicObject):
581
    def to_str(self, db, body): raise NotImplementedError()
582 3510 aaronmk
583
    def __repr__(self): return self.to_str(mockDb, '<body>')
584 3509 aaronmk
585
class ExcHandler(BaseExcHandler):
586 3454 aaronmk
    def __init__(self, exc, handler=None):
587
        if handler != None: handler = as_Code(handler)
588
589
        self.exc = exc
590
        self.handler = handler
591
592
    def to_str(self, db, body):
593
        body = as_Code(body)
594
595 3467 aaronmk
        if self.handler != None:
596
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
597 3463 aaronmk
        else: handler_str = ' NULL;\n'
598 3454 aaronmk
599
        str_ = '''\
600
BEGIN
601 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
602 3454 aaronmk
EXCEPTION
603 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
604 3454 aaronmk
END;\
605
'''
606
        return str_
607
608 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
609
    def __init__(self, *handlers):
610
        '''
611
        @param handlers Sorted from outermost to innermost
612
        '''
613
        self.handlers = handlers
614
615
    def to_str(self, db, body):
616
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
617
        return body
618
619 3503 aaronmk
class ExcToWarning(Code):
620
    def __init__(self, return_):
621
        '''
622
        @param return_ Statement to return a default value in case of error
623
        '''
624
        Code.__init__(self)
625
626
        return_ = as_Code(return_)
627
628
        self.return_ = return_
629
630
    def to_str(self, db):
631
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
632
633 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
634
635 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
636
RAISE data_exception USING MESSAGE =
637
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
638
''')
639
640 3505 aaronmk
def data_exception_handler(handler):
641
    return ExcHandler('data_exception', handler)
642
643 3449 aaronmk
class RowExcIgnore(Code):
644
    def __init__(self, row_type, select_query, with_row, cols=None,
645 3455 aaronmk
        exc_handler=unique_violation_handler, row_var='row'):
646 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
647
648 3482 aaronmk
        row_type = as_Code(row_type)
649 3449 aaronmk
        select_query = as_Code(select_query)
650
        with_row = as_Code(with_row)
651 3452 aaronmk
        row_var = as_Table(row_var)
652 3449 aaronmk
653
        self.row_type = row_type
654
        self.select_query = select_query
655
        self.with_row = with_row
656
        self.cols = cols
657 3455 aaronmk
        self.exc_handler = exc_handler
658 3452 aaronmk
        self.row_var = row_var
659 3449 aaronmk
660
    def to_str(self, db):
661 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
662
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
663 3449 aaronmk
664
        str_ = '''\
665
DECLARE
666 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
667 3449 aaronmk
BEGIN
668 3467 aaronmk
    /* Need an EXCEPTION block for each individual row because "When
669
    an error is caught by an EXCEPTION clause, [...] all changes to
670
    persistent database state within the block are rolled back."
671
    This is unfortunate because "A block containing an EXCEPTION
672
    clause is significantly more expensive to enter and exit than a
673
    block without one."
674 3449 aaronmk
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
675
#PLPGSQL-ERROR-TRAPPING)
676
    */
677
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
678 3467 aaronmk
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
679 3449 aaronmk
    LOOP
680 3467 aaronmk
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
681 3449 aaronmk
    END LOOP;
682
END;\
683
'''
684
        return str_
685
686 2986 aaronmk
##### Casts
687
688
class Cast(FunctionCall):
689
    def __init__(self, type_, value):
690
        value = as_Value(value)
691
692
        self.type_ = type_
693
        self.value = value
694
695
    def to_str(self, db):
696
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
697
698 3354 aaronmk
def cast_literal(value):
699 3429 aaronmk
    if not is_literal(value): return value
700 3354 aaronmk
701
    if util.is_str(value.value): value = Cast('text', value)
702
    return value
703
704 2335 aaronmk
##### Conditions
705 2259 aaronmk
706 3350 aaronmk
class NotCond(Code):
707
    def __init__(self, cond):
708 3445 aaronmk
        Code.__init__(self)
709
710 3350 aaronmk
        self.cond = cond
711
712
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
713
714 2398 aaronmk
class ColValueCond(Code):
715
    def __init__(self, col, value):
716 3445 aaronmk
        Code.__init__(self)
717
718 2398 aaronmk
        value = as_ValueCond(value)
719
720
        self.col = col
721
        self.value = value
722
723
    def to_str(self, db): return self.value.to_str(db, self.col)
724
725 2577 aaronmk
def combine_conds(conds, keyword=None):
726
    '''
727
    @param keyword The keyword to add before the conditions, if any
728
    '''
729
    str_ = ''
730
    if keyword != None:
731
        if conds == []: whitespace = ''
732
        elif len(conds) == 1: whitespace = ' '
733
        else: whitespace = '\n'
734
        str_ += keyword+whitespace
735
736
    str_ += '\nAND '.join(conds)
737
    return str_
738
739 2398 aaronmk
##### Condition column comparisons
740
741 2514 aaronmk
class ValueCond(BasicObject):
742 2213 aaronmk
    def __init__(self, value):
743 2858 aaronmk
        value = remove_col_rename(as_Value(value))
744 2213 aaronmk
745
        self.value = value
746 2214 aaronmk
747 2216 aaronmk
    def to_str(self, db, left_value):
748 2214 aaronmk
        '''
749 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
750 2214 aaronmk
        '''
751
        raise NotImplemented()
752 2228 aaronmk
753 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
754 2211 aaronmk
755
class CompareCond(ValueCond):
756
    def __init__(self, value, operator='='):
757 2222 aaronmk
        '''
758
        @param operator By default, compares NULL values literally. Use '~=' or
759
            '~!=' to pass NULLs through.
760
        '''
761 2211 aaronmk
        ValueCond.__init__(self, value)
762
        self.operator = operator
763
764 2216 aaronmk
    def to_str(self, db, left_value):
765 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
766 2216 aaronmk
767 2222 aaronmk
        right_value = self.value
768
769
        # Parse operator
770 2216 aaronmk
        operator = self.operator
771 2222 aaronmk
        passthru_null_ref = [False]
772
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
773
        neg_ref = [False]
774
        operator = strings.remove_prefix('!', operator, neg_ref)
775 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
776 2222 aaronmk
777 2825 aaronmk
        # Handle nullable columns
778
        check_null = False
779 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
780 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
781 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
782
                check_null = equals and isinstance(right_value, Col)
783 2837 aaronmk
            else:
784 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
785
                    right_value = ensure_not_null(db, right_value,
786
                        left_value.type) # apply same function to both sides
787 2825 aaronmk
788 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
789
790 2825 aaronmk
        left = left_value.to_str(db)
791
        right = right_value.to_str(db)
792
793 2222 aaronmk
        # Create str
794
        str_ = left+' '+operator+' '+right
795 2825 aaronmk
        if check_null:
796 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
797
        if neg_ref[0]: str_ = 'NOT '+str_
798 2222 aaronmk
        return str_
799 2216 aaronmk
800 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
801
assume_literal = object()
802
803
def as_ValueCond(value, default_table=assume_literal):
804
    if not isinstance(value, ValueCond):
805
        if default_table is not assume_literal:
806 2748 aaronmk
            value = with_default_table(value, default_table)
807 2260 aaronmk
        return CompareCond(value)
808 2216 aaronmk
    else: return value
809 2219 aaronmk
810 2335 aaronmk
##### Joins
811
812 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
813 2260 aaronmk
814 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
815
join_same_not_null = object()
816
817 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
818
819 2514 aaronmk
class Join(BasicObject):
820 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
821 2260 aaronmk
        '''
822
        @param mapping dict(right_table_col=left_table_col, ...)
823 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
824 2353 aaronmk
              * Note that right_table_col must be a string
825
            * if left_table_col is join_same_not_null:
826
              left_table_col = right_table_col and both have NOT NULL constraint
827
              * Note that right_table_col must be a string
828 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
829
            * filter_out: equivalent to 'LEFT' with the query filtered by
830
              `table_pkey IS NULL` (indicating no match)
831
        '''
832
        if util.is_str(table): table = Table(table)
833
        assert type_ == None or util.is_str(type_) or type_ is filter_out
834
835
        self.table = table
836
        self.mapping = mapping
837
        self.type_ = type_
838
839 2749 aaronmk
    def to_str(self, db, left_table_):
840 2260 aaronmk
        def join(entry):
841
            '''Parses non-USING joins'''
842
            right_table_col, left_table_col = entry
843
844 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
845
            left = right_table_col
846
            right = left_table_col
847 2749 aaronmk
            left_table = self.table
848
            right_table = left_table_
849 2353 aaronmk
850 2747 aaronmk
            # Parse left side
851 2748 aaronmk
            left = with_default_table(left, left_table)
852 2747 aaronmk
853 2260 aaronmk
            # Parse special values
854 2747 aaronmk
            left_on_right = Col(left.name, right_table)
855
            if right is join_same: right = left_on_right
856 2353 aaronmk
            elif right is join_same_not_null:
857 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
858 2260 aaronmk
859 2747 aaronmk
            # Parse right side
860 2353 aaronmk
            right = as_ValueCond(right, right_table)
861 2747 aaronmk
862
            return right.to_str(db, left)
863 2260 aaronmk
864 2265 aaronmk
        # Create join condition
865
        type_ = self.type_
866 2276 aaronmk
        joins = self.mapping
867 2746 aaronmk
        if joins == {}: join_cond = None
868
        elif type_ is not filter_out and reduce(operator.and_,
869 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
870 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
871 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
872
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
873 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
874 2260 aaronmk
875 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
876
        else: whitespace = ' '
877
878 2260 aaronmk
        # Create join
879
        if type_ is filter_out: type_ = 'LEFT'
880 2266 aaronmk
        str_ = ''
881
        if type_ != None: str_ += type_+' '
882 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
883
        if join_cond != None: str_ += whitespace+join_cond
884 2266 aaronmk
        return str_
885 2349 aaronmk
886 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
887 2424 aaronmk
888
##### Value exprs
889
890 3089 aaronmk
all_cols = CustomCode('*')
891
892 2737 aaronmk
default = CustomCode('DEFAULT')
893
894 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
895 2674 aaronmk
896 3061 aaronmk
class Coalesce(FunctionCall):
897
    def __init__(self, *args):
898
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
899 3060 aaronmk
900 3062 aaronmk
class Nullif(FunctionCall):
901
    def __init__(self, *args):
902
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
903
904 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
905 2958 aaronmk
null_sentinels = {
906
    'character varying': r'\N',
907
    'double precision': 'NaN',
908
    'integer': 2147483647,
909
    'text': r'\N',
910
    'timestamp with time zone': 'infinity'
911
}
912 2692 aaronmk
913 3061 aaronmk
class EnsureNotNull(Coalesce):
914 2850 aaronmk
    def __init__(self, value, type_):
915 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
916 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
917 2850 aaronmk
918
        self.type = type_
919 3001 aaronmk
920
    def to_str(self, db):
921
        col = self.args[0]
922
        index_col_ = index_col(col)
923
        if index_col_ != None: return index_col_.to_str(db)
924 3061 aaronmk
        return Coalesce.to_str(self, db)
925 2850 aaronmk
926 2737 aaronmk
##### Table exprs
927
928
class Values(Code):
929
    def __init__(self, values):
930 2739 aaronmk
        '''
931
        @param values [...]|[[...], ...] Can be one or multiple rows.
932
        '''
933 3445 aaronmk
        Code.__init__(self)
934
935 2739 aaronmk
        rows = values
936
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
937
            rows = [values]
938
        for i, row in enumerate(rows):
939
            rows[i] = map(remove_col_rename, map(as_Value, row))
940 2737 aaronmk
941 2739 aaronmk
        self.rows = rows
942 2737 aaronmk
943
    def to_str(self, db):
944 3491 aaronmk
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
945 2737 aaronmk
946 2740 aaronmk
def NamedValues(name, cols, values):
947 2745 aaronmk
    '''
948 3048 aaronmk
    @param cols None|[...]
949 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
950
    '''
951 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
952 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
953 2834 aaronmk
    return table
954 2740 aaronmk
955 2674 aaronmk
##### Database structure
956
957 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
958
959 2851 aaronmk
def ensure_not_null(db, col, type_=None):
960 2840 aaronmk
    '''
961 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
962 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
963 2840 aaronmk
    @return EnsureNotNull|Col
964
    @throws ensure_not_null_excs
965
    '''
966 2855 aaronmk
    nullable = True
967
    try: typed_col = db.col_info(underlying_col(col))
968
    except NoUnderlyingTableException:
969 3355 aaronmk
        col = remove_col_rename(col)
970 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
971 3355 aaronmk
        elif type_ == None: raise
972 2855 aaronmk
    else:
973
        if type_ == None: type_ = typed_col.type
974
        nullable = typed_col.nullable
975
976 2953 aaronmk
    if nullable:
977
        try: col = EnsureNotNull(col, type_)
978
        except KeyError, e:
979
            # Warn of no null sentinel for type, even if caller catches error
980
            warnings.warn(UserWarning(exc.str_(e)))
981
            raise
982
983 2840 aaronmk
    return col