Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 3158 aaronmk
from ordereddict import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
25 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
26
    to collisions.'''
27 2613 aaronmk
    # Preserve version
28 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
29 2985 aaronmk
    if match:
30
        str_, old_suffix = match.groups()
31
        suffix = old_suffix+suffix
32 2613 aaronmk
33 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_): return concat(str_, '')
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
38 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
39
    * contains only *lowercase* word (\w) characters
40
    * doesn't start with a digit
41
    * contains "_", so that it's not a keyword
42 2984 aaronmk
    '''
43
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
46
    return quote + name.replace(quote, quote+quote) + quote
47
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
50
    removed_ref = [False]
51
    name = strings.remove_prefix(quote, name, removed_ref)
52
    if removed_ref[0]:
53
        name = strings.remove_suffix(quote, name, removed_ref)
54
        assert removed_ref[0]
55
        name = name.replace(quote+quote, quote)
56
    return name
57
58 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
59
60 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
61
62 3182 aaronmk
def lstrip(str_):
63
    '''Also removes comments.'''
64
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
65
    return str_.lstrip()
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
70 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
71 2349 aaronmk
72 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
73 2859 aaronmk
74
    def col_info(self, col):
75
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76
77 2349 aaronmk
mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
85 3446 aaronmk
    def __init__(self, lang='sql'):
86
        self.lang = lang
87 3445 aaronmk
88 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
89 2349 aaronmk
90 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
93 3445 aaronmk
    def __init__(self, str_):
94
        Code.__init__(self)
95
96
        self.str_ = str_
97 2256 aaronmk
98
    def to_str(self, db): return self.str_
99
100 2815 aaronmk
def as_Code(value, db=None):
101
    '''
102
    @param db If set, runs db.std_code() on the value.
103
    '''
104 3447 aaronmk
    if isinstance(value, Code): return value
105
106 2815 aaronmk
    if util.is_str(value):
107
        if db != None: value = db.std_code(value)
108
        return CustomCode(value)
109 2659 aaronmk
    else: return Literal(value)
110
111 2540 aaronmk
class Expr(Code):
112 3445 aaronmk
    def __init__(self, expr):
113
        Code.__init__(self)
114
115
        self.expr = expr
116 2540 aaronmk
117
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
122 3087 aaronmk
    def __init__(self, name):
123 3445 aaronmk
        Code.__init__(self)
124
125 3087 aaronmk
        name = truncate(name)
126
127
        self.name = name
128 3086 aaronmk
129
    def to_str(self, db): return db.esc_name(self.name)
130
131
def as_Name(value):
132
    if isinstance(value, Code): return value
133
    else: return Name(value)
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
140 3445 aaronmk
    def __init__(self, value):
141
        Code.__init__(self)
142
143
        self.value = value
144 2213 aaronmk
145
    def to_str(self, db): return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
148
    if isinstance(value, Code): return value
149
    else: return Literal(value)
150
151 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
152 2216 aaronmk
153 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
154
155 3519 aaronmk
#### Composites
156
157 3521 aaronmk
class List(Code):
158
    def __init__(self, values):
159 3519 aaronmk
        Code.__init__(self)
160
161
        self.values = values
162
163 3521 aaronmk
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
164 3519 aaronmk
165 3521 aaronmk
class Tuple(List):
166
    def __init__(self, *values):
167
        List.__init__(self, values)
168
169
    def to_str(self, db): return '('+List.to_str(self, db)+')'
170
171 3520 aaronmk
class Row(Tuple):
172
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
173 3519 aaronmk
174 3522 aaronmk
### Arrays
175
176
class Array(List):
177
    def __init__(self, values):
178
        values = map(remove_col_rename, values)
179
180
        List.__init__(self, values)
181
182
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
183
184
def to_Array(value):
185
    if isinstance(value, Array): return value
186
    return Array(lists.mk_seq(value))
187
188 2711 aaronmk
##### Derived elements
189
190
src_self = object() # tells Col that it is its own source column
191
192
class Derived(Code):
193
    def __init__(self, srcs):
194 2712 aaronmk
        '''An element which was derived from some other element(s).
195 2711 aaronmk
        @param srcs See self.set_srcs()
196
        '''
197 3445 aaronmk
        Code.__init__(self)
198
199 2711 aaronmk
        self.set_srcs(srcs)
200
201 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
202 2711 aaronmk
        '''
203
        @param srcs (self_type...)|src_self The element(s) this is derived from
204
        '''
205 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
206
207 2711 aaronmk
        if srcs == src_self: srcs = (self,)
208
        srcs = tuple(srcs) # make Col hashable
209
        self.srcs = srcs
210
211
    def _compare_on(self):
212
        compare_on = self.__dict__.copy()
213
        del compare_on['srcs'] # ignore
214
        return compare_on
215
216
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
217
218 2335 aaronmk
##### Tables
219
220 2712 aaronmk
class Table(Derived):
221 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
222 2211 aaronmk
        '''
223
        @param schema str|None (for no schema)
224 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
225 2211 aaronmk
        '''
226 2712 aaronmk
        Derived.__init__(self, srcs)
227
228 3091 aaronmk
        if util.is_str(name): name = truncate(name)
229 2843 aaronmk
230 2211 aaronmk
        self.name = name
231
        self.schema = schema
232 2991 aaronmk
        self.is_temp = is_temp
233 3000 aaronmk
        self.index_cols = {}
234 2211 aaronmk
235 2348 aaronmk
    def to_str(self, db):
236
        str_ = ''
237 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
238
        str_ += as_Name(self.name).to_str(db)
239 2348 aaronmk
        return str_
240 2336 aaronmk
241
    def to_Table(self): return self
242 3000 aaronmk
243
    def _compare_on(self):
244
        compare_on = Derived._compare_on(self)
245
        del compare_on['index_cols'] # ignore
246
        return compare_on
247 2211 aaronmk
248 2835 aaronmk
def is_underlying_table(table):
249
    return isinstance(table, Table) and table.to_Table() is table
250 2832 aaronmk
251 4114 aaronmk
class NoUnderlyingTableException(Exception):
252
    def __init__(self, ref):
253
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
254
        self.ref = ref
255 2902 aaronmk
256
def underlying_table(table):
257
    table = remove_table_rename(table)
258 3532 aaronmk
    if table != None and table.srcs:
259
        table, = table.srcs # for derived tables or row vars
260 4114 aaronmk
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
261 2902 aaronmk
    return table
262
263 2776 aaronmk
def as_Table(table, schema=None):
264 2270 aaronmk
    if table == None or isinstance(table, Code): return table
265 2776 aaronmk
    else: return Table(table, schema)
266 2219 aaronmk
267 3101 aaronmk
def suffixed_table(table, suffix):
268 3128 aaronmk
    table = copy.copy(table) # don't modify input!
269
    table.name = concat(table.name, suffix)
270
    return table
271 2707 aaronmk
272 2336 aaronmk
class NamedTable(Table):
273
    def __init__(self, name, code, cols=None):
274
        Table.__init__(self, name)
275
276 3016 aaronmk
        code = as_Table(code)
277 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
278 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
279 2336 aaronmk
280
        self.code = code
281
        self.cols = cols
282
283
    def to_str(self, db):
284 3026 aaronmk
        str_ = self.code.to_str(db)
285
        if str_.find('\n') >= 0: whitespace = '\n'
286
        else: whitespace = ' '
287
        str_ += whitespace+'AS '+Table.to_str(self, db)
288 2742 aaronmk
        if self.cols != None:
289
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
290 2336 aaronmk
        return str_
291
292
    def to_Table(self): return Table(self.name)
293
294 2753 aaronmk
def remove_table_rename(table):
295
    if isinstance(table, NamedTable): table = table.code
296
    return table
297
298 2335 aaronmk
##### Columns
299
300 2711 aaronmk
class Col(Derived):
301 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
302 2211 aaronmk
        '''
303
        @param table Table|None (for no table)
304 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
305 2211 aaronmk
        '''
306 2711 aaronmk
        Derived.__init__(self, srcs)
307
308 3091 aaronmk
        if util.is_str(name): name = truncate(name)
309 2241 aaronmk
        if util.is_str(table): table = Table(table)
310 2211 aaronmk
        assert table == None or isinstance(table, Table)
311
312
        self.name = name
313
        self.table = table
314
315 2989 aaronmk
    def to_str(self, db, for_str=False):
316 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
317 2989 aaronmk
        if for_str: str_ = clean_name(str_)
318 2933 aaronmk
        if self.table != None:
319 2989 aaronmk
            table = self.table.to_Table()
320 3750 aaronmk
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
321 2989 aaronmk
            else: str_ = table.to_str(db)+'.'+str_
322 2211 aaronmk
        return str_
323 2314 aaronmk
324 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
325 2933 aaronmk
326 2314 aaronmk
    def to_Col(self): return self
327 2211 aaronmk
328 3512 aaronmk
def is_col(col): return isinstance(col, Col)
329 2393 aaronmk
330 3512 aaronmk
def is_table_col(col): return is_col(col) and col.table != None
331
332 3000 aaronmk
def index_col(col):
333
    if not is_table_col(col): return None
334 3104 aaronmk
335
    table = col.table
336
    try: name = table.index_cols[col.name]
337
    except KeyError: return None
338
    else: return Col(name, table, col.srcs)
339 2999 aaronmk
340 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
341 2996 aaronmk
342 2563 aaronmk
def as_Col(col, table=None, name=None):
343
    '''
344
    @param name If not None, any non-Col input will be renamed using NamedCol.
345
    '''
346
    if name != None:
347
        col = as_Value(col)
348
        if not isinstance(col, Col): col = NamedCol(name, col)
349 2333 aaronmk
350
    if isinstance(col, Code): return col
351 3513 aaronmk
    elif util.is_str(col): return Col(col, table)
352
    else: return Literal(col)
353 2260 aaronmk
354 3093 aaronmk
def with_table(col, table):
355 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
356
    elif isinstance(col, FunctionCall):
357
        col = copy.deepcopy(col) # don't modify input!
358
        col.args[0].table = table
359 3493 aaronmk
    elif isinstance(col, Col):
360 3098 aaronmk
        col = copy.copy(col) # don't modify input!
361
        col.table = table
362 3093 aaronmk
    return col
363
364 3100 aaronmk
def with_default_table(col, table):
365 2747 aaronmk
    col = as_Col(col)
366 3100 aaronmk
    if col.table == None: col = with_table(col, table)
367 2747 aaronmk
    return col
368
369 2744 aaronmk
def set_cols_table(table, cols):
370
    table = as_Table(table)
371
372
    for i, col in enumerate(cols):
373
        col = cols[i] = as_Col(col)
374
        col.table = table
375
376 2401 aaronmk
def to_name_only_col(col, check_table=None):
377
    col = as_Col(col)
378 3020 aaronmk
    if not is_table_col(col): return col
379 2401 aaronmk
380
    if check_table != None:
381
        table = col.table
382
        assert table == None or table == check_table
383
    return Col(col.name)
384
385 2993 aaronmk
def suffixed_col(col, suffix):
386
    return Col(concat(col.name, suffix), col.table, col.srcs)
387
388 3516 aaronmk
def has_srcs(col): return is_col(col) and col.srcs
389
390 3518 aaronmk
def cross_join_srcs(cols):
391
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
392
    srcs = [[s.name for s in c.srcs] for c in cols]
393
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
394 3514 aaronmk
395 2323 aaronmk
class NamedCol(Col):
396 2229 aaronmk
    def __init__(self, name, code):
397 2310 aaronmk
        Col.__init__(self, name)
398
399 3016 aaronmk
        code = as_Value(code)
400 2229 aaronmk
401
        self.code = code
402
403
    def to_str(self, db):
404 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
405 2314 aaronmk
406
    def to_Col(self): return Col(self.name)
407 2229 aaronmk
408 2462 aaronmk
def remove_col_rename(col):
409
    if isinstance(col, NamedCol): col = col.code
410
    return col
411
412 2830 aaronmk
def underlying_col(col):
413
    col = remove_col_rename(col)
414 4114 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
415 2849 aaronmk
416 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
417 2830 aaronmk
418 2703 aaronmk
def wrap(wrap_func, value):
419
    '''Wraps a value, propagating any column renaming to the returned value.'''
420
    if isinstance(value, NamedCol):
421
        return NamedCol(value.name, wrap_func(value.code))
422
    else: return wrap_func(value)
423
424 2667 aaronmk
class ColDict(dicts.DictProxy):
425 3691 aaronmk
    '''A dict that automatically makes inserted entries Col objects.
426
    Anything that isn't a column is wrapped in a NamedCol with the key's column
427
    name by `as_Col(value, name=key.name)`.
428
    '''
429 2564 aaronmk
430 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
431 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
432 2667 aaronmk
433 2645 aaronmk
        keys_table = as_Table(keys_table)
434
435 2642 aaronmk
        self.db = db
436 2641 aaronmk
        self.table = keys_table
437 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
438 2641 aaronmk
439 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
440 2655 aaronmk
441 2667 aaronmk
    def __getitem__(self, key):
442
        return dicts.DictProxy.__getitem__(self, self._key(key))
443 2653 aaronmk
444 2564 aaronmk
    def __setitem__(self, key, value):
445 2642 aaronmk
        key = self._key(key)
446 3639 aaronmk
        if value == None:
447
            try: value = self.db.col_info(key).default
448
            except NoUnderlyingTableException: pass # not a table column
449 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
450 2564 aaronmk
451 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
452 2564 aaronmk
453 3519 aaronmk
##### Definitions
454 3491 aaronmk
455 3469 aaronmk
class TypedCol(Col):
456
    def __init__(self, name, type_, default=None, nullable=True,
457
        constraints=None):
458
        assert default == None or isinstance(default, Code)
459
460
        Col.__init__(self, name)
461
462
        self.type = type_
463
        self.default = default
464
        self.nullable = nullable
465
        self.constraints = constraints
466
467
    def to_str(self, db):
468 3481 aaronmk
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
469 3469 aaronmk
        if not self.nullable: str_ += ' NOT NULL'
470
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
471
        if self.constraints != None: str_ += ' '+self.constraints
472
        return str_
473
474
    def to_Col(self): return Col(self.name)
475
476 3488 aaronmk
class SetOf(Code):
477
    def __init__(self, type_):
478
        Code.__init__(self)
479
480
        self.type = type_
481
482
    def to_str(self, db):
483
        return 'SETOF '+self.type.to_str(db)
484
485 3483 aaronmk
class RowType(Code):
486
    def __init__(self, table):
487
        Code.__init__(self)
488
489
        self.table = table
490
491
    def to_str(self, db):
492
        return self.table.to_str(db)+'%ROWTYPE'
493
494 3485 aaronmk
class ColType(Code):
495
    def __init__(self, col):
496
        Code.__init__(self)
497
498
        self.col = col
499
500
    def to_str(self, db):
501
        return self.col.to_str(db)+'%TYPE'
502
503 4935 aaronmk
class ArrayType(Code):
504
    def __init__(self, elem_type):
505
        Code.__init__(self)
506
        elem_type = as_Code(elem_type)
507
508
        self.elem_type = elem_type
509
510
    def to_str(self, db):
511
        return self.elem_type.to_str(db)+'[]'
512
513 2524 aaronmk
##### Functions
514
515 2912 aaronmk
Function = Table
516 2911 aaronmk
as_Function = as_Table
517
518 2691 aaronmk
class InternalFunction(CustomCode): pass
519
520 3442 aaronmk
#### Calls
521
522 2941 aaronmk
class NamedArg(NamedCol):
523
    def __init__(self, name, value):
524
        NamedCol.__init__(self, name, value)
525
526
    def to_str(self, db):
527
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
528
529 2524 aaronmk
class FunctionCall(Code):
530 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
531 2524 aaronmk
        '''
532 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
533 2524 aaronmk
        '''
534 3445 aaronmk
        Code.__init__(self)
535
536 3016 aaronmk
        function = as_Function(function)
537 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
538
        args = map(filter_, args)
539
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
540 2524 aaronmk
541
        self.function = function
542
        self.args = args
543
544
    def to_str(self, db):
545
        args_str = ', '.join((v.to_str(db) for v in self.args))
546
        return self.function.to_str(db)+'('+args_str+')'
547
548 2533 aaronmk
def wrap_in_func(function, value):
549
    '''Wraps a value inside a function call.
550
    Propagates any column renaming to the returned value.
551
    '''
552 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
553 2533 aaronmk
554 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
555
    '''Unwraps any function call to its first argument.
556
    Also removes any column renaming.
557
    '''
558
    func_call = remove_col_rename(func_call)
559
    if not isinstance(func_call, FunctionCall): return func_call
560
561
    if check_name != None:
562
        name = func_call.function.name
563
        assert name == None or name == check_name
564
    return func_call.args[0]
565
566 3442 aaronmk
#### Definitions
567
568
class FunctionDef(Code):
569 3471 aaronmk
    def __init__(self, function, return_type, body, params=[], modifiers=None):
570 3445 aaronmk
        Code.__init__(self)
571
572 3487 aaronmk
        return_type = as_Code(return_type)
573 3444 aaronmk
        body = as_Code(body)
574
575 3442 aaronmk
        self.function = function
576
        self.return_type = return_type
577
        self.body = body
578 3471 aaronmk
        self.params = params
579 3456 aaronmk
        self.modifiers = modifiers
580 3442 aaronmk
581
    def to_str(self, db):
582 3487 aaronmk
        params_str = (', '.join((p.to_str(db) for p in self.params)))
583 3442 aaronmk
        str_ = '''\
584 3471 aaronmk
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
585 3487 aaronmk
RETURNS '''+self.return_type.to_str(db)+'''
586 3448 aaronmk
LANGUAGE '''+self.body.lang+'''
587 3456 aaronmk
'''
588
        if self.modifiers != None: str_ += self.modifiers+'\n'
589
        str_ += '''\
590 3442 aaronmk
AS $$
591 3444 aaronmk
'''+self.body.to_str(db)+'''
592 3442 aaronmk
$$;
593
'''
594
        return str_
595
596 3469 aaronmk
class FunctionParam(TypedCol):
597
    def __init__(self, name, type_, default=None, out=False):
598
        TypedCol.__init__(self, name, type_, default)
599
600
        self.out = out
601
602
    def to_str(self, db):
603
        str_ = TypedCol.to_str(self, db)
604
        if self.out: str_ = 'OUT '+str_
605
        return str_
606
607
    def to_Col(self): return Col(self.name)
608
609 3454 aaronmk
### PL/pgSQL
610
611 3496 aaronmk
class ReturnQuery(Code):
612
    def __init__(self, query):
613
        Code.__init__(self)
614
615
        query = as_Code(query)
616
617
        self.query = query
618
619
    def to_str(self, db):
620
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
621
622 3515 aaronmk
## Exceptions
623
624 3509 aaronmk
class BaseExcHandler(BasicObject):
625
    def to_str(self, db, body): raise NotImplementedError()
626 3510 aaronmk
627
    def __repr__(self): return self.to_str(mockDb, '<body>')
628 3509 aaronmk
629 3549 aaronmk
suppress_exc = 'NULL;\n';
630
631 3554 aaronmk
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
632
633 3509 aaronmk
class ExcHandler(BaseExcHandler):
634 3454 aaronmk
    def __init__(self, exc, handler=None):
635
        if handler != None: handler = as_Code(handler)
636
637
        self.exc = exc
638
        self.handler = handler
639
640
    def to_str(self, db, body):
641
        body = as_Code(body)
642
643 3467 aaronmk
        if self.handler != None:
644
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
645 3549 aaronmk
        else: handler_str = ' '+suppress_exc
646 3454 aaronmk
647
        str_ = '''\
648
BEGIN
649 3467 aaronmk
'''+strings.indent(body.to_str(db))+'''\
650 3454 aaronmk
EXCEPTION
651 3463 aaronmk
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
652 3454 aaronmk
END;\
653
'''
654
        return str_
655
656 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
657
    def __init__(self, *handlers):
658
        '''
659
        @param handlers Sorted from outermost to innermost
660
        '''
661
        self.handlers = handlers
662
663
    def to_str(self, db, body):
664
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
665
        return body
666
667 3503 aaronmk
class ExcToWarning(Code):
668
    def __init__(self, return_):
669
        '''
670
        @param return_ Statement to return a default value in case of error
671
        '''
672
        Code.__init__(self)
673
674
        return_ = as_Code(return_)
675
676
        self.return_ = return_
677
678
    def to_str(self, db):
679
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
680
681 3454 aaronmk
unique_violation_handler = ExcHandler('unique_violation')
682
683 3555 aaronmk
# Note doubled "\"s because inside Python string
684 3468 aaronmk
plpythonu_error_handler = ExcHandler('internal_error', '''\
685 3591 aaronmk
-- Handle PL/Python exceptions
686 3555 aaronmk
DECLARE
687 3595 aaronmk
    matches text[] := regexp_matches(SQLERRM,
688
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
689 3555 aaronmk
    exc_name text := matches[1];
690
    msg text := matches[2];
691
BEGIN
692 3596 aaronmk
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
693
    This allows the exception to be parsed like a native exception.
694
    Always raise as data_exception so it goes in the errors table. */
695 3598 aaronmk
    IF exc_name IS NOT NULL THEN
696 4026 aaronmk
        RAISE data_exception USING MESSAGE = msg;
697 3595 aaronmk
    -- Re-raise non-PL/Python exceptions
698
    ELSE
699
        '''+reraise_exc+'''\
700 3555 aaronmk
    END IF;
701
END;
702 3468 aaronmk
''')
703
704 3505 aaronmk
def data_exception_handler(handler):
705
    return ExcHandler('data_exception', handler)
706
707 3527 aaronmk
row_var = Table('row')
708
709 3449 aaronmk
class RowExcIgnore(Code):
710
    def __init__(self, row_type, select_query, with_row, cols=None,
711 3527 aaronmk
        exc_handler=unique_violation_handler, row_var=row_var):
712 3529 aaronmk
        '''
713
        @param row_type Ignored if a custom row_var is used.
714
        @pre If a custom row_var is used, it must already be defined.
715
        '''
716 3449 aaronmk
        Code.__init__(self, lang='plpgsql')
717
718 3482 aaronmk
        row_type = as_Code(row_type)
719 3449 aaronmk
        select_query = as_Code(select_query)
720
        with_row = as_Code(with_row)
721 3452 aaronmk
        row_var = as_Table(row_var)
722 3449 aaronmk
723
        self.row_type = row_type
724
        self.select_query = select_query
725
        self.with_row = with_row
726
        self.cols = cols
727 3455 aaronmk
        self.exc_handler = exc_handler
728 3452 aaronmk
        self.row_var = row_var
729 3449 aaronmk
730
    def to_str(self, db):
731 3452 aaronmk
        if self.cols == None: row_vars = [self.row_var]
732
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
733 3449 aaronmk
734 3526 aaronmk
        # Need an EXCEPTION block for each individual row because "When an error
735
        # is caught by an EXCEPTION clause, [...] all changes to persistent
736
        # database state within the block are rolled back."
737
        # This is unfortunate because "A block containing an EXCEPTION clause is
738
        # significantly more expensive to enter and exit than a block without
739
        # one."
740
        # (http://www.postgresql.org/docs/8.3/static/\
741
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
742 3449 aaronmk
        str_ = '''\
743 3529 aaronmk
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
744
'''+strings.indent(self.select_query.to_str(db))+'''\
745
LOOP
746
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
747
END LOOP;
748
'''
749 3533 aaronmk
        if self.row_var == row_var:
750 3529 aaronmk
            str_ = '''\
751 3449 aaronmk
DECLARE
752 3482 aaronmk
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
753 3449 aaronmk
BEGIN
754 3529 aaronmk
'''+strings.indent(str_)+'''\
755
END;
756 3449 aaronmk
'''
757
        return str_
758
759 2986 aaronmk
##### Casts
760
761
class Cast(FunctionCall):
762
    def __init__(self, type_, value):
763 3539 aaronmk
        type_ = as_Code(type_)
764 2986 aaronmk
        value = as_Value(value)
765
766
        self.type_ = type_
767
        self.value = value
768
769
    def to_str(self, db):
770 3539 aaronmk
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
771 2986 aaronmk
772 3354 aaronmk
def cast_literal(value):
773 3429 aaronmk
    if not is_literal(value): return value
774 3354 aaronmk
775
    if util.is_str(value.value): value = Cast('text', value)
776
    return value
777
778 2335 aaronmk
##### Conditions
779 2259 aaronmk
780 3350 aaronmk
class NotCond(Code):
781
    def __init__(self, cond):
782 3445 aaronmk
        Code.__init__(self)
783
784 3350 aaronmk
        self.cond = cond
785
786
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
787
788 2398 aaronmk
class ColValueCond(Code):
789
    def __init__(self, col, value):
790 3445 aaronmk
        Code.__init__(self)
791
792 2398 aaronmk
        value = as_ValueCond(value)
793
794
        self.col = col
795
        self.value = value
796
797
    def to_str(self, db): return self.value.to_str(db, self.col)
798
799 2577 aaronmk
def combine_conds(conds, keyword=None):
800
    '''
801
    @param keyword The keyword to add before the conditions, if any
802
    '''
803
    str_ = ''
804
    if keyword != None:
805
        if conds == []: whitespace = ''
806
        elif len(conds) == 1: whitespace = ' '
807
        else: whitespace = '\n'
808
        str_ += keyword+whitespace
809
810
    str_ += '\nAND '.join(conds)
811
    return str_
812
813 2398 aaronmk
##### Condition column comparisons
814
815 2514 aaronmk
class ValueCond(BasicObject):
816 2213 aaronmk
    def __init__(self, value):
817 2858 aaronmk
        value = remove_col_rename(as_Value(value))
818 2213 aaronmk
819
        self.value = value
820 2214 aaronmk
821 2216 aaronmk
    def to_str(self, db, left_value):
822 2214 aaronmk
        '''
823 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
824 2214 aaronmk
        '''
825
        raise NotImplemented()
826 2228 aaronmk
827 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
828 2211 aaronmk
829
class CompareCond(ValueCond):
830
    def __init__(self, value, operator='='):
831 2222 aaronmk
        '''
832
        @param operator By default, compares NULL values literally. Use '~=' or
833
            '~!=' to pass NULLs through.
834
        '''
835 2211 aaronmk
        ValueCond.__init__(self, value)
836
        self.operator = operator
837
838 2216 aaronmk
    def to_str(self, db, left_value):
839 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
840 2216 aaronmk
841 2222 aaronmk
        right_value = self.value
842
843
        # Parse operator
844 2216 aaronmk
        operator = self.operator
845 2222 aaronmk
        passthru_null_ref = [False]
846
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
847
        neg_ref = [False]
848
        operator = strings.remove_prefix('!', operator, neg_ref)
849 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
850 2222 aaronmk
851 2825 aaronmk
        # Handle nullable columns
852
        check_null = False
853 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
854 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
855 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
856
                check_null = equals and isinstance(right_value, Col)
857 2837 aaronmk
            else:
858 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
859
                    right_value = ensure_not_null(db, right_value,
860
                        left_value.type) # apply same function to both sides
861 2825 aaronmk
862 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
863
864 2825 aaronmk
        left = left_value.to_str(db)
865
        right = right_value.to_str(db)
866
867 2222 aaronmk
        # Create str
868
        str_ = left+' '+operator+' '+right
869 2825 aaronmk
        if check_null:
870 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
871
        if neg_ref[0]: str_ = 'NOT '+str_
872 2222 aaronmk
        return str_
873 2216 aaronmk
874 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
875
assume_literal = object()
876
877
def as_ValueCond(value, default_table=assume_literal):
878
    if not isinstance(value, ValueCond):
879
        if default_table is not assume_literal:
880 2748 aaronmk
            value = with_default_table(value, default_table)
881 2260 aaronmk
        return CompareCond(value)
882 2216 aaronmk
    else: return value
883 2219 aaronmk
884 2335 aaronmk
##### Joins
885
886 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
887 2260 aaronmk
888 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
889
join_same_not_null = object()
890
891 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
892
893 2514 aaronmk
class Join(BasicObject):
894 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
895 2260 aaronmk
        '''
896
        @param mapping dict(right_table_col=left_table_col, ...)
897 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
898 2353 aaronmk
              * Note that right_table_col must be a string
899
            * if left_table_col is join_same_not_null:
900
              left_table_col = right_table_col and both have NOT NULL constraint
901
              * Note that right_table_col must be a string
902 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
903
            * filter_out: equivalent to 'LEFT' with the query filtered by
904
              `table_pkey IS NULL` (indicating no match)
905
        '''
906
        if util.is_str(table): table = Table(table)
907
        assert type_ == None or util.is_str(type_) or type_ is filter_out
908
909
        self.table = table
910
        self.mapping = mapping
911
        self.type_ = type_
912
913 2749 aaronmk
    def to_str(self, db, left_table_):
914 2260 aaronmk
        def join(entry):
915
            '''Parses non-USING joins'''
916
            right_table_col, left_table_col = entry
917
918 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
919
            left = right_table_col
920
            right = left_table_col
921 2749 aaronmk
            left_table = self.table
922
            right_table = left_table_
923 2353 aaronmk
924 2747 aaronmk
            # Parse left side
925 2748 aaronmk
            left = with_default_table(left, left_table)
926 2747 aaronmk
927 2260 aaronmk
            # Parse special values
928 2747 aaronmk
            left_on_right = Col(left.name, right_table)
929
            if right is join_same: right = left_on_right
930 2353 aaronmk
            elif right is join_same_not_null:
931 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
932 2260 aaronmk
933 2747 aaronmk
            # Parse right side
934 2353 aaronmk
            right = as_ValueCond(right, right_table)
935 2747 aaronmk
936
            return right.to_str(db, left)
937 2260 aaronmk
938 2265 aaronmk
        # Create join condition
939
        type_ = self.type_
940 2276 aaronmk
        joins = self.mapping
941 2746 aaronmk
        if joins == {}: join_cond = None
942
        elif type_ is not filter_out and reduce(operator.and_,
943 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
944 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
945 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
946
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
947 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
948 2260 aaronmk
949 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
950
        else: whitespace = ' '
951
952 2260 aaronmk
        # Create join
953
        if type_ is filter_out: type_ = 'LEFT'
954 2266 aaronmk
        str_ = ''
955
        if type_ != None: str_ += type_+' '
956 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
957
        if join_cond != None: str_ += whitespace+join_cond
958 2266 aaronmk
        return str_
959 2349 aaronmk
960 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
961 2424 aaronmk
962
##### Value exprs
963
964 3089 aaronmk
all_cols = CustomCode('*')
965
966 2737 aaronmk
default = CustomCode('DEFAULT')
967
968 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
969 2674 aaronmk
970 3061 aaronmk
class Coalesce(FunctionCall):
971
    def __init__(self, *args):
972
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
973 3060 aaronmk
974 3062 aaronmk
class Nullif(FunctionCall):
975
    def __init__(self, *args):
976
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
977
978 3706 aaronmk
null_as_str = Cast('text', 'NULL')
979
980
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
981
982 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
983 2958 aaronmk
null_sentinels = {
984
    'character varying': r'\N',
985
    'double precision': 'NaN',
986
    'integer': 2147483647,
987
    'text': r'\N',
988 5498 aaronmk
    'date': 'infinity',
989 5266 aaronmk
    'timestamp with time zone': 'infinity',
990
    'taxonrank': 'unknown',
991 2958 aaronmk
}
992 2692 aaronmk
993 3061 aaronmk
class EnsureNotNull(Coalesce):
994 2850 aaronmk
    def __init__(self, value, type_):
995 4938 aaronmk
        if isinstance(type_, ArrayType): null = []
996
        else: null = null_sentinels[type_]
997
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
998 2850 aaronmk
999
        self.type = type_
1000 3001 aaronmk
1001
    def to_str(self, db):
1002
        col = self.args[0]
1003
        index_col_ = index_col(col)
1004
        if index_col_ != None: return index_col_.to_str(db)
1005 3061 aaronmk
        return Coalesce.to_str(self, db)
1006 2850 aaronmk
1007 3523 aaronmk
#### Arrays
1008
1009 3535 aaronmk
class ArrayMerge(FunctionCall):
1010 3523 aaronmk
    def __init__(self, sep, array):
1011
        array = to_Array(array)
1012
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1013
            sep)
1014
1015 3537 aaronmk
def merge_not_null(db, sep, values):
1016 3707 aaronmk
    return ArrayMerge(sep, map(to_text, values))
1017 3537 aaronmk
1018 2737 aaronmk
##### Table exprs
1019
1020
class Values(Code):
1021
    def __init__(self, values):
1022 2739 aaronmk
        '''
1023
        @param values [...]|[[...], ...] Can be one or multiple rows.
1024
        '''
1025 3445 aaronmk
        Code.__init__(self)
1026
1027 2739 aaronmk
        rows = values
1028
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1029
            rows = [values]
1030
        for i, row in enumerate(rows):
1031
            rows[i] = map(remove_col_rename, map(as_Value, row))
1032 2737 aaronmk
1033 2739 aaronmk
        self.rows = rows
1034 2737 aaronmk
1035
    def to_str(self, db):
1036 3520 aaronmk
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1037 2737 aaronmk
1038 2740 aaronmk
def NamedValues(name, cols, values):
1039 2745 aaronmk
    '''
1040 3048 aaronmk
    @param cols None|[...]
1041 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
1042
    '''
1043 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
1044 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
1045 2834 aaronmk
    return table
1046 2740 aaronmk
1047 2674 aaronmk
##### Database structure
1048
1049 5074 aaronmk
def is_nullable(db, value):
1050 5330 aaronmk
    if not is_table_col(value): return is_null(value)
1051
    try: return db.col_info(value).nullable
1052 5074 aaronmk
    except NoUnderlyingTableException: return True # not a table column
1053
1054 4442 aaronmk
text_types = set(['character varying', 'text'])
1055 4406 aaronmk
1056 5334 aaronmk
def is_text_type(type_): return type_ in text_types
1057
1058 5335 aaronmk
def is_text_col(db, col): return is_text_type(db.col_info(col).type)
1059 4442 aaronmk
1060 5397 aaronmk
def canon_type(type_):
1061
    if type_ in text_types: return 'text'
1062
    else: return type_
1063
1064 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1065
1066 2851 aaronmk
def ensure_not_null(db, col, type_=None):
1067 2840 aaronmk
    '''
1068 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
1069 4488 aaronmk
    @param type_ If set, overrides the underlying column's type and casts the
1070
        column to it if needed.
1071 2840 aaronmk
    @return EnsureNotNull|Col
1072
    @throws ensure_not_null_excs
1073
    '''
1074 5329 aaronmk
    col = remove_col_rename(col)
1075 5332 aaronmk
1076
    try: col_type = db.col_info(underlying_col(col)).type
1077 2855 aaronmk
    except NoUnderlyingTableException:
1078 5333 aaronmk
        if type_ == None and is_null(col): raise # NULL has no type
1079 2855 aaronmk
    else:
1080 5332 aaronmk
        if type_ == None: type_ = col_type
1081
        elif type_ != col_type: col = Cast(type_, col)
1082 2855 aaronmk
1083 5331 aaronmk
    if is_nullable(db, col):
1084 2953 aaronmk
        try: col = EnsureNotNull(col, type_)
1085
        except KeyError, e:
1086
            # Warn of no null sentinel for type, even if caller catches error
1087
            warnings.warn(UserWarning(exc.str_(e)))
1088
            raise
1089
1090 2840 aaronmk
    return col
1091 3536 aaronmk
1092
def try_mk_not_null(db, value):
1093
    '''
1094
    Warning: This function does not guarantee that its result is NOT NULL.
1095
    '''
1096
    try: return ensure_not_null(db, value)
1097
    except ensure_not_null_excs: return value
1098 5367 aaronmk
1099
##### Expression transforming
1100
1101
true_expr = 'true'
1102
false_expr = 'false'
1103
1104
true_re = true_expr
1105
false_re = false_expr
1106
bool_re = r'(?:'+true_re+r'|'+false_re+r')'
1107
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
1108
1109
def logic_op_re(op, value_re, expr_re=''):
1110
    op_re = ' '+op+' '
1111
    return '(?:'+expr_re+op_re+value_re+'|'+value_re+op_re+expr_re+')'
1112
1113 5370 aaronmk
not_re = r'\bNOT '
1114 5372 aaronmk
not_false_re = not_re+false_re+r'\b'
1115
not_true_re = not_re+true_re+r'\b'
1116 5367 aaronmk
and_false_re = logic_op_re('AND', false_re, atom_re)
1117 5370 aaronmk
and_false_not_true_re = '(?:'+not_true_re+'|'+and_false_re+')'
1118 5367 aaronmk
and_true_re = logic_op_re('AND', true_re)
1119
or_re = logic_op_re('OR', bool_re)
1120
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
1121
1122
def simplify_parens(expr):
1123
    return regexp.sub_nested(r'\(('+atom_re+')\)', r'\1', expr)
1124
1125
def simplify_recursive(sub_func, expr):
1126
    '''
1127
    @param sub_func See regexp.sub_recursive() sub_func param
1128
    '''
1129 5378 aaronmk
    return regexp.sub_recursive(lambda s: sub_func(simplify_parens(s)), expr)
1130
        # simplify_parens() is also done at end in final iteration
1131 5367 aaronmk
1132
def simplify_expr(expr):
1133
    def simplify_logic_ops(expr):
1134
        total_n = 0
1135 5371 aaronmk
        expr, n = re.subn(not_false_re, true_re, expr)
1136
        total_n += n
1137 5370 aaronmk
        expr, n = re.subn(and_false_not_true_re, false_expr, expr)
1138 5367 aaronmk
        total_n += n
1139
        expr, n = re.subn(or_and_true_re, r'', expr)
1140
        total_n += n
1141
        return expr, total_n
1142
1143
    expr = expr.replace('(NULL IS NULL)', true_expr)
1144
    expr = expr.replace('(NULL IS NOT NULL)', false_expr)
1145
    expr = simplify_recursive(simplify_logic_ops, expr)
1146
    return expr
1147
1148
name_re = r'(?:\w+|(?:"[^"]*")+)'
1149
1150
def parse_expr_col(str_):
1151
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1152
    if match: str_ = match.group(1)
1153
    return unesc_name(str_)
1154
1155
def map_expr(db, expr, mapping, in_cols_found=None):
1156
    '''Replaces output columns with input columns in an expression.
1157
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1158
    '''
1159
    for out, in_ in mapping.iteritems():
1160
        orig_expr = expr
1161
        out = to_name_only_col(out)
1162
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
1163
1164
        # Replace out both with and without quotes
1165
        expr = expr.replace(out.to_str(db), in_str)
1166 5507 aaronmk
        expr = re.sub(r'(?<!["\'\.=\[])\b'+out.name+r'\b(?!["\',\.=\]])',
1167
            in_str, expr)
1168 5367 aaronmk
1169
        if in_cols_found != None and expr != orig_expr: # replaced something
1170
            in_cols_found.append(in_)
1171
1172
    return simplify_expr(expr)