Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 3158 aaronmk
from ordereddict import OrderedDict
6 2568 aaronmk
import re
7 2653 aaronmk
import UserDict
8 2953 aaronmk
import warnings
9 2276 aaronmk
10 2667 aaronmk
import dicts
11 2953 aaronmk
import exc
12 2701 aaronmk
import iters
13
import lists
14 2360 aaronmk
import objects
15 2222 aaronmk
import strings
16 2227 aaronmk
import util
17 2211 aaronmk
18 2587 aaronmk
##### Names
19 2499 aaronmk
20 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21 2587 aaronmk
22 2932 aaronmk
def concat(str_, suffix):
23 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25 2613 aaronmk
    # Preserve version
26 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27 2985 aaronmk
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30 2613 aaronmk
31 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
32 2587 aaronmk
33 2932 aaronmk
def truncate(str_): return concat(str_, '')
34 2842 aaronmk
35 2575 aaronmk
def is_safe_name(name):
36 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40 2984 aaronmk
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42 2568 aaronmk
43 2499 aaronmk
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46
47 3320 aaronmk
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55
56 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
57
58 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59
60 3182 aaronmk
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64
65 2659 aaronmk
##### General SQL code objects
66 2219 aaronmk
67 2349 aaronmk
class MockDb:
68 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
69 2349 aaronmk
70 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
71 2859 aaronmk
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74
75 2349 aaronmk
mockDb = MockDb()
76
77 2514 aaronmk
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79
80 2659 aaronmk
##### Unparameterized code objects
81
82 2514 aaronmk
class Code(BasicObject):
83 3446 aaronmk
    def __init__(self, lang='sql'):
84
        self.lang = lang
85 3445 aaronmk
86 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
87 2349 aaronmk
88 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
89 2211 aaronmk
90 2269 aaronmk
class CustomCode(Code):
91 3445 aaronmk
    def __init__(self, str_):
92
        Code.__init__(self)
93
94
        self.str_ = str_
95 2256 aaronmk
96
    def to_str(self, db): return self.str_
97
98 2815 aaronmk
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102 3447 aaronmk
    if isinstance(value, Code): return value
103
104 2815 aaronmk
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107 2659 aaronmk
    else: return Literal(value)
108
109 2540 aaronmk
class Expr(Code):
110 3445 aaronmk
    def __init__(self, expr):
111
        Code.__init__(self)
112
113
        self.expr = expr
114 2540 aaronmk
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116
117 3086 aaronmk
##### Names
118
119
class Name(Code):
120 3087 aaronmk
    def __init__(self, name):
121 3445 aaronmk
        Code.__init__(self)
122
123 3087 aaronmk
        name = truncate(name)
124
125
        self.name = name
126 3086 aaronmk
127
    def to_str(self, db): return db.esc_name(self.name)
128
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132
133 2335 aaronmk
##### Literal values
134
135 2216 aaronmk
class Literal(Code):
136 3445 aaronmk
    def __init__(self, value):
137
        Code.__init__(self)
138
139
        self.value = value
140 2213 aaronmk
141
    def to_str(self, db): return db.esc_value(self.value)
142 2211 aaronmk
143 2400 aaronmk
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146
147 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
148 2216 aaronmk
149 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
150
151 2711 aaronmk
##### Derived elements
152
153
src_self = object() # tells Col that it is its own source column
154
155
class Derived(Code):
156
    def __init__(self, srcs):
157 2712 aaronmk
        '''An element which was derived from some other element(s).
158 2711 aaronmk
        @param srcs See self.set_srcs()
159
        '''
160 3445 aaronmk
        Code.__init__(self)
161
162 2711 aaronmk
        self.set_srcs(srcs)
163
164 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
165 2711 aaronmk
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
169
170 2711 aaronmk
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180
181 2335 aaronmk
##### Tables
182
183 2712 aaronmk
class Table(Derived):
184 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185 2211 aaronmk
        '''
186
        @param schema str|None (for no schema)
187 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
188 2211 aaronmk
        '''
189 2712 aaronmk
        Derived.__init__(self, srcs)
190
191 3091 aaronmk
        if util.is_str(name): name = truncate(name)
192 2843 aaronmk
193 2211 aaronmk
        self.name = name
194
        self.schema = schema
195 2991 aaronmk
        self.is_temp = is_temp
196 3000 aaronmk
        self.index_cols = {}
197 2211 aaronmk
198 2348 aaronmk
    def to_str(self, db):
199
        str_ = ''
200 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202 2348 aaronmk
        return str_
203 2336 aaronmk
204
    def to_Table(self): return self
205 3000 aaronmk
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210 2211 aaronmk
211 2835 aaronmk
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213 2832 aaronmk
214 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
215
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220
221 2776 aaronmk
def as_Table(table, schema=None):
222 2270 aaronmk
    if table == None or isinstance(table, Code): return table
223 2776 aaronmk
    else: return Table(table, schema)
224 2219 aaronmk
225 3101 aaronmk
def suffixed_table(table, suffix):
226 3128 aaronmk
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229 2707 aaronmk
230 2336 aaronmk
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
234 3016 aaronmk
        code = as_Table(code)
235 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237 2336 aaronmk
238
        self.code = code
239
        self.cols = cols
240
241
    def to_str(self, db):
242 3026 aaronmk
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246 2742 aaronmk
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248 2336 aaronmk
        return str_
249
250
    def to_Table(self): return Table(self.name)
251
252 2753 aaronmk
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255
256 2335 aaronmk
##### Columns
257
258 2711 aaronmk
class Col(Derived):
259 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
260 2211 aaronmk
        '''
261
        @param table Table|None (for no table)
262 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
263 2211 aaronmk
        '''
264 2711 aaronmk
        Derived.__init__(self, srcs)
265
266 3091 aaronmk
        if util.is_str(name): name = truncate(name)
267 2241 aaronmk
        if util.is_str(table): table = Table(table)
268 2211 aaronmk
        assert table == None or isinstance(table, Table)
269
270
        self.name = name
271
        self.table = table
272
273 2989 aaronmk
    def to_str(self, db, for_str=False):
274 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
275 2989 aaronmk
        if for_str: str_ = clean_name(str_)
276 2933 aaronmk
        if self.table != None:
277 2989 aaronmk
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280 2211 aaronmk
        return str_
281 2314 aaronmk
282 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
283 2933 aaronmk
284 2314 aaronmk
    def to_Col(self): return self
285 2211 aaronmk
286 2767 aaronmk
def is_table_col(col): return isinstance(col, Col) and col.table != None
287 2393 aaronmk
288 3000 aaronmk
def index_col(col):
289
    if not is_table_col(col): return None
290 3104 aaronmk
291
    table = col.table
292
    try: name = table.index_cols[col.name]
293
    except KeyError: return None
294
    else: return Col(name, table, col.srcs)
295 2999 aaronmk
296 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
297 2996 aaronmk
298 2563 aaronmk
def as_Col(col, table=None, name=None):
299
    '''
300
    @param name If not None, any non-Col input will be renamed using NamedCol.
301
    '''
302
    if name != None:
303
        col = as_Value(col)
304
        if not isinstance(col, Col): col = NamedCol(name, col)
305 2333 aaronmk
306
    if isinstance(col, Code): return col
307 2260 aaronmk
    else: return Col(col, table)
308
309 3093 aaronmk
def with_table(col, table):
310 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
311
    elif isinstance(col, FunctionCall):
312
        col = copy.deepcopy(col) # don't modify input!
313
        col.args[0].table = table
314 3493 aaronmk
    elif isinstance(col, Col):
315 3098 aaronmk
        col = copy.copy(col) # don't modify input!
316
        col.table = table
317 3093 aaronmk
    return col
318
319 3100 aaronmk
def with_default_table(col, table):
320 2747 aaronmk
    col = as_Col(col)
321 3100 aaronmk
    if col.table == None: col = with_table(col, table)
322 2747 aaronmk
    return col
323
324 2744 aaronmk
def set_cols_table(table, cols):
325
    table = as_Table(table)
326
327
    for i, col in enumerate(cols):
328
        col = cols[i] = as_Col(col)
329
        col.table = table
330
331 2401 aaronmk
def to_name_only_col(col, check_table=None):
332
    col = as_Col(col)
333 3020 aaronmk
    if not is_table_col(col): return col
334 2401 aaronmk
335
    if check_table != None:
336
        table = col.table
337
        assert table == None or table == check_table
338
    return Col(col.name)
339
340 2993 aaronmk
def suffixed_col(col, suffix):
341
    return Col(concat(col.name, suffix), col.table, col.srcs)
342
343 2323 aaronmk
class NamedCol(Col):
344 2229 aaronmk
    def __init__(self, name, code):
345 2310 aaronmk
        Col.__init__(self, name)
346
347 3016 aaronmk
        code = as_Value(code)
348 2229 aaronmk
349
        self.code = code
350
351
    def to_str(self, db):
352 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
353 2314 aaronmk
354
    def to_Col(self): return Col(self.name)
355 2229 aaronmk
356 2462 aaronmk
def remove_col_rename(col):
357
    if isinstance(col, NamedCol): col = col.code
358
    return col
359
360 2830 aaronmk
def underlying_col(col):
361
    col = remove_col_rename(col)
362 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
363
364 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
365 2830 aaronmk
366 2703 aaronmk
def wrap(wrap_func, value):
367
    '''Wraps a value, propagating any column renaming to the returned value.'''
368
    if isinstance(value, NamedCol):
369
        return NamedCol(value.name, wrap_func(value.code))
370
    else: return wrap_func(value)
371
372 2667 aaronmk
class ColDict(dicts.DictProxy):
373 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
374
375 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
376 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
377 2667 aaronmk
378 2645 aaronmk
        keys_table = as_Table(keys_table)
379
380 2642 aaronmk
        self.db = db
381 2641 aaronmk
        self.table = keys_table
382 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
383 2641 aaronmk
384 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
385 2655 aaronmk
386 2667 aaronmk
    def __getitem__(self, key):
387
        return dicts.DictProxy.__getitem__(self, self._key(key))
388 2653 aaronmk
389 2564 aaronmk
    def __setitem__(self, key, value):
390 2642 aaronmk
        key = self._key(key)
391 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
392 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
393 2564 aaronmk
394 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
395 2564 aaronmk
396 3491 aaronmk
##### Composite types
397
398
class List(Code):
399
    def __init__(self, *values):
400
        Code.__init__(self)
401
402
        self.values = values
403
404
    def to_str(self, db):
405
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
406
407 3492 aaronmk
class Tuple(List):
408
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
409
410 3469 aaronmk
#### Definitions
411
412
class TypedCol(Col):
413
    def __init__(self, name, type_, default=None, nullable=True,
414
        constraints=None):
415
        assert default == None or isinstance(default, Code)
416
417
        Col.__init__(self, name)
418
419
        self.type = type_
420
        self.default = default
421
        self.nullable = nullable
422
        self.constraints = constraints
423
424
    def to_str(self, db):
425 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
426 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
427
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
428
        if self.constraints != None: str_ += ' '+self.constraints
429
        return str_
430
431
    def to_Col(self): return Col(self.name)
432
433 3488 aaronmk
class SetOf(Code):
434
    def __init__(self, type_):
435
        Code.__init__(self)
436
437
        self.type = type_
438
439
    def to_str(self, db):
440
        return 'SETOF '+self.type.to_str(db)
441
442 3483 aaronmk
class RowType(Code):
443
    def __init__(self, table):
444
        Code.__init__(self)
445
446
        self.table = table
447
448
    def to_str(self, db):
449
        return self.table.to_str(db)+'%ROWTYPE'
450
451 3485 aaronmk
class ColType(Code):
452
    def __init__(self, col):
453
        Code.__init__(self)
454
455
        self.col = col
456
457
    def to_str(self, db):
458
        return self.col.to_str(db)+'%TYPE'
459
460 2524 aaronmk
##### Functions
461
462 2912 aaronmk
Function = Table
463 2911 aaronmk
as_Function = as_Table
464
465 2691 aaronmk
class InternalFunction(CustomCode): pass
466
467 3442 aaronmk
#### Calls
468
469 2941 aaronmk
class NamedArg(NamedCol):
470
    def __init__(self, name, value):
471
        NamedCol.__init__(self, name, value)
472
473
    def to_str(self, db):
474
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
475
476 2524 aaronmk
class FunctionCall(Code):
477 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
478 2524 aaronmk
        '''
479 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
480 2524 aaronmk
        '''
481 3445 aaronmk
        Code.__init__(self)
482
483 3016 aaronmk
        function = as_Function(function)
484 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
485
        args = map(filter_, args)
486
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
487 2524 aaronmk
488
        self.function = function
489
        self.args = args
490
491
    def to_str(self, db):
492
        args_str = ', '.join((v.to_str(db) for v in self.args))
493
        return self.function.to_str(db)+'('+args_str+')'
494
495 2533 aaronmk
def wrap_in_func(function, value):
496
    '''Wraps a value inside a function call.
497
    Propagates any column renaming to the returned value.
498
    '''
499 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
500 2533 aaronmk
501 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
502
    '''Unwraps any function call to its first argument.
503
    Also removes any column renaming.
504
    '''
505
    func_call = remove_col_rename(func_call)
506
    if not isinstance(func_call, FunctionCall): return func_call
507
508
    if check_name != None:
509
        name = func_call.function.name
510
        assert name == None or name == check_name
511
    return func_call.args[0]
512
513 3442 aaronmk
#### Definitions
514
515
class FunctionDef(Code):
516 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
517 3445 aaronmk
        Code.__init__(self)
518
519 3487 aaronmk
        return_type = as_Code(return_type)
520 3444 aaronmk
        body = as_Code(body)
521
522 3442 aaronmk
        self.function = function
523
        self.return_type = return_type
524
        self.body = body
525 3471 aaronmk
        self.params = params
526 3456 aaronmk
        self.modifiers = modifiers
527 3442 aaronmk
528
    def to_str(self, db):
529 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
530 3442 aaronmk
        str_ = '''\
531 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
532 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
533 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
534 3456 aaronmk
'''
535
        if self.modifiers != None: str_ += self.modifiers+'\n'
536
        str_ += '''\
537 3442 aaronmk
AS $$
538 3444 aaronmk
'''+self.body.to_str(db)+'''
539 3442 aaronmk
$$;
540
'''
541
        return str_
542
543 3469 aaronmk
class FunctionParam(TypedCol):
544
    def __init__(self, name, type_, default=None, out=False):
545
        TypedCol.__init__(self, name, type_, default)
546
547
        self.out = out
548
549
    def to_str(self, db):
550
        str_ = TypedCol.to_str(self, db)
551
        if self.out: str_ = 'OUT '+str_
552
        return str_
553
554
    def to_Col(self): return Col(self.name)
555
556 3454 aaronmk
### PL/pgSQL
557
558 3496 aaronmk
class ReturnQuery(Code):
559
    def __init__(self, query):
560
        Code.__init__(self)
561
562
        query = as_Code(query)
563
564
        self.query = query
565
566
    def to_str(self, db):
567
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
568
569 3509 aaronmk
class BaseExcHandler(BasicObject):
570
    def to_str(self, db, body): raise NotImplementedError()
571
572
class ExcHandler(BaseExcHandler):
573 3454 aaronmk
    def __init__(self, exc, handler=None):
574
        if handler != None: handler = as_Code(handler)
575
576
        self.exc = exc
577
        self.handler = handler
578
579
    def to_str(self, db, body):
580
        body = as_Code(body)
581
582 3467 aaronmk
        if self.handler != None:
583
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
584 3463 aaronmk
        else: handler_str = ' NULL;\n'
585 3454 aaronmk
586
        str_ = '''\
587
BEGIN
588 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
589 3454 aaronmk
EXCEPTION
590 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
591 3454 aaronmk
END;\
592
'''
593
        return str_
594 3457 aaronmk
595
    def __repr__(self): return self.to_str(mockDb, '<body>')
596 3454 aaronmk
597 3503 aaronmk
class ExcToWarning(Code):
598
    def __init__(self, return_):
599
        '''
600
        @param return_ Statement to return a default value in case of error
601
        '''
602
        Code.__init__(self)
603
604
        return_ = as_Code(return_)
605
606
        self.return_ = return_
607
608
    def to_str(self, db):
609
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
610
611 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
612
613 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
614
RAISE data_exception USING MESSAGE =
615
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
616
''')
617
618 3505 aaronmk
def data_exception_handler(handler):
619
    return ExcHandler('data_exception', handler)
620
621 3449 aaronmk
class RowExcIgnore(Code):
622
    def __init__(self, row_type, select_query, with_row, cols=None,
623 3455 aaronmk
        exc_handler=unique_violation_handler, row_var='row'):
624 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
625
626 3482 aaronmk
        row_type = as_Code(row_type)
627 3449 aaronmk
        select_query = as_Code(select_query)
628
        with_row = as_Code(with_row)
629 3452 aaronmk
        row_var = as_Table(row_var)
630 3449 aaronmk
631
        self.row_type = row_type
632
        self.select_query = select_query
633
        self.with_row = with_row
634
        self.cols = cols
635 3455 aaronmk
        self.exc_handler = exc_handler
636 3452 aaronmk
        self.row_var = row_var
637 3449 aaronmk
638
    def to_str(self, db):
639 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
640
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
641 3449 aaronmk
642
        str_ = '''\
643
DECLARE
644 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
645 3449 aaronmk
BEGIN
646 3467 aaronmk
    /* Need an EXCEPTION block for each individual row because "When
647
    an error is caught by an EXCEPTION clause, [...] all changes to
648
    persistent database state within the block are rolled back."
649
    This is unfortunate because "A block containing an EXCEPTION
650
    clause is significantly more expensive to enter and exit than a
651
    block without one."
652 3449 aaronmk
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
653
#PLPGSQL-ERROR-TRAPPING)
654
    */
655
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
656 3467 aaronmk
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
657 3449 aaronmk
    LOOP
658 3467 aaronmk
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
659 3449 aaronmk
    END LOOP;
660
END;\
661
'''
662
        return str_
663
664 2986 aaronmk
##### Casts
665
666
class Cast(FunctionCall):
667
    def __init__(self, type_, value):
668
        value = as_Value(value)
669
670
        self.type_ = type_
671
        self.value = value
672
673
    def to_str(self, db):
674
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
675
676 3354 aaronmk
def cast_literal(value):
677 3429 aaronmk
    if not is_literal(value): return value
678 3354 aaronmk
679
    if util.is_str(value.value): value = Cast('text', value)
680
    return value
681
682 2335 aaronmk
##### Conditions
683 2259 aaronmk
684 3350 aaronmk
class NotCond(Code):
685
    def __init__(self, cond):
686 3445 aaronmk
        Code.__init__(self)
687
688 3350 aaronmk
        self.cond = cond
689
690
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
691
692 2398 aaronmk
class ColValueCond(Code):
693
    def __init__(self, col, value):
694 3445 aaronmk
        Code.__init__(self)
695
696 2398 aaronmk
        value = as_ValueCond(value)
697
698
        self.col = col
699
        self.value = value
700
701
    def to_str(self, db): return self.value.to_str(db, self.col)
702
703 2577 aaronmk
def combine_conds(conds, keyword=None):
704
    '''
705
    @param keyword The keyword to add before the conditions, if any
706
    '''
707
    str_ = ''
708
    if keyword != None:
709
        if conds == []: whitespace = ''
710
        elif len(conds) == 1: whitespace = ' '
711
        else: whitespace = '\n'
712
        str_ += keyword+whitespace
713
714
    str_ += '\nAND '.join(conds)
715
    return str_
716
717 2398 aaronmk
##### Condition column comparisons
718
719 2514 aaronmk
class ValueCond(BasicObject):
720 2213 aaronmk
    def __init__(self, value):
721 2858 aaronmk
        value = remove_col_rename(as_Value(value))
722 2213 aaronmk
723
        self.value = value
724 2214 aaronmk
725 2216 aaronmk
    def to_str(self, db, left_value):
726 2214 aaronmk
        '''
727 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
728 2214 aaronmk
        '''
729
        raise NotImplemented()
730 2228 aaronmk
731 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
732 2211 aaronmk
733
class CompareCond(ValueCond):
734
    def __init__(self, value, operator='='):
735 2222 aaronmk
        '''
736
        @param operator By default, compares NULL values literally. Use '~=' or
737
            '~!=' to pass NULLs through.
738
        '''
739 2211 aaronmk
        ValueCond.__init__(self, value)
740
        self.operator = operator
741
742 2216 aaronmk
    def to_str(self, db, left_value):
743 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
744 2216 aaronmk
745 2222 aaronmk
        right_value = self.value
746
747
        # Parse operator
748 2216 aaronmk
        operator = self.operator
749 2222 aaronmk
        passthru_null_ref = [False]
750
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
751
        neg_ref = [False]
752
        operator = strings.remove_prefix('!', operator, neg_ref)
753 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
754 2222 aaronmk
755 2825 aaronmk
        # Handle nullable columns
756
        check_null = False
757 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
758 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
759 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
760
                check_null = equals and isinstance(right_value, Col)
761 2837 aaronmk
            else:
762 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
763
                    right_value = ensure_not_null(db, right_value,
764
                        left_value.type) # apply same function to both sides
765 2825 aaronmk
766 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
767
768 2825 aaronmk
        left = left_value.to_str(db)
769
        right = right_value.to_str(db)
770
771 2222 aaronmk
        # Create str
772
        str_ = left+' '+operator+' '+right
773 2825 aaronmk
        if check_null:
774 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
775
        if neg_ref[0]: str_ = 'NOT '+str_
776 2222 aaronmk
        return str_
777 2216 aaronmk
778 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
779
assume_literal = object()
780
781
def as_ValueCond(value, default_table=assume_literal):
782
    if not isinstance(value, ValueCond):
783
        if default_table is not assume_literal:
784 2748 aaronmk
            value = with_default_table(value, default_table)
785 2260 aaronmk
        return CompareCond(value)
786 2216 aaronmk
    else: return value
787 2219 aaronmk
788 2335 aaronmk
##### Joins
789
790 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
791 2260 aaronmk
792 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
793
join_same_not_null = object()
794
795 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
796
797 2514 aaronmk
class Join(BasicObject):
798 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
799 2260 aaronmk
        '''
800
        @param mapping dict(right_table_col=left_table_col, ...)
801 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
802 2353 aaronmk
              * Note that right_table_col must be a string
803
            * if left_table_col is join_same_not_null:
804
              left_table_col = right_table_col and both have NOT NULL constraint
805
              * Note that right_table_col must be a string
806 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
807
            * filter_out: equivalent to 'LEFT' with the query filtered by
808
              `table_pkey IS NULL` (indicating no match)
809
        '''
810
        if util.is_str(table): table = Table(table)
811
        assert type_ == None or util.is_str(type_) or type_ is filter_out
812
813
        self.table = table
814
        self.mapping = mapping
815
        self.type_ = type_
816
817 2749 aaronmk
    def to_str(self, db, left_table_):
818 2260 aaronmk
        def join(entry):
819
            '''Parses non-USING joins'''
820
            right_table_col, left_table_col = entry
821
822 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
823
            left = right_table_col
824
            right = left_table_col
825 2749 aaronmk
            left_table = self.table
826
            right_table = left_table_
827 2353 aaronmk
828 2747 aaronmk
            # Parse left side
829 2748 aaronmk
            left = with_default_table(left, left_table)
830 2747 aaronmk
831 2260 aaronmk
            # Parse special values
832 2747 aaronmk
            left_on_right = Col(left.name, right_table)
833
            if right is join_same: right = left_on_right
834 2353 aaronmk
            elif right is join_same_not_null:
835 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
836 2260 aaronmk
837 2747 aaronmk
            # Parse right side
838 2353 aaronmk
            right = as_ValueCond(right, right_table)
839 2747 aaronmk
840
            return right.to_str(db, left)
841 2260 aaronmk
842 2265 aaronmk
        # Create join condition
843
        type_ = self.type_
844 2276 aaronmk
        joins = self.mapping
845 2746 aaronmk
        if joins == {}: join_cond = None
846
        elif type_ is not filter_out and reduce(operator.and_,
847 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
848 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
849 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
850
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
851 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
852 2260 aaronmk
853 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
854
        else: whitespace = ' '
855
856 2260 aaronmk
        # Create join
857
        if type_ is filter_out: type_ = 'LEFT'
858 2266 aaronmk
        str_ = ''
859
        if type_ != None: str_ += type_+' '
860 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
861
        if join_cond != None: str_ += whitespace+join_cond
862 2266 aaronmk
        return str_
863 2349 aaronmk
864 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
865 2424 aaronmk
866
##### Value exprs
867
868 3089 aaronmk
all_cols = CustomCode('*')
869
870 2737 aaronmk
default = CustomCode('DEFAULT')
871
872 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
873 2674 aaronmk
874 3061 aaronmk
class Coalesce(FunctionCall):
875
    def __init__(self, *args):
876
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
877 3060 aaronmk
878 3062 aaronmk
class Nullif(FunctionCall):
879
    def __init__(self, *args):
880
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
881
882 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
883 2958 aaronmk
null_sentinels = {
884
    'character varying': r'\N',
885
    'double precision': 'NaN',
886
    'integer': 2147483647,
887
    'text': r'\N',
888
    'timestamp with time zone': 'infinity'
889
}
890 2692 aaronmk
891 3061 aaronmk
class EnsureNotNull(Coalesce):
892 2850 aaronmk
    def __init__(self, value, type_):
893 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
894 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
895 2850 aaronmk
896
        self.type = type_
897 3001 aaronmk
898
    def to_str(self, db):
899
        col = self.args[0]
900
        index_col_ = index_col(col)
901
        if index_col_ != None: return index_col_.to_str(db)
902 3061 aaronmk
        return Coalesce.to_str(self, db)
903 2850 aaronmk
904 2737 aaronmk
##### Table exprs
905
906
class Values(Code):
907
    def __init__(self, values):
908 2739 aaronmk
        '''
909
        @param values [...]|[[...], ...] Can be one or multiple rows.
910
        '''
911 3445 aaronmk
        Code.__init__(self)
912
913 2739 aaronmk
        rows = values
914
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
915
            rows = [values]
916
        for i, row in enumerate(rows):
917
            rows[i] = map(remove_col_rename, map(as_Value, row))
918 2737 aaronmk
919 2739 aaronmk
        self.rows = rows
920 2737 aaronmk
921
    def to_str(self, db):
922 3491 aaronmk
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
923 2737 aaronmk
924 2740 aaronmk
def NamedValues(name, cols, values):
925 2745 aaronmk
    '''
926 3048 aaronmk
    @param cols None|[...]
927 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
928
    '''
929 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
930 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
931 2834 aaronmk
    return table
932 2740 aaronmk
933 2674 aaronmk
##### Database structure
934
935 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
936
937 2851 aaronmk
def ensure_not_null(db, col, type_=None):
938 2840 aaronmk
    '''
939 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
940 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
941 2840 aaronmk
    @return EnsureNotNull|Col
942
    @throws ensure_not_null_excs
943
    '''
944 2855 aaronmk
    nullable = True
945
    try: typed_col = db.col_info(underlying_col(col))
946
    except NoUnderlyingTableException:
947 3355 aaronmk
        col = remove_col_rename(col)
948 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
949 3355 aaronmk
        elif type_ == None: raise
950 2855 aaronmk
    else:
951
        if type_ == None: type_ = typed_col.type
952
        nullable = typed_col.nullable
953
954 2953 aaronmk
    if nullable:
955
        try: col = EnsureNotNull(col, type_)
956
        except KeyError, e:
957
            # Warn of no null sentinel for type, even if caller catches error
958
            warnings.warn(UserWarning(exc.str_(e)))
959
            raise
960
961 2840 aaronmk
    return col