Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
    '''Appends a suffix via strings.concat(), passing identifier_max_len.
    A version-like tail already on str_ is moved in front of the new suffix,
    so that it won't be truncated off the string, leading to collisions.
    '''
    version_re = r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$'
    found = re.match(version_re, str_)
    if found is not None: # str_ ends in a tail that must be preserved
        str_ = found.group(1)
        suffix = found.group(2)+suffix
    
    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_):
    '''Same as concat() with an empty suffix'''
    return concat(str_, '')
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
    r'''A name is safe *and unambiguous* if it:
    * contains only *lowercase* word (\w) characters
    * doesn't start with a digit
    * contains "_", so that it's not a keyword
    @return A match object (truthy) if the name is safe, otherwise None
    '''
    # \Z rather than $: $ also matches just before a trailing newline, which
    # would let a name such as 'a_b\n' through as safe
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+\Z', name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
    '''Wraps a name in quote characters, doubling any embedded quote.
    Doubling an embedded quote escapes it in both PostgreSQL and MySQL.
    '''
    doubled = quote*2
    return ''.join((quote, name.replace(quote, doubled), quote))
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
    '''Reverses esc_name(): if the name starts with the quote char, strips the
    surrounding quotes and collapses doubled quotes. Otherwise the name is
    returned unchanged.
    '''
    stripped_ref = [False] # passed to the strings.remove_*() helpers
    name = strings.remove_prefix(quote, name, stripped_ref)
    if not stripped_ref[0]: return name # wasn't quoted
    
    name = strings.remove_suffix(quote, name, stripped_ref)
    assert stripped_ref[0] # an opening quote needs a closing quote
    return name.replace(quote+quote, quote)
57
58 2513 aaronmk
def clean_name(name):
    '''Removes every double-quote and backtick character from a name'''
    for quote_char in ('"', '`'): name = name.replace(quote_char, '')
    return name
59
60 3041 aaronmk
def esc_comment(comment):
    '''Wraps text in /* */, breaking up any embedded "*/" so that it can't
    end the comment early.
    '''
    safe_text = '* /'.join(comment.split('*/'))
    return '/*'+safe_text+'*/'
61
62 3182 aaronmk
def lstrip(str_):
    '''Removes leading whitespace and leading /* */ comments.
    Whitespace before a comment and several comments in a row are handled.
    An unterminated comment swallows the rest of the string.
    '''
    str_ = str_.lstrip()
    while str_.startswith('/*'):
        comment, sep, str_ = str_.partition('*/')
        str_ = str_.lstrip() # there may be another comment after this one
    return str_
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
    '''Stand-in for a DB connection, used to render Code objects as strs when
    no real connection is passed in (see mockDb).
    '''
    
    def esc_value(self, value):
        return strings.repr_no_u(value)
    
    def esc_name(self, name):
        return esc_name(name) # the module-level function, not this method
    
    def col_info(self, col):
        '''Returns a placeholder TypedCol that has the column's name'''
        placeholder_default = CustomCode('<default>')
        return TypedCol(col.name, '<type>', placeholder_default, True)
76
77 2349 aaronmk
mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
    '''Base class whose str() form is its repr with name quotes removed'''
    
    def __str__(self):
        repr_ = strings.repr_no_u(self)
        return clean_name(repr_)
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
    '''Base class for everything that can be rendered to code by to_str()'''
    
    def __init__(self, lang='sql'):
        '''
        @param lang Name of the language the code is written in
        '''
        self.lang = lang
    
    def to_str(self, db):
        '''Renders this object as a str. Subclasses must override this.
        @param db Object providing the escaping methods (e.g. a MockDb)
        '''
        raise NotImplementedError()
    
    def __repr__(self):
        return self.to_str(mockDb) # no real DB is available here
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
    '''Code given as a ready-made str, which is rendered verbatim'''
    
    def __init__(self, str_):
        '''
        @param str_ The code's text
        '''
        Code.__init__(self)
        
        self.str_ = str_
    
    def to_str(self, db):
        return self.str_ # db is not needed
99
100 2815 aaronmk
def as_Code(value, db=None):
    '''Coerces a value to a Code object: Code objects pass through, strs
    become CustomCode, and anything else becomes a Literal.
    @param db If set, runs db.std_code() on a str value.
    '''
    if isinstance(value, Code): return value
    if not util.is_str(value): return Literal(value)
    
    if db != None: value = db.std_code(value)
    return CustomCode(value)
110
111 2540 aaronmk
class Expr(Code):
    '''Renders the wrapped code object inside parentheses'''
    
    def __init__(self, expr):
        '''
        @param expr Code The code to parenthesize
        '''
        Code.__init__(self)
        
        self.expr = expr
    
    def to_str(self, db):
        inner_str = self.expr.to_str(db)
        return '('+inner_str+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
    '''An identifier, which is escaped by the DB when rendered'''
    
    def __init__(self, name):
        '''
        @param name The identifier. It is run through truncate() first.
        '''
        Code.__init__(self)
        
        self.name = truncate(name)
    
    def to_str(self, db):
        return db.esc_name(self.name)
130
131
def as_Name(value):
    '''Wraps a value in a Name, unless it is already a Code object'''
    if not isinstance(value, Code): value = Name(value)
    return value
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
    '''A plain value, which is escaped by the DB when rendered'''
    
    def __init__(self, value):
        '''
        @param value The value to wrap
        '''
        Code.__init__(self)
        
        self.value = value
    
    def to_str(self, db):
        return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
    '''Wraps a value in a Literal, unless it is already a Code object'''
    if not isinstance(value, Code): value = Literal(value)
    return value
150
151 6225 aaronmk
def get_value(value):
    '''Unwraps a Literal's value.
    Any column renaming is removed first. Anything that isn't a Literal is
    returned as-is, and must not be a Code object.
    '''
    value = remove_col_rename(value)
    if not isinstance(value, Literal):
        assert not isinstance(value, Code)
        return value
    return value.value
158
159 3429 aaronmk
def is_literal(value):
    '''Whether a value is a Literal object'''
    return isinstance(value, Literal)
160 2216 aaronmk
161 3429 aaronmk
def is_null(value):
    '''Whether a value is a Literal that wraps None'''
    if not is_literal(value): return False
    return value.value == None
162
163 3519 aaronmk
#### Composites
164
165 3521 aaronmk
class List(Code):
    '''A sequence of code objects, rendered separated by ", "'''
    
    def __init__(self, values):
        '''
        @param values [Code...] The elements; each must have to_str()
        '''
        Code.__init__(self)
        
        self.values = values
    
    def to_str(self, db):
        rendered = [value.to_str(db) for value in self.values]
        return ', '.join(rendered)
172 3519 aaronmk
173 3521 aaronmk
class Tuple(List):
    '''A List that takes its elements as separate args and renders them
    inside parentheses'''
    
    def __init__(self, *values):
        List.__init__(self, values)
    
    def to_str(self, db):
        elems_str = List.to_str(self, db)
        return '('+elems_str+')'
178
179 3520 aaronmk
class Row(Tuple):
    '''A Tuple rendered with a ROW prefix'''
    
    def to_str(self, db):
        tuple_str = Tuple.to_str(self, db)
        return 'ROW'+tuple_str
181 3519 aaronmk
182 3522 aaronmk
### Arrays
183
184
class Array(List):
    '''A List rendered as ARRAY[...]'''
    
    def __init__(self, values):
        '''
        @param values The elements. Any column renaming is removed from them.
        '''
        values = [remove_col_rename(value) for value in values]
        
        List.__init__(self, values)
    
    def to_str(self, db):
        elems_str = List.to_str(self, db)
        return 'ARRAY['+elems_str+']'
191
192
def to_Array(value):
    '''Converts a value to an Array, unless it already is one.
    The value is passed through lists.mk_seq() first.
    '''
    if isinstance(value, Array): return value
    else:
        seq = lists.mk_seq(value)
        return Array(seq)
195
196 2711 aaronmk
##### Derived elements
197
198
src_self = object() # tells Col that it is its own source column
199
200
class Derived(Code):
    '''An element which was derived from some other element(s)'''
    
    def __init__(self, srcs):
        '''
        @param srcs See self.set_srcs()
        '''
        Code.__init__(self)
        
        self.set_srcs(srcs)
    
    def set_srcs(self, srcs, overwrite=True):
        '''
        @param srcs (self_type...)|src_self The element(s) this is derived from
        @param overwrite If False, srcs that are already non-empty are kept
        '''
        if not overwrite and self.srcs != (): return # already set
        
        # src_self is a sentinel object, so test identity rather than equality
        if srcs is src_self: srcs = (self,)
        srcs = tuple(srcs) # make Col hashable
        self.srcs = srcs
    
    def _compare_on(self):
        '''Returns a copy of the instance's attrs with srcs left out'''
        compare_on = self.__dict__.copy()
        del compare_on['srcs'] # ignore
        return compare_on
223
224
def cols_srcs(cols):
    '''Collects the srcs of all the given cols, passing them through
    iters.flatten() and then lists.uniqify().
    '''
    each_srcs = (col.srcs for col in cols)
    all_srcs = iters.flatten(each_srcs)
    return lists.uniqify(all_srcs)
225
226 2335 aaronmk
##### Tables
227
228 2712 aaronmk
class Table(Derived):
    '''A table reference, optionally qualified by a schema'''
    
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
        '''
        @param name A str name is run through truncate() first
        @param schema str|None (for no schema)
        @param srcs (Table...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)
        
        self.name = truncate(name) if util.is_str(name) else name
        self.schema = schema
        self.is_temp = is_temp
        self.order_by = None
        self.index_cols = {}
    
    def to_str(self, db):
        '''Renders as `schema.name`, or just `name` when there is no schema'''
        parts = []
        if self.schema != None: parts.append(as_Name(self.schema).to_str(db))
        parts.append(as_Name(self.name).to_str(db))
        return '.'.join(parts)
    
    def to_Table(self): return self
    
    def _compare_on(self):
        '''Same as Derived._compare_on(), but also leaves out order_by and
        index_cols'''
        compare_on = Derived._compare_on(self)
        for ignored in ('order_by', 'index_cols'): del compare_on[ignored]
        return compare_on
257 2211 aaronmk
258 2835 aaronmk
def is_underlying_table(table):
    '''Whether a value is a Table whose to_Table() returns the table itself'''
    if not isinstance(table, Table): return False
    return table.to_Table() is table
260 2832 aaronmk
261 4114 aaronmk
class NoUnderlyingTableException(Exception):
    '''Raised when a reference can't be traced to an underlying table'''
    
    def __init__(self, ref):
        '''
        @param ref The reference that has no underlying table
        '''
        msg = 'for: '+strings.as_tt(strings.urepr(ref))
        Exception.__init__(self, msg)
        self.ref = ref
265 2902 aaronmk
266
def underlying_table(table):
    '''Follows a table reference back to the table it ultimately refers to.
    Any table renaming is removed first. If the result has srcs, its single
    src is used (unpacking fails with ValueError if there are several).
    @return The underlying Table
    @raise NoUnderlyingTableException If the result is not an underlying table
    '''
    table = remove_table_rename(table)
    # The code behind a NamedTable can be a FunctionCall or Expr, which have
    # no srcs attr. Use getattr() so that these (and None) get the documented
    # exception below instead of an AttributeError.
    srcs = getattr(table, 'srcs', ())
    if srcs:
        table, = srcs # for derived tables or row vars
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
    return table
272
273 2776 aaronmk
def as_Table(table, schema=None):
    '''Converts a value to a Table. None and Code objects pass through.
    @param schema Only used when a new Table is created
    '''
    passes_through = table == None or isinstance(table, Code)
    if passes_through: return table
    return Table(table, schema)
276 2219 aaronmk
277 3101 aaronmk
def suffixed_table(table, suffix):
    '''Returns a shallow copy of a table with a suffix added to its name'''
    suffixed = copy.copy(table) # don't modify input!
    suffixed.name = concat(table.name, suffix)
    return suffixed
281 2707 aaronmk
282 2336 aaronmk
class NamedTable(Table):
    '''Code with a table alias: renders as `code AS name (cols)`'''
    
    def __init__(self, name, code, cols=None):
        '''
        @param code The aliased code. It is passed through as_Table(), and the
            result is wrapped in an Expr unless it is a Table, FunctionCall, or
            Expr.
        @param cols None|[col...] Column aliases, reduced to name-only Cols
        '''
        Table.__init__(self, name)
        
        code = as_Table(code)
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
        if cols != None:
            cols = [to_name_only_col(col).to_Col() for col in cols]
        
        self.code = code
        self.cols = cols
    
    def to_str(self, db):
        code_str = self.code.to_str(db)
        if '\n' in code_str: sep = '\n' # multiline code: AS on its own line
        else: sep = ' '
        str_ = code_str+sep+'AS '+Table.to_str(self, db)
        if self.cols != None:
            cols_str = ', '.join([col.to_str(db) for col in self.cols])
            str_ += ' ('+cols_str+')'
        return str_
    
    def to_Table(self):
        '''Returns a plain Table that has just the alias name'''
        return Table(self.name)
303
304 2753 aaronmk
def remove_table_rename(table):
    '''Returns the code behind a NamedTable; other values pass through'''
    if not isinstance(table, NamedTable): return table
    return table.code
307
308 2335 aaronmk
##### Columns
309
310 2711 aaronmk
class Col(Derived):
    '''A column reference, optionally qualified by a table'''
    
    def __init__(self, name, table=None, srcs=()):
        '''
        @param name A str name is run through truncate() first
        @param table Table|None (for no table). A str is wrapped in a Table.
        @param srcs (Col...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        if util.is_str(table): table = Table(table)
        assert table == None or isinstance(table, Table)
        
        self.name = name
        self.table = table
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str If set, renders the form used by __str__(): the name is
            passed through clean_name() and joined to the table's str form with
            concat()
        '''
        name_str = as_Name(self.name).to_str(db)
        if for_str: name_str = clean_name(name_str)
        if self.table != None:
            table = self.table.to_Table()
            if for_str: return concat(strings.ustr(table), '.'+name_str)
            return table.to_str(db)+'.'+name_str
        return name_str
    
    def __str__(self): return self.to_str(mockDb, for_str=True)
    
    def to_Col(self): return self
337 2211 aaronmk
338 3512 aaronmk
def is_col(col):
    '''Whether a value is a Col object'''
    return isinstance(col, Col)
339 2393 aaronmk
340 3512 aaronmk
def is_table_col(col):
    '''Whether a value is a Col that has a table'''
    if not is_col(col): return False
    return col.table != None
341
342 3000 aaronmk
def index_col(col):
    '''Looks up a column's entry in its table's index_cols.
    @return Col|None A Col with the looked-up name (same table and srcs), or
        None if col has no table or its name is not in index_cols
    '''
    if is_table_col(col):
        table = col.table
        try: index_name = table.index_cols[col.name]
        except KeyError: pass # no entry for this column
        else: return Col(index_name, table, col.srcs)
    return None
349 2999 aaronmk
350 3024 aaronmk
def is_temp_col(col):
    '''Whether a value is a Col whose table has is_temp set'''
    if not is_table_col(col): return False
    return col.table.is_temp
351 2996 aaronmk
352 2563 aaronmk
def as_Col(col, table=None, name=None):
    '''Converts a value to a code object usable as a column: Code objects pass
    through, strs become a Col in `table`, anything else becomes a Literal.
    @param name If not None, any non-Col input will be renamed using NamedCol.
    '''
    if name != None:
        col = as_Value(col)
        if not isinstance(col, Col): col = NamedCol(name, col)
    
    if isinstance(col, Code): return col
    if util.is_str(col): return Col(col, table)
    return Literal(col)
363 2260 aaronmk
364 3093 aaronmk
def with_table(col, table):
    '''Returns a copy of a column with its table set.
    For a FunctionCall, the table is set on a deep copy's first arg. NamedCols
    and all other values are returned unchanged.
    '''
    if isinstance(col, NamedCol): return col # doesn't take a table
    
    if isinstance(col, FunctionCall):
        copy_ = copy.deepcopy(col) # don't modify input!
        copy_.args[0].table = table
        return copy_
    if isinstance(col, Col):
        copy_ = copy.copy(col) # don't modify input!
        copy_.table = table
        return copy_
    return col
373
374 3100 aaronmk
def with_default_table(col, table):
    '''Sets a column's table only if it doesn't have one yet.
    The input is passed through as_Col() first.
    '''
    col = as_Col(col)
    if col.table == None: return with_table(col, table)
    return col
378
379 2744 aaronmk
def set_cols_table(table, cols):
    '''Sets the table of every column in a list, in place.
    @param cols Mutable sequence. Each entry is replaced by its as_Col() form.
    '''
    table = as_Table(table)
    
    for i, orig_col in enumerate(cols):
        col = as_Col(orig_col)
        cols[i] = col # so the caller sees the converted object
        col.table = table
385
386 2401 aaronmk
def to_name_only_col(col, check_table=None):
    '''Strips the table off a column, returning a new Col with just the name.
    Values that aren't a Col with a table are returned in their as_Col() form.
    @param check_table If not None, asserts that it equals the col's table
    '''
    col = as_Col(col)
    if is_table_col(col):
        if check_table != None:
            table = col.table
            assert table == None or table == check_table
        col = Col(col.name)
    return col
394
395 2993 aaronmk
def suffixed_col(col, suffix):
    '''Returns a new Col with a suffix added to the name, and the same table
    and srcs'''
    new_name = concat(col.name, suffix)
    return Col(new_name, col.table, col.srcs)
397
398 3516 aaronmk
def has_srcs(col):
    '''Truthy if a value is a Col with non-empty srcs.
    @return False for non-Cols, otherwise the col's srcs
    '''
    if not is_col(col): return False
    return col.srcs
399
400 3518 aaronmk
def cross_join_srcs(cols):
    '''Builds one Col for each combination of the cols' src names.
    Each new Col is named with the combination's names joined by ",".
    '''
    # Cols with empty srcs will mess up the cross join, so they are left out
    src_names = [[src.name for src in col.srcs]
        for col in cols if has_srcs(col)]
    if not src_names: return [] # itertools.product() returns [()] for no input
    
    combos = itertools.product(*src_names)
    return [Col(','.join(combo)) for combo in combos]
405 3514 aaronmk
406 2323 aaronmk
class NamedCol(Col):
    '''Code with a column alias: renders as `code AS name`'''
    
    def __init__(self, name, code):
        '''
        @param code The aliased value. Non-Code values become Literals.
        '''
        Col.__init__(self, name)
        
        self.code = as_Value(code)
    
    def to_str(self, db):
        code_str = self.code.to_str(db)
        return code_str+' AS '+Col.to_str(self, db)
    
    def to_Col(self):
        '''Returns a plain Col that has just the alias name'''
        return Col(self.name)
418 2229 aaronmk
419 2462 aaronmk
def remove_col_rename(col):
    '''Returns the code behind a NamedCol; other values pass through'''
    if not isinstance(col, NamedCol): return col
    return col.code
422
423 2830 aaronmk
def underlying_col(col):
    '''Returns a Col with the same name and srcs, on the underlying table.
    Any column renaming is removed first.
    @raise NoUnderlyingTableException If the result is not a Col, or if
        underlying_table() raises it
    '''
    col = remove_col_rename(col)
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
    
    table = underlying_table(col.table)
    return Col(col.name, table, col.srcs)
428 2830 aaronmk
429 2703 aaronmk
def wrap(wrap_func, value):
    '''Wraps a value, propagating any column renaming to the returned value.
    @param wrap_func Function that takes the value to wrap and returns the
        wrapped value
    '''
    if not isinstance(value, NamedCol): return wrap_func(value)
    
    # Wrap the code behind the rename and put the same name back on it
    return NamedCol(value.name, wrap_func(value.code))
434
435 2667 aaronmk
class ColDict(dicts.DictProxy):
    '''A dict that automatically makes inserted entries Col objects.
    Anything that isn't a column is wrapped in a NamedCol with the key's column
    name by `as_Col(value, name=key.name)`.
    '''
    
    def __init__(self, db, keys_table, dict_=None):
        '''
        @param db Used to look up a column's default when a value is None
        @param keys_table The table that keys are put in (see self._key())
        @param dict_ None|dict Initial entries
        '''
        dicts.DictProxy.__init__(self, OrderedDict())
        
        # Use a None sentinel rather than a shared mutable {} default
        if dict_ is None: dict_ = {}
        keys_table = as_Table(keys_table)
        
        self.db = db
        self.table = keys_table
        self.update(dict_) # after setting vars because __setitem__() needs them
    
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
    
    def __getitem__(self, key):
        return dicts.DictProxy.__getitem__(self, self._key(key))
    
    def __setitem__(self, key, value):
        key = self._key(key)
        if value is None: # use the column's default instead
            try: value = self.db.col_info(key).default
            except NoUnderlyingTableException: pass # not a table column
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
    
    def _key(self, key):
        '''Converts a key to a Col, using self.table for str keys'''
        return as_Col(key, self.table)
463 2564 aaronmk
464 3519 aaronmk
##### Definitions
465 3491 aaronmk
466 3469 aaronmk
class TypedCol(Col):
    '''A column definition: name, type, and optional modifiers'''
    
    def __init__(self, name, type_, default=None, nullable=True,
        constraints=None):
        '''
        @param type_ Code|str The column's type
        @param default None|Code The DEFAULT expression
        @param nullable If False, NOT NULL is added
        @param constraints None|str Text to append after the other modifiers
        '''
        assert default == None or isinstance(default, Code)
        
        Col.__init__(self, name)
        
        self.type = type_
        self.default = default
        self.nullable = nullable
        self.constraints = constraints
    
    def to_str(self, db):
        parts = [Col.to_str(self, db), as_Code(self.type).to_str(db)]
        if not self.nullable: parts.append('NOT NULL')
        if self.default != None:
            parts.append('DEFAULT '+self.default.to_str(db))
        if self.constraints != None: parts.append(self.constraints)
        return ' '.join(parts)
    
    def to_Col(self):
        '''Returns a plain Col that has just the name'''
        return Col(self.name)
486
487 3488 aaronmk
class SetOf(Code):
    '''Renders a type with a SETOF prefix'''
    
    def __init__(self, type_):
        '''
        @param type_ Code The element type
        '''
        Code.__init__(self)
        
        self.type = type_
    
    def to_str(self, db):
        type_str = self.type.to_str(db)
        return 'SETOF '+type_str
495
496 3483 aaronmk
class RowType(Code):
    '''Renders a table with a %ROWTYPE suffix'''
    
    def __init__(self, table):
        '''
        @param table Code The table whose row type is meant
        '''
        Code.__init__(self)
        
        self.table = table
    
    def to_str(self, db):
        table_str = self.table.to_str(db)
        return table_str+'%ROWTYPE'
504
505 3485 aaronmk
class ColType(Code):
    '''Renders a column with a %TYPE suffix'''
    
    def __init__(self, col):
        '''
        @param col Code The column whose type is meant
        '''
        Code.__init__(self)
        
        self.col = col
    
    def to_str(self, db):
        col_str = self.col.to_str(db)
        return col_str+'%TYPE'
513
514 4935 aaronmk
class ArrayType(Code):
    '''Renders an element type with a [] suffix'''
    
    def __init__(self, elem_type):
        '''
        @param elem_type Code|str The element type
        '''
        Code.__init__(self)
        
        self.elem_type = as_Code(elem_type)
    
    def to_str(self, db):
        elem_str = self.elem_type.to_str(db)
        return elem_str+'[]'
523
524 2524 aaronmk
##### Functions
525
526 2912 aaronmk
Function = Table
527 2911 aaronmk
as_Function = as_Table
528
529 2691 aaronmk
class InternalFunction(CustomCode):
    '''Marker subclass of CustomCode; adds no behavior of its own'''
530
531 3442 aaronmk
#### Calls
532
533 2941 aaronmk
class NamedArg(NamedCol):
    '''A named function argument: renders as `name := value`'''
    
    def __init__(self, name, value):
        NamedCol.__init__(self, name, value)
    
    def to_str(self, db):
        name_str = Col.to_str(self, db)
        return name_str+' := '+self.code.to_str(db)
539
540 2524 aaronmk
class FunctionCall(Code):
    '''A call of a function: renders as `function(arg, ..., name := arg)`'''
    
    def __init__(self, function, *args, **kw_args):
        '''
        @param args [Code|literal-value...] The function's arguments
        @param kw_args Named arguments, which are added after args as NamedArgs
        '''
        Code.__init__(self)
        
        function = as_Function(function)
        def unwrap(arg): return remove_col_rename(as_Value(arg))
        positional = [unwrap(arg) for arg in args]
        named = [NamedArg(key, unwrap(val))
            for key, val in kw_args.iteritems()]
        
        self.function = function
        self.args = positional+named
    
    def to_str(self, db):
        args_str = ', '.join([arg.to_str(db) for arg in self.args])
        return self.function.to_str(db)+'('+args_str+')'
558
559 2533 aaronmk
def wrap_in_func(function, value):
    '''Wraps a value inside a function call.
    Propagates any column renaming to the returned value.
    '''
    def call(arg): return FunctionCall(function, arg)
    return wrap(call, value)
564 2533 aaronmk
565 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
    '''Unwraps any function call to its first argument.
    Also removes any column renaming.
    @param check_name If not None, asserts that it equals the function's name
    '''
    func_call = remove_col_rename(func_call)
    if isinstance(func_call, FunctionCall):
        if check_name != None:
            name = func_call.function.name
            assert name == None or name == check_name
        return func_call.args[0]
    return func_call # not a function call, so there's nothing to unwrap
576
577 3442 aaronmk
#### Definitions
578
579
class FunctionDef(Code):
    '''A CREATE FUNCTION statement'''
    
    def __init__(self, function, return_type, body, params=None,
        modifiers=None):
        '''
        @param function Code The function's name; must have to_str()
        @param return_type Code|str
        @param body Code|str The body's lang attr is used for LANGUAGE
        @param params None|[Code...] The parameters; each must have to_str()
        @param modifiers None|str Text to put on its own line before the body
        '''
        Code.__init__(self)
        
        # Use a None sentinel: a [] default would be one list object stored on
        # (and shared by) every instance created without params
        if params is None: params = []
        return_type = as_Code(return_type)
        body = as_Code(body)
        
        self.function = function
        self.return_type = return_type
        self.body = body
        self.params = params
        self.modifiers = modifiers
    
    def to_str(self, db):
        params_str = (', '.join((p.to_str(db) for p in self.params)))
        str_ = '''\
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
RETURNS '''+self.return_type.to_str(db)+'''
LANGUAGE '''+self.body.lang+'''
'''
        if self.modifiers is not None: str_ += self.modifiers+'\n'
        str_ += '''\
AS $$
'''+self.body.to_str(db)+'''
$$;
'''
        return str_
606
607 3469 aaronmk
class FunctionParam(TypedCol):
    '''A function parameter definition, optionally marked OUT'''
    
    def __init__(self, name, type_, default=None, out=False):
        '''
        @param out If set, the definition is prefixed with OUT
        '''
        TypedCol.__init__(self, name, type_, default)
        
        self.out = out
    
    def to_str(self, db):
        def_str = TypedCol.to_str(self, db)
        if not self.out: return def_str
        return 'OUT '+def_str
    
    def to_Col(self):
        '''Returns a plain Col that has just the name'''
        return Col(self.name)
619
620 3454 aaronmk
### PL/pgSQL
621
622 3496 aaronmk
class ReturnQuery(Code):
    '''A RETURN QUERY statement'''
    
    def __init__(self, query):
        '''
        @param query Code|str The query whose results are returned
        '''
        Code.__init__(self)
        
        self.query = as_Code(query)
    
    def to_str(self, db):
        query_str = strings.indent(self.query.to_str(db))
        return 'RETURN QUERY\n'+query_str+';\n'
632
633 3515 aaronmk
## Exceptions
634
635 3509 aaronmk
class BaseExcHandler(BasicObject):
    '''Base class for exception handlers, which wrap a body of code'''
    
    def to_str(self, db, body):
        '''Renders the body wrapped in this handler. Subclasses must override
        this.
        @param body The code to wrap
        '''
        raise NotImplementedError()
    
    def __repr__(self):
        return self.to_str(mockDb, '<body>') # placeholder body
639 3509 aaronmk
640 3549 aaronmk
suppress_exc = 'NULL;\n';
641
642 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
643
644 3509 aaronmk
class ExcHandler(BaseExcHandler):
    '''Wraps code in a BEGIN ... EXCEPTION WHEN exc THEN ... END; block'''
    
    def __init__(self, exc, handler=None):
        '''
        @param exc str The exception condition that goes after WHEN
        @param handler None|Code|str The handler code. If None, suppress_exc
            is used.
        '''
        if handler != None: handler = as_Code(handler)
        
        self.exc = exc
        self.handler = handler
    
    def to_str(self, db, body):
        body = as_Code(body)
        
        if self.handler != None:
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
        else: handler_str = ' '+suppress_exc
        
        # Note that there is no newline after END;
        return ('BEGIN\n'
            +strings.indent(body.to_str(db))
            +'EXCEPTION\n'
            +'    WHEN '+self.exc+' THEN'+handler_str
            +'END;')
666
667 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
    '''Applies several handlers to the same body, one around the other'''
    
    def __init__(self, *handlers):
        '''
        @param handlers Sorted from outermost to innermost
        '''
        self.handlers = handlers
    
    def to_str(self, db, body):
        # Apply the innermost handler first, so that it ends up on the inside
        for handler in self.handlers[::-1]:
            body = handler.to_str(db, body)
        return body
677
678 3503 aaronmk
class ExcToWarning(Code):
    '''Handler code that raises the error message as a warning, followed by
    a return statement'''
    
    def __init__(self, return_):
        '''
        @param return_ Statement to return a default value in case of error
        '''
        Code.__init__(self)
        
        self.return_ = as_Code(return_)
    
    def to_str(self, db):
        return_str = self.return_.to_str(db)
        return "RAISE WARNING '%', SQLERRM;\n"+return_str
691
692 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
693
694 3555 aaronmk
# Note doubled "\"s because inside Python string
695 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
696 3591 aaronmk
-- Handle PL/Python exceptions
697 3555 aaronmk
DECLARE
698 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
699
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
700 3555 aaronmk
    exc_name text := matches[1];
701
    msg text := matches[2];
702
BEGIN
703 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
704
    This allows the exception to be parsed like a native exception.
705
    Always raise as data_exception so it goes in the errors table. */
706 3598 aaronmk
    IF exc_name IS NOT NULL THEN
707 4026 aaronmk
        RAISE data_exception USING MESSAGE = msg;
708 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
709
    ELSE
710
        '''+reraise_exc+'''\
711 3555 aaronmk
    END IF;
712
END;
713 3468 aaronmk
''')
714
715 3505 aaronmk
def data_exception_handler(handler):
    '''Creates an ExcHandler for data_exception that runs the given handler'''
    exc = 'data_exception'
    return ExcHandler(exc, handler)
717
718 3527 aaronmk
row_var = Table('row')
719
720 3449 aaronmk
class RowExcIgnore(Code):
    '''PL/pgSQL FOR loop over a query, which runs code for each row inside its
    own exception handler'''
    
    def __init__(self, row_type, select_query, with_row, cols=None,
        exc_handler=unique_violation_handler, row_var=row_var):
        '''
        @param row_type Ignored if a custom row_var is used.
        @param select_query Code|str The query to loop over
        @param with_row Code|str The code to run for each row
        @param cols None|[Col...] If set, the loop vars are these columns of
            row_var instead of row_var itself
        @pre If a custom row_var is used, it must already be defined.
        '''
        Code.__init__(self, lang='plpgsql')
        
        self.row_type = as_Code(row_type)
        self.select_query = as_Code(select_query)
        self.with_row = as_Code(with_row)
        self.cols = cols
        self.exc_handler = exc_handler
        self.row_var = as_Table(row_var)
    
    def to_str(self, db):
        if self.cols == None: loop_vars = [self.row_var]
        else: loop_vars = [Col(c.name, self.row_var) for c in self.cols]
        
        # Need an EXCEPTION block for each individual row because "When an error
        # is caught by an EXCEPTION clause, [...] all changes to persistent
        # database state within the block are rolled back."
        # This is unfortunate because "A block containing an EXCEPTION clause is
        # significantly more expensive to enter and exit than a block without
        # one."
        # (http://www.postgresql.org/docs/8.3/static/\
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
        loop_str = ('FOR '+', '.join([v.to_str(db) for v in loop_vars])+' IN\n'
            +strings.indent(self.select_query.to_str(db))
            +'LOOP\n'
            +strings.indent(self.exc_handler.to_str(db, self.with_row))
            +'END LOOP;\n')
        
        # Here, row_var is the module-level default row var, which is the only
        # one that gets declared by this code
        if self.row_var == row_var:
            return ('DECLARE\n'
                +'    '+self.row_var.to_str(db)+' '+self.row_type.to_str(db)
                +';\n'
                +'BEGIN\n'
                +strings.indent(loop_str)
                +'END;\n')
        return loop_str
769
770 2986 aaronmk
##### Casts
771
772
class Cast(FunctionCall):
    '''A type cast: renders as `CAST(value AS type)`'''
    
    def __init__(self, type_, value):
        '''
        @param type_ Code|str The type to cast to
        @param value Code|literal-value The value to cast
        '''
        # Set the attrs that every Code object has (lang). FunctionCall's own
        # __init__() is not used because a cast has no function or args.
        Code.__init__(self)
        
        type_ = as_Code(type_)
        value = as_Value(value)
        
        # Most types cannot be cast directly to unknown
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)
        
        self.type_ = type_
        self.value = value
    
    def to_str(self, db):
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
785 2986 aaronmk
786 3354 aaronmk
def cast_literal(value):
    '''Casts a Literal that wraps a str to text. Other values pass through.'''
    if is_literal(value) and util.is_str(value.value):
        return Cast('text', value)
    return value
791
792 2335 aaronmk
##### Conditions
793 2259 aaronmk
794 3350 aaronmk
class NotCond(Code):
    '''A negated condition: renders as `NOT cond`'''
    
    def __init__(self, cond):
        '''
        @param cond The condition to negate. Unless it is already a Coalesce,
            it is wrapped in Coalesce(cond, False).
        '''
        Code.__init__(self)
        
        already_wrapped = isinstance(cond, Coalesce)
        if not already_wrapped: cond = Coalesce(cond, False)
        
        self.cond = cond
    
    def to_str(self, db):
        cond_str = self.cond.to_str(db)
        return 'NOT '+cond_str
803
804 2398 aaronmk
class ColValueCond(Code):
    '''A condition that applies a ValueCond to a column'''
    
    def __init__(self, col, value):
        '''
        @param col The left-hand side that the condition is applied on
        @param value Passed through as_ValueCond()
        '''
        Code.__init__(self)
        
        self.col = col
        self.value = as_ValueCond(value)
    
    def to_str(self, db):
        cond = self.value
        return cond.to_str(db, self.col)
814
815 2577 aaronmk
def combine_conds(conds, keyword=None):
    '''Joins condition strs with AND, putting each one on its own line.
    @param conds Sequence of condition strs. Must support len() if a keyword
        is used.
    @param keyword The keyword to add before the conditions, if any
    @return str The keyword is followed by nothing if there are no conditions,
        by a space if there is one, and by a newline if there are several
    '''
    str_ = ''
    if keyword is not None:
        if not conds: whitespace = '' # any empty sequence, not just []
        elif len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        str_ += keyword+whitespace
    
    str_ += '\nAND '.join(conds)
    return str_
828
829 2398 aaronmk
##### Condition column comparisons
830
831 2514 aaronmk
class ValueCond(BasicObject):
    '''Base class for conditions that compare a left-hand value to a value'''
    
    def __init__(self, value):
        '''
        @param value The value to compare to. Non-Code values become Literals,
            and any column renaming is removed.
        '''
        value = remove_col_rename(as_Value(value))
        
        self.value = value
    
    def to_str(self, db, left_value):
        '''Renders the condition. Subclasses must override this.
        @param left_value The Code object that the condition is being applied on
        '''
        # NotImplemented is a constant, not an exception class: calling it
        # raises a TypeError instead of the intended error
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
844 2211 aaronmk
845
class CompareCond(ValueCond):
    '''Compares the left value to the stored value with a SQL operator.'''
    def __init__(self, value, operator='='):
        '''
        @param operator By default, compares NULL values literally. Use '~=' or
            '~!=' to pass NULLs through.
        '''
        ValueCond.__init__(self, value)
        self.operator = operator
    
    def to_str(self, db, left_value):
        lhs = remove_col_rename(as_Col(left_value))
        rhs = self.value
        
        # Split off the optional '~' (pass NULLs through) and '!' (negate)
        # prefixes, in that order, leaving the bare operator
        passthru_null_ref = [False]
        op = strings.remove_prefix('~', self.operator, passthru_null_ref)
        neg_ref = [False]
        op = strings.remove_prefix('!', op, neg_ref)
        is_equality = op.endswith('=') # also includes <=, >=
        
        # Handle nullable columns
        check_null = False
        if not passthru_null_ref[0]: # NULLs compare equal
            try: lhs = ensure_not_null(db, lhs)
            except ensure_not_null_excs: # fall back to alternate method
                check_null = is_equality and isinstance(rhs, Col)
            else:
                if isinstance(lhs, EnsureNotNull):
                    # apply same function to both sides
                    rhs = ensure_not_null(db, rhs, lhs.type)
        
        if is_equality and is_null(rhs): op = 'IS'
        
        lhs_str = lhs.to_str(db)
        rhs_str = rhs.to_str(db)
        
        # Create str
        sql = lhs_str+' '+op+' '+rhs_str
        if check_null:
            # Also treat the two sides as matching when both are NULL
            both_null = lhs_str+' IS NULL AND '+rhs_str+' IS NULL'
            sql = '('+sql+' OR ('+both_null+'))'
        if neg_ref[0]: sql = 'NOT '+sql
        return sql
889 2216 aaronmk
890 2260 aaronmk
# Sentinel default for as_ValueCond()'s default_table param: tells it to assume
# a non-ValueCond is a literal value
assume_literal = object()

def as_ValueCond(value, default_table=assume_literal):
    '''Converts a value to a ValueCond, if it isn't one already.
    @param default_table If given, a non-ValueCond value is first passed
        through with_default_table() with this table
    @return ValueCond The value itself, or a CompareCond wrapping it
    '''
    if isinstance(value, ValueCond): return value
    
    if default_table is not assume_literal:
        value = with_default_table(value, default_table)
    return CompareCond(value)
899 2219 aaronmk
900 2335 aaronmk
##### Joins
901
902 2352 aaronmk
# Sentinel: tells Join the left and right columns have the same name
join_same = object()

# Sentinel: tells Join the left and right columns have the same name and are
# never NULL
join_same_not_null = object()

# Sentinel: tells Join to filter out rows that match the join
filter_out = object()
908
909 2514 aaronmk
class Join(BasicObject):
    '''A JOIN clause that joins a table onto a left table.'''
    def __init__(self, table, mapping={}, type_=None):
        '''
        @param table str|table object A string is wrapped in Table()
        @param mapping dict(right_table_col=left_table_col, ...)
            or [using_col...]
            * if left_table_col is join_same: left_table_col = right_table_col
              * Note that right_table_col must be a string
            * if left_table_col is join_same_not_null:
              left_table_col = right_table_col and both have NOT NULL constraint
              * Note that right_table_col must be a string
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
            * filter_out: equivalent to 'LEFT' with the query filtered by
              `table_pkey IS NULL` (indicating no match)
        '''
        if util.is_str(table): table = Table(table)
        if lists.is_seq(mapping): # [using_col...] form
            mapping = dict([(using_col, join_same_not_null)
                for using_col in mapping])
        assert type_ == None or util.is_str(type_) or type_ is filter_out
        
        self.table = table
        self.mapping = mapping
        self.type_ = type_
    
    def to_str(self, db, left_table_):
        '''
        @param left_table_ The table that self.table is being joined onto
        '''
        def on_cond(entry):
            '''Parses non-USING joins'''
            right_table_col, left_table_col = entry
            
            # Switch order (right_table_col is on the left in the comparison)
            lhs = with_default_table(right_table_col, self.table)
            rhs = left_table_col
            
            # Parse special values
            same_name_col = Col(lhs.name, left_table_)
            if rhs is join_same: rhs = same_name_col
            elif rhs is join_same_not_null:
                rhs = CompareCond(same_name_col, '~=')
            
            return as_ValueCond(rhs, left_table_).to_str(db, lhs)
        
        # Create join condition
        type_ = self.type_
        mapping = self.mapping
        if mapping == {}: join_cond = None
        elif type_ is not filter_out and all(
            (v is join_same_not_null for v in mapping.itervalues())):
            # all cols w/ USING, so can use simpler USING syntax
            using_cols = [to_name_only_col(c) for c in mapping.iterkeys()]
            col_strs = [c.to_str(db) for c in using_cols]
            join_cond = 'USING ('+', '.join(col_strs)+')'
        else:
            conds = [on_cond(entry) for entry in mapping.iteritems()]
            join_cond = combine_conds(conds, 'ON')
        
        if isinstance(self.table, NamedTable): whitespace = '\n'
        else: whitespace = ' '
        
        # Create join
        if type_ is filter_out: type_ = 'LEFT'
        if type_ == None: prefix = ''
        else: prefix = type_+' '
        sql = prefix+'JOIN'+whitespace+self.table.to_str(db)
        if join_cond != None: sql += whitespace+join_cond
        return sql
    
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
980 2424 aaronmk
981
##### Value exprs
982
983 3089 aaronmk
# The `*` column wildcard, as a custom code fragment
all_cols = CustomCode('*')
984
985 2737 aaronmk
# The DEFAULT keyword, as a custom code fragment
default = CustomCode('DEFAULT')
986
987 3090 aaronmk
# A call to the COUNT internal function with all_cols (`*`) as its argument
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
988 2674 aaronmk
989 3061 aaronmk
class Coalesce(FunctionCall):
    '''A call to the COALESCE internal function.'''
    def __init__(self, *args):
        '''
        @param args The function's arguments, passed on unchanged
        '''
        coalesce_func = InternalFunction('COALESCE')
        FunctionCall.__init__(self, coalesce_func, *args)
992 3060 aaronmk
993 3062 aaronmk
class Nullif(FunctionCall):
    '''A call to the NULLIF internal function.'''
    def __init__(self, *args):
        '''
        @param args The function's arguments, passed on unchanged
        '''
        nullif_func = InternalFunction('NULLIF')
        FunctionCall.__init__(self, nullif_func, *args)
996
997 5891 aaronmk
# The literal wrapping None
null = Literal(None)
# The null literal cast to text
null_as_str = Cast('text', null)

def to_text(value):
    '''
    @return Coalesce of `value` cast to text and null_as_str
    '''
    text_value = Cast('text', value)
    return Coalesce(text_value, null_as_str)
1001
1002 2850 aaronmk
# Maps a type name to the value that stands in for NULL in that type
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
null_sentinels = {
    # string types
    'character varying': r'\N',
    # numeric types
    'double precision': 'NaN',
    'integer': 2147483647, # 2**31 - 1
    # string types (cont.)
    'text': r'\N',
    # date/time types
    'date': 'infinity',
    'timestamp with time zone': 'infinity',
    # other types
    'taxonrank': 'unknown',
}
1012 2692 aaronmk
1013 3061 aaronmk
class EnsureNotNull(Coalesce):
    '''Coalesce of a column with the null sentinel for a type.'''
    def __init__(self, value, type_):
        '''
        @param value Converted with as_Col()
        @param type_ ArrayType|a key of null_sentinels
        @throws KeyError if type_ is not an ArrayType and has no entry in
            null_sentinels
        '''
        # An array type uses [] as its sentinel
        if isinstance(type_, ArrayType): sentinel = []
        else: sentinel = null_sentinels[type_]
        
        Coalesce.__init__(self, as_Col(value), Cast(type_, sentinel))
        self.type = type_
    
    def to_str(self, db):
        # Use what index_col() returns for the column instead, if anything
        indexed = index_col(self.args[0])
        if indexed != None: return indexed.to_str(db)
        return Coalesce.to_str(self, db)
1026 2850 aaronmk
1027 3523 aaronmk
#### Arrays
1028
1029 3535 aaronmk
class ArrayMerge(FunctionCall):
    '''A call to the array_to_string internal function.'''
    def __init__(self, sep, array):
        '''
        @param sep The function's second argument
        @param array Converted with to_Array(); the function's first argument
        '''
        array = to_Array(array)
        merge_func = InternalFunction('array_to_string')
        FunctionCall.__init__(self, merge_func, array, sep)
1034
1035 3537 aaronmk
def merge_not_null(db, sep, values):
    '''
    @param db Not used
    @return ArrayMerge of the values, each converted with to_text()
    '''
    text_values = [to_text(v) for v in values]
    return ArrayMerge(sep, text_values)
1037 3537 aaronmk
1038 2737 aaronmk
##### Table exprs
1039
1040
class Values(Code):
    '''A VALUES list of one or more rows.'''
    def __init__(self, values):
        '''
        @param values [...]|[[...], ...] Can be one or multiple rows.
            Not modified; may be any sequence, including a tuple of rows.
        '''
        Code.__init__(self)
        
        rows = values
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
            rows = [values]
        
        # Build new lists instead of assigning into `rows`: for multi-row input
        # `rows` is the caller's own sequence, so item assignment would change
        # the caller's data (and would fail on a tuple)
        self.rows = [[remove_col_rename(as_Value(v)) for v in row]
            for row in rows]
    
    def to_str(self, db):
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1057 2737 aaronmk
1058 2740 aaronmk
def NamedValues(name, cols, values):
    '''Creates a NamedTable whose code is a Values list.
    @param cols None|[...]
    @post `cols` will be changed to Col objects with the table set to `name`.
    '''
    values_table = NamedTable(name, Values(values), cols)
    if cols == None: return values_table
    
    set_cols_table(values_table, cols)
    return values_table
1066 2740 aaronmk
1067 2674 aaronmk
##### Database structure
1068
1069 5074 aaronmk
def is_nullable(db, value):
    '''
    @return For a table column, the `nullable` attribute of its db.col_info(),
        or True if it has no underlying table. For anything else, whether it
        is NULL (per is_null()).
    '''
    if is_table_col(value):
        try: nullable = db.col_info(value).nullable
        except NoUnderlyingTableException: nullable = True # not a table column
        return nullable
    return is_null(value)
1073
1074 4442 aaronmk
# The type names that count as text types
text_types = set(['character varying', 'text'])

def is_text_type(type_):
    '''
    @return Whether type_ is in text_types
    '''
    return type_ in text_types
1077
1078 5335 aaronmk
def is_text_col(db, col):
    '''
    @return Whether the type that db.col_info() reports for col is a text type
    '''
    col_type = db.col_info(col).type
    return is_text_type(col_type)
1079 4442 aaronmk
1080 5397 aaronmk
def canon_type(type_):
    '''
    @return 'text' for any type in text_types, otherwise type_ unchanged
    '''
    if type_ not in text_types: return type_
    return 'text'
1083
1084 2840 aaronmk
# The exception types that ensure_not_null() is documented to throw; usable
# directly in an `except` clause
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1085
1086 2851 aaronmk
def ensure_not_null(db, col, type_=None):
    '''Wraps a nullable column in EnsureNotNull.
    @param col If type_ is not set, must have an underlying column.
    @param type_ If set, overrides the underlying column's type and casts the
        column to it if needed.
    @return EnsureNotNull|Col
    @throws ensure_not_null_excs
    '''
    col = remove_col_rename(col)
    
    # Determine the type to use, from the underlying column where there is one
    try: underlying_type = db.col_info(underlying_col(col)).type
    except NoUnderlyingTableException:
        if type_ == None and is_null(col): raise # NULL has no type
    else:
        if type_ == None: type_ = underlying_type
        elif type_ != underlying_type: col = Cast(type_, col)
    
    if not is_nullable(db, col): return col
    
    try: return EnsureNotNull(col, type_)
    except KeyError as e:
        # Warn of no null sentinel for type, even if caller catches error
        warnings.warn(UserWarning(exc.str_(e)))
        raise
1111 3536 aaronmk
1112
def try_mk_not_null(db, value):
    '''Like ensure_not_null(), but returns the value unchanged on failure.
    Warning: This function does not guarantee that its result is NOT NULL.
    '''
    try: result = ensure_not_null(db, value)
    except ensure_not_null_excs: result = value
    return result
1118 5367 aaronmk
1119
##### Expression transforming
1120
1121
# The boolean literals as they appear in expression strings
true_expr = 'true'
false_expr = 'false'

# Regexps for the boolean literals (the literals are used as-is)
true_re = true_expr
false_re = false_expr
# Either boolean literal
bool_re = '(?:%s|%s)' % (true_re, false_re)
# A boolean literal, or a parenthesized group containing no other parentheses
atom_re = r'(?:%s|\([^()]*\))' % bool_re
1128
1129
def logic_op_re(op, value_re, expr_re=''):
    '''Builds a regexp for a binary operator with a given operand on either
    side.
    @param op The operator, e.g. 'AND'. Matched with one space on each side.
    @param value_re Regexp for the operand that must be present
    @param expr_re Regexp for the other operand. If '', only the value and the
        operator (with its spaces) are matched.
    @return A non-capturing group matching `expr op value` or `value op expr`
    '''
    spaced_op = ' '+op+' '
    value_on_right = expr_re+spaced_op+value_re
    value_on_left = value_re+spaced_op+expr_re
    return '(?:'+value_on_right+'|'+value_on_left+')'
1132
1133 5370 aaronmk
# `NOT ` starting at a word boundary
not_re = r'\bNOT '
# NOT applied to each boolean literal
not_false_re = not_re+false_re+r'\b'
not_true_re = not_re+true_re+r'\b'

# AND between `false` and an atom, in either order
and_false_re = logic_op_re('AND', false_re, atom_re)
# Either of: `NOT true`, or an AND with `false`
and_false_not_true_re = '(?:%s|%s)' % (not_true_re, and_false_re)

# `true` together with an adjacent AND, but not the other operand
and_true_re = logic_op_re('AND', true_re)
# A boolean literal together with an adjacent OR, but not the other operand
or_re = logic_op_re('OR', bool_re)
# Either of the two above
or_and_true_re = '(?:%s|%s)' % (and_true_re, or_re)
1141
1142
def simplify_parens(expr):
    '''Replaces each atom (see atom_re) that is directly wrapped in an extra
    pair of parentheses with the atom itself, using regexp.sub_nested().
    '''
    wrapped_atom_re = r'\((%s)\)' % atom_re
    return regexp.sub_nested(wrapped_atom_re, r'\1', expr)
1144
1145
def simplify_recursive(sub_func, expr):
    '''Runs regexp.sub_recursive() with simplify_parens() applied before each
    call to sub_func.
    @param sub_func See regexp.sub_recursive() sub_func param
    '''
    def simplify_parens_then_sub(s):
        # simplify_parens() is also done at end in final iteration
        return sub_func(simplify_parens(s))
    
    return regexp.sub_recursive(simplify_parens_then_sub, expr)
1151 5367 aaronmk
1152
def simplify_expr(expr):
    '''Simplifies the constant boolean parts of an expression string.
    @param expr str The expression
    @return str The simplified expression
    '''
    def simplify_logic_ops(expr):
        '''
        @return tuple(expr, number of substitutions made)
        '''
        # Applied in this order, each to the result of the one before
        rules = (
            (not_false_re, true_re),
            (and_false_not_true_re, false_expr),
            (or_and_true_re, r''),
        )
        total_n = 0
        for pattern, repl in rules:
            expr, n = re.subn(pattern, repl, expr)
            total_n += n
        return expr, total_n
    
    # Evaluate constant NULL tests first
    expr = expr.replace('NULL IS NULL', true_expr)
    expr = expr.replace('NULL IS NOT NULL', false_expr)
    return simplify_recursive(simplify_logic_ops, expr)
1167
1168
# Regexp for a name: either a run of word characters, or one or more adjacent
# double-quoted strings (so a doubled "" inside a quoted name is also matched)
name_re = r'(?:\w+|(?:"[^"]*")+)'
1169
1170
def parse_expr_col(str_):
    '''Gets the column name out of an expression string.
    @param str_ If of the form `(name(name...))`, the first name inside the
        inner parentheses is used. Otherwise, the whole string is used.
    @return The name, passed through unesc_name()
    '''
    func_call_re = r'^\(%s\((%s).*\)\)$' % (name_re, name_re)
    match = re.match(func_call_re, str_)
    if match: return unesc_name(match.group(1))
    return unesc_name(str_)
1174
1175
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The expression
    @param mapping dict(out_col=in_col, ...)
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The mapped expression, passed through simplify_expr()
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = to_name_only_col(out)
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The unquoted name is only replaced as a whole word that is not next
        # to any of the characters in the lookbehind/lookahead sets.
        # The name is escaped so that regexp metacharacters in it are matched
        # literally, and the replacement is given as a function so that
        # backslashes in in_str are not interpreted as escapes or group refs.
        out_name_re = (r'(?<!["\'\.=\[])\b'+re.escape(out.name)
            +r'\b(?!["\',\.=\]])')
        expr = re.sub(out_name_re, lambda match: in_str, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)