Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 9383 aaronmk
from collections import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
    '''Appends suffix to str_ via strings.concat(), limited to
    identifier_max_len.
    If str_ ends in a part matched by the regexp below (the "version"), that
    part is moved to the front of the suffix first, so that it is passed to
    strings.concat() as suffix rather than as part of the base string.
    '''
    versioned = re.match(
        r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$', str_)
    if versioned is not None:
        base, kept_suffix = versioned.groups()
        return strings.concat(base, kept_suffix+suffix, identifier_max_len)

    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_):
    '''Runs str_ through concat() with an empty suffix.'''
    no_suffix = ''
    return concat(str_, no_suffix)
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
    r'''Tests whether a name is safe *and unambiguous*, i.e. whether it:
    * contains only *lowercase* word (\w) characters
    * doesn't start with a digit
    * contains "_", so that it's not a keyword
    @return The regexp match object if the name is safe, otherwise None
    '''
    safe_name_re = r'^(?=.*_)(?!\d)[^\WA-Z]+$'
    return re.match(safe_name_re, name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
    '''Quotes an identifier.
    @param quote The quote character to wrap the name in
    @return The name surrounded by quote, with each embedded quote doubled
    '''
    # doubling an embedded quote escapes it in both PostgreSQL and MySQL
    doubled = name.replace(quote, quote*2)
    return ''.join((quote, doubled, quote))
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
    '''Undoes esc_name(): removes the surrounding quotes and un-doubles any
    embedded quotes. A name that doesn't start with quote is returned without
    further processing.
    '''
    found_ref = [False] # out-param, read back after each strings.remove_*()
    unquoted = strings.remove_prefix(quote, name, found_ref)
    if not found_ref[0]: return unquoted # no opening quote

    unquoted = strings.remove_suffix(quote, unquoted, found_ref)
    assert found_ref[0] # an opening quote requires a closing quote
    return unquoted.replace(quote+quote, quote)
57
58 2513 aaronmk
def clean_name(name):
    '''Removes every double quote and backtick from name.'''
    for quote_char in ('"', '`'): name = name.replace(quote_char, '')
    return name
59
60 3041 aaronmk
def esc_comment(comment):
    '''Wraps comment in /* */, inserting a space into any embedded "*/" so that
    it can't end the comment early.'''
    body = comment.replace('*/', '* /')
    return ''.join(('/*', body, '*/'))
61
62 3182 aaronmk
def lstrip(str_):
    '''Strips leading whitespace. If the string starts with "/*", everything up
    to and including the first "*/" is removed first.'''
    if str_.startswith('/*'):
        str_ = str_.partition('*/')[2] # keep only what follows the comment
    return str_.lstrip()
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
    '''Stand-in for a db object, used to render code objects when no real db is
    at hand (e.g. by Code.__repr__()).'''

    def esc_value(self, value):
        '''Renders a value with strings.repr_no_u().'''
        return strings.repr_no_u(value)

    def esc_name(self, name):
        '''Quotes a name with the module-level esc_name().'''
        return esc_name(name)

    def col_info(self, col):
        '''Returns a placeholder TypedCol carrying only the column's name.'''
        placeholder_default = CustomCode('<default>')
        return TypedCol(col.name, '<type>', placeholder_default, True)

mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
    '''objects.BasicObject whose str() form has the name quotes removed.'''

    def __str__(self):
        quoted = strings.repr_no_u(self)
        return clean_name(quoted)
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
    '''Base class of the objects that are rendered to code by to_str().'''

    def __init__(self, lang='sql'):
        '''
        @param lang The language that the code is written in
        '''
        self.lang = lang

    def to_str(self, db):
        '''Renders this object for the given db. Subclasses must override it.'''
        raise NotImplementedError()

    def __repr__(self):
        # No real db is available here, so render against the mock one
        return self.to_str(mockDb)
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
    '''Code supplied as a ready-made string, which is rendered unchanged.'''

    def __init__(self, str_):
        '''
        @param str_ The code string that to_str() returns
        '''
        Code.__init__(self)
        self.str_ = str_

    def to_str(self, db):
        return self.str_
99
100 2815 aaronmk
def as_Code(value, db=None):
    '''Converts a value to a Code object:
    * Code objects are returned as-is
    * strings become CustomCode
    * anything else becomes a Literal
    @param db If set, runs db.std_code() on string values.
    '''
    if isinstance(value, Code): return value
    if not util.is_str(value): return Literal(value)

    if db != None: value = db.std_code(value)
    return CustomCode(value)
110
111 2540 aaronmk
class Expr(Code):
    '''Renders another Code object inside parentheses.'''

    def __init__(self, expr):
        '''
        @param expr Code The object to parenthesize
        '''
        Code.__init__(self)
        self.expr = expr

    def to_str(self, db):
        inner = self.expr.to_str(db)
        return '('+inner+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
    '''A name that is truncated on creation and escaped by the db on output.'''

    def __init__(self, name):
        Code.__init__(self)
        self.name = truncate(name)

    def to_str(self, db):
        return db.esc_name(self.name)
130
131
def as_Name(value):
    '''Wraps value in a Name, unless it already is a Code object.'''
    if not isinstance(value, Code): value = Name(value)
    return value
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
    '''A plain value, which the db escapes when it is rendered.'''

    def __init__(self, value):
        '''
        @param value The value that is passed to db.esc_value()
        '''
        Code.__init__(self)
        self.value = value

    def to_str(self, db):
        return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
    '''Wraps value in a Literal, unless it already is a Code object.'''
    if not isinstance(value, Code): value = Literal(value)
    return value
150
151 6225 aaronmk
def get_value(value):
    '''Unwraps a Literal's value.
    Any column renaming is removed first. A non-Literal input is returned as-is
    and must not be a Code object.
    '''
    unwrapped = remove_col_rename(value)
    if isinstance(unwrapped, Literal): return unwrapped.value

    assert not isinstance(unwrapped, Code)
    return unwrapped
158
159 3429 aaronmk
def is_literal(value):
    '''Whether value is a Literal object.'''
    return isinstance(value, Literal)
160 2216 aaronmk
161 3429 aaronmk
def is_null(value):
    '''Whether value is a Literal that wraps None.'''
    if not is_literal(value): return False
    return value.value == None
162
163 3519 aaronmk
#### Composites
164
165 3521 aaronmk
class List(Code):
    '''A sequence of Code objects, rendered comma-separated.'''

    def __init__(self, values):
        '''
        @param values Iterable of objects that provide to_str(db)
        '''
        Code.__init__(self)
        self.values = values

    def to_str(self, db):
        rendered = [v.to_str(db) for v in self.values]
        return ', '.join(rendered)
172 3519 aaronmk
173 3521 aaronmk
class Tuple(List):
    '''A List that takes its values as separate arguments and is rendered in
    parentheses.'''

    def __init__(self, *values):
        List.__init__(self, values)

    def to_str(self, db):
        items_str = List.to_str(self, db)
        return '('+items_str+')'
178
179 3520 aaronmk
class Row(Tuple):
    '''A Tuple rendered with a leading ROW keyword.'''

    def to_str(self, db):
        tuple_str = Tuple.to_str(self, db)
        return 'ROW'+tuple_str
181 3519 aaronmk
182 3522 aaronmk
### Arrays
183
184
class Array(List):
    '''A List rendered as ARRAY[...]. Column renamings are removed from its
    values on creation.'''

    def __init__(self, values):
        unrenamed = [remove_col_rename(v) for v in values]
        List.__init__(self, unrenamed)

    def to_str(self, db):
        items_str = List.to_str(self, db)
        return 'ARRAY['+items_str+']'
191
192
def to_Array(value):
    '''Returns value as an Array. An existing Array is returned as-is; anything
    else is passed through lists.mk_seq() and wrapped.'''
    if not isinstance(value, Array): value = Array(lists.mk_seq(value))
    return value
195
196 2711 aaronmk
##### Derived elements
197
198
# Sentinel for the srcs param of Derived.set_srcs(), which replaces it with
# (self,)
src_self = object() # tells Col that it is its own source column
199
200
class Derived(Code):
    '''A code element that records which element(s) it was derived from.'''

    def __init__(self, srcs):
        '''An element which was derived from some other element(s).
        @param srcs See self.set_srcs()
        '''
        Code.__init__(self)
        self.set_srcs(srcs)

    def set_srcs(self, srcs, overwrite=True):
        '''
        @param srcs (self_type...)|src_self The element(s) this is derived from
        @param overwrite If False, srcs that are already set are left alone
        '''
        if not overwrite:
            if self.srcs != (): return # already set

        if srcs == src_self: srcs = (self,)
        self.srcs = tuple(srcs) # make Col hashable

    def _compare_on(self):
        '''Returns a copy of the instance attributes without srcs, which is
        ignored.'''
        attrs = dict(self.__dict__)
        attrs.pop('srcs') # ignore
        return attrs
223
224
def cols_srcs(cols):
    '''Collects the srcs of all the given columns and passes the flattened
    result through lists.uniqify().'''
    all_srcs = iters.flatten((c.srcs for c in cols))
    return lists.uniqify(all_srcs)
225
226 2335 aaronmk
##### Tables
227
228 2712 aaronmk
class Table(Derived):
    '''A table reference, optionally qualified by a schema.'''

    def __init__(self, name, schema=None, srcs=(), is_temp=False):
        '''
        @param name String names are truncated; other values are stored as-is
        @param schema str|None (for no schema)
        @param srcs (Table...)|src_self See Derived.set_srcs()
        @param is_temp Stored as self.is_temp
        '''
        Derived.__init__(self, srcs)

        if util.is_str(name): name = truncate(name)

        self.name = name
        self.schema = schema
        self.is_temp = is_temp
        self.order_by = None
        self.index_cols = {}

    def to_str(self, db):
        '''Renders the schema (if any) and the name, separated by ".".'''
        parts = []
        if self.schema != None: parts.append(as_Name(self.schema).to_str(db))
        parts.append(as_Name(self.name).to_str(db))
        return '.'.join(parts)

    def to_Table(self):
        return self

    def _compare_on(self):
        '''Like Derived._compare_on(), but also ignores order_by and index_cols.
        '''
        attrs = Derived._compare_on(self)
        for ignored in ('order_by', 'index_cols'): del attrs[ignored]
        return attrs
257 2211 aaronmk
258 2835 aaronmk
def is_underlying_table(table):
    '''Whether table is a Table whose to_Table() returns the table itself.'''
    if not isinstance(table, Table): return False
    return table.to_Table() is table
260 2832 aaronmk
261 10186 aaronmk
def table2regclass_text(db, table):
    '''Renders the table for db and escapes the result as a db value.'''
    assert isinstance(table, Table)
    table_str = table.to_str(db)
    return db.esc_value(table_str)
264
265 4114 aaronmk
class NoUnderlyingTableException(Exception):
    '''Raised when a reference can't be resolved to an underlying table.'''

    def __init__(self, ref):
        '''
        @param ref The reference that couldn't be resolved; kept as self.ref
        '''
        msg = 'for: '+strings.as_tt(strings.urepr(ref))
        Exception.__init__(self, msg)
        self.ref = ref
269 2902 aaronmk
270
def underlying_table(table):
    '''Resolves a table reference to its underlying table: removes any table
    renaming, then follows the (single) source if there is one.
    @raise NoUnderlyingTableException if the result is not an underlying table
    '''
    resolved = remove_table_rename(table)
    if resolved != None and resolved.srcs:
        resolved, = resolved.srcs # for derived tables or row vars
    if is_underlying_table(resolved): return resolved
    raise NoUnderlyingTableException(resolved)
276
277 2776 aaronmk
def as_Table(table, schema=None):
    '''Wraps table in a Table (with the given schema), unless it is None or
    already a Code object, in which case it is returned as-is.'''
    passthru = table == None or isinstance(table, Code)
    if passthru: return table
    return Table(table, schema)
280 2219 aaronmk
281 3101 aaronmk
def suffixed_table(table, suffix):
    '''Returns a shallow copy of table whose name has suffix added by concat().
    '''
    renamed = copy.copy(table) # don't modify input!
    renamed.name = concat(renamed.name, suffix)
    return renamed
285 2707 aaronmk
286 2336 aaronmk
class NamedTable(Table):
    '''Code that is given a table name, rendered as "<code> AS <name>".'''

    def __init__(self, name, code, cols=None):
        '''
        @param code The code to name. Anything that isn't a Table, FunctionCall
            or Expr after as_Table() is wrapped in an Expr.
        @param cols If set, the column names to list after the table name
        '''
        Table.__init__(self, name)

        code = as_Table(code)
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]

        self.code = code
        self.cols = cols

    def to_str(self, db):
        code_str = self.code.to_str(db)
        # Multiline code gets the AS clause on a line of its own
        sep = '\n' if '\n' in code_str else ' '
        parts = [code_str, sep, 'AS ', Table.to_str(self, db)]
        if self.cols != None:
            col_names = ', '.join([c.to_str(db) for c in self.cols])
            parts += [' (', col_names, ')']
        return ''.join(parts)

    def to_Table(self):
        return Table(self.name)
307
308 2753 aaronmk
def remove_table_rename(table):
    '''Returns the code of a NamedTable; any other value is returned as-is.'''
    if isinstance(table, NamedTable): return table.code
    return table
311
312 2335 aaronmk
##### Columns
313
314 2711 aaronmk
class Col(Derived):
    '''A column reference, optionally qualified by a table.'''

    def __init__(self, name, table=None, srcs=()):
        '''
        @param name String names are truncated; other values are stored as-is
        @param table Table|None (for no table)
        @param srcs (Col...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)

        if util.is_str(name): name = truncate(name)
        if util.is_str(table): table = Table(table)
        assert table == None or isinstance(table, Table)

        self.name = name
        self.table = table

    def to_str(self, db, for_str=False):
        '''
        @param for_str If set, renders the form used by __str__(): name quotes
            are removed and the table is joined on using concat()
        '''
        col_str = as_Name(self.name).to_str(db)
        if for_str: col_str = clean_name(col_str)
        if self.table != None:
            table = self.table.to_Table()
            if for_str: return concat(strings.ustr(table), '.'+col_str)
            return table.to_str(db)+'.'+col_str
        return col_str

    def __str__(self):
        return self.to_str(mockDb, for_str=True)

    def to_Col(self):
        return self
341 2211 aaronmk
342 3512 aaronmk
def is_col(col):
    '''Whether col is a Col object.'''
    return isinstance(col, Col)
343 2393 aaronmk
344 3512 aaronmk
def is_table_col(col):
    '''Whether col is a Col that has a table.'''
    if not is_col(col): return False
    return col.table != None
345
346 10138 aaronmk
def col2col_ref(db, col):
    '''Escapes the pair (rendered table, column name) as a db value.'''
    assert isinstance(col, Col)
    ref = (col.table.to_str(db), col.name)
    return db.esc_value(ref)
349
350 3000 aaronmk
def index_col(col):
    '''Looks up the entry for col in its table's index_cols.
    @return A Col with the registered name (same table and srcs), or None if col
        has no table or its name isn't in index_cols
    '''
    if not is_table_col(col): return None

    table = col.table
    try: index_name = table.index_cols[col.name]
    except KeyError: return None
    return Col(index_name, table, col.srcs)
357 2999 aaronmk
358 3024 aaronmk
def is_temp_col(col):
    '''Whether col has a table and that table's is_temp is set.'''
    if not is_table_col(col): return False
    return col.table.is_temp
359 2996 aaronmk
360 2563 aaronmk
def as_Col(col, table=None, name=None):
    '''Converts a value to a Code object that can be used as a column:
    * Code objects are returned as-is
    * strings become a Col in the given table
    * anything else becomes a Literal
    @param name If not None, any non-Col input will be renamed using NamedCol.
    '''
    if name != None:
        col = as_Value(col)
        if not isinstance(col, Col): col = NamedCol(name, col)

    if isinstance(col, Code): return col
    if util.is_str(col): return Col(col, table)
    return Literal(col)
371 2260 aaronmk
372 3093 aaronmk
def with_table(col, table):
    '''Returns col with its table set to the given table, without modifying the
    input:
    * a NamedCol is returned unchanged
    * a FunctionCall is deep-copied and the table is set on its first arg
    * any other Col is shallow-copied and the table is set on the copy
    * anything else is returned unchanged
    '''
    if isinstance(col, NamedCol): return col # doesn't take a table
    if isinstance(col, FunctionCall):
        call = copy.deepcopy(col) # don't modify input!
        call.args[0].table = table
        return call
    if isinstance(col, Col):
        moved = copy.copy(col) # don't modify input!
        moved.table = table
        return moved
    return col
381
382 3100 aaronmk
def with_default_table(col, table):
    '''Converts col with as_Col() and, if it has no table yet, applies
    with_table() to give it the given one.'''
    col = as_Col(col)
    if col.table == None: return with_table(col, table)
    return col
386
387 2744 aaronmk
def set_cols_table(table, cols):
    '''Replaces each entry of cols, in place, with its as_Col() form and sets
    that column's table to the given table.
    @param cols Mutable sequence; modified in place
    '''
    table = as_Table(table)

    for i, orig_col in enumerate(cols):
        cols[i] = as_Col(orig_col)
        cols[i].table = table
393
394 2401 aaronmk
def to_name_only_col(col, check_table=None):
    '''Converts col with as_Col() and, if it is a column with a table, returns a
    new table-less Col of the same name.
    @param check_table If set, asserts that the column's table equals it
    '''
    col = as_Col(col)
    if is_table_col(col):
        if check_table != None:
            assert col.table == None or col.table == check_table
        col = Col(col.name)
    return col
402
403 2993 aaronmk
def suffixed_col(col, suffix):
    '''Returns a new Col with the same table and srcs, whose name has suffix
    added by concat().'''
    new_name = concat(col.name, suffix)
    return Col(new_name, col.table, col.srcs)
405
406 3516 aaronmk
def has_srcs(col):
    '''Returns col's srcs if col is a Col (truthy when non-empty), else False.
    '''
    if not is_col(col): return False
    return col.srcs
407
408 3518 aaronmk
def cross_join_srcs(cols):
    '''Returns one Col per combination of the given columns' source names, with
    each combination's names joined by ",".'''
    # Columns with empty srcs are left out, as they would mess up the cross join
    src_names = [[s.name for s in c.srcs] for c in cols if has_srcs(c)]
    if not src_names: return [] # itertools.product() returns [()] for no input
    return [Col(','.join(combo)) for combo in itertools.product(*src_names)]
413 3514 aaronmk
414 2323 aaronmk
class NamedCol(Col):
    '''Code that is given a column name, rendered as "<code> AS <name>".'''

    def __init__(self, name, code):
        '''
        @param code Code|value The code to name; converted with as_Value()
        '''
        Col.__init__(self, name)
        self.code = as_Value(code)

    def to_str(self, db):
        return ' AS '.join((self.code.to_str(db), Col.to_str(self, db)))

    def to_Col(self):
        return Col(self.name)
426 2229 aaronmk
427 2462 aaronmk
def remove_col_rename(col):
    '''Returns the code of a NamedCol; any other value is returned as-is.'''
    if isinstance(col, NamedCol): return col.code
    return col
430
431 2830 aaronmk
def underlying_col(col):
    '''Returns a Col of the same name and srcs in the underlying table of col's
    table. Any column renaming is removed first.
    @raise NoUnderlyingTableException if col isn't a Col after that, or if its
        table has no underlying table
    '''
    unrenamed = remove_col_rename(col)
    if not isinstance(unrenamed, Col):
        raise NoUnderlyingTableException(unrenamed)

    table = underlying_table(unrenamed.table)
    return Col(unrenamed.name, table, unrenamed.srcs)
436 2830 aaronmk
437 2703 aaronmk
def wrap(wrap_func, value):
    '''Applies wrap_func to a value. For a NamedCol, wrap_func is applied to its
    code and the result is renamed to the same name.'''
    if not isinstance(value, NamedCol): return wrap_func(value)
    return NamedCol(value.name, wrap_func(value.code))
442
443 2667 aaronmk
class ColDict(dicts.DictProxy):
    '''An ordered dict that turns its keys and values into column objects.
    Keys are converted with as_Col() against the dict's table. Anything inserted
    as a value that isn't a column is wrapped in a NamedCol with the key's
    column name by `as_Col(value, name=key.name)`.
    '''

    def __init__(self, db, keys_table, dict_={}):
        '''
        @param db Used to look up column defaults for None values
        @param keys_table The table that string keys are placed in
        @param dict_ Initial entries
        '''
        dicts.DictProxy.__init__(self, OrderedDict())

        self.db = db
        self.table = as_Table(keys_table)
        # Must come last, because __setitem__() needs the attributes above
        self.update(dict_)

    def copy(self):
        return ColDict(self.db, self.table, self.inner.copy())

    def __getitem__(self, key):
        col_key = self._key(key)
        return dicts.DictProxy.__getitem__(self, col_key)

    def __setitem__(self, key, value):
        col_key = self._key(key)
        if value == None:
            # Use the column's default instead, unless col_info() reports that
            # there is no underlying table
            try: value = self.db.col_info(col_key).default
            except NoUnderlyingTableException: pass # not a table column
        entry = as_Col(value, name=col_key.name)
        dicts.DictProxy.__setitem__(self, col_key, entry)

    def _key(self, key):
        return as_Col(key, self.table)
471 2564 aaronmk
472 3519 aaronmk
##### Definitions
473 3491 aaronmk
474 3469 aaronmk
class TypedCol(Col):
    '''A column definition: name, type and optional NOT NULL, DEFAULT and
    constraints clauses.'''

    def __init__(self, name, type_, default=None, nullable=True,
        constraints=None):
        '''
        @param type_ Code|str The column's type
        @param default Code|None
        @param nullable If False, NOT NULL is added
        @param constraints str|None Text appended after the other clauses
        '''
        assert default == None or isinstance(default, Code)

        Col.__init__(self, name)

        self.type = type_
        self.default = default
        self.nullable = nullable
        self.constraints = constraints

    def to_str(self, db):
        clauses = [Col.to_str(self, db), as_Code(self.type).to_str(db)]
        if not self.nullable: clauses.append('NOT NULL')
        if self.default != None:
            clauses.append('DEFAULT '+self.default.to_str(db))
        if self.constraints != None: clauses.append(self.constraints)
        return ' '.join(clauses)

    def to_Col(self):
        return Col(self.name)
494
495 3488 aaronmk
class SetOf(Code):
    '''A type rendered with a leading SETOF keyword.'''

    def __init__(self, type_):
        '''
        @param type_ Code The element type; must provide to_str()
        '''
        Code.__init__(self)
        self.type = type_

    def to_str(self, db):
        type_str = self.type.to_str(db)
        return 'SETOF '+type_str
503
504 3483 aaronmk
class RowType(Code):
    '''The row type of a table, rendered as <table>%ROWTYPE.'''

    def __init__(self, table):
        '''
        @param table Code The table; must provide to_str()
        '''
        Code.__init__(self)
        self.table = table

    def to_str(self, db):
        table_str = self.table.to_str(db)
        return table_str+'%ROWTYPE'
512
513 3485 aaronmk
class ColType(Code):
    '''The type of a column, rendered as <col>%TYPE.'''

    def __init__(self, col):
        '''
        @param col Code The column; must provide to_str()
        '''
        Code.__init__(self)
        self.col = col

    def to_str(self, db):
        col_str = self.col.to_str(db)
        return col_str+'%TYPE'
521
522 4935 aaronmk
class ArrayType(Code):
    '''An array type, rendered as <elem_type>[].'''

    def __init__(self, elem_type):
        '''
        @param elem_type Code|str The element type; converted with as_Code()
        '''
        Code.__init__(self)
        self.elem_type = as_Code(elem_type)

    def to_str(self, db):
        elem_str = self.elem_type.to_str(db)
        return elem_str+'[]'
531
532 2524 aaronmk
##### Functions
533
534 2912 aaronmk
# Functions are represented by the same class as tables, so the Table class and
# its converter are made available under function-specific names
Function = Table
as_Function = as_Table
536
537 2691 aaronmk
class InternalFunction(CustomCode):
    '''A CustomCode subclass that adds no behavior of its own.'''
538
539 3442 aaronmk
#### Calls
540
541 2941 aaronmk
class NamedArg(NamedCol):
    '''A named function argument, rendered as "<name> := <value>".'''

    def __init__(self, name, value):
        NamedCol.__init__(self, name, value)

    def to_str(self, db):
        return ' := '.join((Col.to_str(self, db), self.code.to_str(db)))
547
548 2524 aaronmk
class FunctionCall(Code):
    '''A call to a function, rendered as "<function>(<args>)".'''

    def __init__(self, function, *args, **kw_args):
        '''
        @param function Converted with as_Function()
        @param args [Code|literal-value...] The function's arguments
        @param kw_args Named arguments, which are added after args as NamedArgs
        '''
        Code.__init__(self)

        function = as_Function(function)
        def filter_(arg): return remove_col_rename(as_Value(arg))
        positional = [filter_(a) for a in args]
        named = [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]

        self.function = function
        self.args = positional+named

    def to_str(self, db):
        rendered_args = [a.to_str(db) for a in self.args]
        return self.function.to_str(db)+'('+', '.join(rendered_args)+')'
566
567 2533 aaronmk
def wrap_in_func(function, value):
    '''Wraps a value inside a call to the given function.
    Any column renaming is carried over to the returned value by wrap().
    '''
    def call(arg): return FunctionCall(function, arg)
    return wrap(call, value)
572 2533 aaronmk
573 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
    '''Unwraps any function call to its first argument.
    Also removes any column renaming. A value that isn't a FunctionCall after
    that is returned as-is.
    @param check_name If set, asserts that the function's name is None or
        equals it
    '''
    unwrapped = remove_col_rename(func_call)
    if isinstance(unwrapped, FunctionCall):
        if check_name != None:
            func_name = unwrapped.function.name
            assert func_name == None or func_name == check_name
        unwrapped = unwrapped.args[0]
    return unwrapped
584
585 3442 aaronmk
#### Definitions
586
587
class FunctionDef(Code):
    '''A CREATE FUNCTION statement.'''

    def __init__(self, function, return_type, body, params=[], modifiers=None):
        '''
        @param function Code The function's name; must provide to_str()
        @param return_type Code|str Converted with as_Code()
        @param body Code|str The function body, converted with as_Code(). Its
            lang attribute is used in the LANGUAGE clause.
        @param params [Code...] The parameter definitions. The sequence is
            copied, so that instances never share the default list object.
        @param modifiers str|None Text placed on its own line before AS
        '''
        Code.__init__(self)

        return_type = as_Code(return_type)
        body = as_Code(body)

        self.function = function
        self.return_type = return_type
        self.body = body
        # Copy, because storing the mutable default itself would let a change to
        # one instance's params show up in every other instance
        self.params = list(params)
        self.modifiers = modifiers

    def to_str(self, db):
        params_str = (', '.join((p.to_str(db) for p in self.params)))
        str_ = '''\
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
RETURNS '''+self.return_type.to_str(db)+'''
LANGUAGE '''+self.body.lang+'''
'''
        if self.modifiers != None: str_ += self.modifiers+'\n'
        str_ += '''\
AS $$
'''+self.body.to_str(db)+'''
$$;
'''
        return str_
614
615 3469 aaronmk
class FunctionParam(TypedCol):
    '''A function parameter definition, optionally marked as OUT.'''

    def __init__(self, name, type_, default=None, out=False):
        '''
        @param out If set, the definition is prefixed with OUT
        '''
        TypedCol.__init__(self, name, type_, default)
        self.out = out

    def to_str(self, db):
        definition = TypedCol.to_str(self, db)
        if self.out: return 'OUT '+definition
        return definition

    def to_Col(self):
        return Col(self.name)
627
628 3454 aaronmk
### PL/pgSQL
629
630 3496 aaronmk
class ReturnQuery(Code):
    '''A RETURN QUERY statement around a query.'''

    def __init__(self, query):
        '''
        @param query Code|str Converted with as_Code()
        '''
        Code.__init__(self)
        self.query = as_Code(query)

    def to_str(self, db):
        query_str = strings.indent(self.query.to_str(db))
        return ''.join(('RETURN QUERY\n', query_str, ';\n'))
640
641 3515 aaronmk
## Exceptions
642
643 3509 aaronmk
class BaseExcHandler(BasicObject):
    '''Base class of exception handlers, which render code around a body.'''

    def to_str(self, db, body):
        '''Renders the handler around body. Subclasses must override it.'''
        raise NotImplementedError()

    def __repr__(self):
        # Render against the mock db, with a placeholder for the body
        return self.to_str(mockDb, '<body>')
647 3509 aaronmk
648 3549 aaronmk
# Statement used by ExcHandler when it is given no handler code
suppress_exc = 'NULL;\n'

# Statement embedded in plpythonu_error_handler's ELSE branch
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n'
651
652 3509 aaronmk
class ExcHandler(BaseExcHandler):
    '''Renders a body inside a BEGIN ... EXCEPTION WHEN <exc> THEN ... END;
    block.'''
    def __init__(self, exc, handler=None):
        '''
        @param exc str The condition name placed after WHEN
        @param handler Code|str|None The code to run when exc is caught. If
            None, suppress_exc is used instead.
        '''
        if handler != None: handler = as_Code(handler)

        self.exc = exc
        self.handler = handler

    def to_str(self, db, body):
        '''
        @param body Code|str The code to run inside the block
        '''
        body = as_Code(body)

        # Custom handler code goes on its own lines, indented two levels;
        # the default handler follows THEN on the same line
        if self.handler != None:
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
        else: handler_str = ' '+suppress_exc

        # The trailing "\" after END; is a line continuation inside the string,
        # so the result ends in "END;" without a newline
        str_ = '''\
BEGIN
'''+strings.indent(body.to_str(db))+'''\
EXCEPTION
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
END;\
'''
        return str_
674
675 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
    '''Applies several exception handlers to the same body, one around the
    other.'''

    def __init__(self, *handlers):
        '''
        @param handlers Sorted from outermost to innermost
        '''
        self.handlers = handlers

    def to_str(self, db, body):
        wrapped = body
        # Apply the innermost handler first, so that it ends up on the inside
        for handler in self.handlers[::-1]:
            wrapped = handler.to_str(db, wrapped)
        return wrapped
685
686 3503 aaronmk
class ExcToWarning(Code):
    '''Handler code that raises the error message as a warning and then runs a
    return statement.'''

    def __init__(self, return_):
        '''
        @param return_ Statement to return a default value in case of error
        '''
        Code.__init__(self)
        self.return_ = as_Code(return_)

    def to_str(self, db):
        raise_warning = "RAISE WARNING '%', SQLERRM;\n"
        return raise_warning+self.return_.to_str(db)
699
700 3454 aaronmk
# Catches unique_violation and, having no handler code, runs suppress_exc
unique_violation_handler = ExcHandler('unique_violation')

# Catches internal_error. If the message parses as "<exc_name>: <msg>"
# (optionally prefixed with "PL/Python: "), the error is re-raised as
# data_exception with just <msg>; otherwise it is re-raised using reraise_exc.
# Note doubled "\"s because inside Python string
plpythonu_error_handler = ExcHandler('internal_error', '''\
-- Handle PL/Python exceptions
DECLARE
    matches text[] := regexp_matches(SQLERRM,
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
    exc_name text := matches[1];
    msg text := matches[2];
BEGIN
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
    This allows the exception to be parsed like a native exception.
    Always raise as data_exception so it goes in the errors table. */
    IF exc_name IS NOT NULL THEN
        RAISE data_exception USING MESSAGE = msg;
    -- Re-raise non-PL/Python exceptions
    ELSE
        '''+reraise_exc+'''\
    END IF;
END;
''')
722
723 3505 aaronmk
def data_exception_handler(handler):
    '''Creates an ExcHandler that runs handler when data_exception is caught.
    '''
    caught_exc = 'data_exception'
    return ExcHandler(caught_exc, handler)
725
726 3527 aaronmk
# Default row variable of RowExcIgnore, which adds a DECLARE for it when this
# default is the row var in use
row_var = Table('row')
727
728 3449 aaronmk
class RowExcIgnore(Code):
    '''PL/pgSQL code that loops over the rows of a query and runs with_row for
    each row inside its own exception handler block.'''
    def __init__(self, row_type, select_query, with_row, cols=None,
        exc_handler=unique_violation_handler, row_var=row_var):
        '''
        @param row_type Ignored if a custom row_var is used.
        @param select_query Code|str The query whose rows are looped over
        @param with_row Code|str The code to run for each row
        @param cols If set, the loop reads into these columns of row_var instead
            of into row_var itself
        @param exc_handler Handler whose to_str(db, body) is applied to with_row
        @param row_var Table|str The loop's row variable
        @pre If a custom row_var is used, it must already be defined.
        '''
        Code.__init__(self, lang='plpgsql')

        row_type = as_Code(row_type)
        select_query = as_Code(select_query)
        with_row = as_Code(with_row)
        row_var = as_Table(row_var)

        self.row_type = row_type
        self.select_query = select_query
        self.with_row = with_row
        self.cols = cols
        self.exc_handler = exc_handler
        self.row_var = row_var

    def to_str(self, db):
        # The loop targets: the row var itself, or one column of it per col
        if self.cols == None: row_vars = [self.row_var]
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]

        # Need an EXCEPTION block for each individual row because "When an error
        # is caught by an EXCEPTION clause, [...] all changes to persistent
        # database state within the block are rolled back."
        # This is unfortunate because "A block containing an EXCEPTION clause is
        # significantly more expensive to enter and exit than a block without
        # one."
        # (http://www.postgresql.org/docs/8.3/static/\
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
        str_ = '''\
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+strings.indent(self.select_query.to_str(db))+'''\
LOOP
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
END LOOP;
'''
        # row_var here is the module-level default (this method has no local of
        # that name). Only the default row var is declared here; a custom one
        # must already be defined (see __init__()).
        if self.row_var == row_var:
            str_ = '''\
DECLARE
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
BEGIN
'''+strings.indent(str_)+'''\
END;
'''
        return str_
777
778 2986 aaronmk
##### Casts
779
780
class Cast(FunctionCall):
    '''A type cast, rendered as CAST(<value> AS <type>).'''

    def __init__(self, type_, value):
        '''
        @param type_ Code|str The type to cast to; converted with as_Code()
        @param value Code|literal-value The value to cast
        '''
        # Set the base Code attributes (lang), like every other Code subclass.
        # FunctionCall.__init__() is deliberately not called, because a cast
        # has no function or args.
        Code.__init__(self)

        type_ = as_Code(type_)
        value = as_Value(value)

        # Most types cannot be cast directly to unknown
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)

        self.type_ = type_
        self.value = value

    def to_str(self, db):
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
793 2986 aaronmk
794 3354 aaronmk
def cast_literal(value):
    '''Wraps a Literal that holds a string in a cast to text. Any other value is
    returned as-is.'''
    if is_literal(value) and util.is_str(value.value):
        return Cast('text', value)
    return value
799
800 2335 aaronmk
##### Conditions
801 2259 aaronmk
802 3350 aaronmk
class NotCond(Code):
    '''The negation of a condition, rendered as "NOT <cond>".'''

    def __init__(self, cond):
        '''
        @param cond The condition to negate. Unless it already is a Coalesce, it
            is wrapped in Coalesce(cond, False).
        '''
        Code.__init__(self)

        if isinstance(cond, Coalesce): self.cond = cond
        else: self.cond = Coalesce(cond, False)

    def to_str(self, db):
        cond_str = self.cond.to_str(db)
        return 'NOT '+cond_str
811
812 10840 aaronmk
# Sentinel passed as ColValueCond's col. ColValueCond then renders value on its
# own, without applying it to a column.
custom_cond = object() # tells ColValueCond that value is a plain SQL cond
813
814 2398 aaronmk
class ColValueCond(Code):
    '''A condition that applies a value condition to a column.'''

    def __init__(self, col, value):
        '''
        @param col The column the condition is applied on, or custom_cond
        @param value Converted with as_ValueCond(), unless col is custom_cond,
            in which case it is stored as-is
        '''
        Code.__init__(self)

        uses_col = col is not custom_cond
        if uses_col: value = as_ValueCond(value)

        self.col = col
        self.value = value

    def to_str(self, db):
        if self.col is not custom_cond:
            return self.value.to_str(db, self.col)
        return self.value.to_str(db) # value is a condition in itself
826 2398 aaronmk
827 2577 aaronmk
def combine_conds(conds, keyword=None):
    '''Joins condition strings with AND, one condition per line.
    @param conds Iterable of condition strings. Any iterable is accepted, not
        just lists.
    @param keyword The keyword to add before the conditions, if any
    @return The combined string. The keyword is followed by nothing if there
        are no conditions, by a space if there is one, and by a newline if
        there are several.
    '''
    # Materialize, so that emptiness and length can be checked for tuples and
    # iterators as well (comparing to [] only recognizes an empty list)
    conds = list(conds)

    str_ = ''
    if keyword is not None:
        if not conds: whitespace = ''
        elif len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        str_ += keyword+whitespace

    str_ += '\nAND '.join(conds)
    return str_
840
841 2398 aaronmk
##### Condition column comparisons
842
843 2514 aaronmk
class ValueCond(BasicObject):
    '''Base class of conditions that compare a left value to a stored value.'''

    def __init__(self, value):
        '''
        @param value Code|literal-value The value to compare to. It is converted
            with as_Value() and any column renaming is removed.
        '''
        value = remove_col_rename(as_Value(value))

        self.value = value

    def to_str(self, db, left_value):
        '''
        @param left_value The Code object that the condition is being applied on
        @raise NotImplementedError Always; subclasses must override this method
        '''
        # NotImplemented is a non-callable constant, so calling it would raise a
        # TypeError instead of signaling a missing override
        raise NotImplementedError()

    def __repr__(self): return self.to_str(mockDb, '<left_value>')
856 2211 aaronmk
857
class CompareCond(ValueCond):
    '''Compares a column to a value using a binary operator'''

    def __init__(self, value, operator='='):
        '''
        @param operator By default, compares NULL values literally. Use '~=' or
            '~!=' to pass NULLs through.
        '''
        ValueCond.__init__(self, value)
        self.operator = operator

    def to_str(self, db, left_value):
        '''
        @param left_value The column that the comparison is applied to
        '''
        lhs = remove_col_rename(as_Col(left_value))
        rhs = self.value

        # Parse operator: strip the "~" flag, then the "!" flag
        passthru_null_ref = [False]
        neg_ref = [False]
        op = strings.remove_prefix('~', self.operator, passthru_null_ref)
        op = strings.remove_prefix('!', op, neg_ref)
        ends_in_equals = op.endswith('=') # also includes <=, >=

        # Handle nullable columns
        add_null_check = False
        if not passthru_null_ref[0]: # NULLs compare equal
            try: lhs = ensure_not_null(db, lhs)
            except ensure_not_null_excs: # fall back to alternate method
                add_null_check = ends_in_equals and isinstance(rhs, Col)
            else:
                if isinstance(lhs, EnsureNotNull):
                    # apply same function to both sides
                    rhs = ensure_not_null(db, rhs, lhs.type)

        if ends_in_equals and is_null(rhs): op = 'IS'

        left = lhs.to_str(db)
        right = rhs.to_str(db)

        # Create str
        cond = ' '.join((left, op, right))
        if add_null_check:
            cond = '(%s OR (%s IS NULL AND %s IS NULL))' % (cond, left, right)
        if neg_ref[0]: cond = 'NOT '+cond
        return cond
901 2216 aaronmk
902 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
903
assume_literal = object() # sentinel; compared by identity to default_table
904
905
def as_ValueCond(value, default_table=assume_literal):
    '''Converts a value to a ValueCond.
    @param default_table Unless this is assume_literal, a non-ValueCond value is
        first passed through with_default_table() with this table
    @return value itself if it is already a ValueCond, otherwise a CompareCond
        wrapping it
    '''
    if isinstance(value, ValueCond): return value

    if default_table is not assume_literal:
        value = with_default_table(value, default_table)
    return CompareCond(value)
911 2219 aaronmk
912 2335 aaronmk
##### Joins
913
914 2352 aaronmk
# Sentinel for a value in a Join mapping; Join.to_str() matches it by identity
join_same = object() # tells Join the left and right columns have the same name
915 2260 aaronmk
916 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
917
join_same_not_null = object() # also assigned to every col of a list mapping
918
919 2260 aaronmk
# Join.to_str() renders this type as a LEFT join and never uses USING for it
filter_out = object() # tells Join to filter out rows that match the join
920
921 2514 aaronmk
class Join(BasicObject):
    '''A JOIN clause for one table, including its join condition'''

    def __init__(self, table, mapping=None, type_=None, custom_cond=None):
        '''
        @param table str|Table The table to join. A str is wrapped in a Table.
        @param mapping None|dict(right_table_col=left_table_col, ...)
            or [using_col...]
            * None is the same as an empty dict (no mapped columns)
            * if left_table_col is join_same: left_table_col = right_table_col
              * Note that right_table_col must be a string
            * if left_table_col is join_same_not_null:
              left_table_col = right_table_col and both have NOT NULL constraint
              * Note that right_table_col must be a string
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
            * filter_out: equivalent to 'LEFT' with the query filtered by
              `table_pkey IS NULL` (indicating no match)
        @param custom_cond None|str Plain SQL condition that is ANDed onto the
            mapping's conditions. If the mapping is empty, it becomes the whole
            ON condition.
        '''
        # A fresh dict per instance, so that instances never share (and can
        # never pollute) one default mapping object
        if mapping is None: mapping = {}
        if util.is_str(table): table = Table(table)
        if lists.is_seq(mapping):
            mapping = dict(((c, join_same_not_null) for c in mapping))
        assert type_ is None or util.is_str(type_) or type_ is filter_out

        self.table = table
        self.mapping = mapping
        self.type_ = type_
        self.custom_cond = custom_cond

    def to_str(self, db, left_table_):
        '''
        @param left_table_ The table that this join's table is joined onto
        '''
        def join(entry):
            '''Parses non-USING joins'''
            right_table_col, left_table_col = entry

            # Switch order (right_table_col is on the left in the comparison)
            left = right_table_col
            right = left_table_col
            left_table = self.table
            right_table = left_table_

            # Parse left side
            left = with_default_table(left, left_table)

            # Parse special values
            left_on_right = Col(left.name, right_table)
            if right is join_same: right = left_on_right
            elif right is join_same_not_null:
                right = CompareCond(left_on_right, '~=')

            # Parse right side
            right = as_ValueCond(right, right_table)

            return right.to_str(db, left)

        # Create join condition
        type_ = self.type_
        joins = self.mapping
        extra_cond = self.custom_cond
        if not joins: join_cond = None
        elif (type_ is not filter_out and extra_cond is None
            and all(v is join_same_not_null for v in joins.itervalues())):
            # all cols w/ USING, so can use simpler USING syntax. This is
            # skipped when there is a custom cond, because a condition can only
            # be ANDed onto an ON clause, not onto USING (...).
            cols = map(to_name_only_col, joins.iterkeys())
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')

        # Add custom condition
        if extra_cond is not None:
            # without mapped columns there is no ON clause to extend yet
            if join_cond is None: join_cond = 'ON '+extra_cond
            else: join_cond += '\nAND '+extra_cond

        if isinstance(self.table, NamedTable): whitespace = '\n'
        else: whitespace = ' '

        # Create join
        if type_ is filter_out: type_ = 'LEFT'
        str_ = ''
        if type_ is not None: str_ += type_+' '
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
        if join_cond is not None: str_ += whitespace+join_cond
        return str_

    def __repr__(self): return self.to_str(mockDb, '<left_table>')
994 2424 aaronmk
995
##### Value exprs
996
997 3089 aaronmk
# The `*` column wildcard, as custom code
all_cols = CustomCode('*')
998
999 2737 aaronmk
# The DEFAULT keyword, as custom code
default = CustomCode('DEFAULT')
1000
1001 3090 aaronmk
# A call of the internal COUNT function with all_cols (`*`) as its argument
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
1002 2674 aaronmk
1003 3061 aaronmk
class Coalesce(FunctionCall):
    '''A call of the internal COALESCE function'''

    def __init__(self, *args):
        '''
        @param args The arguments of the COALESCE call, in order
        '''
        func = InternalFunction('COALESCE')
        FunctionCall.__init__(self, func, *args)
1006 3060 aaronmk
1007 3062 aaronmk
class Nullif(FunctionCall):
    '''A call of the internal NULLIF function'''

    def __init__(self, *args):
        '''
        @param args The arguments of the NULLIF call, in order
        '''
        func = InternalFunction('NULLIF')
        FunctionCall.__init__(self, func, *args)
1010
1011 5891 aaronmk
# A Literal wrapping None (presumably rendered as SQL NULL -- see Literal)
null = Literal(None)
1012 5892 aaronmk
# `null` cast to text; to_text() uses it as the fallback of its Coalesce
null_as_str = Cast('text', null)
1013 3706 aaronmk
1014
def to_text(value):
    '''Casts a value to text, wrapped in a Coalesce that falls back to
    null_as_str'''
    as_text = Cast('text', value)
    return Coalesce(as_text, null_as_str)
1015
1016 2850 aaronmk
# Maps a type name to the value that EnsureNotNull substitutes for NULL.
# A type without an entry raises KeyError there.
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
null_sentinels = dict((
    ('character varying', r'\N'),
    ('double precision', 'NaN'),
    ('integer', 2147483647),
    ('text', r'\N'),
    ('date', 'infinity'),
    ('timestamp with time zone', 'infinity'),
    ('taxonrank', 'unknown'),
))
1026 2692 aaronmk
1027 3061 aaronmk
class EnsureNotNull(Coalesce):
    '''A Coalesce of a column with the null sentinel of a type'''

    def __init__(self, value, type_):
        '''
        @param value Converted with as_Col()
        @param type_ The type to cast the sentinel to. An ArrayType uses [] as
            its sentinel; any other type is looked up in null_sentinels.
        @throws KeyError if type_ has no entry in null_sentinels
        '''
        if isinstance(type_, ArrayType): sentinel = []
        else: sentinel = null_sentinels[type_]
        col = as_Col(value)
        Coalesce.__init__(self, col, Cast(type_, sentinel))

        self.type = type_

    def to_str(self, db):
        # If index_col() returns a column for the wrapped column, render that
        # column instead of the COALESCE call
        substitute = index_col(self.args[0])
        if substitute != None: return substitute.to_str(db)
        return Coalesce.to_str(self, db)
1040 2850 aaronmk
1041 3523 aaronmk
#### Arrays
1042
1043 3535 aaronmk
class ArrayMerge(FunctionCall):
    '''A call of the internal array_to_string function'''

    def __init__(self, sep, array):
        '''
        @param sep The separator, passed as the call's second argument
        @param array Converted with to_Array() and passed as the first argument
        '''
        array_ = to_Array(array)
        func = InternalFunction('array_to_string')
        FunctionCall.__init__(self, func, array_, sep)
1048
1049 3537 aaronmk
def merge_not_null(db, sep, values):
    '''Creates an ArrayMerge of the values, each converted with to_text().
    @param db Not used
    '''
    text_values = [to_text(v) for v in values]
    return ArrayMerge(sep, text_values)
1051 3537 aaronmk
1052 2737 aaronmk
##### Table exprs
1053
1054
class Values(Code):
    '''A VALUES expression of one or more rows'''

    def __init__(self, values):
        '''
        @param values [...]|[[...], ...] Can be one or multiple rows.
            The sequence itself is left unmodified: each row is converted (with
            as_Value() and remove_col_rename()) into a new list.
        '''
        Code.__init__(self)

        rows = values
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
            rows = [values]
        # Build new lists rather than assigning into rows. With multiple rows,
        # rows is the caller's own sequence (which may also be a tuple).
        self.rows = [map(remove_col_rename, map(as_Value, r)) for r in rows]

    def to_str(self, db):
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1071 2737 aaronmk
1072 2740 aaronmk
def NamedValues(name, cols, values):
    '''Creates a NamedTable that wraps a Values expression.
    @param cols None|[...]
    @param values Passed to Values()
    @post `cols` will be changed to Col objects with the table set to `name`.
    '''
    values_expr = Values(values)
    named_table = NamedTable(name, values_expr, cols)
    if cols != None: set_cols_table(named_table, cols)
    return named_table
1080 2740 aaronmk
1081 2674 aaronmk
##### Database structure
1082
1083 5074 aaronmk
def is_nullable(db, value):
    '''Whether a value can be NULL.
    * A value that is not a table column is nullable only if it is itself null
    * A table column uses the nullable flag from db.col_info(), and is taken to
      be nullable if it has no underlying table
    '''
    if is_table_col(value):
        try: nullable = db.col_info(value).nullable
        except NoUnderlyingTableException: nullable = True # not a table column
        return nullable
    return is_null(value)
1087
1088 4442 aaronmk
# Type names that is_text_type() accepts and that canon_type() maps to 'text'
text_types = set(['character varying', 'text'])
1089 4406 aaronmk
1090 5334 aaronmk
def is_text_type(type_):
    '''Whether type_ is one of the text_types'''
    is_text = type_ in text_types
    return is_text
1091
1092 5335 aaronmk
def is_text_col(db, col):
    '''Whether the type that db.col_info() reports for col is a text type'''
    col_type = db.col_info(col).type
    return is_text_type(col_type)
1093 4442 aaronmk
1094 5397 aaronmk
def canon_type(type_):
    '''Maps every type in text_types to 'text'. Other types are returned
    unchanged.'''
    return 'text' if type_ in text_types else type_
1097
1098 2840 aaronmk
# The exceptions documented as thrown by ensure_not_null(). They are caught
# together by CompareCond.to_str() and try_mk_not_null().
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1099
1100 2851 aaronmk
def ensure_not_null(db, col, type_=None):
    '''Wraps a nullable column in an EnsureNotNull.
    @param col If type_ is not set, must have an underlying column.
    @param type_ If set, overrides the underlying column's type and casts the
        column to it if needed.
    @return EnsureNotNull|Col The column itself if it is not nullable
    @throws ensure_not_null_excs
    '''
    col = remove_col_rename(col)

    try: underlying_type = db.col_info(underlying_col(col)).type
    except NoUnderlyingTableException:
        if type_ == None and is_null(col): raise # NULL has no type
    else:
        if type_ == None: type_ = underlying_type
        elif type_ != underlying_type: col = Cast(type_, col)

    if not is_nullable(db, col): return col

    try: return EnsureNotNull(col, type_)
    except KeyError as e:
        # Warn of no null sentinel for type, even if caller catches error
        warnings.warn(UserWarning(exc.str_(e)))
        raise
1125 3536 aaronmk
1126
def try_mk_not_null(db, value):
    '''Like ensure_not_null(), but returns the value unchanged if any of
    ensure_not_null_excs is raised.
    Warning: This function does not guarantee that its result is NOT NULL.
    '''
    try: not_null = ensure_not_null(db, value)
    except ensure_not_null_excs: not_null = value
    return not_null
1132 5367 aaronmk
1133
##### Expression transforming
1134
1135
# The boolean literals, which are also used as the regexps that match them
true_expr = 'true'
false_expr = 'false'

true_re = true_expr
false_re = false_expr
bool_re = r'(?:%s|%s)' % (true_re, false_re) # either boolean literal
# A boolean literal, or a parenthesized group that contains no other parens
atom_re = r'(?:%s|\([^()]*\))' % bool_re
1142
1143
def logic_op_re(op, value_re, expr_re=''):
    '''Creates a regexp that matches a binary logic operation with the value on
    either side of the operator.
    @param op The operator keyword. It is matched with one space on each side.
    @param value_re Regexp for the operand on one side
    @param expr_re Regexp for the operand on the other side (default: nothing)
    '''
    spaced_op = ' '+op+' '
    expr_first = expr_re+spaced_op+value_re
    value_first = value_re+spaced_op+expr_re
    return '(?:%s|%s)' % (expr_first, value_first)
1146
1147 5370 aaronmk
# Regexps for the logic operations that simplify_expr() rewrites
not_re = r'\bNOT '
not_false_re = r'%s%s\b' % (not_re, false_re)
not_true_re = r'%s%s\b' % (not_re, true_re)
and_false_re = logic_op_re('AND', false_re, atom_re)
and_false_not_true_re = '(?:%s|%s)' % (not_true_re, and_false_re)
and_true_re = logic_op_re('AND', true_re)
or_re = logic_op_re('OR', bool_re)
or_and_true_re = '(?:%s|%s)' % (and_true_re, or_re)
1155
1156
def simplify_parens(expr):
    '''Replaces each parenthesized atom_re match with the match itself, using
    regexp.sub_nested()'''
    parenthesized_atom_re = r'\((%s)\)' % atom_re
    return regexp.sub_nested(parenthesized_atom_re, r'\1', expr)
1158
1159
def simplify_recursive(sub_func, expr):
    '''Runs regexp.sub_recursive() with a sub_func that first applies
    simplify_parens() to its input.
    @param sub_func See regexp.sub_recursive() sub_func param
    '''
    def sub_simplified(str_):
        # simplify_parens() is also done at end in final iteration
        return sub_func(simplify_parens(str_))

    return regexp.sub_recursive(sub_simplified, expr)
1165 5367 aaronmk
1166
def simplify_expr(expr):
    '''Simplifies the constant boolean logic in an expression string.

    this can also be done in Postgres with expression substitution
    (wiki.vegpath.org/Postgres_queries#expression-substitution)
    '''
    def simplify_logic_ops(expr):
        '''
        @return (expr, n) The new expr and the number of substitutions made
        '''
        # NOTE(review): or_and_true_re appears to also match ` OR true`, which
        # is deleted here rather than collapsed to true -- confirm intended
        substitutions = (
            (not_false_re, true_re),
            (and_false_not_true_re, false_expr),
            (or_and_true_re, r''),
        )
        total_n = 0
        for pattern, repl in substitutions:
            expr, n = re.subn(pattern, repl, expr)
            total_n += n
        return expr, total_n

    null_tests = (
        ('NULL IS NULL', true_expr),
        ('NULL IS NOT NULL', false_expr),
    )
    for null_test, result in null_tests:
        expr = expr.replace(null_test, result)
    return simplify_recursive(simplify_logic_ops, expr)
1185
1186
# Matches a name: either a run of word characters, or one or more adjacent
# double-quoted segments (so that a doubled "" inside the quotes is included)
name_re = r'(?:\w+|(?:"[^"]*")+)'
1187
1188
def parse_expr_col(str_):
    '''Extracts the column name from a string.
    If the whole string has the form (name(name ...)), the second name is
    used; otherwise the entire string is. The result is passed through
    unesc_name().
    '''
    func_call_re = r'^\(%s\((%s).*\)\)$' % (name_re, name_re)
    match = re.match(func_call_re, str_)
    col_name = match.group(1) if match else str_
    return unesc_name(col_name)
1192
1193
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The expression to transform
    @param mapping dict(out_col=in_col, ...)
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The mapped expression, passed through simplify_expr()

    this can also be done in Postgres with expression substitution
    (wiki.vegpath.org/Postgres_queries#expression-substitution)

    this is a special case of bin/repl SQL identifier handling which does not
    handle entire source files, but which does simplify the resulting expression
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = to_name_only_col(out)
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)

        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The name is escaped so that any regexp metacharacters in it only
        # match literally. The replacement is returned from a function so that
        # re.sub() inserts it verbatim instead of expanding backslashes in it.
        unquoted_re = (r'(?<!["\'\.=\[])\b'+re.escape(out.name)
            +r'\b(?!["\',\.=\]])')
        expr = re.sub(unquoted_re, lambda match: in_str, expr)

        if in_cols_found is not None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)

    return simplify_expr(expr)