Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
25 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
26
    to collisions.'''
27 2613 aaronmk
    # Preserve version
28 5552 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$', str_)
29 2985 aaronmk
    if match:
30
        str_, old_suffix = match.groups()
31
        suffix = old_suffix+suffix
32 2613 aaronmk
33 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_): return concat(str_, '')
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
38 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
39
    * contains only *lowercase* word (\w) characters
40
    * doesn't start with a digit
41
    * contains "_", so that it's not a keyword
42 2984 aaronmk
    '''
43
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
46
    return quote + name.replace(quote, quote+quote) + quote
47
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
50
    removed_ref = [False]
51
    name = strings.remove_prefix(quote, name, removed_ref)
52
    if removed_ref[0]:
53
        name = strings.remove_suffix(quote, name, removed_ref)
54
        assert removed_ref[0]
55
        name = name.replace(quote+quote, quote)
56
    return name
57
58 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
59
60 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
61
62 3182 aaronmk
def lstrip(str_):
63
    '''Also removes comments.'''
64
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
65
    return str_.lstrip()
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
70 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
71 2349 aaronmk
72 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
73 2859 aaronmk
74
    def col_info(self, col):
75
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76
77 2349 aaronmk
mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
85 3446 aaronmk
    def __init__(self, lang='sql'):
86
        self.lang = lang
87 3445 aaronmk
88 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
89 2349 aaronmk
90 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
93 3445 aaronmk
    def __init__(self, str_):
94
        Code.__init__(self)
95
96
        self.str_ = str_
97 2256 aaronmk
98
    def to_str(self, db): return self.str_
99
100 2815 aaronmk
def as_Code(value, db=None):
101
    '''
102
    @param db If set, runs db.std_code() on the value.
103
    '''
104 3447 aaronmk
    if isinstance(value, Code): return value
105
106 2815 aaronmk
    if util.is_str(value):
107
        if db != None: value = db.std_code(value)
108
        return CustomCode(value)
109 2659 aaronmk
    else: return Literal(value)
110
111 2540 aaronmk
class Expr(Code):
112 3445 aaronmk
    def __init__(self, expr):
113
        Code.__init__(self)
114
115
        self.expr = expr
116 2540 aaronmk
117
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
122 3087 aaronmk
    def __init__(self, name):
123 3445 aaronmk
        Code.__init__(self)
124
125 3087 aaronmk
        name = truncate(name)
126
127
        self.name = name
128 3086 aaronmk
129
    def to_str(self, db): return db.esc_name(self.name)
130
131
def as_Name(value):
132
    if isinstance(value, Code): return value
133
    else: return Name(value)
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
140 3445 aaronmk
    def __init__(self, value):
141
        Code.__init__(self)
142
143
        self.value = value
144 2213 aaronmk
145
    def to_str(self, db): return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
148
    if isinstance(value, Code): return value
149
    else: return Literal(value)
150
151 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
152 2216 aaronmk
153 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
154
155 3519 aaronmk
#### Composites
156
157 3521 aaronmk
class List(Code):
158
    def __init__(self, values):
159 3519 aaronmk
        Code.__init__(self)
160
161
        self.values = values
162
163 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
164 3519 aaronmk
165 3521 aaronmk
class Tuple(List):
166
    def __init__(self, *values):
167
        List.__init__(self, values)
168
169
    def to_str(self, db): return '('+List.to_str(self, db)+')'
170
171 3520 aaronmk
class Row(Tuple):
172
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
173 3519 aaronmk
174 3522 aaronmk
### Arrays
175
176
class Array(List):
177
    def __init__(self, values):
178
        values = map(remove_col_rename, values)
179
180
        List.__init__(self, values)
181
182
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
183
184
def to_Array(value):
185
    if isinstance(value, Array): return value
186
    return Array(lists.mk_seq(value))
187
188 2711 aaronmk
##### Derived elements
189
190
src_self = object() # tells Col that it is its own source column
191
192
class Derived(Code):
193
    def __init__(self, srcs):
194 2712 aaronmk
        '''An element which was derived from some other element(s).
195 2711 aaronmk
        @param srcs See self.set_srcs()
196
        '''
197 3445 aaronmk
        Code.__init__(self)
198
199 2711 aaronmk
        self.set_srcs(srcs)
200
201 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
202 2711 aaronmk
        '''
203
        @param srcs (self_type...)|src_self The element(s) this is derived from
204
        '''
205 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
206
207 2711 aaronmk
        if srcs == src_self: srcs = (self,)
208
        srcs = tuple(srcs) # make Col hashable
209
        self.srcs = srcs
210
211
    def _compare_on(self):
212
        compare_on = self.__dict__.copy()
213
        del compare_on['srcs'] # ignore
214
        return compare_on
215
216
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
217
218 2335 aaronmk
##### Tables
219
220 2712 aaronmk
class Table(Derived):
221 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
222 2211 aaronmk
        '''
223
        @param schema str|None (for no schema)
224 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
225 2211 aaronmk
        '''
226 2712 aaronmk
        Derived.__init__(self, srcs)
227
228 3091 aaronmk
        if util.is_str(name): name = truncate(name)
229 2843 aaronmk
230 2211 aaronmk
        self.name = name
231
        self.schema = schema
232 2991 aaronmk
        self.is_temp = is_temp
233 5524 aaronmk
        self.order_by = None
234 3000 aaronmk
        self.index_cols = {}
235 2211 aaronmk
236 2348 aaronmk
    def to_str(self, db):
237
        str_ = ''
238 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
239
        str_ += as_Name(self.name).to_str(db)
240 2348 aaronmk
        return str_
241 2336 aaronmk
242
    def to_Table(self): return self
243 3000 aaronmk
244
    def _compare_on(self):
245
        compare_on = Derived._compare_on(self)
246 5524 aaronmk
        del compare_on['order_by'] # ignore
247 3000 aaronmk
        del compare_on['index_cols'] # ignore
248
        return compare_on
249 2211 aaronmk
250 2835 aaronmk
def is_underlying_table(table):
251
    return isinstance(table, Table) and table.to_Table() is table
252 2832 aaronmk
253 4114 aaronmk
class NoUnderlyingTableException(Exception):
254
    def __init__(self, ref):
255
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
256
        self.ref = ref
257 2902 aaronmk
258
def underlying_table(table):
259
    table = remove_table_rename(table)
260 3532 aaronmk
    if table != None and table.srcs:
261
        table, = table.srcs # for derived tables or row vars
262 4114 aaronmk
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
263 2902 aaronmk
    return table
264
265 2776 aaronmk
def as_Table(table, schema=None):
266 2270 aaronmk
    if table == None or isinstance(table, Code): return table
267 2776 aaronmk
    else: return Table(table, schema)
268 2219 aaronmk
269 3101 aaronmk
def suffixed_table(table, suffix):
270 3128 aaronmk
    table = copy.copy(table) # don't modify input!
271
    table.name = concat(table.name, suffix)
272
    return table
273 2707 aaronmk
274 2336 aaronmk
class NamedTable(Table):
275
    def __init__(self, name, code, cols=None):
276
        Table.__init__(self, name)
277
278 3016 aaronmk
        code = as_Table(code)
279 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
280 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
281 2336 aaronmk
282
        self.code = code
283
        self.cols = cols
284
285
    def to_str(self, db):
286 3026 aaronmk
        str_ = self.code.to_str(db)
287
        if str_.find('\n') >= 0: whitespace = '\n'
288
        else: whitespace = ' '
289
        str_ += whitespace+'AS '+Table.to_str(self, db)
290 2742 aaronmk
        if self.cols != None:
291
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
292 2336 aaronmk
        return str_
293
294
    def to_Table(self): return Table(self.name)
295
296 2753 aaronmk
def remove_table_rename(table):
297
    if isinstance(table, NamedTable): table = table.code
298
    return table
299
300 2335 aaronmk
##### Columns
301
302 2711 aaronmk
class Col(Derived):
303 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
304 2211 aaronmk
        '''
305
        @param table Table|None (for no table)
306 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
307 2211 aaronmk
        '''
308 2711 aaronmk
        Derived.__init__(self, srcs)
309
310 3091 aaronmk
        if util.is_str(name): name = truncate(name)
311 2241 aaronmk
        if util.is_str(table): table = Table(table)
312 2211 aaronmk
        assert table == None or isinstance(table, Table)
313
314
        self.name = name
315
        self.table = table
316
317 2989 aaronmk
    def to_str(self, db, for_str=False):
318 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
319 2989 aaronmk
        if for_str: str_ = clean_name(str_)
320 2933 aaronmk
        if self.table != None:
321 2989 aaronmk
            table = self.table.to_Table()
322 3750 aaronmk
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
323 2989 aaronmk
            else: str_ = table.to_str(db)+'.'+str_
324 2211 aaronmk
        return str_
325 2314 aaronmk
326 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
327 2933 aaronmk
328 2314 aaronmk
    def to_Col(self): return self
329 2211 aaronmk
330 3512 aaronmk
def is_col(col): return isinstance(col, Col)
331 2393 aaronmk
332 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
333
334 3000 aaronmk
def index_col(col):
335
    if not is_table_col(col): return None
336 3104 aaronmk
337
    table = col.table
338
    try: name = table.index_cols[col.name]
339
    except KeyError: return None
340
    else: return Col(name, table, col.srcs)
341 2999 aaronmk
342 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
343 2996 aaronmk
344 2563 aaronmk
def as_Col(col, table=None, name=None):
345
    '''
346
    @param name If not None, any non-Col input will be renamed using NamedCol.
347
    '''
348
    if name != None:
349
        col = as_Value(col)
350
        if not isinstance(col, Col): col = NamedCol(name, col)
351 2333 aaronmk
352
    if isinstance(col, Code): return col
353 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
354
    else: return Literal(col)
355 2260 aaronmk
356 3093 aaronmk
def with_table(col, table):
357 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
358
    elif isinstance(col, FunctionCall):
359
        col = copy.deepcopy(col) # don't modify input!
360
        col.args[0].table = table
361 3493 aaronmk
    elif isinstance(col, Col):
362 3098 aaronmk
        col = copy.copy(col) # don't modify input!
363
        col.table = table
364 3093 aaronmk
    return col
365
366 3100 aaronmk
def with_default_table(col, table):
367 2747 aaronmk
    col = as_Col(col)
368 3100 aaronmk
    if col.table == None: col = with_table(col, table)
369 2747 aaronmk
    return col
370
371 2744 aaronmk
def set_cols_table(table, cols):
372
    table = as_Table(table)
373
374
    for i, col in enumerate(cols):
375
        col = cols[i] = as_Col(col)
376
        col.table = table
377
378 2401 aaronmk
def to_name_only_col(col, check_table=None):
379
    col = as_Col(col)
380 3020 aaronmk
    if not is_table_col(col): return col
381 2401 aaronmk
382
    if check_table != None:
383
        table = col.table
384
        assert table == None or table == check_table
385
    return Col(col.name)
386
387 2993 aaronmk
def suffixed_col(col, suffix):
388
    return Col(concat(col.name, suffix), col.table, col.srcs)
389
390 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
391
392 3518 aaronmk
def cross_join_srcs(cols):
393
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
394
    srcs = [[s.name for s in c.srcs] for c in cols]
395 5816 aaronmk
    if not srcs: return [] # itertools.product() returns [()] for empty input
396 3518 aaronmk
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
397 3514 aaronmk
398 2323 aaronmk
class NamedCol(Col):
399 2229 aaronmk
    def __init__(self, name, code):
400 2310 aaronmk
        Col.__init__(self, name)
401
402 3016 aaronmk
        code = as_Value(code)
403 2229 aaronmk
404
        self.code = code
405
406
    def to_str(self, db):
407 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
408 2314 aaronmk
409
    def to_Col(self): return Col(self.name)
410 2229 aaronmk
411 2462 aaronmk
def remove_col_rename(col):
412
    if isinstance(col, NamedCol): col = col.code
413
    return col
414
415 2830 aaronmk
def underlying_col(col):
416
    col = remove_col_rename(col)
417 4114 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
418 2849 aaronmk
419 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
420 2830 aaronmk
421 2703 aaronmk
def wrap(wrap_func, value):
422
    '''Wraps a value, propagating any column renaming to the returned value.'''
423
    if isinstance(value, NamedCol):
424
        return NamedCol(value.name, wrap_func(value.code))
425
    else: return wrap_func(value)
426
427 2667 aaronmk
class ColDict(dicts.DictProxy):
428 3691 aaronmk
    '''A dict that automatically makes inserted entries Col objects.
429
    Anything that isn't a column is wrapped in a NamedCol with the key's column
430
    name by `as_Col(value, name=key.name)`.
431
    '''
432 2564 aaronmk
433 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
434 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
435 2667 aaronmk
436 2645 aaronmk
        keys_table = as_Table(keys_table)
437
438 2642 aaronmk
        self.db = db
439 2641 aaronmk
        self.table = keys_table
440 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
441 2641 aaronmk
442 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
443 2655 aaronmk
444 2667 aaronmk
    def __getitem__(self, key):
445
        return dicts.DictProxy.__getitem__(self, self._key(key))
446 2653 aaronmk
447 2564 aaronmk
    def __setitem__(self, key, value):
448 2642 aaronmk
        key = self._key(key)
449 3639 aaronmk
        if value == None:
450
            try: value = self.db.col_info(key).default
451
            except NoUnderlyingTableException: pass # not a table column
452 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
453 2564 aaronmk
454 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
455 2564 aaronmk
456 3519 aaronmk
##### Definitions
457 3491 aaronmk
458 3469 aaronmk
class TypedCol(Col):
459
    def __init__(self, name, type_, default=None, nullable=True,
460
        constraints=None):
461
        assert default == None or isinstance(default, Code)
462
463
        Col.__init__(self, name)
464
465
        self.type = type_
466
        self.default = default
467
        self.nullable = nullable
468
        self.constraints = constraints
469
470
    def to_str(self, db):
471 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
472 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
473
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
474
        if self.constraints != None: str_ += ' '+self.constraints
475
        return str_
476
477
    def to_Col(self): return Col(self.name)
478
479 3488 aaronmk
class SetOf(Code):
480
    def __init__(self, type_):
481
        Code.__init__(self)
482
483
        self.type = type_
484
485
    def to_str(self, db):
486
        return 'SETOF '+self.type.to_str(db)
487
488 3483 aaronmk
class RowType(Code):
489
    def __init__(self, table):
490
        Code.__init__(self)
491
492
        self.table = table
493
494
    def to_str(self, db):
495
        return self.table.to_str(db)+'%ROWTYPE'
496
497 3485 aaronmk
class ColType(Code):
498
    def __init__(self, col):
499
        Code.__init__(self)
500
501
        self.col = col
502
503
    def to_str(self, db):
504
        return self.col.to_str(db)+'%TYPE'
505
506 4935 aaronmk
class ArrayType(Code):
507
    def __init__(self, elem_type):
508
        Code.__init__(self)
509
        elem_type = as_Code(elem_type)
510
511
        self.elem_type = elem_type
512
513
    def to_str(self, db):
514
        return self.elem_type.to_str(db)+'[]'
515
516 2524 aaronmk
##### Functions
517
518 2912 aaronmk
Function = Table
519 2911 aaronmk
as_Function = as_Table
520
521 2691 aaronmk
class InternalFunction(CustomCode): pass
522
523 3442 aaronmk
#### Calls
524
525 2941 aaronmk
class NamedArg(NamedCol):
526
    def __init__(self, name, value):
527
        NamedCol.__init__(self, name, value)
528
529
    def to_str(self, db):
530
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
531
532 2524 aaronmk
class FunctionCall(Code):
533 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
534 2524 aaronmk
        '''
535 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
536 2524 aaronmk
        '''
537 3445 aaronmk
        Code.__init__(self)
538
539 3016 aaronmk
        function = as_Function(function)
540 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
541
        args = map(filter_, args)
542
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
543 2524 aaronmk
544
        self.function = function
545
        self.args = args
546
547
    def to_str(self, db):
548
        args_str = ', '.join((v.to_str(db) for v in self.args))
549
        return self.function.to_str(db)+'('+args_str+')'
550
551 2533 aaronmk
def wrap_in_func(function, value):
552
    '''Wraps a value inside a function call.
553
    Propagates any column renaming to the returned value.
554
    '''
555 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
556 2533 aaronmk
557 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
558
    '''Unwraps any function call to its first argument.
559
    Also removes any column renaming.
560
    '''
561
    func_call = remove_col_rename(func_call)
562
    if not isinstance(func_call, FunctionCall): return func_call
563
564
    if check_name != None:
565
        name = func_call.function.name
566
        assert name == None or name == check_name
567
    return func_call.args[0]
568
569 3442 aaronmk
#### Definitions
570
571
class FunctionDef(Code):
572 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
573 3445 aaronmk
        Code.__init__(self)
574
575 3487 aaronmk
        return_type = as_Code(return_type)
576 3444 aaronmk
        body = as_Code(body)
577
578 3442 aaronmk
        self.function = function
579
        self.return_type = return_type
580
        self.body = body
581 3471 aaronmk
        self.params = params
582 3456 aaronmk
        self.modifiers = modifiers
583 3442 aaronmk
584
    def to_str(self, db):
585 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
586 3442 aaronmk
        str_ = '''\
587 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
588 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
589 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
590 3456 aaronmk
'''
591
        if self.modifiers != None: str_ += self.modifiers+'\n'
592
        str_ += '''\
593 3442 aaronmk
AS $$
594 3444 aaronmk
'''+self.body.to_str(db)+'''
595 3442 aaronmk
$$;
596
'''
597
        return str_
598
599 3469 aaronmk
class FunctionParam(TypedCol):
600
    def __init__(self, name, type_, default=None, out=False):
601
        TypedCol.__init__(self, name, type_, default)
602
603
        self.out = out
604
605
    def to_str(self, db):
606
        str_ = TypedCol.to_str(self, db)
607
        if self.out: str_ = 'OUT '+str_
608
        return str_
609
610
    def to_Col(self): return Col(self.name)
611
612 3454 aaronmk
### PL/pgSQL
613
614 3496 aaronmk
class ReturnQuery(Code):
615
    def __init__(self, query):
616
        Code.__init__(self)
617
618
        query = as_Code(query)
619
620
        self.query = query
621
622
    def to_str(self, db):
623
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
624
625 3515 aaronmk
## Exceptions
626
627 3509 aaronmk
class BaseExcHandler(BasicObject):
628
    def to_str(self, db, body): raise NotImplementedError()
629 3510 aaronmk
630
    def __repr__(self): return self.to_str(mockDb, '<body>')
631 3509 aaronmk
632 3549 aaronmk
suppress_exc = 'NULL;\n';
633
634 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
635
636 3509 aaronmk
class ExcHandler(BaseExcHandler):
637 3454 aaronmk
    def __init__(self, exc, handler=None):
638
        if handler != None: handler = as_Code(handler)
639
640
        self.exc = exc
641
        self.handler = handler
642
643
    def to_str(self, db, body):
644
        body = as_Code(body)
645
646 3467 aaronmk
        if self.handler != None:
647
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
648 3549 aaronmk
        else: handler_str = ' '+suppress_exc
649 3454 aaronmk
650
        str_ = '''\
651
BEGIN
652 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
653 3454 aaronmk
EXCEPTION
654 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
655 3454 aaronmk
END;\
656
'''
657
        return str_
658
659 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
660
    def __init__(self, *handlers):
661
        '''
662
        @param handlers Sorted from outermost to innermost
663
        '''
664
        self.handlers = handlers
665
666
    def to_str(self, db, body):
667
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
668
        return body
669
670 3503 aaronmk
class ExcToWarning(Code):
671
    def __init__(self, return_):
672
        '''
673
        @param return_ Statement to return a default value in case of error
674
        '''
675
        Code.__init__(self)
676
677
        return_ = as_Code(return_)
678
679
        self.return_ = return_
680
681
    def to_str(self, db):
682
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
683
684 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
685
686 3555 aaronmk
# Note doubled "\"s because inside Python string
687 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
688 3591 aaronmk
-- Handle PL/Python exceptions
689 3555 aaronmk
DECLARE
690 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
691
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
692 3555 aaronmk
    exc_name text := matches[1];
693
    msg text := matches[2];
694
BEGIN
695 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
696
    This allows the exception to be parsed like a native exception.
697
    Always raise as data_exception so it goes in the errors table. */
698 3598 aaronmk
    IF exc_name IS NOT NULL THEN
699 4026 aaronmk
        RAISE data_exception USING MESSAGE = msg;
700 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
701
    ELSE
702
        '''+reraise_exc+'''\
703 3555 aaronmk
    END IF;
704
END;
705 3468 aaronmk
''')
706
707 3505 aaronmk
def data_exception_handler(handler):
708
    return ExcHandler('data_exception', handler)
709
710 3527 aaronmk
row_var = Table('row')
711
712 3449 aaronmk
class RowExcIgnore(Code):
713
    def __init__(self, row_type, select_query, with_row, cols=None,
714 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
715 3529 aaronmk
        '''
716
        @param row_type Ignored if a custom row_var is used.
717
        @pre If a custom row_var is used, it must already be defined.
718
        '''
719 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
720
721 3482 aaronmk
        row_type = as_Code(row_type)
722 3449 aaronmk
        select_query = as_Code(select_query)
723
        with_row = as_Code(with_row)
724 3452 aaronmk
        row_var = as_Table(row_var)
725 3449 aaronmk
726
        self.row_type = row_type
727
        self.select_query = select_query
728
        self.with_row = with_row
729
        self.cols = cols
730 3455 aaronmk
        self.exc_handler = exc_handler
731 3452 aaronmk
        self.row_var = row_var
732 3449 aaronmk
733
    def to_str(self, db):
734 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
735
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
736 3449 aaronmk
737 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
738
        # is caught by an EXCEPTION clause, [...] all changes to persistent
739
        # database state within the block are rolled back."
740
        # This is unfortunate because "A block containing an EXCEPTION clause is
741
        # significantly more expensive to enter and exit than a block without
742
        # one."
743
        # (http://www.postgresql.org/docs/8.3/static/\
744
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
745 3449 aaronmk
        str_ = '''\
746 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
747
'''+strings.indent(self.select_query.to_str(db))+'''\
748
LOOP
749
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
750
END LOOP;
751
'''
752 3533 aaronmk
        if self.row_var == row_var:
753 3529 aaronmk
            str_ = '''\
754 3449 aaronmk
DECLARE
755 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
756 3449 aaronmk
BEGIN
757 3529 aaronmk
'''+strings.indent(str_)+'''\
758
END;
759 3449 aaronmk
'''
760
        return str_
761
762 2986 aaronmk
##### Casts
763
764
class Cast(FunctionCall):
765
    def __init__(self, type_, value):
766 3539 aaronmk
        type_ = as_Code(type_)
767 2986 aaronmk
        value = as_Value(value)
768
769
        self.type_ = type_
770
        self.value = value
771
772
    def to_str(self, db):
773 3539 aaronmk
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
774 2986 aaronmk
775 3354 aaronmk
def cast_literal(value):
776 3429 aaronmk
    if not is_literal(value): return value
777 3354 aaronmk
778
    if util.is_str(value.value): value = Cast('text', value)
779
    return value
780
781 2335 aaronmk
##### Conditions
782 2259 aaronmk
783 3350 aaronmk
class NotCond(Code):
784
    def __init__(self, cond):
785 3445 aaronmk
        Code.__init__(self)
786
787 5893 aaronmk
        if not isinstance(cond, Coalesce): cond = Coalesce(cond, False)
788
789 3350 aaronmk
        self.cond = cond
790
791
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
792
793 2398 aaronmk
class ColValueCond(Code):
794
    def __init__(self, col, value):
795 3445 aaronmk
        Code.__init__(self)
796
797 2398 aaronmk
        value = as_ValueCond(value)
798
799
        self.col = col
800
        self.value = value
801
802
    def to_str(self, db): return self.value.to_str(db, self.col)
803
804 2577 aaronmk
def combine_conds(conds, keyword=None):
805
    '''
806
    @param keyword The keyword to add before the conditions, if any
807
    '''
808
    str_ = ''
809
    if keyword != None:
810
        if conds == []: whitespace = ''
811
        elif len(conds) == 1: whitespace = ' '
812
        else: whitespace = '\n'
813
        str_ += keyword+whitespace
814
815
    str_ += '\nAND '.join(conds)
816
    return str_
817
818 2398 aaronmk
##### Condition column comparisons
819
820 2514 aaronmk
class ValueCond(BasicObject):
821 2213 aaronmk
    def __init__(self, value):
822 2858 aaronmk
        value = remove_col_rename(as_Value(value))
823 2213 aaronmk
824
        self.value = value
825 2214 aaronmk
826 2216 aaronmk
    def to_str(self, db, left_value):
827 2214 aaronmk
        '''
828 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
829 2214 aaronmk
        '''
830
        raise NotImplemented()
831 2228 aaronmk
832 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
833 2211 aaronmk
834
class CompareCond(ValueCond):
835
    def __init__(self, value, operator='='):
836 2222 aaronmk
        '''
837
        @param operator By default, compares NULL values literally. Use '~=' or
838
            '~!=' to pass NULLs through.
839
        '''
840 2211 aaronmk
        ValueCond.__init__(self, value)
841
        self.operator = operator
842
843 2216 aaronmk
    def to_str(self, db, left_value):
844 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
845 2216 aaronmk
846 2222 aaronmk
        right_value = self.value
847
848
        # Parse operator
849 2216 aaronmk
        operator = self.operator
850 2222 aaronmk
        passthru_null_ref = [False]
851
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
852
        neg_ref = [False]
853
        operator = strings.remove_prefix('!', operator, neg_ref)
854 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
855 2222 aaronmk
856 2825 aaronmk
        # Handle nullable columns
857
        check_null = False
858 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
859 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
860 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
861
                check_null = equals and isinstance(right_value, Col)
862 2837 aaronmk
            else:
863 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
864
                    right_value = ensure_not_null(db, right_value,
865
                        left_value.type) # apply same function to both sides
866 2825 aaronmk
867 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
868
869 2825 aaronmk
        left = left_value.to_str(db)
870
        right = right_value.to_str(db)
871
872 2222 aaronmk
        # Create str
873
        str_ = left+' '+operator+' '+right
874 2825 aaronmk
        if check_null:
875 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
876
        if neg_ref[0]: str_ = 'NOT '+str_
877 2222 aaronmk
        return str_
878 2216 aaronmk
879 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
880
assume_literal = object()
881
882
def as_ValueCond(value, default_table=assume_literal):
883
    if not isinstance(value, ValueCond):
884
        if default_table is not assume_literal:
885 2748 aaronmk
            value = with_default_table(value, default_table)
886 2260 aaronmk
        return CompareCond(value)
887 2216 aaronmk
    else: return value
888 2219 aaronmk
889 2335 aaronmk
##### Joins
890
891 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
892 2260 aaronmk
893 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
894
join_same_not_null = object()
895
896 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
897
898 2514 aaronmk
class Join(BasicObject):
899 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
900 2260 aaronmk
        '''
901
        @param mapping dict(right_table_col=left_table_col, ...)
902 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
903 2353 aaronmk
              * Note that right_table_col must be a string
904
            * if left_table_col is join_same_not_null:
905
              left_table_col = right_table_col and both have NOT NULL constraint
906
              * Note that right_table_col must be a string
907 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
908
            * filter_out: equivalent to 'LEFT' with the query filtered by
909
              `table_pkey IS NULL` (indicating no match)
910
        '''
911
        if util.is_str(table): table = Table(table)
912
        assert type_ == None or util.is_str(type_) or type_ is filter_out
913
914
        self.table = table
915
        self.mapping = mapping
916
        self.type_ = type_
917
918 2749 aaronmk
    def to_str(self, db, left_table_):
919 2260 aaronmk
        def join(entry):
920
            '''Parses non-USING joins'''
921
            right_table_col, left_table_col = entry
922
923 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
924
            left = right_table_col
925
            right = left_table_col
926 2749 aaronmk
            left_table = self.table
927
            right_table = left_table_
928 2353 aaronmk
929 2747 aaronmk
            # Parse left side
930 2748 aaronmk
            left = with_default_table(left, left_table)
931 2747 aaronmk
932 2260 aaronmk
            # Parse special values
933 2747 aaronmk
            left_on_right = Col(left.name, right_table)
934
            if right is join_same: right = left_on_right
935 2353 aaronmk
            elif right is join_same_not_null:
936 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
937 2260 aaronmk
938 2747 aaronmk
            # Parse right side
939 2353 aaronmk
            right = as_ValueCond(right, right_table)
940 2747 aaronmk
941
            return right.to_str(db, left)
942 2260 aaronmk
943 2265 aaronmk
        # Create join condition
944
        type_ = self.type_
945 2276 aaronmk
        joins = self.mapping
946 2746 aaronmk
        if joins == {}: join_cond = None
947
        elif type_ is not filter_out and reduce(operator.and_,
948 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
949 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
950 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
951
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
952 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
953 2260 aaronmk
954 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
955
        else: whitespace = ' '
956
957 2260 aaronmk
        # Create join
958
        if type_ is filter_out: type_ = 'LEFT'
959 2266 aaronmk
        str_ = ''
960
        if type_ != None: str_ += type_+' '
961 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
962
        if join_cond != None: str_ += whitespace+join_cond
963 2266 aaronmk
        return str_
964 2349 aaronmk
965 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
966 2424 aaronmk
967
##### Value exprs
968
969 3089 aaronmk
all_cols = CustomCode('*')
970
971 2737 aaronmk
default = CustomCode('DEFAULT')
972
973 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
974 2674 aaronmk
975 3061 aaronmk
class Coalesce(FunctionCall):
976
    def __init__(self, *args):
977
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
978 3060 aaronmk
979 3062 aaronmk
class Nullif(FunctionCall):
980
    def __init__(self, *args):
981
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
982
983 5891 aaronmk
null = Literal(None)
984 5892 aaronmk
null_as_str = Cast('text', null)
985 3706 aaronmk
986
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
987
988 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
989 2958 aaronmk
null_sentinels = {
990
    'character varying': r'\N',
991
    'double precision': 'NaN',
992
    'integer': 2147483647,
993
    'text': r'\N',
994 5498 aaronmk
    'date': 'infinity',
995 5266 aaronmk
    'timestamp with time zone': 'infinity',
996
    'taxonrank': 'unknown',
997 2958 aaronmk
}
998 2692 aaronmk
999 3061 aaronmk
class EnsureNotNull(Coalesce):
1000 2850 aaronmk
    def __init__(self, value, type_):
1001 4938 aaronmk
        if isinstance(type_, ArrayType): null = []
1002
        else: null = null_sentinels[type_]
1003
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
1004 2850 aaronmk
1005
        self.type = type_
1006 3001 aaronmk
1007
    def to_str(self, db):
1008
        col = self.args[0]
1009
        index_col_ = index_col(col)
1010
        if index_col_ != None: return index_col_.to_str(db)
1011 3061 aaronmk
        return Coalesce.to_str(self, db)
1012 2850 aaronmk
1013 3523 aaronmk
#### Arrays
1014
1015 3535 aaronmk
class ArrayMerge(FunctionCall):
1016 3523 aaronmk
    def __init__(self, sep, array):
1017
        array = to_Array(array)
1018
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1019
            sep)
1020
1021 3537 aaronmk
def merge_not_null(db, sep, values):
1022 3707 aaronmk
    return ArrayMerge(sep, map(to_text, values))
1023 3537 aaronmk
1024 2737 aaronmk
##### Table exprs
1025
1026
class Values(Code):
1027
    def __init__(self, values):
1028 2739 aaronmk
        '''
1029
        @param values [...]|[[...], ...] Can be one or multiple rows.
1030
        '''
1031 3445 aaronmk
        Code.__init__(self)
1032
1033 2739 aaronmk
        rows = values
1034
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1035
            rows = [values]
1036
        for i, row in enumerate(rows):
1037
            rows[i] = map(remove_col_rename, map(as_Value, row))
1038 2737 aaronmk
1039 2739 aaronmk
        self.rows = rows
1040 2737 aaronmk
1041
    def to_str(self, db):
1042 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1043 2737 aaronmk
1044 2740 aaronmk
def NamedValues(name, cols, values):
1045 2745 aaronmk
    '''
1046 3048 aaronmk
    @param cols None|[...]
1047 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
1048
    '''
1049 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
1050 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
1051 2834 aaronmk
    return table
1052 2740 aaronmk
1053 2674 aaronmk
##### Database structure
1054
1055 5074 aaronmk
def is_nullable(db, value):
1056 5330 aaronmk
    if not is_table_col(value): return is_null(value)
1057
    try: return db.col_info(value).nullable
1058 5074 aaronmk
    except NoUnderlyingTableException: return True # not a table column
1059
1060 4442 aaronmk
text_types = set(['character varying', 'text'])
1061 4406 aaronmk
1062 5334 aaronmk
def is_text_type(type_): return type_ in text_types
1063
1064 5335 aaronmk
def is_text_col(db, col): return is_text_type(db.col_info(col).type)
1065 4442 aaronmk
1066 5397 aaronmk
def canon_type(type_):
1067
    if type_ in text_types: return 'text'
1068
    else: return type_
1069
1070 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1071
1072 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1073 2840 aaronmk
    '''
1074 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1075 4488 aaronmk
    @param type_ If set, overrides the underlying column's type and casts the
1076
        column to it if needed.
1077 2840 aaronmk
    @return EnsureNotNull|Col
1078
    @throws ensure_not_null_excs
1079
    '''
1080 5329 aaronmk
    col = remove_col_rename(col)
1081 5332 aaronmk
1082
    try: col_type = db.col_info(underlying_col(col)).type
1083 2855 aaronmk
    except NoUnderlyingTableException:
1084 5333 aaronmk
        if type_ == None and is_null(col): raise # NULL has no type
1085 2855 aaronmk
    else:
1086 5332 aaronmk
        if type_ == None: type_ = col_type
1087
        elif type_ != col_type: col = Cast(type_, col)
1088 2855 aaronmk
1089 5331 aaronmk
    if is_nullable(db, col):
1090 2953 aaronmk
        try: col = EnsureNotNull(col, type_)
1091
        except KeyError, e:
1092
            # Warn of no null sentinel for type, even if caller catches error
1093
            warnings.warn(UserWarning(exc.str_(e)))
1094
            raise
1095
1096 2840 aaronmk
    return col
1097 3536 aaronmk
1098
def try_mk_not_null(db, value):
1099
    '''
1100
    Warning: This function does not guarantee that its result is NOT NULL.
1101
    '''
1102
    try: return ensure_not_null(db, value)
1103
    except ensure_not_null_excs: return value
1104 5367 aaronmk
1105
##### Expression transforming
1106
1107
true_expr = 'true'
1108
false_expr = 'false'
1109
1110
true_re = true_expr
1111
false_re = false_expr
1112
bool_re = r'(?:'+true_re+r'|'+false_re+r')'
1113
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
1114
1115
def logic_op_re(op, value_re, expr_re=''):
1116
    op_re = ' '+op+' '
1117
    return '(?:'+expr_re+op_re+value_re+'|'+value_re+op_re+expr_re+')'
1118
1119 5370 aaronmk
not_re = r'\bNOT '
1120 5372 aaronmk
not_false_re = not_re+false_re+r'\b'
1121
not_true_re = not_re+true_re+r'\b'
1122 5367 aaronmk
and_false_re = logic_op_re('AND', false_re, atom_re)
1123 5370 aaronmk
and_false_not_true_re = '(?:'+not_true_re+'|'+and_false_re+')'
1124 5367 aaronmk
and_true_re = logic_op_re('AND', true_re)
1125
or_re = logic_op_re('OR', bool_re)
1126
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
1127
1128
def simplify_parens(expr):
1129
    return regexp.sub_nested(r'\(('+atom_re+')\)', r'\1', expr)
1130
1131
def simplify_recursive(sub_func, expr):
1132
    '''
1133
    @param sub_func See regexp.sub_recursive() sub_func param
1134
    '''
1135 5378 aaronmk
    return regexp.sub_recursive(lambda s: sub_func(simplify_parens(s)), expr)
1136
        # simplify_parens() is also done at end in final iteration
1137 5367 aaronmk
1138
def simplify_expr(expr):
1139
    def simplify_logic_ops(expr):
1140
        total_n = 0
1141 5371 aaronmk
        expr, n = re.subn(not_false_re, true_re, expr)
1142
        total_n += n
1143 5370 aaronmk
        expr, n = re.subn(and_false_not_true_re, false_expr, expr)
1144 5367 aaronmk
        total_n += n
1145
        expr, n = re.subn(or_and_true_re, r'', expr)
1146
        total_n += n
1147
        return expr, total_n
1148
1149 5826 aaronmk
    expr = expr.replace('NULL IS NULL', true_expr)
1150
    expr = expr.replace('NULL IS NOT NULL', false_expr)
1151 5367 aaronmk
    expr = simplify_recursive(simplify_logic_ops, expr)
1152
    return expr
1153
1154
name_re = r'(?:\w+|(?:"[^"]*")+)'
1155
1156
def parse_expr_col(str_):
1157
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1158
    if match: str_ = match.group(1)
1159
    return unesc_name(str_)
1160
1161
def map_expr(db, expr, mapping, in_cols_found=None):
1162
    '''Replaces output columns with input columns in an expression.
1163
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1164
    '''
1165
    for out, in_ in mapping.iteritems():
1166
        orig_expr = expr
1167
        out = to_name_only_col(out)
1168
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
1169
1170
        # Replace out both with and without quotes
1171
        expr = expr.replace(out.to_str(db), in_str)
1172 5507 aaronmk
        expr = re.sub(r'(?<!["\'\.=\[])\b'+out.name+r'\b(?!["\',\.=\]])',
1173
            in_str, expr)
1174 5367 aaronmk
1175
        if in_cols_found != None and expr != orig_expr: # replaced something
1176
            in_cols_found.append(in_)
1177
1178
    return simplify_expr(expr)