Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 2222 aaronmk
import strings
17 2227 aaronmk
import util
18 2211 aaronmk
19 2587 aaronmk
##### Names
20 2499 aaronmk
21 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22 2587 aaronmk
23 2932 aaronmk
def concat(str_, suffix):
24 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26 2613 aaronmk
    # Preserve version
27 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28 2985 aaronmk
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31 2613 aaronmk
32 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
33 2587 aaronmk
34 2932 aaronmk
def truncate(str_): return concat(str_, '')
35 2842 aaronmk
36 2575 aaronmk
def is_safe_name(name):
37 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41 2984 aaronmk
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43 2568 aaronmk
44 2499 aaronmk
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47
48 3320 aaronmk
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56
57 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
58
59 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60
61 3182 aaronmk
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65
66 2659 aaronmk
##### General SQL code objects
67 2219 aaronmk
68 2349 aaronmk
class MockDb:
69 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
70 2349 aaronmk
71 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
72 2859 aaronmk
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75
76 2349 aaronmk
mockDb = MockDb()
77
78 2514 aaronmk
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80
81 2659 aaronmk
##### Unparameterized code objects
82
83 2514 aaronmk
class Code(BasicObject):
84 3446 aaronmk
    def __init__(self, lang='sql'):
85
        self.lang = lang
86 3445 aaronmk
87 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
88 2349 aaronmk
89 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
90 2211 aaronmk
91 2269 aaronmk
class CustomCode(Code):
92 3445 aaronmk
    def __init__(self, str_):
93
        Code.__init__(self)
94
95
        self.str_ = str_
96 2256 aaronmk
97
    def to_str(self, db): return self.str_
98
99 2815 aaronmk
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103 3447 aaronmk
    if isinstance(value, Code): return value
104
105 2815 aaronmk
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108 2659 aaronmk
    else: return Literal(value)
109
110 2540 aaronmk
class Expr(Code):
111 3445 aaronmk
    def __init__(self, expr):
112
        Code.__init__(self)
113
114
        self.expr = expr
115 2540 aaronmk
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117
118 3086 aaronmk
##### Names
119
120
class Name(Code):
121 3087 aaronmk
    def __init__(self, name):
122 3445 aaronmk
        Code.__init__(self)
123
124 3087 aaronmk
        name = truncate(name)
125
126
        self.name = name
127 3086 aaronmk
128
    def to_str(self, db): return db.esc_name(self.name)
129
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133
134 2335 aaronmk
##### Literal values
135
136 3519 aaronmk
#### Primitives
137
138 2216 aaronmk
class Literal(Code):
139 3445 aaronmk
    def __init__(self, value):
140
        Code.__init__(self)
141
142
        self.value = value
143 2213 aaronmk
144
    def to_str(self, db): return db.esc_value(self.value)
145 2211 aaronmk
146 2400 aaronmk
def as_Value(value):
147
    if isinstance(value, Code): return value
148
    else: return Literal(value)
149
150 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
151 2216 aaronmk
152 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
153
154 3519 aaronmk
#### Composites
155
156 3521 aaronmk
class List(Code):
157
    def __init__(self, values):
158 3519 aaronmk
        Code.__init__(self)
159
160
        self.values = values
161
162 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
163 3519 aaronmk
164 3521 aaronmk
class Tuple(List):
165
    def __init__(self, *values):
166
        List.__init__(self, values)
167
168
    def to_str(self, db): return '('+List.to_str(self, db)+')'
169
170 3520 aaronmk
class Row(Tuple):
171
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
172 3519 aaronmk
173 3522 aaronmk
### Arrays
174
175
class Array(List):
176
    def __init__(self, values):
177
        values = map(remove_col_rename, values)
178
179
        List.__init__(self, values)
180
181
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
182
183
def to_Array(value):
184
    if isinstance(value, Array): return value
185
    return Array(lists.mk_seq(value))
186
187 2711 aaronmk
##### Derived elements
188
189
src_self = object() # tells Col that it is its own source column
190
191
class Derived(Code):
192
    def __init__(self, srcs):
193 2712 aaronmk
        '''An element which was derived from some other element(s).
194 2711 aaronmk
        @param srcs See self.set_srcs()
195
        '''
196 3445 aaronmk
        Code.__init__(self)
197
198 2711 aaronmk
        self.set_srcs(srcs)
199
200 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
201 2711 aaronmk
        '''
202
        @param srcs (self_type...)|src_self The element(s) this is derived from
203
        '''
204 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
205
206 2711 aaronmk
        if srcs == src_self: srcs = (self,)
207
        srcs = tuple(srcs) # make Col hashable
208
        self.srcs = srcs
209
210
    def _compare_on(self):
211
        compare_on = self.__dict__.copy()
212
        del compare_on['srcs'] # ignore
213
        return compare_on
214
215
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
216
217 2335 aaronmk
##### Tables
218
219 2712 aaronmk
class Table(Derived):
220 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
221 2211 aaronmk
        '''
222
        @param schema str|None (for no schema)
223 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
224 2211 aaronmk
        '''
225 2712 aaronmk
        Derived.__init__(self, srcs)
226
227 3091 aaronmk
        if util.is_str(name): name = truncate(name)
228 2843 aaronmk
229 2211 aaronmk
        self.name = name
230
        self.schema = schema
231 2991 aaronmk
        self.is_temp = is_temp
232 3000 aaronmk
        self.index_cols = {}
233 2211 aaronmk
234 2348 aaronmk
    def to_str(self, db):
235
        str_ = ''
236 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
237
        str_ += as_Name(self.name).to_str(db)
238 2348 aaronmk
        return str_
239 2336 aaronmk
240
    def to_Table(self): return self
241 3000 aaronmk
242
    def _compare_on(self):
243
        compare_on = Derived._compare_on(self)
244
        del compare_on['index_cols'] # ignore
245
        return compare_on
246 2211 aaronmk
247 2835 aaronmk
def is_underlying_table(table):
248
    return isinstance(table, Table) and table.to_Table() is table
249 2832 aaronmk
250 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
251
252
def underlying_table(table):
253
    table = remove_table_rename(table)
254 3532 aaronmk
    if table != None and table.srcs:
255
        table, = table.srcs # for derived tables or row vars
256 2902 aaronmk
    if not is_underlying_table(table): raise NoUnderlyingTableException
257
    return table
258
259 2776 aaronmk
def as_Table(table, schema=None):
260 2270 aaronmk
    if table == None or isinstance(table, Code): return table
261 2776 aaronmk
    else: return Table(table, schema)
262 2219 aaronmk
263 3101 aaronmk
def suffixed_table(table, suffix):
264 3128 aaronmk
    table = copy.copy(table) # don't modify input!
265
    table.name = concat(table.name, suffix)
266
    return table
267 2707 aaronmk
268 2336 aaronmk
class NamedTable(Table):
269
    def __init__(self, name, code, cols=None):
270
        Table.__init__(self, name)
271
272 3016 aaronmk
        code = as_Table(code)
273 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
274 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
275 2336 aaronmk
276
        self.code = code
277
        self.cols = cols
278
279
    def to_str(self, db):
280 3026 aaronmk
        str_ = self.code.to_str(db)
281
        if str_.find('\n') >= 0: whitespace = '\n'
282
        else: whitespace = ' '
283
        str_ += whitespace+'AS '+Table.to_str(self, db)
284 2742 aaronmk
        if self.cols != None:
285
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
286 2336 aaronmk
        return str_
287
288
    def to_Table(self): return Table(self.name)
289
290 2753 aaronmk
def remove_table_rename(table):
291
    if isinstance(table, NamedTable): table = table.code
292
    return table
293
294 2335 aaronmk
##### Columns
295
296 2711 aaronmk
class Col(Derived):
297 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
298 2211 aaronmk
        '''
299
        @param table Table|None (for no table)
300 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
301 2211 aaronmk
        '''
302 2711 aaronmk
        Derived.__init__(self, srcs)
303
304 3091 aaronmk
        if util.is_str(name): name = truncate(name)
305 2241 aaronmk
        if util.is_str(table): table = Table(table)
306 2211 aaronmk
        assert table == None or isinstance(table, Table)
307
308
        self.name = name
309
        self.table = table
310
311 2989 aaronmk
    def to_str(self, db, for_str=False):
312 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
313 2989 aaronmk
        if for_str: str_ = clean_name(str_)
314 2933 aaronmk
        if self.table != None:
315 2989 aaronmk
            table = self.table.to_Table()
316 3750 aaronmk
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
317 2989 aaronmk
            else: str_ = table.to_str(db)+'.'+str_
318 2211 aaronmk
        return str_
319 2314 aaronmk
320 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
321 2933 aaronmk
322 2314 aaronmk
    def to_Col(self): return self
323 2211 aaronmk
324 3512 aaronmk
def is_col(col): return isinstance(col, Col)
325 2393 aaronmk
326 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
327
328 3000 aaronmk
def index_col(col):
329
    if not is_table_col(col): return None
330 3104 aaronmk
331
    table = col.table
332
    try: name = table.index_cols[col.name]
333
    except KeyError: return None
334
    else: return Col(name, table, col.srcs)
335 2999 aaronmk
336 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
337 2996 aaronmk
338 2563 aaronmk
def as_Col(col, table=None, name=None):
339
    '''
340
    @param name If not None, any non-Col input will be renamed using NamedCol.
341
    '''
342
    if name != None:
343
        col = as_Value(col)
344
        if not isinstance(col, Col): col = NamedCol(name, col)
345 2333 aaronmk
346
    if isinstance(col, Code): return col
347 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
348
    else: return Literal(col)
349 2260 aaronmk
350 3093 aaronmk
def with_table(col, table):
351 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
352
    elif isinstance(col, FunctionCall):
353
        col = copy.deepcopy(col) # don't modify input!
354
        col.args[0].table = table
355 3493 aaronmk
    elif isinstance(col, Col):
356 3098 aaronmk
        col = copy.copy(col) # don't modify input!
357
        col.table = table
358 3093 aaronmk
    return col
359
360 3100 aaronmk
def with_default_table(col, table):
361 2747 aaronmk
    col = as_Col(col)
362 3100 aaronmk
    if col.table == None: col = with_table(col, table)
363 2747 aaronmk
    return col
364
365 2744 aaronmk
def set_cols_table(table, cols):
366
    table = as_Table(table)
367
368
    for i, col in enumerate(cols):
369
        col = cols[i] = as_Col(col)
370
        col.table = table
371
372 2401 aaronmk
def to_name_only_col(col, check_table=None):
373
    col = as_Col(col)
374 3020 aaronmk
    if not is_table_col(col): return col
375 2401 aaronmk
376
    if check_table != None:
377
        table = col.table
378
        assert table == None or table == check_table
379
    return Col(col.name)
380
381 2993 aaronmk
def suffixed_col(col, suffix):
382
    return Col(concat(col.name, suffix), col.table, col.srcs)
383
384 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
385
386 3518 aaronmk
def cross_join_srcs(cols):
387
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
388
    srcs = [[s.name for s in c.srcs] for c in cols]
389
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
390 3514 aaronmk
391 2323 aaronmk
class NamedCol(Col):
392 2229 aaronmk
    def __init__(self, name, code):
393 2310 aaronmk
        Col.__init__(self, name)
394
395 3016 aaronmk
        code = as_Value(code)
396 2229 aaronmk
397
        self.code = code
398
399
    def to_str(self, db):
400 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
401 2314 aaronmk
402
    def to_Col(self): return Col(self.name)
403 2229 aaronmk
404 2462 aaronmk
def remove_col_rename(col):
405
    if isinstance(col, NamedCol): col = col.code
406
    return col
407
408 2830 aaronmk
def underlying_col(col):
409
    col = remove_col_rename(col)
410 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
411
412 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
413 2830 aaronmk
414 2703 aaronmk
def wrap(wrap_func, value):
415
    '''Wraps a value, propagating any column renaming to the returned value.'''
416
    if isinstance(value, NamedCol):
417
        return NamedCol(value.name, wrap_func(value.code))
418
    else: return wrap_func(value)
419
420 2667 aaronmk
class ColDict(dicts.DictProxy):
421 3691 aaronmk
    '''A dict that automatically makes inserted entries Col objects.
422
    Anything that isn't a column is wrapped in a NamedCol with the key's column
423
    name by `as_Col(value, name=key.name)`.
424
    '''
425 2564 aaronmk
426 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
427 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
428 2667 aaronmk
429 2645 aaronmk
        keys_table = as_Table(keys_table)
430
431 2642 aaronmk
        self.db = db
432 2641 aaronmk
        self.table = keys_table
433 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
434 2641 aaronmk
435 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
436 2655 aaronmk
437 2667 aaronmk
    def __getitem__(self, key):
438
        return dicts.DictProxy.__getitem__(self, self._key(key))
439 2653 aaronmk
440 2564 aaronmk
    def __setitem__(self, key, value):
441 2642 aaronmk
        key = self._key(key)
442 3639 aaronmk
        if value == None:
443
            try: value = self.db.col_info(key).default
444
            except NoUnderlyingTableException: pass # not a table column
445 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
446 2564 aaronmk
447 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
448 2564 aaronmk
449 3519 aaronmk
##### Definitions
450 3491 aaronmk
451 3469 aaronmk
class TypedCol(Col):
452
    def __init__(self, name, type_, default=None, nullable=True,
453
        constraints=None):
454
        assert default == None or isinstance(default, Code)
455
456
        Col.__init__(self, name)
457
458
        self.type = type_
459
        self.default = default
460
        self.nullable = nullable
461
        self.constraints = constraints
462
463
    def to_str(self, db):
464 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
465 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
466
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
467
        if self.constraints != None: str_ += ' '+self.constraints
468
        return str_
469
470
    def to_Col(self): return Col(self.name)
471
472 3488 aaronmk
class SetOf(Code):
473
    def __init__(self, type_):
474
        Code.__init__(self)
475
476
        self.type = type_
477
478
    def to_str(self, db):
479
        return 'SETOF '+self.type.to_str(db)
480
481 3483 aaronmk
class RowType(Code):
482
    def __init__(self, table):
483
        Code.__init__(self)
484
485
        self.table = table
486
487
    def to_str(self, db):
488
        return self.table.to_str(db)+'%ROWTYPE'
489
490 3485 aaronmk
class ColType(Code):
491
    def __init__(self, col):
492
        Code.__init__(self)
493
494
        self.col = col
495
496
    def to_str(self, db):
497
        return self.col.to_str(db)+'%TYPE'
498
499 2524 aaronmk
##### Functions
500
501 2912 aaronmk
Function = Table
502 2911 aaronmk
as_Function = as_Table
503
504 2691 aaronmk
class InternalFunction(CustomCode): pass
505
506 3442 aaronmk
#### Calls
507
508 2941 aaronmk
class NamedArg(NamedCol):
509
    def __init__(self, name, value):
510
        NamedCol.__init__(self, name, value)
511
512
    def to_str(self, db):
513
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
514
515 2524 aaronmk
class FunctionCall(Code):
516 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
517 2524 aaronmk
        '''
518 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
519 2524 aaronmk
        '''
520 3445 aaronmk
        Code.__init__(self)
521
522 3016 aaronmk
        function = as_Function(function)
523 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
524
        args = map(filter_, args)
525
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
526 2524 aaronmk
527
        self.function = function
528
        self.args = args
529
530
    def to_str(self, db):
531
        args_str = ', '.join((v.to_str(db) for v in self.args))
532
        return self.function.to_str(db)+'('+args_str+')'
533
534 2533 aaronmk
def wrap_in_func(function, value):
535
    '''Wraps a value inside a function call.
536
    Propagates any column renaming to the returned value.
537
    '''
538 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
539 2533 aaronmk
540 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
541
    '''Unwraps any function call to its first argument.
542
    Also removes any column renaming.
543
    '''
544
    func_call = remove_col_rename(func_call)
545
    if not isinstance(func_call, FunctionCall): return func_call
546
547
    if check_name != None:
548
        name = func_call.function.name
549
        assert name == None or name == check_name
550
    return func_call.args[0]
551
552 3442 aaronmk
#### Definitions
553
554
class FunctionDef(Code):
555 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
556 3445 aaronmk
        Code.__init__(self)
557
558 3487 aaronmk
        return_type = as_Code(return_type)
559 3444 aaronmk
        body = as_Code(body)
560
561 3442 aaronmk
        self.function = function
562
        self.return_type = return_type
563
        self.body = body
564 3471 aaronmk
        self.params = params
565 3456 aaronmk
        self.modifiers = modifiers
566 3442 aaronmk
567
    def to_str(self, db):
568 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
569 3442 aaronmk
        str_ = '''\
570 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
571 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
572 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
573 3456 aaronmk
'''
574
        if self.modifiers != None: str_ += self.modifiers+'\n'
575
        str_ += '''\
576 3442 aaronmk
AS $$
577 3444 aaronmk
'''+self.body.to_str(db)+'''
578 3442 aaronmk
$$;
579
'''
580
        return str_
581
582 3469 aaronmk
class FunctionParam(TypedCol):
583
    def __init__(self, name, type_, default=None, out=False):
584
        TypedCol.__init__(self, name, type_, default)
585
586
        self.out = out
587
588
    def to_str(self, db):
589
        str_ = TypedCol.to_str(self, db)
590
        if self.out: str_ = 'OUT '+str_
591
        return str_
592
593
    def to_Col(self): return Col(self.name)
594
595 3454 aaronmk
### PL/pgSQL
596
597 3496 aaronmk
class ReturnQuery(Code):
598
    def __init__(self, query):
599
        Code.__init__(self)
600
601
        query = as_Code(query)
602
603
        self.query = query
604
605
    def to_str(self, db):
606
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
607
608 3515 aaronmk
## Exceptions
609
610 3509 aaronmk
class BaseExcHandler(BasicObject):
611
    def to_str(self, db, body): raise NotImplementedError()
612 3510 aaronmk
613
    def __repr__(self): return self.to_str(mockDb, '<body>')
614 3509 aaronmk
615 3549 aaronmk
suppress_exc = 'NULL;\n';
616
617 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
618
619 3509 aaronmk
class ExcHandler(BaseExcHandler):
620 3454 aaronmk
    def __init__(self, exc, handler=None):
621
        if handler != None: handler = as_Code(handler)
622
623
        self.exc = exc
624
        self.handler = handler
625
626
    def to_str(self, db, body):
627
        body = as_Code(body)
628
629 3467 aaronmk
        if self.handler != None:
630
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
631 3549 aaronmk
        else: handler_str = ' '+suppress_exc
632 3454 aaronmk
633
        str_ = '''\
634
BEGIN
635 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
636 3454 aaronmk
EXCEPTION
637 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
638 3454 aaronmk
END;\
639
'''
640
        return str_
641
642 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
643
    def __init__(self, *handlers):
644
        '''
645
        @param handlers Sorted from outermost to innermost
646
        '''
647
        self.handlers = handlers
648
649
    def to_str(self, db, body):
650
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
651
        return body
652
653 3503 aaronmk
class ExcToWarning(Code):
654
    def __init__(self, return_):
655
        '''
656
        @param return_ Statement to return a default value in case of error
657
        '''
658
        Code.__init__(self)
659
660
        return_ = as_Code(return_)
661
662
        self.return_ = return_
663
664
    def to_str(self, db):
665
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
666
667 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
668
669 3555 aaronmk
# Note doubled "\"s because inside Python string
670 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
671 3591 aaronmk
-- Handle PL/Python exceptions
672 3555 aaronmk
DECLARE
673 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
674
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
675 3555 aaronmk
    exc_name text := matches[1];
676
    msg text := matches[2];
677
BEGIN
678 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
679
    This allows the exception to be parsed like a native exception.
680
    Always raise as data_exception so it goes in the errors table. */
681 3598 aaronmk
    IF exc_name IS NOT NULL THEN
682
        RAISE data_exception USING MESSAGE = msg;
683 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
684
    ELSE
685
        '''+reraise_exc+'''\
686 3555 aaronmk
    END IF;
687
END;
688 3468 aaronmk
''')
689
690 3505 aaronmk
def data_exception_handler(handler):
691
    return ExcHandler('data_exception', handler)
692
693 3527 aaronmk
row_var = Table('row')
694
695 3449 aaronmk
class RowExcIgnore(Code):
696
    def __init__(self, row_type, select_query, with_row, cols=None,
697 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
698 3529 aaronmk
        '''
699
        @param row_type Ignored if a custom row_var is used.
700
        @pre If a custom row_var is used, it must already be defined.
701
        '''
702 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
703
704 3482 aaronmk
        row_type = as_Code(row_type)
705 3449 aaronmk
        select_query = as_Code(select_query)
706
        with_row = as_Code(with_row)
707 3452 aaronmk
        row_var = as_Table(row_var)
708 3449 aaronmk
709
        self.row_type = row_type
710
        self.select_query = select_query
711
        self.with_row = with_row
712
        self.cols = cols
713 3455 aaronmk
        self.exc_handler = exc_handler
714 3452 aaronmk
        self.row_var = row_var
715 3449 aaronmk
716
    def to_str(self, db):
717 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
718
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
719 3449 aaronmk
720 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
721
        # is caught by an EXCEPTION clause, [...] all changes to persistent
722
        # database state within the block are rolled back."
723
        # This is unfortunate because "A block containing an EXCEPTION clause is
724
        # significantly more expensive to enter and exit than a block without
725
        # one."
726
        # (http://www.postgresql.org/docs/8.3/static/\
727
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
728 3449 aaronmk
        str_ = '''\
729 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
730
'''+strings.indent(self.select_query.to_str(db))+'''\
731
LOOP
732
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
733
END LOOP;
734
'''
735 3533 aaronmk
        if self.row_var == row_var:
736 3529 aaronmk
            str_ = '''\
737 3449 aaronmk
DECLARE
738 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
739 3449 aaronmk
BEGIN
740 3529 aaronmk
'''+strings.indent(str_)+'''\
741
END;
742 3449 aaronmk
'''
743
        return str_
744
745 2986 aaronmk
##### Casts
746
747
class Cast(FunctionCall):
748
    def __init__(self, type_, value):
749 3539 aaronmk
        type_ = as_Code(type_)
750 2986 aaronmk
        value = as_Value(value)
751
752
        self.type_ = type_
753
        self.value = value
754
755
    def to_str(self, db):
756 3539 aaronmk
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
757 2986 aaronmk
758 3354 aaronmk
def cast_literal(value):
759 3429 aaronmk
    if not is_literal(value): return value
760 3354 aaronmk
761
    if util.is_str(value.value): value = Cast('text', value)
762
    return value
763
764 2335 aaronmk
##### Conditions
765 2259 aaronmk
766 3350 aaronmk
class NotCond(Code):
767
    def __init__(self, cond):
768 3445 aaronmk
        Code.__init__(self)
769
770 3350 aaronmk
        self.cond = cond
771
772
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
773
774 2398 aaronmk
class ColValueCond(Code):
775
    def __init__(self, col, value):
776 3445 aaronmk
        Code.__init__(self)
777
778 2398 aaronmk
        value = as_ValueCond(value)
779
780
        self.col = col
781
        self.value = value
782
783
    def to_str(self, db): return self.value.to_str(db, self.col)
784
785 2577 aaronmk
def combine_conds(conds, keyword=None):
786
    '''
787
    @param keyword The keyword to add before the conditions, if any
788
    '''
789
    str_ = ''
790
    if keyword != None:
791
        if conds == []: whitespace = ''
792
        elif len(conds) == 1: whitespace = ' '
793
        else: whitespace = '\n'
794
        str_ += keyword+whitespace
795
796
    str_ += '\nAND '.join(conds)
797
    return str_
798
799 2398 aaronmk
##### Condition column comparisons
800
801 2514 aaronmk
class ValueCond(BasicObject):
802 2213 aaronmk
    def __init__(self, value):
803 2858 aaronmk
        value = remove_col_rename(as_Value(value))
804 2213 aaronmk
805
        self.value = value
806 2214 aaronmk
807 2216 aaronmk
    def to_str(self, db, left_value):
808 2214 aaronmk
        '''
809 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
810 2214 aaronmk
        '''
811
        raise NotImplemented()
812 2228 aaronmk
813 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
814 2211 aaronmk
815
class CompareCond(ValueCond):
816
    def __init__(self, value, operator='='):
817 2222 aaronmk
        '''
818
        @param operator By default, compares NULL values literally. Use '~=' or
819
            '~!=' to pass NULLs through.
820
        '''
821 2211 aaronmk
        ValueCond.__init__(self, value)
822
        self.operator = operator
823
824 2216 aaronmk
    def to_str(self, db, left_value):
825 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
826 2216 aaronmk
827 2222 aaronmk
        right_value = self.value
828
829
        # Parse operator
830 2216 aaronmk
        operator = self.operator
831 2222 aaronmk
        passthru_null_ref = [False]
832
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
833
        neg_ref = [False]
834
        operator = strings.remove_prefix('!', operator, neg_ref)
835 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
836 2222 aaronmk
837 2825 aaronmk
        # Handle nullable columns
838
        check_null = False
839 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
840 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
841 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
842
                check_null = equals and isinstance(right_value, Col)
843 2837 aaronmk
            else:
844 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
845
                    right_value = ensure_not_null(db, right_value,
846
                        left_value.type) # apply same function to both sides
847 2825 aaronmk
848 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
849
850 2825 aaronmk
        left = left_value.to_str(db)
851
        right = right_value.to_str(db)
852
853 2222 aaronmk
        # Create str
854
        str_ = left+' '+operator+' '+right
855 2825 aaronmk
        if check_null:
856 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
857
        if neg_ref[0]: str_ = 'NOT '+str_
858 2222 aaronmk
        return str_
859 2216 aaronmk
860 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
861
assume_literal = object()
862
863
def as_ValueCond(value, default_table=assume_literal):
864
    if not isinstance(value, ValueCond):
865
        if default_table is not assume_literal:
866 2748 aaronmk
            value = with_default_table(value, default_table)
867 2260 aaronmk
        return CompareCond(value)
868 2216 aaronmk
    else: return value
869 2219 aaronmk
870 2335 aaronmk
##### Joins
871
872 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
873 2260 aaronmk
874 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
875
join_same_not_null = object()
876
877 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
878
879 2514 aaronmk
class Join(BasicObject):
880 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
881 2260 aaronmk
        '''
882
        @param mapping dict(right_table_col=left_table_col, ...)
883 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
884 2353 aaronmk
              * Note that right_table_col must be a string
885
            * if left_table_col is join_same_not_null:
886
              left_table_col = right_table_col and both have NOT NULL constraint
887
              * Note that right_table_col must be a string
888 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
889
            * filter_out: equivalent to 'LEFT' with the query filtered by
890
              `table_pkey IS NULL` (indicating no match)
891
        '''
892
        if util.is_str(table): table = Table(table)
893
        assert type_ == None or util.is_str(type_) or type_ is filter_out
894
895
        self.table = table
896
        self.mapping = mapping
897
        self.type_ = type_
898
899 2749 aaronmk
    def to_str(self, db, left_table_):
900 2260 aaronmk
        def join(entry):
901
            '''Parses non-USING joins'''
902
            right_table_col, left_table_col = entry
903
904 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
905
            left = right_table_col
906
            right = left_table_col
907 2749 aaronmk
            left_table = self.table
908
            right_table = left_table_
909 2353 aaronmk
910 2747 aaronmk
            # Parse left side
911 2748 aaronmk
            left = with_default_table(left, left_table)
912 2747 aaronmk
913 2260 aaronmk
            # Parse special values
914 2747 aaronmk
            left_on_right = Col(left.name, right_table)
915
            if right is join_same: right = left_on_right
916 2353 aaronmk
            elif right is join_same_not_null:
917 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
918 2260 aaronmk
919 2747 aaronmk
            # Parse right side
920 2353 aaronmk
            right = as_ValueCond(right, right_table)
921 2747 aaronmk
922
            return right.to_str(db, left)
923 2260 aaronmk
924 2265 aaronmk
        # Create join condition
925
        type_ = self.type_
926 2276 aaronmk
        joins = self.mapping
927 2746 aaronmk
        if joins == {}: join_cond = None
928
        elif type_ is not filter_out and reduce(operator.and_,
929 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
930 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
931 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
932
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
933 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
934 2260 aaronmk
935 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
936
        else: whitespace = ' '
937
938 2260 aaronmk
        # Create join
939
        if type_ is filter_out: type_ = 'LEFT'
940 2266 aaronmk
        str_ = ''
941
        if type_ != None: str_ += type_+' '
942 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
943
        if join_cond != None: str_ += whitespace+join_cond
944 2266 aaronmk
        return str_
945 2349 aaronmk
946 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
947 2424 aaronmk
948
##### Value exprs
949
950 3089 aaronmk
all_cols = CustomCode('*')
951
952 2737 aaronmk
default = CustomCode('DEFAULT')
953
954 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
955 2674 aaronmk
956 3061 aaronmk
class Coalesce(FunctionCall):
957
    def __init__(self, *args):
958
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
959 3060 aaronmk
960 3062 aaronmk
class Nullif(FunctionCall):
961
    def __init__(self, *args):
962
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
963
964 3706 aaronmk
null_as_str = Cast('text', 'NULL')
965
966
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
967
968 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
969 2958 aaronmk
null_sentinels = {
970
    'character varying': r'\N',
971
    'double precision': 'NaN',
972
    'integer': 2147483647,
973
    'text': r'\N',
974
    'timestamp with time zone': 'infinity'
975
}
976 2692 aaronmk
977 3061 aaronmk
class EnsureNotNull(Coalesce):
978 2850 aaronmk
    def __init__(self, value, type_):
979 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
980 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
981 2850 aaronmk
982
        self.type = type_
983 3001 aaronmk
984
    def to_str(self, db):
985
        col = self.args[0]
986
        index_col_ = index_col(col)
987
        if index_col_ != None: return index_col_.to_str(db)
988 3061 aaronmk
        return Coalesce.to_str(self, db)
989 2850 aaronmk
990 3523 aaronmk
#### Arrays
991
992 3535 aaronmk
class ArrayMerge(FunctionCall):
993 3523 aaronmk
    def __init__(self, sep, array):
994
        array = to_Array(array)
995
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
996
            sep)
997
998 3537 aaronmk
def merge_not_null(db, sep, values):
999 3707 aaronmk
    return ArrayMerge(sep, map(to_text, values))
1000 3537 aaronmk
1001 2737 aaronmk
##### Table exprs
1002
1003
class Values(Code):
1004
    def __init__(self, values):
1005 2739 aaronmk
        '''
1006
        @param values [...]|[[...], ...] Can be one or multiple rows.
1007
        '''
1008 3445 aaronmk
        Code.__init__(self)
1009
1010 2739 aaronmk
        rows = values
1011
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1012
            rows = [values]
1013
        for i, row in enumerate(rows):
1014
            rows[i] = map(remove_col_rename, map(as_Value, row))
1015 2737 aaronmk
1016 2739 aaronmk
        self.rows = rows
1017 2737 aaronmk
1018
    def to_str(self, db):
1019 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1020 2737 aaronmk
1021 2740 aaronmk
def NamedValues(name, cols, values):
1022 2745 aaronmk
    '''
1023 3048 aaronmk
    @param cols None|[...]
1024 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
1025
    '''
1026 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
1027 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
1028 2834 aaronmk
    return table
1029 2740 aaronmk
1030 2674 aaronmk
##### Database structure
1031
1032 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1033
1034 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1035 2840 aaronmk
    '''
1036 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1037 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
1038 2840 aaronmk
    @return EnsureNotNull|Col
1039
    @throws ensure_not_null_excs
1040
    '''
1041 2855 aaronmk
    nullable = True
1042
    try: typed_col = db.col_info(underlying_col(col))
1043
    except NoUnderlyingTableException:
1044 3355 aaronmk
        col = remove_col_rename(col)
1045 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
1046 3355 aaronmk
        elif type_ == None: raise
1047 2855 aaronmk
    else:
1048
        if type_ == None: type_ = typed_col.type
1049
        nullable = typed_col.nullable
1050
1051 2953 aaronmk
    if nullable:
1052
        try: col = EnsureNotNull(col, type_)
1053
        except KeyError, e:
1054
            # Warn of no null sentinel for type, even if caller catches error
1055
            warnings.warn(UserWarning(exc.str_(e)))
1056
            raise
1057
1058 2840 aaronmk
    return col
1059 3536 aaronmk
1060
def try_mk_not_null(db, value):
1061
    '''
1062
    Warning: This function does not guarantee that its result is NOT NULL.
1063
    '''
1064
    try: return ensure_not_null(db, value)
1065
    except ensure_not_null_excs: return value