Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 2222 aaronmk
import strings
17 2227 aaronmk
import util
18 2211 aaronmk
19 2587 aaronmk
##### Names
20 2499 aaronmk
21 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22 2587 aaronmk
23 2932 aaronmk
def concat(str_, suffix):
24 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26 2613 aaronmk
    # Preserve version
27 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28 2985 aaronmk
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31 2613 aaronmk
32 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
33 2587 aaronmk
34 2932 aaronmk
def truncate(str_): return concat(str_, '')
35 2842 aaronmk
36 2575 aaronmk
def is_safe_name(name):
37 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41 2984 aaronmk
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43 2568 aaronmk
44 2499 aaronmk
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47
48 3320 aaronmk
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56
57 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
58
59 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60
61 3182 aaronmk
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65
66 2659 aaronmk
##### General SQL code objects
67 2219 aaronmk
68 2349 aaronmk
class MockDb:
69 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
70 2349 aaronmk
71 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
72 2859 aaronmk
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75
76 2349 aaronmk
mockDb = MockDb()
77
78 2514 aaronmk
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80
81 2659 aaronmk
##### Unparameterized code objects
82
83 2514 aaronmk
class Code(BasicObject):
84 3446 aaronmk
    def __init__(self, lang='sql'):
85
        self.lang = lang
86 3445 aaronmk
87 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
88 2349 aaronmk
89 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
90 2211 aaronmk
91 2269 aaronmk
class CustomCode(Code):
92 3445 aaronmk
    def __init__(self, str_):
93
        Code.__init__(self)
94
95
        self.str_ = str_
96 2256 aaronmk
97
    def to_str(self, db): return self.str_
98
99 2815 aaronmk
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103 3447 aaronmk
    if isinstance(value, Code): return value
104
105 2815 aaronmk
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108 2659 aaronmk
    else: return Literal(value)
109
110 2540 aaronmk
class Expr(Code):
111 3445 aaronmk
    def __init__(self, expr):
112
        Code.__init__(self)
113
114
        self.expr = expr
115 2540 aaronmk
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117
118 3086 aaronmk
##### Names
119
120
class Name(Code):
121 3087 aaronmk
    def __init__(self, name):
122 3445 aaronmk
        Code.__init__(self)
123
124 3087 aaronmk
        name = truncate(name)
125
126
        self.name = name
127 3086 aaronmk
128
    def to_str(self, db): return db.esc_name(self.name)
129
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133
134 2335 aaronmk
##### Literal values
135
136 2216 aaronmk
class Literal(Code):
137 3445 aaronmk
    def __init__(self, value):
138
        Code.__init__(self)
139
140
        self.value = value
141 2213 aaronmk
142
    def to_str(self, db): return db.esc_value(self.value)
143 2211 aaronmk
144 2400 aaronmk
def as_Value(value):
145
    if isinstance(value, Code): return value
146
    else: return Literal(value)
147
148 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
149 2216 aaronmk
150 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
151
152 2711 aaronmk
##### Derived elements
153
154
src_self = object() # tells Col that it is its own source column
155
156
class Derived(Code):
157
    def __init__(self, srcs):
158 2712 aaronmk
        '''An element which was derived from some other element(s).
159 2711 aaronmk
        @param srcs See self.set_srcs()
160
        '''
161 3445 aaronmk
        Code.__init__(self)
162
163 2711 aaronmk
        self.set_srcs(srcs)
164
165 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
166 2711 aaronmk
        '''
167
        @param srcs (self_type...)|src_self The element(s) this is derived from
168
        '''
169 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
170
171 2711 aaronmk
        if srcs == src_self: srcs = (self,)
172
        srcs = tuple(srcs) # make Col hashable
173
        self.srcs = srcs
174
175
    def _compare_on(self):
176
        compare_on = self.__dict__.copy()
177
        del compare_on['srcs'] # ignore
178
        return compare_on
179
180
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
181
182 2335 aaronmk
##### Tables
183
184 2712 aaronmk
class Table(Derived):
185 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
186 2211 aaronmk
        '''
187
        @param schema str|None (for no schema)
188 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
189 2211 aaronmk
        '''
190 2712 aaronmk
        Derived.__init__(self, srcs)
191
192 3091 aaronmk
        if util.is_str(name): name = truncate(name)
193 2843 aaronmk
194 2211 aaronmk
        self.name = name
195
        self.schema = schema
196 2991 aaronmk
        self.is_temp = is_temp
197 3000 aaronmk
        self.index_cols = {}
198 2211 aaronmk
199 2348 aaronmk
    def to_str(self, db):
200
        str_ = ''
201 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
202
        str_ += as_Name(self.name).to_str(db)
203 2348 aaronmk
        return str_
204 2336 aaronmk
205
    def to_Table(self): return self
206 3000 aaronmk
207
    def _compare_on(self):
208
        compare_on = Derived._compare_on(self)
209
        del compare_on['index_cols'] # ignore
210
        return compare_on
211 2211 aaronmk
212 2835 aaronmk
def is_underlying_table(table):
213
    return isinstance(table, Table) and table.to_Table() is table
214 2832 aaronmk
215 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
216
217
def underlying_table(table):
218
    table = remove_table_rename(table)
219
    if not is_underlying_table(table): raise NoUnderlyingTableException
220
    return table
221
222 2776 aaronmk
def as_Table(table, schema=None):
223 2270 aaronmk
    if table == None or isinstance(table, Code): return table
224 2776 aaronmk
    else: return Table(table, schema)
225 2219 aaronmk
226 3101 aaronmk
def suffixed_table(table, suffix):
227 3128 aaronmk
    table = copy.copy(table) # don't modify input!
228
    table.name = concat(table.name, suffix)
229
    return table
230 2707 aaronmk
231 2336 aaronmk
class NamedTable(Table):
232
    def __init__(self, name, code, cols=None):
233
        Table.__init__(self, name)
234
235 3016 aaronmk
        code = as_Table(code)
236 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
237 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
238 2336 aaronmk
239
        self.code = code
240
        self.cols = cols
241
242
    def to_str(self, db):
243 3026 aaronmk
        str_ = self.code.to_str(db)
244
        if str_.find('\n') >= 0: whitespace = '\n'
245
        else: whitespace = ' '
246
        str_ += whitespace+'AS '+Table.to_str(self, db)
247 2742 aaronmk
        if self.cols != None:
248
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
249 2336 aaronmk
        return str_
250
251
    def to_Table(self): return Table(self.name)
252
253 2753 aaronmk
def remove_table_rename(table):
254
    if isinstance(table, NamedTable): table = table.code
255
    return table
256
257 2335 aaronmk
##### Columns
258
259 2711 aaronmk
class Col(Derived):
260 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
261 2211 aaronmk
        '''
262
        @param table Table|None (for no table)
263 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
264 2211 aaronmk
        '''
265 2711 aaronmk
        Derived.__init__(self, srcs)
266
267 3091 aaronmk
        if util.is_str(name): name = truncate(name)
268 2241 aaronmk
        if util.is_str(table): table = Table(table)
269 2211 aaronmk
        assert table == None or isinstance(table, Table)
270
271
        self.name = name
272
        self.table = table
273
274 2989 aaronmk
    def to_str(self, db, for_str=False):
275 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
276 2989 aaronmk
        if for_str: str_ = clean_name(str_)
277 2933 aaronmk
        if self.table != None:
278 2989 aaronmk
            table = self.table.to_Table()
279
            if for_str: str_ = concat(str(table), '.'+str_)
280
            else: str_ = table.to_str(db)+'.'+str_
281 2211 aaronmk
        return str_
282 2314 aaronmk
283 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
284 2933 aaronmk
285 2314 aaronmk
    def to_Col(self): return self
286 2211 aaronmk
287 3512 aaronmk
def is_col(col): return isinstance(col, Col)
288 2393 aaronmk
289 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
290
291 3000 aaronmk
def index_col(col):
292
    if not is_table_col(col): return None
293 3104 aaronmk
294
    table = col.table
295
    try: name = table.index_cols[col.name]
296
    except KeyError: return None
297
    else: return Col(name, table, col.srcs)
298 2999 aaronmk
299 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
300 2996 aaronmk
301 2563 aaronmk
def as_Col(col, table=None, name=None):
302
    '''
303
    @param name If not None, any non-Col input will be renamed using NamedCol.
304
    '''
305
    if name != None:
306
        col = as_Value(col)
307
        if not isinstance(col, Col): col = NamedCol(name, col)
308 2333 aaronmk
309
    if isinstance(col, Code): return col
310 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
311
    else: return Literal(col)
312 2260 aaronmk
313 3093 aaronmk
def with_table(col, table):
314 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
315
    elif isinstance(col, FunctionCall):
316
        col = copy.deepcopy(col) # don't modify input!
317
        col.args[0].table = table
318 3493 aaronmk
    elif isinstance(col, Col):
319 3098 aaronmk
        col = copy.copy(col) # don't modify input!
320
        col.table = table
321 3093 aaronmk
    return col
322
323 3100 aaronmk
def with_default_table(col, table):
324 2747 aaronmk
    col = as_Col(col)
325 3100 aaronmk
    if col.table == None: col = with_table(col, table)
326 2747 aaronmk
    return col
327
328 2744 aaronmk
def set_cols_table(table, cols):
329
    table = as_Table(table)
330
331
    for i, col in enumerate(cols):
332
        col = cols[i] = as_Col(col)
333
        col.table = table
334
335 2401 aaronmk
def to_name_only_col(col, check_table=None):
336
    col = as_Col(col)
337 3020 aaronmk
    if not is_table_col(col): return col
338 2401 aaronmk
339
    if check_table != None:
340
        table = col.table
341
        assert table == None or table == check_table
342
    return Col(col.name)
343
344 2993 aaronmk
def suffixed_col(col, suffix):
345
    return Col(concat(col.name, suffix), col.table, col.srcs)
346
347 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
348
349 3518 aaronmk
def cross_join_srcs(cols):
350
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
351
    srcs = [[s.name for s in c.srcs] for c in cols]
352
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
353 3514 aaronmk
354 2323 aaronmk
class NamedCol(Col):
355 2229 aaronmk
    def __init__(self, name, code):
356 2310 aaronmk
        Col.__init__(self, name)
357
358 3016 aaronmk
        code = as_Value(code)
359 2229 aaronmk
360
        self.code = code
361
362
    def to_str(self, db):
363 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
364 2314 aaronmk
365
    def to_Col(self): return Col(self.name)
366 2229 aaronmk
367 2462 aaronmk
def remove_col_rename(col):
368
    if isinstance(col, NamedCol): col = col.code
369
    return col
370
371 2830 aaronmk
def underlying_col(col):
372
    col = remove_col_rename(col)
373 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
374
375 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
376 2830 aaronmk
377 2703 aaronmk
def wrap(wrap_func, value):
378
    '''Wraps a value, propagating any column renaming to the returned value.'''
379
    if isinstance(value, NamedCol):
380
        return NamedCol(value.name, wrap_func(value.code))
381
    else: return wrap_func(value)
382
383 2667 aaronmk
class ColDict(dicts.DictProxy):
384 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
385
386 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
387 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
388 2667 aaronmk
389 2645 aaronmk
        keys_table = as_Table(keys_table)
390
391 2642 aaronmk
        self.db = db
392 2641 aaronmk
        self.table = keys_table
393 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
394 2641 aaronmk
395 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
396 2655 aaronmk
397 2667 aaronmk
    def __getitem__(self, key):
398
        return dicts.DictProxy.__getitem__(self, self._key(key))
399 2653 aaronmk
400 2564 aaronmk
    def __setitem__(self, key, value):
401 2642 aaronmk
        key = self._key(key)
402 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
403 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
404 2564 aaronmk
405 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
406 2564 aaronmk
407 3491 aaronmk
##### Composite types
408
409
class List(Code):
410
    def __init__(self, *values):
411
        Code.__init__(self)
412
413
        self.values = values
414
415
    def to_str(self, db):
416
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
417
418 3492 aaronmk
class Tuple(List):
419
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
420
421 3469 aaronmk
#### Definitions
422
423
class TypedCol(Col):
424
    def __init__(self, name, type_, default=None, nullable=True,
425
        constraints=None):
426
        assert default == None or isinstance(default, Code)
427
428
        Col.__init__(self, name)
429
430
        self.type = type_
431
        self.default = default
432
        self.nullable = nullable
433
        self.constraints = constraints
434
435
    def to_str(self, db):
436 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
437 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
438
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
439
        if self.constraints != None: str_ += ' '+self.constraints
440
        return str_
441
442
    def to_Col(self): return Col(self.name)
443
444 3488 aaronmk
class SetOf(Code):
445
    def __init__(self, type_):
446
        Code.__init__(self)
447
448
        self.type = type_
449
450
    def to_str(self, db):
451
        return 'SETOF '+self.type.to_str(db)
452
453 3483 aaronmk
class RowType(Code):
454
    def __init__(self, table):
455
        Code.__init__(self)
456
457
        self.table = table
458
459
    def to_str(self, db):
460
        return self.table.to_str(db)+'%ROWTYPE'
461
462 3485 aaronmk
class ColType(Code):
463
    def __init__(self, col):
464
        Code.__init__(self)
465
466
        self.col = col
467
468
    def to_str(self, db):
469
        return self.col.to_str(db)+'%TYPE'
470
471 2524 aaronmk
##### Functions
472
473 2912 aaronmk
Function = Table
474 2911 aaronmk
as_Function = as_Table
475
476 2691 aaronmk
class InternalFunction(CustomCode): pass
477
478 3442 aaronmk
#### Calls
479
480 2941 aaronmk
class NamedArg(NamedCol):
481
    def __init__(self, name, value):
482
        NamedCol.__init__(self, name, value)
483
484
    def to_str(self, db):
485
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
486
487 2524 aaronmk
class FunctionCall(Code):
488 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
489 2524 aaronmk
        '''
490 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
491 2524 aaronmk
        '''
492 3445 aaronmk
        Code.__init__(self)
493
494 3016 aaronmk
        function = as_Function(function)
495 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
496
        args = map(filter_, args)
497
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
498 2524 aaronmk
499
        self.function = function
500
        self.args = args
501
502
    def to_str(self, db):
503
        args_str = ', '.join((v.to_str(db) for v in self.args))
504
        return self.function.to_str(db)+'('+args_str+')'
505
506 2533 aaronmk
def wrap_in_func(function, value):
507
    '''Wraps a value inside a function call.
508
    Propagates any column renaming to the returned value.
509
    '''
510 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
511 2533 aaronmk
512 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
513
    '''Unwraps any function call to its first argument.
514
    Also removes any column renaming.
515
    '''
516
    func_call = remove_col_rename(func_call)
517
    if not isinstance(func_call, FunctionCall): return func_call
518
519
    if check_name != None:
520
        name = func_call.function.name
521
        assert name == None or name == check_name
522
    return func_call.args[0]
523
524 3442 aaronmk
#### Definitions
525
526
class FunctionDef(Code):
527 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
528 3445 aaronmk
        Code.__init__(self)
529
530 3487 aaronmk
        return_type = as_Code(return_type)
531 3444 aaronmk
        body = as_Code(body)
532
533 3442 aaronmk
        self.function = function
534
        self.return_type = return_type
535
        self.body = body
536 3471 aaronmk
        self.params = params
537 3456 aaronmk
        self.modifiers = modifiers
538 3442 aaronmk
539
    def to_str(self, db):
540 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
541 3442 aaronmk
        str_ = '''\
542 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
543 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
544 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
545 3456 aaronmk
'''
546
        if self.modifiers != None: str_ += self.modifiers+'\n'
547
        str_ += '''\
548 3442 aaronmk
AS $$
549 3444 aaronmk
'''+self.body.to_str(db)+'''
550 3442 aaronmk
$$;
551
'''
552
        return str_
553
554 3469 aaronmk
class FunctionParam(TypedCol):
555
    def __init__(self, name, type_, default=None, out=False):
556
        TypedCol.__init__(self, name, type_, default)
557
558
        self.out = out
559
560
    def to_str(self, db):
561
        str_ = TypedCol.to_str(self, db)
562
        if self.out: str_ = 'OUT '+str_
563
        return str_
564
565
    def to_Col(self): return Col(self.name)
566
567 3454 aaronmk
### PL/pgSQL
568
569 3496 aaronmk
class ReturnQuery(Code):
570
    def __init__(self, query):
571
        Code.__init__(self)
572
573
        query = as_Code(query)
574
575
        self.query = query
576
577
    def to_str(self, db):
578
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
579
580 3515 aaronmk
## Exceptions
581
582 3509 aaronmk
class BaseExcHandler(BasicObject):
583
    def to_str(self, db, body): raise NotImplementedError()
584 3510 aaronmk
585
    def __repr__(self): return self.to_str(mockDb, '<body>')
586 3509 aaronmk
587
class ExcHandler(BaseExcHandler):
588 3454 aaronmk
    def __init__(self, exc, handler=None):
589
        if handler != None: handler = as_Code(handler)
590
591
        self.exc = exc
592
        self.handler = handler
593
594
    def to_str(self, db, body):
595
        body = as_Code(body)
596
597 3467 aaronmk
        if self.handler != None:
598
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
599 3463 aaronmk
        else: handler_str = ' NULL;\n'
600 3454 aaronmk
601
        str_ = '''\
602
BEGIN
603 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
604 3454 aaronmk
EXCEPTION
605 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
606 3454 aaronmk
END;\
607
'''
608
        return str_
609
610 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
611
    def __init__(self, *handlers):
612
        '''
613
        @param handlers Sorted from outermost to innermost
614
        '''
615
        self.handlers = handlers
616
617
    def to_str(self, db, body):
618
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
619
        return body
620
621 3503 aaronmk
class ExcToWarning(Code):
622
    def __init__(self, return_):
623
        '''
624
        @param return_ Statement to return a default value in case of error
625
        '''
626
        Code.__init__(self)
627
628
        return_ = as_Code(return_)
629
630
        self.return_ = return_
631
632
    def to_str(self, db):
633
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
634
635 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
636
637 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
638
RAISE data_exception USING MESSAGE =
639
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
640
''')
641
642 3505 aaronmk
def data_exception_handler(handler):
643
    return ExcHandler('data_exception', handler)
644
645 3449 aaronmk
class RowExcIgnore(Code):
646
    def __init__(self, row_type, select_query, with_row, cols=None,
647 3455 aaronmk
        exc_handler=unique_violation_handler, row_var='row'):
648 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
649
650 3482 aaronmk
        row_type = as_Code(row_type)
651 3449 aaronmk
        select_query = as_Code(select_query)
652
        with_row = as_Code(with_row)
653 3452 aaronmk
        row_var = as_Table(row_var)
654 3449 aaronmk
655
        self.row_type = row_type
656
        self.select_query = select_query
657
        self.with_row = with_row
658
        self.cols = cols
659 3455 aaronmk
        self.exc_handler = exc_handler
660 3452 aaronmk
        self.row_var = row_var
661 3449 aaronmk
662
    def to_str(self, db):
663 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
664
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
665 3449 aaronmk
666
        str_ = '''\
667
DECLARE
668 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
669 3449 aaronmk
BEGIN
670 3467 aaronmk
    /* Need an EXCEPTION block for each individual row because "When
671
    an error is caught by an EXCEPTION clause, [...] all changes to
672
    persistent database state within the block are rolled back."
673
    This is unfortunate because "A block containing an EXCEPTION
674
    clause is significantly more expensive to enter and exit than a
675
    block without one."
676 3449 aaronmk
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
677
#PLPGSQL-ERROR-TRAPPING)
678
    */
679
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
680 3467 aaronmk
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
681 3449 aaronmk
    LOOP
682 3467 aaronmk
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
683 3449 aaronmk
    END LOOP;
684
END;\
685
'''
686
        return str_
687
688 2986 aaronmk
##### Casts
689
690
class Cast(FunctionCall):
691
    def __init__(self, type_, value):
692
        value = as_Value(value)
693
694
        self.type_ = type_
695
        self.value = value
696
697
    def to_str(self, db):
698
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
699
700 3354 aaronmk
def cast_literal(value):
701 3429 aaronmk
    if not is_literal(value): return value
702 3354 aaronmk
703
    if util.is_str(value.value): value = Cast('text', value)
704
    return value
705
706 2335 aaronmk
##### Conditions
707 2259 aaronmk
708 3350 aaronmk
class NotCond(Code):
709
    def __init__(self, cond):
710 3445 aaronmk
        Code.__init__(self)
711
712 3350 aaronmk
        self.cond = cond
713
714
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
715
716 2398 aaronmk
class ColValueCond(Code):
717
    def __init__(self, col, value):
718 3445 aaronmk
        Code.__init__(self)
719
720 2398 aaronmk
        value = as_ValueCond(value)
721
722
        self.col = col
723
        self.value = value
724
725
    def to_str(self, db): return self.value.to_str(db, self.col)
726
727 2577 aaronmk
def combine_conds(conds, keyword=None):
728
    '''
729
    @param keyword The keyword to add before the conditions, if any
730
    '''
731
    str_ = ''
732
    if keyword != None:
733
        if conds == []: whitespace = ''
734
        elif len(conds) == 1: whitespace = ' '
735
        else: whitespace = '\n'
736
        str_ += keyword+whitespace
737
738
    str_ += '\nAND '.join(conds)
739
    return str_
740
741 2398 aaronmk
##### Condition column comparisons
742
743 2514 aaronmk
class ValueCond(BasicObject):
744 2213 aaronmk
    def __init__(self, value):
745 2858 aaronmk
        value = remove_col_rename(as_Value(value))
746 2213 aaronmk
747
        self.value = value
748 2214 aaronmk
749 2216 aaronmk
    def to_str(self, db, left_value):
750 2214 aaronmk
        '''
751 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
752 2214 aaronmk
        '''
753
        raise NotImplemented()
754 2228 aaronmk
755 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
756 2211 aaronmk
757
class CompareCond(ValueCond):
758
    def __init__(self, value, operator='='):
759 2222 aaronmk
        '''
760
        @param operator By default, compares NULL values literally. Use '~=' or
761
            '~!=' to pass NULLs through.
762
        '''
763 2211 aaronmk
        ValueCond.__init__(self, value)
764
        self.operator = operator
765
766 2216 aaronmk
    def to_str(self, db, left_value):
767 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
768 2216 aaronmk
769 2222 aaronmk
        right_value = self.value
770
771
        # Parse operator
772 2216 aaronmk
        operator = self.operator
773 2222 aaronmk
        passthru_null_ref = [False]
774
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
775
        neg_ref = [False]
776
        operator = strings.remove_prefix('!', operator, neg_ref)
777 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
778 2222 aaronmk
779 2825 aaronmk
        # Handle nullable columns
780
        check_null = False
781 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
782 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
783 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
784
                check_null = equals and isinstance(right_value, Col)
785 2837 aaronmk
            else:
786 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
787
                    right_value = ensure_not_null(db, right_value,
788
                        left_value.type) # apply same function to both sides
789 2825 aaronmk
790 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
791
792 2825 aaronmk
        left = left_value.to_str(db)
793
        right = right_value.to_str(db)
794
795 2222 aaronmk
        # Create str
796
        str_ = left+' '+operator+' '+right
797 2825 aaronmk
        if check_null:
798 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
799
        if neg_ref[0]: str_ = 'NOT '+str_
800 2222 aaronmk
        return str_
801 2216 aaronmk
802 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
803
assume_literal = object()
804
805
def as_ValueCond(value, default_table=assume_literal):
806
    if not isinstance(value, ValueCond):
807
        if default_table is not assume_literal:
808 2748 aaronmk
            value = with_default_table(value, default_table)
809 2260 aaronmk
        return CompareCond(value)
810 2216 aaronmk
    else: return value
811 2219 aaronmk
812 2335 aaronmk
##### Joins
813
814 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
815 2260 aaronmk
816 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
817
join_same_not_null = object()
818
819 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
820
821 2514 aaronmk
class Join(BasicObject):
822 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
823 2260 aaronmk
        '''
824
        @param mapping dict(right_table_col=left_table_col, ...)
825 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
826 2353 aaronmk
              * Note that right_table_col must be a string
827
            * if left_table_col is join_same_not_null:
828
              left_table_col = right_table_col and both have NOT NULL constraint
829
              * Note that right_table_col must be a string
830 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
831
            * filter_out: equivalent to 'LEFT' with the query filtered by
832
              `table_pkey IS NULL` (indicating no match)
833
        '''
834
        if util.is_str(table): table = Table(table)
835
        assert type_ == None or util.is_str(type_) or type_ is filter_out
836
837
        self.table = table
838
        self.mapping = mapping
839
        self.type_ = type_
840
841 2749 aaronmk
    def to_str(self, db, left_table_):
842 2260 aaronmk
        def join(entry):
843
            '''Parses non-USING joins'''
844
            right_table_col, left_table_col = entry
845
846 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
847
            left = right_table_col
848
            right = left_table_col
849 2749 aaronmk
            left_table = self.table
850
            right_table = left_table_
851 2353 aaronmk
852 2747 aaronmk
            # Parse left side
853 2748 aaronmk
            left = with_default_table(left, left_table)
854 2747 aaronmk
855 2260 aaronmk
            # Parse special values
856 2747 aaronmk
            left_on_right = Col(left.name, right_table)
857
            if right is join_same: right = left_on_right
858 2353 aaronmk
            elif right is join_same_not_null:
859 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
860 2260 aaronmk
861 2747 aaronmk
            # Parse right side
862 2353 aaronmk
            right = as_ValueCond(right, right_table)
863 2747 aaronmk
864
            return right.to_str(db, left)
865 2260 aaronmk
866 2265 aaronmk
        # Create join condition
867
        type_ = self.type_
868 2276 aaronmk
        joins = self.mapping
869 2746 aaronmk
        if joins == {}: join_cond = None
870
        elif type_ is not filter_out and reduce(operator.and_,
871 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
872 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
873 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
874
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
875 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
876 2260 aaronmk
877 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
878
        else: whitespace = ' '
879
880 2260 aaronmk
        # Create join
881
        if type_ is filter_out: type_ = 'LEFT'
882 2266 aaronmk
        str_ = ''
883
        if type_ != None: str_ += type_+' '
884 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
885
        if join_cond != None: str_ += whitespace+join_cond
886 2266 aaronmk
        return str_
887 2349 aaronmk
888 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
889 2424 aaronmk
890
##### Value exprs
891
892 3089 aaronmk
all_cols = CustomCode('*')
893
894 2737 aaronmk
default = CustomCode('DEFAULT')
895
896 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
897 2674 aaronmk
898 3061 aaronmk
class Coalesce(FunctionCall):
899
    def __init__(self, *args):
900
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
901 3060 aaronmk
902 3062 aaronmk
class Nullif(FunctionCall):
903
    def __init__(self, *args):
904
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
905
906 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
907 2958 aaronmk
null_sentinels = {
908
    'character varying': r'\N',
909
    'double precision': 'NaN',
910
    'integer': 2147483647,
911
    'text': r'\N',
912
    'timestamp with time zone': 'infinity'
913
}
914 2692 aaronmk
915 3061 aaronmk
class EnsureNotNull(Coalesce):
916 2850 aaronmk
    def __init__(self, value, type_):
917 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
918 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
919 2850 aaronmk
920
        self.type = type_
921 3001 aaronmk
922
    def to_str(self, db):
923
        col = self.args[0]
924
        index_col_ = index_col(col)
925
        if index_col_ != None: return index_col_.to_str(db)
926 3061 aaronmk
        return Coalesce.to_str(self, db)
927 2850 aaronmk
928 2737 aaronmk
##### Table exprs
929
930
class Values(Code):
931
    def __init__(self, values):
932 2739 aaronmk
        '''
933
        @param values [...]|[[...], ...] Can be one or multiple rows.
934
        '''
935 3445 aaronmk
        Code.__init__(self)
936
937 2739 aaronmk
        rows = values
938
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
939
            rows = [values]
940
        for i, row in enumerate(rows):
941
            rows[i] = map(remove_col_rename, map(as_Value, row))
942 2737 aaronmk
943 2739 aaronmk
        self.rows = rows
944 2737 aaronmk
945
    def to_str(self, db):
946 3491 aaronmk
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
947 2737 aaronmk
948 2740 aaronmk
def NamedValues(name, cols, values):
949 2745 aaronmk
    '''
950 3048 aaronmk
    @param cols None|[...]
951 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
952
    '''
953 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
954 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
955 2834 aaronmk
    return table
956 2740 aaronmk
957 2674 aaronmk
##### Database structure
958
959 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
960
961 2851 aaronmk
def ensure_not_null(db, col, type_=None):
962 2840 aaronmk
    '''
963 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
964 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
965 2840 aaronmk
    @return EnsureNotNull|Col
966
    @throws ensure_not_null_excs
967
    '''
968 2855 aaronmk
    nullable = True
969
    try: typed_col = db.col_info(underlying_col(col))
970
    except NoUnderlyingTableException:
971 3355 aaronmk
        col = remove_col_rename(col)
972 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
973 3355 aaronmk
        elif type_ == None: raise
974 2855 aaronmk
    else:
975
        if type_ == None: type_ = typed_col.type
976
        nullable = typed_col.nullable
977
978 2953 aaronmk
    if nullable:
979
        try: col = EnsureNotNull(col, type_)
980
        except KeyError, e:
981
            # Warn of no null sentinel for type, even if caller catches error
982
            warnings.warn(UserWarning(exc.str_(e)))
983
            raise
984
985 2840 aaronmk
    return col