Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 9383 aaronmk
from collections import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
25 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
26
    to collisions.'''
27 2613 aaronmk
    # Preserve version
28 5552 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$', str_)
29 2985 aaronmk
    if match:
30
        str_, old_suffix = match.groups()
31
        suffix = old_suffix+suffix
32 2613 aaronmk
33 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_): return concat(str_, '')
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
38 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
39
    * contains only *lowercase* word (\w) characters
40
    * doesn't start with a digit
41
    * contains "_", so that it's not a keyword
42 2984 aaronmk
    '''
43
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
46
    return quote + name.replace(quote, quote+quote) + quote
47
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
50
    removed_ref = [False]
51
    name = strings.remove_prefix(quote, name, removed_ref)
52
    if removed_ref[0]:
53
        name = strings.remove_suffix(quote, name, removed_ref)
54
        assert removed_ref[0]
55
        name = name.replace(quote+quote, quote)
56
    return name
57
58 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
59
60 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
61
62 3182 aaronmk
def lstrip(str_):
63
    '''Also removes comments.'''
64
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
65
    return str_.lstrip()
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
70 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
71 2349 aaronmk
72 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
73 2859 aaronmk
74
    def col_info(self, col):
75
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76
77 2349 aaronmk
mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
85 3446 aaronmk
    def __init__(self, lang='sql'):
86
        self.lang = lang
87 3445 aaronmk
88 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
89 2349 aaronmk
90 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
93 3445 aaronmk
    def __init__(self, str_):
94
        Code.__init__(self)
95
96
        self.str_ = str_
97 2256 aaronmk
98
    def to_str(self, db): return self.str_
99
100 2815 aaronmk
def as_Code(value, db=None):
101
    '''
102
    @param db If set, runs db.std_code() on the value.
103
    '''
104 3447 aaronmk
    if isinstance(value, Code): return value
105
106 2815 aaronmk
    if util.is_str(value):
107
        if db != None: value = db.std_code(value)
108
        return CustomCode(value)
109 2659 aaronmk
    else: return Literal(value)
110
111 2540 aaronmk
class Expr(Code):
112 3445 aaronmk
    def __init__(self, expr):
113
        Code.__init__(self)
114
115
        self.expr = expr
116 2540 aaronmk
117
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
122 3087 aaronmk
    def __init__(self, name):
123 3445 aaronmk
        Code.__init__(self)
124
125 3087 aaronmk
        name = truncate(name)
126
127
        self.name = name
128 3086 aaronmk
129
    def to_str(self, db): return db.esc_name(self.name)
130
131
def as_Name(value):
132
    if isinstance(value, Code): return value
133
    else: return Name(value)
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
140 3445 aaronmk
    def __init__(self, value):
141
        Code.__init__(self)
142
143
        self.value = value
144 2213 aaronmk
145
    def to_str(self, db): return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
148
    if isinstance(value, Code): return value
149
    else: return Literal(value)
150
151 6225 aaronmk
def get_value(value):
152
    '''Unwraps a Literal's value'''
153
    value = remove_col_rename(value)
154
    if isinstance(value, Literal): return value.value
155
    else:
156
        assert not isinstance(value, Code)
157
        return value
158
159 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
160 2216 aaronmk
161 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
162
163 3519 aaronmk
#### Composites
164
165 3521 aaronmk
class List(Code):
166
    def __init__(self, values):
167 3519 aaronmk
        Code.__init__(self)
168
169
        self.values = values
170
171 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
172 3519 aaronmk
173 3521 aaronmk
class Tuple(List):
174
    def __init__(self, *values):
175
        List.__init__(self, values)
176
177
    def to_str(self, db): return '('+List.to_str(self, db)+')'
178
179 3520 aaronmk
class Row(Tuple):
180
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
181 3519 aaronmk
182 3522 aaronmk
### Arrays
183
184
class Array(List):
185
    def __init__(self, values):
186
        values = map(remove_col_rename, values)
187
188
        List.__init__(self, values)
189
190
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
191
192
def to_Array(value):
193
    if isinstance(value, Array): return value
194
    return Array(lists.mk_seq(value))
195
196 2711 aaronmk
##### Derived elements
197
198
src_self = object() # tells Col that it is its own source column
199
200
class Derived(Code):
201
    def __init__(self, srcs):
202 2712 aaronmk
        '''An element which was derived from some other element(s).
203 2711 aaronmk
        @param srcs See self.set_srcs()
204
        '''
205 3445 aaronmk
        Code.__init__(self)
206
207 2711 aaronmk
        self.set_srcs(srcs)
208
209 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
210 2711 aaronmk
        '''
211
        @param srcs (self_type...)|src_self The element(s) this is derived from
212
        '''
213 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
214
215 2711 aaronmk
        if srcs == src_self: srcs = (self,)
216
        srcs = tuple(srcs) # make Col hashable
217
        self.srcs = srcs
218
219
    def _compare_on(self):
220
        compare_on = self.__dict__.copy()
221
        del compare_on['srcs'] # ignore
222
        return compare_on
223
224
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
225
226 2335 aaronmk
##### Tables
227
228 2712 aaronmk
class Table(Derived):
229 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
230 2211 aaronmk
        '''
231
        @param schema str|None (for no schema)
232 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
233 2211 aaronmk
        '''
234 2712 aaronmk
        Derived.__init__(self, srcs)
235
236 3091 aaronmk
        if util.is_str(name): name = truncate(name)
237 2843 aaronmk
238 2211 aaronmk
        self.name = name
239
        self.schema = schema
240 2991 aaronmk
        self.is_temp = is_temp
241 5524 aaronmk
        self.order_by = None
242 3000 aaronmk
        self.index_cols = {}
243 2211 aaronmk
244 2348 aaronmk
    def to_str(self, db):
245
        str_ = ''
246 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
247
        str_ += as_Name(self.name).to_str(db)
248 2348 aaronmk
        return str_
249 2336 aaronmk
250
    def to_Table(self): return self
251 3000 aaronmk
252
    def _compare_on(self):
253
        compare_on = Derived._compare_on(self)
254 5524 aaronmk
        del compare_on['order_by'] # ignore
255 3000 aaronmk
        del compare_on['index_cols'] # ignore
256
        return compare_on
257 2211 aaronmk
258 2835 aaronmk
def is_underlying_table(table):
259
    return isinstance(table, Table) and table.to_Table() is table
260 2832 aaronmk
261 4114 aaronmk
class NoUnderlyingTableException(Exception):
262
    def __init__(self, ref):
263
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
264
        self.ref = ref
265 2902 aaronmk
266
def underlying_table(table):
267
    table = remove_table_rename(table)
268 3532 aaronmk
    if table != None and table.srcs:
269
        table, = table.srcs # for derived tables or row vars
270 4114 aaronmk
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
271 2902 aaronmk
    return table
272
273 2776 aaronmk
def as_Table(table, schema=None):
274 2270 aaronmk
    if table == None or isinstance(table, Code): return table
275 2776 aaronmk
    else: return Table(table, schema)
276 2219 aaronmk
277 3101 aaronmk
def suffixed_table(table, suffix):
278 3128 aaronmk
    table = copy.copy(table) # don't modify input!
279
    table.name = concat(table.name, suffix)
280
    return table
281 2707 aaronmk
282 2336 aaronmk
class NamedTable(Table):
283
    def __init__(self, name, code, cols=None):
284
        Table.__init__(self, name)
285
286 3016 aaronmk
        code = as_Table(code)
287 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
288 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
289 2336 aaronmk
290
        self.code = code
291
        self.cols = cols
292
293
    def to_str(self, db):
294 3026 aaronmk
        str_ = self.code.to_str(db)
295
        if str_.find('\n') >= 0: whitespace = '\n'
296
        else: whitespace = ' '
297
        str_ += whitespace+'AS '+Table.to_str(self, db)
298 2742 aaronmk
        if self.cols != None:
299
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
300 2336 aaronmk
        return str_
301
302
    def to_Table(self): return Table(self.name)
303
304 2753 aaronmk
def remove_table_rename(table):
305
    if isinstance(table, NamedTable): table = table.code
306
    return table
307
308 2335 aaronmk
##### Columns
309
310 2711 aaronmk
class Col(Derived):
311 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
312 2211 aaronmk
        '''
313
        @param table Table|None (for no table)
314 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
315 2211 aaronmk
        '''
316 2711 aaronmk
        Derived.__init__(self, srcs)
317
318 3091 aaronmk
        if util.is_str(name): name = truncate(name)
319 2241 aaronmk
        if util.is_str(table): table = Table(table)
320 2211 aaronmk
        assert table == None or isinstance(table, Table)
321
322
        self.name = name
323
        self.table = table
324
325 2989 aaronmk
    def to_str(self, db, for_str=False):
326 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
327 2989 aaronmk
        if for_str: str_ = clean_name(str_)
328 2933 aaronmk
        if self.table != None:
329 2989 aaronmk
            table = self.table.to_Table()
330 3750 aaronmk
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
331 2989 aaronmk
            else: str_ = table.to_str(db)+'.'+str_
332 2211 aaronmk
        return str_
333 2314 aaronmk
334 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
335 2933 aaronmk
336 2314 aaronmk
    def to_Col(self): return self
337 2211 aaronmk
338 3512 aaronmk
def is_col(col): return isinstance(col, Col)
339 2393 aaronmk
340 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
341
342 10138 aaronmk
def col2col_ref(db, col):
343
    assert isinstance(col, Col)
344
    return db.esc_value((col.table.to_str(db), col.name))
345
346 3000 aaronmk
def index_col(col):
347
    if not is_table_col(col): return None
348 3104 aaronmk
349
    table = col.table
350
    try: name = table.index_cols[col.name]
351
    except KeyError: return None
352
    else: return Col(name, table, col.srcs)
353 2999 aaronmk
354 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
355 2996 aaronmk
356 2563 aaronmk
def as_Col(col, table=None, name=None):
357
    '''
358
    @param name If not None, any non-Col input will be renamed using NamedCol.
359
    '''
360
    if name != None:
361
        col = as_Value(col)
362
        if not isinstance(col, Col): col = NamedCol(name, col)
363 2333 aaronmk
364
    if isinstance(col, Code): return col
365 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
366
    else: return Literal(col)
367 2260 aaronmk
368 3093 aaronmk
def with_table(col, table):
369 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
370
    elif isinstance(col, FunctionCall):
371
        col = copy.deepcopy(col) # don't modify input!
372
        col.args[0].table = table
373 3493 aaronmk
    elif isinstance(col, Col):
374 3098 aaronmk
        col = copy.copy(col) # don't modify input!
375
        col.table = table
376 3093 aaronmk
    return col
377
378 3100 aaronmk
def with_default_table(col, table):
379 2747 aaronmk
    col = as_Col(col)
380 3100 aaronmk
    if col.table == None: col = with_table(col, table)
381 2747 aaronmk
    return col
382
383 2744 aaronmk
def set_cols_table(table, cols):
384
    table = as_Table(table)
385
386
    for i, col in enumerate(cols):
387
        col = cols[i] = as_Col(col)
388
        col.table = table
389
390 2401 aaronmk
def to_name_only_col(col, check_table=None):
391
    col = as_Col(col)
392 3020 aaronmk
    if not is_table_col(col): return col
393 2401 aaronmk
394
    if check_table != None:
395
        table = col.table
396
        assert table == None or table == check_table
397
    return Col(col.name)
398
399 2993 aaronmk
def suffixed_col(col, suffix):
400
    return Col(concat(col.name, suffix), col.table, col.srcs)
401
402 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
403
404 3518 aaronmk
def cross_join_srcs(cols):
405
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
406
    srcs = [[s.name for s in c.srcs] for c in cols]
407 5816 aaronmk
    if not srcs: return [] # itertools.product() returns [()] for empty input
408 3518 aaronmk
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
409 3514 aaronmk
410 2323 aaronmk
class NamedCol(Col):
411 2229 aaronmk
    def __init__(self, name, code):
412 2310 aaronmk
        Col.__init__(self, name)
413
414 3016 aaronmk
        code = as_Value(code)
415 2229 aaronmk
416
        self.code = code
417
418
    def to_str(self, db):
419 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
420 2314 aaronmk
421
    def to_Col(self): return Col(self.name)
422 2229 aaronmk
423 2462 aaronmk
def remove_col_rename(col):
424
    if isinstance(col, NamedCol): col = col.code
425
    return col
426
427 2830 aaronmk
def underlying_col(col):
428
    col = remove_col_rename(col)
429 4114 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
430 2849 aaronmk
431 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
432 2830 aaronmk
433 2703 aaronmk
def wrap(wrap_func, value):
434
    '''Wraps a value, propagating any column renaming to the returned value.'''
435
    if isinstance(value, NamedCol):
436
        return NamedCol(value.name, wrap_func(value.code))
437
    else: return wrap_func(value)
438
439 2667 aaronmk
class ColDict(dicts.DictProxy):
440 3691 aaronmk
    '''A dict that automatically makes inserted entries Col objects.
441
    Anything that isn't a column is wrapped in a NamedCol with the key's column
442
    name by `as_Col(value, name=key.name)`.
443
    '''
444 2564 aaronmk
445 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
446 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
447 2667 aaronmk
448 2645 aaronmk
        keys_table = as_Table(keys_table)
449
450 2642 aaronmk
        self.db = db
451 2641 aaronmk
        self.table = keys_table
452 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
453 2641 aaronmk
454 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
455 2655 aaronmk
456 2667 aaronmk
    def __getitem__(self, key):
457
        return dicts.DictProxy.__getitem__(self, self._key(key))
458 2653 aaronmk
459 2564 aaronmk
    def __setitem__(self, key, value):
460 2642 aaronmk
        key = self._key(key)
461 3639 aaronmk
        if value == None:
462
            try: value = self.db.col_info(key).default
463
            except NoUnderlyingTableException: pass # not a table column
464 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
465 2564 aaronmk
466 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
467 2564 aaronmk
468 3519 aaronmk
##### Definitions
469 3491 aaronmk
470 3469 aaronmk
class TypedCol(Col):
471
    def __init__(self, name, type_, default=None, nullable=True,
472
        constraints=None):
473
        assert default == None or isinstance(default, Code)
474
475
        Col.__init__(self, name)
476
477
        self.type = type_
478
        self.default = default
479
        self.nullable = nullable
480
        self.constraints = constraints
481
482
    def to_str(self, db):
483 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
484 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
485
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
486
        if self.constraints != None: str_ += ' '+self.constraints
487
        return str_
488
489
    def to_Col(self): return Col(self.name)
490
491 3488 aaronmk
class SetOf(Code):
492
    def __init__(self, type_):
493
        Code.__init__(self)
494
495
        self.type = type_
496
497
    def to_str(self, db):
498
        return 'SETOF '+self.type.to_str(db)
499
500 3483 aaronmk
class RowType(Code):
501
    def __init__(self, table):
502
        Code.__init__(self)
503
504
        self.table = table
505
506
    def to_str(self, db):
507
        return self.table.to_str(db)+'%ROWTYPE'
508
509 3485 aaronmk
class ColType(Code):
510
    def __init__(self, col):
511
        Code.__init__(self)
512
513
        self.col = col
514
515
    def to_str(self, db):
516
        return self.col.to_str(db)+'%TYPE'
517
518 4935 aaronmk
class ArrayType(Code):
519
    def __init__(self, elem_type):
520
        Code.__init__(self)
521
        elem_type = as_Code(elem_type)
522
523
        self.elem_type = elem_type
524
525
    def to_str(self, db):
526
        return self.elem_type.to_str(db)+'[]'
527
528 2524 aaronmk
##### Functions
529
530 2912 aaronmk
Function = Table
531 2911 aaronmk
as_Function = as_Table
532
533 2691 aaronmk
class InternalFunction(CustomCode): pass
534
535 3442 aaronmk
#### Calls
536
537 2941 aaronmk
class NamedArg(NamedCol):
538
    def __init__(self, name, value):
539
        NamedCol.__init__(self, name, value)
540
541
    def to_str(self, db):
542
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
543
544 2524 aaronmk
class FunctionCall(Code):
545 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
546 2524 aaronmk
        '''
547 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
548 2524 aaronmk
        '''
549 3445 aaronmk
        Code.__init__(self)
550
551 3016 aaronmk
        function = as_Function(function)
552 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
553
        args = map(filter_, args)
554
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
555 2524 aaronmk
556
        self.function = function
557
        self.args = args
558
559
    def to_str(self, db):
560
        args_str = ', '.join((v.to_str(db) for v in self.args))
561
        return self.function.to_str(db)+'('+args_str+')'
562
563 2533 aaronmk
def wrap_in_func(function, value):
564
    '''Wraps a value inside a function call.
565
    Propagates any column renaming to the returned value.
566
    '''
567 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
568 2533 aaronmk
569 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
570
    '''Unwraps any function call to its first argument.
571
    Also removes any column renaming.
572
    '''
573
    func_call = remove_col_rename(func_call)
574
    if not isinstance(func_call, FunctionCall): return func_call
575
576
    if check_name != None:
577
        name = func_call.function.name
578
        assert name == None or name == check_name
579
    return func_call.args[0]
580
581 3442 aaronmk
#### Definitions
582
583
class FunctionDef(Code):
584 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
585 3445 aaronmk
        Code.__init__(self)
586
587 3487 aaronmk
        return_type = as_Code(return_type)
588 3444 aaronmk
        body = as_Code(body)
589
590 3442 aaronmk
        self.function = function
591
        self.return_type = return_type
592
        self.body = body
593 3471 aaronmk
        self.params = params
594 3456 aaronmk
        self.modifiers = modifiers
595 3442 aaronmk
596
    def to_str(self, db):
597 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
598 3442 aaronmk
        str_ = '''\
599 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
600 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
601 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
602 3456 aaronmk
'''
603
        if self.modifiers != None: str_ += self.modifiers+'\n'
604
        str_ += '''\
605 3442 aaronmk
AS $$
606 3444 aaronmk
'''+self.body.to_str(db)+'''
607 3442 aaronmk
$$;
608
'''
609
        return str_
610
611 3469 aaronmk
class FunctionParam(TypedCol):
612
    def __init__(self, name, type_, default=None, out=False):
613
        TypedCol.__init__(self, name, type_, default)
614
615
        self.out = out
616
617
    def to_str(self, db):
618
        str_ = TypedCol.to_str(self, db)
619
        if self.out: str_ = 'OUT '+str_
620
        return str_
621
622
    def to_Col(self): return Col(self.name)
623
624 3454 aaronmk
### PL/pgSQL
625
626 3496 aaronmk
class ReturnQuery(Code):
627
    def __init__(self, query):
628
        Code.__init__(self)
629
630
        query = as_Code(query)
631
632
        self.query = query
633
634
    def to_str(self, db):
635
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
636
637 3515 aaronmk
## Exceptions
638
639 3509 aaronmk
class BaseExcHandler(BasicObject):
640
    def to_str(self, db, body): raise NotImplementedError()
641 3510 aaronmk
642
    def __repr__(self): return self.to_str(mockDb, '<body>')
643 3509 aaronmk
644 3549 aaronmk
suppress_exc = 'NULL;\n';
645
646 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
647
648 3509 aaronmk
class ExcHandler(BaseExcHandler):
649 3454 aaronmk
    def __init__(self, exc, handler=None):
650
        if handler != None: handler = as_Code(handler)
651
652
        self.exc = exc
653
        self.handler = handler
654
655
    def to_str(self, db, body):
656
        body = as_Code(body)
657
658 3467 aaronmk
        if self.handler != None:
659
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
660 3549 aaronmk
        else: handler_str = ' '+suppress_exc
661 3454 aaronmk
662
        str_ = '''\
663
BEGIN
664 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
665 3454 aaronmk
EXCEPTION
666 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
667 3454 aaronmk
END;\
668
'''
669
        return str_
670
671 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
672
    def __init__(self, *handlers):
673
        '''
674
        @param handlers Sorted from outermost to innermost
675
        '''
676
        self.handlers = handlers
677
678
    def to_str(self, db, body):
679
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
680
        return body
681
682 3503 aaronmk
class ExcToWarning(Code):
683
    def __init__(self, return_):
684
        '''
685
        @param return_ Statement to return a default value in case of error
686
        '''
687
        Code.__init__(self)
688
689
        return_ = as_Code(return_)
690
691
        self.return_ = return_
692
693
    def to_str(self, db):
694
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
695
696 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
697
698 3555 aaronmk
# Note doubled "\"s because inside Python string
699 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
700 3591 aaronmk
-- Handle PL/Python exceptions
701 3555 aaronmk
DECLARE
702 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
703
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
704 3555 aaronmk
    exc_name text := matches[1];
705
    msg text := matches[2];
706
BEGIN
707 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
708
    This allows the exception to be parsed like a native exception.
709
    Always raise as data_exception so it goes in the errors table. */
710 3598 aaronmk
    IF exc_name IS NOT NULL THEN
711 4026 aaronmk
        RAISE data_exception USING MESSAGE = msg;
712 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
713
    ELSE
714
        '''+reraise_exc+'''\
715 3555 aaronmk
    END IF;
716
END;
717 3468 aaronmk
''')
718
719 3505 aaronmk
def data_exception_handler(handler):
720
    return ExcHandler('data_exception', handler)
721
722 3527 aaronmk
row_var = Table('row')
723
724 3449 aaronmk
class RowExcIgnore(Code):
725
    def __init__(self, row_type, select_query, with_row, cols=None,
726 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
727 3529 aaronmk
        '''
728
        @param row_type Ignored if a custom row_var is used.
729
        @pre If a custom row_var is used, it must already be defined.
730
        '''
731 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
732
733 3482 aaronmk
        row_type = as_Code(row_type)
734 3449 aaronmk
        select_query = as_Code(select_query)
735
        with_row = as_Code(with_row)
736 3452 aaronmk
        row_var = as_Table(row_var)
737 3449 aaronmk
738
        self.row_type = row_type
739
        self.select_query = select_query
740
        self.with_row = with_row
741
        self.cols = cols
742 3455 aaronmk
        self.exc_handler = exc_handler
743 3452 aaronmk
        self.row_var = row_var
744 3449 aaronmk
745
    def to_str(self, db):
746 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
747
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
748 3449 aaronmk
749 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
750
        # is caught by an EXCEPTION clause, [...] all changes to persistent
751
        # database state within the block are rolled back."
752
        # This is unfortunate because "A block containing an EXCEPTION clause is
753
        # significantly more expensive to enter and exit than a block without
754
        # one."
755
        # (http://www.postgresql.org/docs/8.3/static/\
756
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
757 3449 aaronmk
        str_ = '''\
758 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
759
'''+strings.indent(self.select_query.to_str(db))+'''\
760
LOOP
761
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
762
END LOOP;
763
'''
764 3533 aaronmk
        if self.row_var == row_var:
765 3529 aaronmk
            str_ = '''\
766 3449 aaronmk
DECLARE
767 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
768 3449 aaronmk
BEGIN
769 3529 aaronmk
'''+strings.indent(str_)+'''\
770
END;
771 3449 aaronmk
'''
772
        return str_
773
774 2986 aaronmk
##### Casts
775
776
class Cast(FunctionCall):
777
    def __init__(self, type_, value):
778 3539 aaronmk
        type_ = as_Code(type_)
779 2986 aaronmk
        value = as_Value(value)
780
781 6299 aaronmk
        # Most types cannot be cast directly to unknown
782
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)
783
784 2986 aaronmk
        self.type_ = type_
785
        self.value = value
786
787
    def to_str(self, db):
788 3539 aaronmk
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
789 2986 aaronmk
790 3354 aaronmk
def cast_literal(value):
791 3429 aaronmk
    if not is_literal(value): return value
792 3354 aaronmk
793
    if util.is_str(value.value): value = Cast('text', value)
794
    return value
795
796 2335 aaronmk
##### Conditions
797 2259 aaronmk
798 3350 aaronmk
class NotCond(Code):
799
    def __init__(self, cond):
800 3445 aaronmk
        Code.__init__(self)
801
802 5893 aaronmk
        if not isinstance(cond, Coalesce): cond = Coalesce(cond, False)
803
804 3350 aaronmk
        self.cond = cond
805
806
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
807
808 2398 aaronmk
class ColValueCond(Code):
809
    def __init__(self, col, value):
810 3445 aaronmk
        Code.__init__(self)
811
812 2398 aaronmk
        value = as_ValueCond(value)
813
814
        self.col = col
815
        self.value = value
816
817
    def to_str(self, db): return self.value.to_str(db, self.col)
818
819 2577 aaronmk
def combine_conds(conds, keyword=None):
820
    '''
821
    @param keyword The keyword to add before the conditions, if any
822
    '''
823
    str_ = ''
824
    if keyword != None:
825
        if conds == []: whitespace = ''
826
        elif len(conds) == 1: whitespace = ' '
827
        else: whitespace = '\n'
828
        str_ += keyword+whitespace
829
830
    str_ += '\nAND '.join(conds)
831
    return str_
832
833 2398 aaronmk
##### Condition column comparisons
834
835 2514 aaronmk
class ValueCond(BasicObject):
836 2213 aaronmk
    def __init__(self, value):
837 2858 aaronmk
        value = remove_col_rename(as_Value(value))
838 2213 aaronmk
839
        self.value = value
840 2214 aaronmk
841 2216 aaronmk
    def to_str(self, db, left_value):
842 2214 aaronmk
        '''
843 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
844 2214 aaronmk
        '''
845
        raise NotImplemented()
846 2228 aaronmk
847 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
848 2211 aaronmk
849
class CompareCond(ValueCond):
850
    def __init__(self, value, operator='='):
851 2222 aaronmk
        '''
852
        @param operator By default, compares NULL values literally. Use '~=' or
853
            '~!=' to pass NULLs through.
854
        '''
855 2211 aaronmk
        ValueCond.__init__(self, value)
856
        self.operator = operator
857
858 2216 aaronmk
    def to_str(self, db, left_value):
859 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
860 2216 aaronmk
861 2222 aaronmk
        right_value = self.value
862
863
        # Parse operator
864 2216 aaronmk
        operator = self.operator
865 2222 aaronmk
        passthru_null_ref = [False]
866
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
867
        neg_ref = [False]
868
        operator = strings.remove_prefix('!', operator, neg_ref)
869 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
870 2222 aaronmk
871 2825 aaronmk
        # Handle nullable columns
872
        check_null = False
873 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
874 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
875 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
876
                check_null = equals and isinstance(right_value, Col)
877 2837 aaronmk
            else:
878 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
879
                    right_value = ensure_not_null(db, right_value,
880
                        left_value.type) # apply same function to both sides
881 2825 aaronmk
882 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
883
884 2825 aaronmk
        left = left_value.to_str(db)
885
        right = right_value.to_str(db)
886
887 2222 aaronmk
        # Create str
888
        str_ = left+' '+operator+' '+right
889 2825 aaronmk
        if check_null:
890 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
891
        if neg_ref[0]: str_ = 'NOT '+str_
892 2222 aaronmk
        return str_
893 2216 aaronmk
894 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
895
assume_literal = object()
896
897
def as_ValueCond(value, default_table=assume_literal):
898
    if not isinstance(value, ValueCond):
899
        if default_table is not assume_literal:
900 2748 aaronmk
            value = with_default_table(value, default_table)
901 2260 aaronmk
        return CompareCond(value)
902 2216 aaronmk
    else: return value
903 2219 aaronmk
904 2335 aaronmk
##### Joins
905
906 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
907 2260 aaronmk
908 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
909
join_same_not_null = object()
910
911 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
912
913 2514 aaronmk
class Join(BasicObject):
914 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
915 2260 aaronmk
        '''
916
        @param mapping dict(right_table_col=left_table_col, ...)
917 7176 aaronmk
            or [using_col...]
918 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
919 2353 aaronmk
              * Note that right_table_col must be a string
920
            * if left_table_col is join_same_not_null:
921
              left_table_col = right_table_col and both have NOT NULL constraint
922
              * Note that right_table_col must be a string
923 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
924
            * filter_out: equivalent to 'LEFT' with the query filtered by
925
              `table_pkey IS NULL` (indicating no match)
926
        '''
927
        if util.is_str(table): table = Table(table)
928 7176 aaronmk
        if lists.is_seq(mapping):
929
            mapping = dict(((c, join_same_not_null) for c in mapping))
930 2260 aaronmk
        assert type_ == None or util.is_str(type_) or type_ is filter_out
931
932
        self.table = table
933
        self.mapping = mapping
934
        self.type_ = type_
935
936 2749 aaronmk
    def to_str(self, db, left_table_):
937 2260 aaronmk
        def join(entry):
938
            '''Parses non-USING joins'''
939
            right_table_col, left_table_col = entry
940
941 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
942
            left = right_table_col
943
            right = left_table_col
944 2749 aaronmk
            left_table = self.table
945
            right_table = left_table_
946 2353 aaronmk
947 2747 aaronmk
            # Parse left side
948 2748 aaronmk
            left = with_default_table(left, left_table)
949 2747 aaronmk
950 2260 aaronmk
            # Parse special values
951 2747 aaronmk
            left_on_right = Col(left.name, right_table)
952
            if right is join_same: right = left_on_right
953 2353 aaronmk
            elif right is join_same_not_null:
954 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
955 2260 aaronmk
956 2747 aaronmk
            # Parse right side
957 2353 aaronmk
            right = as_ValueCond(right, right_table)
958 2747 aaronmk
959
            return right.to_str(db, left)
960 2260 aaronmk
961 2265 aaronmk
        # Create join condition
962
        type_ = self.type_
963 2276 aaronmk
        joins = self.mapping
964 2746 aaronmk
        if joins == {}: join_cond = None
965
        elif type_ is not filter_out and reduce(operator.and_,
966 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
967 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
968 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
969
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
970 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
971 2260 aaronmk
972 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
973
        else: whitespace = ' '
974
975 2260 aaronmk
        # Create join
976
        if type_ is filter_out: type_ = 'LEFT'
977 2266 aaronmk
        str_ = ''
978
        if type_ != None: str_ += type_+' '
979 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
980
        if join_cond != None: str_ += whitespace+join_cond
981 2266 aaronmk
        return str_
982 2349 aaronmk
983 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
984 2424 aaronmk
985
##### Value exprs
986
987 3089 aaronmk
all_cols = CustomCode('*')
988
989 2737 aaronmk
default = CustomCode('DEFAULT')
990
991 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
992 2674 aaronmk
993 3061 aaronmk
class Coalesce(FunctionCall):
994
    def __init__(self, *args):
995
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
996 3060 aaronmk
997 3062 aaronmk
class Nullif(FunctionCall):
998
    def __init__(self, *args):
999
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
1000
1001 5891 aaronmk
null = Literal(None)
1002 5892 aaronmk
null_as_str = Cast('text', null)
1003 3706 aaronmk
1004
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
1005
1006 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
1007 2958 aaronmk
null_sentinels = {
1008
    'character varying': r'\N',
1009
    'double precision': 'NaN',
1010
    'integer': 2147483647,
1011
    'text': r'\N',
1012 5498 aaronmk
    'date': 'infinity',
1013 5266 aaronmk
    'timestamp with time zone': 'infinity',
1014
    'taxonrank': 'unknown',
1015 2958 aaronmk
}
1016 2692 aaronmk
1017 3061 aaronmk
class EnsureNotNull(Coalesce):
1018 2850 aaronmk
    def __init__(self, value, type_):
1019 4938 aaronmk
        if isinstance(type_, ArrayType): null = []
1020
        else: null = null_sentinels[type_]
1021
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
1022 2850 aaronmk
1023
        self.type = type_
1024 3001 aaronmk
1025
    def to_str(self, db):
1026
        col = self.args[0]
1027
        index_col_ = index_col(col)
1028
        if index_col_ != None: return index_col_.to_str(db)
1029 3061 aaronmk
        return Coalesce.to_str(self, db)
1030 2850 aaronmk
1031 3523 aaronmk
#### Arrays
1032
1033 3535 aaronmk
class ArrayMerge(FunctionCall):
1034 3523 aaronmk
    def __init__(self, sep, array):
1035
        array = to_Array(array)
1036
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1037
            sep)
1038
1039 3537 aaronmk
def merge_not_null(db, sep, values):
1040 3707 aaronmk
    return ArrayMerge(sep, map(to_text, values))
1041 3537 aaronmk
1042 2737 aaronmk
##### Table exprs
1043
1044
class Values(Code):
1045
    def __init__(self, values):
1046 2739 aaronmk
        '''
1047
        @param values [...]|[[...], ...] Can be one or multiple rows.
1048
        '''
1049 3445 aaronmk
        Code.__init__(self)
1050
1051 2739 aaronmk
        rows = values
1052
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1053
            rows = [values]
1054
        for i, row in enumerate(rows):
1055
            rows[i] = map(remove_col_rename, map(as_Value, row))
1056 2737 aaronmk
1057 2739 aaronmk
        self.rows = rows
1058 2737 aaronmk
1059
    def to_str(self, db):
1060 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1061 2737 aaronmk
1062 2740 aaronmk
def NamedValues(name, cols, values):
1063 2745 aaronmk
    '''
1064 3048 aaronmk
    @param cols None|[...]
1065 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
1066
    '''
1067 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
1068 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
1069 2834 aaronmk
    return table
1070 2740 aaronmk
1071 2674 aaronmk
##### Database structure
1072
1073 5074 aaronmk
def is_nullable(db, value):
1074 5330 aaronmk
    if not is_table_col(value): return is_null(value)
1075
    try: return db.col_info(value).nullable
1076 5074 aaronmk
    except NoUnderlyingTableException: return True # not a table column
1077
1078 4442 aaronmk
text_types = set(['character varying', 'text'])
1079 4406 aaronmk
1080 5334 aaronmk
def is_text_type(type_): return type_ in text_types
1081
1082 5335 aaronmk
def is_text_col(db, col): return is_text_type(db.col_info(col).type)
1083 4442 aaronmk
1084 5397 aaronmk
def canon_type(type_):
1085
    if type_ in text_types: return 'text'
1086
    else: return type_
1087
1088 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1089
1090 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1091 2840 aaronmk
    '''
1092 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1093 4488 aaronmk
    @param type_ If set, overrides the underlying column's type and casts the
1094
        column to it if needed.
1095 2840 aaronmk
    @return EnsureNotNull|Col
1096
    @throws ensure_not_null_excs
1097
    '''
1098 5329 aaronmk
    col = remove_col_rename(col)
1099 5332 aaronmk
1100
    try: col_type = db.col_info(underlying_col(col)).type
1101 2855 aaronmk
    except NoUnderlyingTableException:
1102 5333 aaronmk
        if type_ == None and is_null(col): raise # NULL has no type
1103 2855 aaronmk
    else:
1104 5332 aaronmk
        if type_ == None: type_ = col_type
1105
        elif type_ != col_type: col = Cast(type_, col)
1106 2855 aaronmk
1107 5331 aaronmk
    if is_nullable(db, col):
1108 2953 aaronmk
        try: col = EnsureNotNull(col, type_)
1109
        except KeyError, e:
1110
            # Warn of no null sentinel for type, even if caller catches error
1111
            warnings.warn(UserWarning(exc.str_(e)))
1112
            raise
1113
1114 2840 aaronmk
    return col
1115 3536 aaronmk
1116
def try_mk_not_null(db, value):
1117
    '''
1118
    Warning: This function does not guarantee that its result is NOT NULL.
1119
    '''
1120
    try: return ensure_not_null(db, value)
1121
    except ensure_not_null_excs: return value
1122 5367 aaronmk
1123
##### Expression transforming
1124
1125
true_expr = 'true'
1126
false_expr = 'false'
1127
1128
true_re = true_expr
1129
false_re = false_expr
1130
bool_re = r'(?:'+true_re+r'|'+false_re+r')'
1131
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
1132
1133
def logic_op_re(op, value_re, expr_re=''):
1134
    op_re = ' '+op+' '
1135
    return '(?:'+expr_re+op_re+value_re+'|'+value_re+op_re+expr_re+')'
1136
1137 5370 aaronmk
not_re = r'\bNOT '
1138 5372 aaronmk
not_false_re = not_re+false_re+r'\b'
1139
not_true_re = not_re+true_re+r'\b'
1140 5367 aaronmk
and_false_re = logic_op_re('AND', false_re, atom_re)
1141 5370 aaronmk
and_false_not_true_re = '(?:'+not_true_re+'|'+and_false_re+')'
1142 5367 aaronmk
and_true_re = logic_op_re('AND', true_re)
1143
or_re = logic_op_re('OR', bool_re)
1144
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
1145
1146
def simplify_parens(expr):
1147
    return regexp.sub_nested(r'\(('+atom_re+')\)', r'\1', expr)
1148
1149
def simplify_recursive(sub_func, expr):
1150
    '''
1151
    @param sub_func See regexp.sub_recursive() sub_func param
1152
    '''
1153 5378 aaronmk
    return regexp.sub_recursive(lambda s: sub_func(simplify_parens(s)), expr)
1154
        # simplify_parens() is also done at end in final iteration
1155 5367 aaronmk
1156
def simplify_expr(expr):
1157
    def simplify_logic_ops(expr):
1158
        total_n = 0
1159 5371 aaronmk
        expr, n = re.subn(not_false_re, true_re, expr)
1160
        total_n += n
1161 5370 aaronmk
        expr, n = re.subn(and_false_not_true_re, false_expr, expr)
1162 5367 aaronmk
        total_n += n
1163
        expr, n = re.subn(or_and_true_re, r'', expr)
1164
        total_n += n
1165
        return expr, total_n
1166
1167 5826 aaronmk
    expr = expr.replace('NULL IS NULL', true_expr)
1168
    expr = expr.replace('NULL IS NOT NULL', false_expr)
1169 5367 aaronmk
    expr = simplify_recursive(simplify_logic_ops, expr)
1170
    return expr
1171
1172
name_re = r'(?:\w+|(?:"[^"]*")+)'
1173
1174
def parse_expr_col(str_):
1175
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1176
    if match: str_ = match.group(1)
1177
    return unesc_name(str_)
1178
1179
def map_expr(db, expr, mapping, in_cols_found=None):
1180
    '''Replaces output columns with input columns in an expression.
1181
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1182
    '''
1183
    for out, in_ in mapping.iteritems():
1184
        orig_expr = expr
1185
        out = to_name_only_col(out)
1186
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
1187
1188
        # Replace out both with and without quotes
1189
        expr = expr.replace(out.to_str(db), in_str)
1190 5507 aaronmk
        expr = re.sub(r'(?<!["\'\.=\[])\b'+out.name+r'\b(?!["\',\.=\]])',
1191
            in_str, expr)
1192 5367 aaronmk
1193
        if in_cols_found != None and expr != orig_expr: # replaced something
1194
            in_cols_found.append(in_)
1195
1196
    return simplify_expr(expr)