Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 3518 aaronmk
import itertools
5 2276 aaronmk
import operator
6 9383 aaronmk
from collections import OrderedDict
7 2568 aaronmk
import re
8 2653 aaronmk
import UserDict
9 2953 aaronmk
import warnings
10 2276 aaronmk
11 2667 aaronmk
import dicts
12 2953 aaronmk
import exc
13 2701 aaronmk
import iters
14
import lists
15 2360 aaronmk
import objects
16 5367 aaronmk
import regexp
17 2222 aaronmk
import strings
18 2227 aaronmk
import util
19 2211 aaronmk
20 2587 aaronmk
##### Names
21 2499 aaronmk
22 2608 aaronmk
# Length cap passed to strings.concat() by concat()/truncate() below
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23 2587 aaronmk
24 2932 aaronmk
def concat(str_, suffix):
    '''Appends a suffix to a name via strings.concat(), capped at
    identifier_max_len.
    Preserves version so that it won't be truncated off the string, leading
    to collisions.
    @param str_ The name to extend
    @param suffix The text to append
    @return The result of strings.concat() (presumably the combined name
        shortened to fit identifier_max_len -- confirm in the strings module)
    '''
    # Preserve version: split off a trailing run of optional "#<digits>"/")"
    # pieces followed by one or more ".word" parts and optional "::type" parts,
    # and move it to the front of the suffix
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$', str_)
    if match:
        str_, old_suffix = match.groups()
        suffix = old_suffix+suffix
    
    return strings.concat(str_, suffix, identifier_max_len)
34 2587 aaronmk
35 2932 aaronmk
def truncate(str_):
    '''Runs a name through concat() with an empty suffix.'''
    no_suffix = ''
    return concat(str_, no_suffix)
36 2842 aaronmk
37 2575 aaronmk
def is_safe_name(name):
    r'''A name is safe *and unambiguous* if it:
    * contains only *lowercase* word (\w) characters
    * doesn't start with a digit
    * contains "_", so that it's not a keyword
    @return A (truthy) match object if the name is safe, otherwise None
    '''
    # \Z rather than $: $ would also accept a name ending in a newline
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+\Z', name)
44 2568 aaronmk
45 2499 aaronmk
def esc_name(name, quote='"'):
    '''Wraps a name in quote characters, escaping embedded quotes.
    @param name The raw identifier
    @param quote The quote character to wrap with and to escape
    @return The quoted name
    '''
    # doubling an embedded quote escapes it in both PostgreSQL and MySQL
    return quote + name.replace(quote, quote+quote) + quote
48
49 3320 aaronmk
def unesc_name(name, quote='"'):
    '''Reverses esc_name(): strips the surrounding quotes and un-doubles
    embedded quotes. A name that does not start with the quote character is
    returned unchanged.
    @param name The possibly-quoted identifier
    @param quote The quote character that was used
    '''
    # One-element list used as an out-parameter by the strings helpers
    # (presumably set to whether the prefix/suffix was found -- confirm in the
    # strings module)
    removed_ref = [False]
    name = strings.remove_prefix(quote, name, removed_ref)
    if removed_ref[0]:
        name = strings.remove_suffix(quote, name, removed_ref)
        assert removed_ref[0] # an opening quote requires a closing quote
        name = name.replace(quote+quote, quote)
    return name
57
58 2513 aaronmk
def clean_name(name):
    '''Deletes every double-quote and backtick character from a name.'''
    for quote_char in ('"', '`'): name = name.replace(quote_char, '')
    return name
59
60 3041 aaronmk
def esc_comment(comment):
    '''Wraps text in /* */, breaking up any embedded comment terminator so it
    can't end the comment early.'''
    safe_body = '* /'.join(comment.split('*/'))
    return '/*'+safe_body+'*/'
61
62 3182 aaronmk
def lstrip(str_):
    '''Strips leading whitespace. Also removes comments.
    Only a single /* */ comment at the very start of the string is removed;
    if that comment is never closed, nothing is left after it and '' is
    returned.
    '''
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
    return str_.lstrip()
66
67 2659 aaronmk
##### General SQL code objects
68 2219 aaronmk
69 2349 aaronmk
class MockDb:
    '''Stand-in for a database connection, providing just the methods the code
    objects in this module call when rendering themselves without a real DB.'''
    
    def esc_value(self, value): return strings.repr_no_u(value)
    
    def esc_name(self, name): return esc_name(name)
    
    def col_info(self, col):
        # Placeholder type and default; nullable
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76
77 2349 aaronmk
# Shared MockDb instance, used to render code objects for repr()/str()
mockDb = MockDb()
78
79 2514 aaronmk
class BasicObject(objects.BasicObject):
    '''objects.BasicObject whose str() is its repr with identifier quote
    characters stripped (see clean_name()).'''
    
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
    '''Base class of everything that can be rendered to code with to_str().'''
    
    def __init__(self, lang='sql'):
        '''
        @param lang Name of the language the code is written in
        '''
        self.lang = lang
    
    def to_str(self, db):
        '''Renders this object as a string for the given db. Subclasses must
        override this.'''
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb)
91 2211 aaronmk
92 2269 aaronmk
class CustomCode(Code):
    '''Code given as a literal string, rendered verbatim.'''
    
    def __init__(self, str_):
        '''
        @param str_ The code text
        '''
        Code.__init__(self)
        
        self.str_ = str_
    
    def to_str(self, db): return self.str_
99
100 2815 aaronmk
def as_Code(value, db=None):
    '''Coerces a value to a Code object: Code is returned as-is, strings
    become CustomCode, and anything else becomes a Literal.
    @param db If set, runs db.std_code() on the value.
    '''
    if isinstance(value, Code): return value
    
    if util.is_str(value):
        if db != None: value = db.std_code(value)
        return CustomCode(value)
    else: return Literal(value)
110
111 2540 aaronmk
class Expr(Code):
    '''Wraps another Code object in parentheses when rendered.'''
    
    def __init__(self, expr):
        '''
        @param expr Code The expression to parenthesize
        '''
        Code.__init__(self)
        
        self.expr = expr
    
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118
119 3086 aaronmk
##### Names
120
121
class Name(Code):
    '''An identifier, rendered with db.esc_name().'''
    
    def __init__(self, name):
        '''
        @param name str The identifier; it is passed through truncate()
        '''
        Code.__init__(self)
        
        name = truncate(name)
        
        self.name = name
    
    def to_str(self, db): return db.esc_name(self.name)
130
131
def as_Name(value):
    '''Coerces a value to Code: existing Code is returned as-is, anything else
    is wrapped in a Name.'''
    if isinstance(value, Code): return value
    else: return Name(value)
134
135 2335 aaronmk
##### Literal values
136
137 3519 aaronmk
#### Primitives
138
139 2216 aaronmk
class Literal(Code):
    '''A literal value, rendered with db.esc_value().'''
    
    def __init__(self, value):
        '''
        @param value The raw value to wrap
        '''
        Code.__init__(self)
        
        self.value = value
    
    def to_str(self, db): return db.esc_value(self.value)
146 2211 aaronmk
147 2400 aaronmk
def as_Value(value):
    '''Coerces a value to Code: existing Code is returned as-is, anything else
    is wrapped in a Literal.'''
    if isinstance(value, Code): return value
    else: return Literal(value)
150
151 6225 aaronmk
def get_value(value):
    '''Unwraps a Literal's value
    Any column renaming (NamedCol) is removed first. Raw (non-Code) values are
    returned unchanged; any other Code object fails the assertion.
    '''
    value = remove_col_rename(value)
    if isinstance(value, Literal): return value.value
    else:
        assert not isinstance(value, Code)
        return value
158
159 3429 aaronmk
def is_literal(value):
    '''Whether the value is a Literal instance.'''
    return isinstance(value, Literal)
160 2216 aaronmk
161 3429 aaronmk
def is_null(value):
    '''Whether the value is a Literal that wraps None.'''
    if not is_literal(value): return False
    return value.value == None
162
163 3519 aaronmk
#### Composites
164
165 3521 aaronmk
class List(Code):
    '''A sequence of Code objects, rendered comma-separated.'''
    
    def __init__(self, values):
        '''
        @param values Iterable of Code objects (each must have to_str())
        '''
        Code.__init__(self)
        
        self.values = values
    
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
172 3519 aaronmk
173 3521 aaronmk
class Tuple(List):
    '''A List rendered inside parentheses.'''
    
    def __init__(self, *values):
        '''
        @param values The elements, passed as separate arguments
        '''
        List.__init__(self, values)
    
    def to_str(self, db): return '('+List.to_str(self, db)+')'
178
179 3520 aaronmk
class Row(Tuple):
    '''A Tuple rendered with a ROW prefix.'''
    
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
181 3519 aaronmk
182 3522 aaronmk
### Arrays
183
184
class Array(List):
    '''A List rendered as ARRAY[...].'''
    
    def __init__(self, values):
        '''
        @param values Iterable of Code objects; any column renaming on an
            element is removed
        '''
        # Build a real list (not a map() result) so that the values can be
        # rendered more than once regardless of Python version
        values = [remove_col_rename(v) for v in values]
        
        List.__init__(self, values)
    
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
191
192
def to_Array(value):
    '''Coerces a value to an Array: an existing Array is returned as-is,
    anything else is passed through lists.mk_seq() and wrapped.'''
    if isinstance(value, Array): return value
    return Array(lists.mk_seq(value))
195
196 2711 aaronmk
##### Derived elements
197
198
# Sentinel checked by Derived.set_srcs(), which replaces it with (self,)
src_self = object() # tells Col that it is its own source column
199
200
class Derived(Code):
    '''Code element that records which other element(s) it was derived from.'''
    
    def __init__(self, srcs):
        '''An element which was derived from some other element(s).
        @param srcs See self.set_srcs()
        '''
        Code.__init__(self)
        
        self.set_srcs(srcs)
    
    def set_srcs(self, srcs, overwrite=True):
        '''
        @param srcs (self_type...)|src_self The element(s) this is derived from
        @param overwrite If False, srcs that are already non-empty are kept
        '''
        if not overwrite and self.srcs != (): return # already set
        
        if srcs == src_self: srcs = (self,)
        srcs = tuple(srcs) # make Col hashable
        self.srcs = srcs
    
    def _compare_on(self):
        '''Returns a copy of the instance attributes without srcs.
        Presumably consumed by objects.BasicObject for comparisons -- confirm.
        '''
        compare_on = self.__dict__.copy()
        del compare_on['srcs'] # ignore
        return compare_on
223
224
def cols_srcs(cols):
    '''Collects the srcs of every column, flattened with iters.flatten() and
    de-duplicated with lists.uniqify().'''
    per_col_srcs = (c.srcs for c in cols)
    return lists.uniqify(iters.flatten(per_col_srcs))
225
226 2335 aaronmk
##### Tables
227
228 2712 aaronmk
class Table(Derived):
    '''A table reference, rendered as an optionally schema-qualified name.'''
    
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
        '''
        @param name The table name; string names are passed through truncate()
        @param schema str|None (for no schema)
        @param srcs (Table...)|src_self See Derived.set_srcs()
        @param is_temp Stored as self.is_temp
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        
        self.name = name
        self.schema = schema
        self.is_temp = is_temp
        self.order_by = None
        self.index_cols = {} # looked up by column name in index_col()
    
    def to_str(self, db):
        str_ = ''
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
        str_ += as_Name(self.name).to_str(db)
        return str_
    
    def to_Table(self): return self
    
    def _compare_on(self):
        '''Same as Derived._compare_on(), additionally dropping order_by and
        index_cols.'''
        compare_on = Derived._compare_on(self)
        del compare_on['order_by'] # ignore
        del compare_on['index_cols'] # ignore
        return compare_on
257 2211 aaronmk
258 2835 aaronmk
def is_underlying_table(table):
    '''Whether the value is a Table whose to_Table() returns itself (i.e. it
    is not a wrapper such as NamedTable, whose to_Table() makes a new Table).'''
    return isinstance(table, Table) and table.to_Table() is table
260 2832 aaronmk
261 10186 aaronmk
def table2regclass_text(db, table):
    '''Renders the table's escaped name as an escaped value (db.esc_value()).
    @param table Table
    '''
    assert isinstance(table, Table)
    return db.esc_value(table.to_str(db))
264
265 4114 aaronmk
class NoUnderlyingTableException(Exception):
    '''Raised when a reference can't be resolved to an underlying table.'''
    
    def __init__(self, ref):
        '''
        @param ref The offending reference; kept as self.ref
        '''
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
        self.ref = ref
269 2902 aaronmk
270
def underlying_table(table):
    '''Resolves a table reference to its underlying Table.
    Removes any table renaming, then follows a non-empty srcs to its single
    source.
    @raise NoUnderlyingTableException if the result is not an underlying table
    '''
    table = remove_table_rename(table)
    if table != None and table.srcs:
        table, = table.srcs # for derived tables or row vars
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
    return table
276
277 2776 aaronmk
def as_Table(table, schema=None):
    '''Coerces a value to a Table. None and existing Code objects are returned
    unchanged (schema is ignored for them); anything else is used as the name
    of a new Table.
    @param schema Passed to Table() when a new Table is created
    '''
    if table == None or isinstance(table, Code): return table
    else: return Table(table, schema)
280 2219 aaronmk
281 3101 aaronmk
def suffixed_table(table, suffix):
    '''Returns a shallow copy of the table with the suffix added to its name
    (via concat()).'''
    table = copy.copy(table) # don't modify input!
    table.name = concat(table.name, suffix)
    return table
285 2707 aaronmk
286 2336 aaronmk
class NamedTable(Table):
    '''Code given a table alias: rendered as "<code> AS <name> (<cols>)".'''
    
    def __init__(self, name, code, cols=None):
        '''
        @param name The alias
        @param code The aliased code; wrapped in Expr (parentheses) unless it
            is a Table, FunctionCall or Expr
        @param cols Optional column aliases; reduced to name-only columns
        '''
        Table.__init__(self, name)
        
        code = as_Table(code)
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
        
        self.code = code
        self.cols = cols
    
    def to_str(self, db):
        str_ = self.code.to_str(db)
        # Multi-line code gets the alias on its own line
        if str_.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        str_ += whitespace+'AS '+Table.to_str(self, db)
        if self.cols != None:
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
        return str_
    
    def to_Table(self): return Table(self.name)
307
308 2753 aaronmk
def remove_table_rename(table):
    '''Returns the code wrapped by a NamedTable; other values are returned
    unchanged.'''
    if isinstance(table, NamedTable): table = table.code
    return table
311
312 2335 aaronmk
##### Columns
313
314 2711 aaronmk
class Col(Derived):
    '''A column reference, optionally qualified by a table.'''
    
    def __init__(self, name, table=None, srcs=()):
        '''
        @param name The column name; string names are passed through truncate()
        @param table Table|None (for no table)
        @param srcs (Col...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        if util.is_str(table): table = Table(table)
        assert table == None or isinstance(table, Table)
        
        self.name = name
        self.table = table
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str If set, builds the display form used by __str__():
            quote characters are stripped from the name and the table prefix
            is joined with concat()
        '''
        str_ = as_Name(self.name).to_str(db)
        if for_str: str_ = clean_name(str_)
        if self.table != None:
            table = self.table.to_Table()
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
            else: str_ = table.to_str(db)+'.'+str_
        return str_
    
    def __str__(self): return self.to_str(mockDb, for_str=True)
    
    def to_Col(self): return self
341 2211 aaronmk
342 3512 aaronmk
def is_col(col):
    '''Whether the value is a Col instance.'''
    return isinstance(col, Col)
343 2393 aaronmk
344 3512 aaronmk
def is_table_col(col):
    '''Whether the value is a Col that has a table.'''
    if not is_col(col): return False
    return col.table != None
345
346 10138 aaronmk
def col2col_ref(db, col):
    '''Renders a (rendered table name, column name) pair as an escaped value.
    @param col Col Its table must be set, because col.table.to_str() is called
    '''
    assert isinstance(col, Col)
    return db.esc_value((col.table.to_str(db), col.name))
349
350 3000 aaronmk
def index_col(col):
    '''Looks up the column registered for col.name in its table's index_cols.
    @return A Col on the same table with the same srcs, or None if col has no
        table or no entry is registered
    '''
    if not is_table_col(col): return None
    
    table = col.table
    try: name = table.index_cols[col.name]
    except KeyError: return None
    else: return Col(name, table, col.srcs)
357 2999 aaronmk
358 3024 aaronmk
def is_temp_col(col):
    '''Whether the value is a table column whose table has is_temp set.'''
    if not is_table_col(col): return False
    return col.table.is_temp
359 2996 aaronmk
360 2563 aaronmk
def as_Col(col, table=None, name=None):
    '''Coerces a value to a column-like Code object: Code is returned as-is,
    strings become a Col on the given table, other values become a Literal.
    @param table Table used only when col is a string
    @param name If not None, any non-Col input will be renamed using NamedCol.
    '''
    if name != None:
        col = as_Value(col)
        if not isinstance(col, Col): col = NamedCol(name, col)
    
    if isinstance(col, Code): return col
    elif util.is_str(col): return Col(col, table)
    else: return Literal(col)
371 2260 aaronmk
372 3093 aaronmk
def with_table(col, table):
    '''Returns a copy of the column with its table replaced.
    A NamedCol is returned unchanged. For a FunctionCall, the table of the
    first argument of a deep copy is set. Values that are none of these are
    returned unchanged.
    '''
    if isinstance(col, NamedCol): pass # doesn't take a table
    elif isinstance(col, FunctionCall):
        col = copy.deepcopy(col) # don't modify input!
        col.args[0].table = table
    elif isinstance(col, Col):
        col = copy.copy(col) # don't modify input!
        col.table = table
    return col
381
382 3100 aaronmk
def with_default_table(col, table):
    '''Coerces the value with as_Col() and gives it the table if it doesn't
    already have one.'''
    col = as_Col(col)
    if col.table == None: col = with_table(col, table)
    return col
386
387 2744 aaronmk
def set_cols_table(table, cols):
    '''Sets the table of every column, modifying cols in place: each entry is
    replaced by its as_Col() form and that object's table is assigned.
    @param cols Mutable sequence of columns
    '''
    table = as_Table(table)
    
    for i, col in enumerate(cols):
        col = cols[i] = as_Col(col)
        col.table = table
393
394 2401 aaronmk
def to_name_only_col(col, check_table=None):
    '''Drops the table from a column, returning a new Col with only the name.
    Values that are not table columns are returned as coerced by as_Col().
    @param check_table If set, asserts that the column's table equals it
    '''
    col = as_Col(col)
    if not is_table_col(col): return col
    
    if check_table != None:
        table = col.table
        assert table == None or table == check_table
    return Col(col.name)
402
403 2993 aaronmk
def suffixed_col(col, suffix):
    '''Returns a new Col with the suffix added to the name (via concat()) and
    the same table and srcs.'''
    return Col(concat(col.name, suffix), col.table, col.srcs)
405
406 3516 aaronmk
def has_srcs(col):
    '''Truthy if the value is a Col with non-empty srcs (returns the srcs
    themselves in that case), False if it is not a Col.'''
    if not is_col(col): return False
    return col.srcs
407
408 3518 aaronmk
def cross_join_srcs(cols):
    '''Builds one Col per combination of the columns' source names, picking
    one source from each column; each name is the comma-joined combination.
    @return List of Cols; empty if no column has srcs
    '''
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
    srcs = [[s.name for s in c.srcs] for c in cols]
    if not srcs: return [] # itertools.product() returns [()] for empty input
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
413 3514 aaronmk
414 2323 aaronmk
class NamedCol(Col):
    '''Code given a column alias: rendered as "<code> AS <name>".'''
    
    def __init__(self, name, code):
        '''
        @param name The alias
        @param code The aliased value; coerced with as_Value()
        '''
        Col.__init__(self, name)
        
        code = as_Value(code)
        
        self.code = code
    
    def to_str(self, db):
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
    
    def to_Col(self): return Col(self.name)
426 2229 aaronmk
427 2462 aaronmk
def remove_col_rename(col):
    '''Returns the code wrapped by a NamedCol; other values are returned
    unchanged.'''
    if isinstance(col, NamedCol): col = col.code
    return col
430
431 2830 aaronmk
def underlying_col(col):
    '''Returns a Col with the same name and srcs on the underlying table of
    the column's table, after removing any column renaming.
    @raise NoUnderlyingTableException if the value is not a Col, or its table
        has no underlying table
    '''
    col = remove_col_rename(col)
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
    
    return Col(col.name, underlying_table(col.table), col.srcs)
436 2830 aaronmk
437 2703 aaronmk
def wrap(wrap_func, value):
    '''Wraps a value, propagating any column renaming to the returned value.
    @param wrap_func Function applied to the value (for a NamedCol, to its
        code, with the result re-wrapped under the same name)
    '''
    if isinstance(value, NamedCol):
        return NamedCol(value.name, wrap_func(value.code))
    else: return wrap_func(value)
442
443 2667 aaronmk
class ColDict(dicts.DictProxy):
    '''A dict that automatically makes inserted entries Col objects.
    Anything that isn't a column is wrapped in a NamedCol with the key's column
    name by `as_Col(value, name=key.name)`.
    Keys are converted with as_Col() against the dict's table.
    '''
    
    def __init__(self, db, keys_table, dict_={}):
        '''
        @param db Used to look up column defaults for None values
        @param keys_table Table that string keys are resolved against
        @param dict_ Initial entries; only read, never modified
        '''
        dicts.DictProxy.__init__(self, OrderedDict())
        
        keys_table = as_Table(keys_table)
        
        self.db = db
        self.table = keys_table
        self.update(dict_) # after setting vars because __setitem__() needs them
    
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
    
    def __getitem__(self, key):
        return dicts.DictProxy.__getitem__(self, self._key(key))
    
    def __setitem__(self, key, value):
        key = self._key(key)
        if value == None:
            # Use the column's default in place of None
            try: value = self.db.col_info(key).default
            except NoUnderlyingTableException: pass # not a table column
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
    
    def _key(self, key): return as_Col(key, self.table)
471 2564 aaronmk
472 3519 aaronmk
##### Definitions
473 3491 aaronmk
474 3469 aaronmk
class TypedCol(Col):
    '''A column definition: name, type, and optional NOT NULL, DEFAULT and
    constraints clauses.'''
    
    def __init__(self, name, type_, default=None, nullable=True,
        constraints=None):
        '''
        @param type_ The column type; coerced with as_Code() when rendered
        @param default Code|None The default value
        @param nullable If False, NOT NULL is added
        @param constraints str|None Appended verbatim after the other clauses
        '''
        assert default == None or isinstance(default, Code)
        
        Col.__init__(self, name)
        
        self.type = type_
        self.default = default
        self.nullable = nullable
        self.constraints = constraints
    
    def to_str(self, db):
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
        if not self.nullable: str_ += ' NOT NULL'
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
        if self.constraints != None: str_ += ' '+self.constraints
        return str_
    
    def to_Col(self): return Col(self.name)
494
495 3488 aaronmk
class SetOf(Code):
    '''A type rendered as "SETOF <type>".'''
    
    def __init__(self, type_):
        '''
        @param type_ Code The element type (must have to_str())
        '''
        Code.__init__(self)
        
        self.type = type_
    
    def to_str(self, db):
        return 'SETOF '+self.type.to_str(db)
503
504 3483 aaronmk
class RowType(Code):
    '''A type rendered as "<table>%ROWTYPE".'''
    
    def __init__(self, table):
        '''
        @param table Code The table (must have to_str())
        '''
        Code.__init__(self)
        
        self.table = table
    
    def to_str(self, db):
        return self.table.to_str(db)+'%ROWTYPE'
512
513 3485 aaronmk
class ColType(Code):
    '''A type rendered as "<col>%TYPE".'''
    
    def __init__(self, col):
        '''
        @param col Code The column (must have to_str())
        '''
        Code.__init__(self)
        
        self.col = col
    
    def to_str(self, db):
        return self.col.to_str(db)+'%TYPE'
521
522 4935 aaronmk
class ArrayType(Code):
    '''A type rendered as "<elem_type>[]".'''
    
    def __init__(self, elem_type):
        '''
        @param elem_type The element type; coerced with as_Code()
        '''
        Code.__init__(self)
        elem_type = as_Code(elem_type)
        
        self.elem_type = elem_type
    
    def to_str(self, db):
        return self.elem_type.to_str(db)+'[]'
531
532 2524 aaronmk
##### Functions
533
534 2912 aaronmk
# Function references reuse the Table class
Function = Table
535 2911 aaronmk
# Function references are coerced the same way as tables
as_Function = as_Table
536
537 2691 aaronmk
# CustomCode subclass that adds no behavior of its own
class InternalFunction(CustomCode): pass
538
539 3442 aaronmk
#### Calls
540
541 2941 aaronmk
class NamedArg(NamedCol):
    '''A named function-call argument: rendered as "<name> := <value>".'''
    
    def __init__(self, name, value):
        '''
        @param name The parameter name
        @param value The argument value; coerced by NamedCol
        '''
        NamedCol.__init__(self, name, value)
    
    def to_str(self, db):
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
547
548 2524 aaronmk
class FunctionCall(Code):
    '''A function call: rendered as "<function>(<arg>, ...)".'''
    
    def __init__(self, function, *args, **kw_args):
        '''
        @param function The function; coerced with as_Function()
        @param args [Code|literal-value...] The function's arguments
        @param kw_args Named arguments; appended after the positional ones as
            NamedArg objects
        '''
        Code.__init__(self)
        
        function = as_Function(function)
        # Coerce each argument to Code and remove any column renaming
        def filter_(arg): return remove_col_rename(as_Value(arg))
        # Build a real list (not a map() result) so that it can be extended
        # and rendered repeatedly regardless of Python version
        args = [filter_(a) for a in args]
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.items()]
        
        self.function = function
        self.args = args
    
    def to_str(self, db):
        args_str = ', '.join((v.to_str(db) for v in self.args))
        return self.function.to_str(db)+'('+args_str+')'
566
567 2533 aaronmk
def wrap_in_func(function, value):
    '''Wraps a value inside a function call.
    Propagates any column renaming to the returned value.
    @param function The function to call with the value as its only argument
    '''
    return wrap(lambda v: FunctionCall(function, v), value)
572 2533 aaronmk
573 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
    '''Unwraps any function call to its first argument.
    Also removes any column renaming.
    Values that are not a FunctionCall are returned as-is (minus renaming).
    @param check_name If set, asserts that the function's name equals it
    '''
    func_call = remove_col_rename(func_call)
    if not isinstance(func_call, FunctionCall): return func_call
    
    if check_name != None:
        name = func_call.function.name
        assert name == None or name == check_name
    return func_call.args[0]
584
585 3442 aaronmk
#### Definitions
586
587
class FunctionDef(Code):
    '''A CREATE FUNCTION statement.'''
    
    def __init__(self, function, return_type, body, params=[], modifiers=None):
        '''
        @param function The function being defined (must have to_str())
        @param return_type The return type; coerced with as_Code()
        @param body The function body; coerced with as_Code(). Its lang is
            used for the LANGUAGE clause.
        @param params The parameter definitions (each must have to_str())
        @param modifiers str|None Extra line placed before the body
        '''
        Code.__init__(self)
        
        return_type = as_Code(return_type)
        body = as_Code(body)
        
        self.function = function
        self.return_type = return_type
        self.body = body
        # Copy so that instances never share the mutable default list (or the
        # caller's list)
        self.params = list(params)
        self.modifiers = modifiers
    
    def to_str(self, db):
        params_str = (', '.join((p.to_str(db) for p in self.params)))
        str_ = '''\
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
RETURNS '''+self.return_type.to_str(db)+'''
LANGUAGE '''+self.body.lang+'''
'''
        if self.modifiers != None: str_ += self.modifiers+'\n'
        str_ += '''\
AS $$
'''+self.body.to_str(db)+'''
$$;
'''
        return str_
614
615 3469 aaronmk
class FunctionParam(TypedCol):
    '''A function parameter definition, optionally marked OUT.'''
    
    def __init__(self, name, type_, default=None, out=False):
        '''
        @param out If set, the parameter is rendered with an OUT prefix
        '''
        TypedCol.__init__(self, name, type_, default)
        
        self.out = out
    
    def to_str(self, db):
        str_ = TypedCol.to_str(self, db)
        if self.out: str_ = 'OUT '+str_
        return str_
    
    def to_Col(self): return Col(self.name)
627
628 3454 aaronmk
### PL/pgSQL
629
630 3496 aaronmk
class ReturnQuery(Code):
    '''A RETURN QUERY statement wrapping a query.'''
    
    def __init__(self, query):
        '''
        @param query The query; coerced with as_Code()
        '''
        Code.__init__(self)
        
        query = as_Code(query)
        
        self.query = query
    
    def to_str(self, db):
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
640
641 3515 aaronmk
## Exceptions
642
643 3509 aaronmk
class BaseExcHandler(BasicObject):
    '''Base class of exception handlers, which render a body wrapped in
    exception-handling code.'''
    
    def to_str(self, db, body):
        '''Renders the body wrapped in this handler. Subclasses must override
        this.'''
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb, '<body>')
647 3509 aaronmk
648 3549 aaronmk
# Handler body used by ExcHandler when no handler code is given
suppress_exc = 'NULL;\n'
649
650 3554 aaronmk
# Statement that re-raises the current exception with the same code and message
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n'
651
652 3509 aaronmk
class ExcHandler(BaseExcHandler):
    '''Wraps a body in a BEGIN ... EXCEPTION WHEN <exc> THEN ... END block.'''
    
    def __init__(self, exc, handler=None):
        '''
        @param exc str The exception condition name for the WHEN clause
        @param handler The handler code (coerced with as_Code()), or None to
            use suppress_exc
        '''
        if handler != None: handler = as_Code(handler)
        
        self.exc = exc
        self.handler = handler
    
    def to_str(self, db, body):
        '''
        @param body The code to protect; coerced with as_Code()
        '''
        body = as_Code(body)
        
        if self.handler != None:
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
        else: handler_str = ' '+suppress_exc
        
        # The trailing "\" line continuations keep the generated text from
        # gaining extra newlines (including after the final END;)
        str_ = '''\
BEGIN
'''+strings.indent(body.to_str(db))+'''\
EXCEPTION
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
END;\
'''
        return str_
674
675 3515 aaronmk
class NestedExcHandler(BaseExcHandler):
    '''Applies several exception handlers to the same body, nested inside
    each other.'''
    
    def __init__(self, *handlers):
        '''
        @param handlers Sorted from outermost to innermost
        '''
        self.handlers = handlers
    
    def to_str(self, db, body):
        # Wrap with the innermost handler first so that the first handler
        # ends up outermost
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
        return body
685
686 3503 aaronmk
class ExcToWarning(Code):
    '''Handler code that raises the current error message as a warning and
    then runs a return statement.'''
    
    def __init__(self, return_):
        '''
        @param return_ Statement to return a default value in case of error
        '''
        Code.__init__(self)
        
        return_ = as_Code(return_)
        
        self.return_ = return_
    
    def to_str(self, db):
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
699
700 3454 aaronmk
# No handler code is given, so ExcHandler falls back to suppress_exc;
# this is the default exc_handler of RowExcIgnore
unique_violation_handler = ExcHandler('unique_violation')
701
702 3555 aaronmk
# Note doubled "\"s because inside Python string
703 3468 aaronmk
# Handler for internal_error: parses SQLERRM as "[PL/Python: ]<name>: <msg>"
# and re-raises a match as data_exception; anything else is re-raised as-is
plpythonu_error_handler = ExcHandler('internal_error', '''\
-- Handle PL/Python exceptions
DECLARE
    matches text[] := regexp_matches(SQLERRM,
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
    exc_name text := matches[1];
    msg text := matches[2];
BEGIN
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
    This allows the exception to be parsed like a native exception.
    Always raise as data_exception so it goes in the errors table. */
    IF exc_name IS NOT NULL THEN
        RAISE data_exception USING MESSAGE = msg;
    -- Re-raise non-PL/Python exceptions
    ELSE
        '''+reraise_exc+'''\
    END IF;
END;
''')
722
723 3505 aaronmk
def data_exception_handler(handler):
    '''Creates an ExcHandler for data_exception with the given handler code.'''
    return ExcHandler('data_exception', handler)
725
726 3527 aaronmk
# Default row variable of RowExcIgnore, which declares it itself when it is used
row_var = Table('row')
727
728 3449 aaronmk
class RowExcIgnore(Code):
    '''PL/pgSQL loop that runs code for each row of a query, with each row's
    code wrapped in its own exception handler.'''
    
    def __init__(self, row_type, select_query, with_row, cols=None,
        exc_handler=unique_violation_handler, row_var=row_var):
        '''
        @param row_type Ignored if a custom row_var is used.
        @param select_query The query to loop over; coerced with as_Code()
        @param with_row The code to run for each row; coerced with as_Code()
        @param cols If set, the loop assigns to these columns of row_var
            instead of to row_var itself
        @param exc_handler Handler that each row's code is wrapped in
        @param row_var The loop variable; coerced with as_Table()
        @pre If a custom row_var is used, it must already be defined.
        '''
        Code.__init__(self, lang='plpgsql')
        
        row_type = as_Code(row_type)
        select_query = as_Code(select_query)
        with_row = as_Code(with_row)
        row_var = as_Table(row_var)
        
        self.row_type = row_type
        self.select_query = select_query
        self.with_row = with_row
        self.cols = cols
        self.exc_handler = exc_handler
        self.row_var = row_var
    
    def to_str(self, db):
        if self.cols == None: row_vars = [self.row_var]
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
        
        # Need an EXCEPTION block for each individual row because "When an error
        # is caught by an EXCEPTION clause, [...] all changes to persistent
        # database state within the block are rolled back."
        # This is unfortunate because "A block containing an EXCEPTION clause is
        # significantly more expensive to enter and exit than a block without
        # one."
        # (http://www.postgresql.org/docs/8.3/static/\
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
        str_ = '''\
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+strings.indent(self.select_query.to_str(db))+'''\
LOOP
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
END LOOP;
'''
        # Only the default (module-level) row_var gets declared here
        if self.row_var == row_var:
            str_ = '''\
DECLARE
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
BEGIN
'''+strings.indent(str_)+'''\
END;
'''
        return str_
777
778 2986 aaronmk
##### Casts
779
780
class Cast(FunctionCall):
    '''A type cast: rendered as "CAST(<value> AS <type>)".
    NOTE: __init__() does not call FunctionCall.__init__() or Code.__init__(),
    so instances have no function, args or lang attributes.
    '''
    
    def __init__(self, type_, value):
        '''
        @param type_ The target type; coerced with as_Code()
        @param value The value to cast; coerced with as_Value()
        '''
        type_ = as_Code(type_)
        value = as_Value(value)
        
        # Most types cannot be cast directly to unknown
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)
        
        self.type_ = type_
        self.value = value
    
    def to_str(self, db):
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
793 2986 aaronmk
794 3354 aaronmk
def cast_literal(value):
    '''Wraps a string Literal in a cast to text. Non-literals and literals of
    other types are returned unchanged.'''
    if not is_literal(value): return value
    
    if util.is_str(value.value): value = Cast('text', value)
    return value
799
800 2335 aaronmk
##### Conditions
801 2259 aaronmk
802 3350 aaronmk
class NotCond(Code):
    '''A negated condition: rendered as "NOT <cond>".'''
    
    def __init__(self, cond):
        '''
        @param cond The condition to negate; wrapped as Coalesce(cond, False)
            unless it is already a Coalesce
        '''
        Code.__init__(self)
        
        if not isinstance(cond, Coalesce): cond = Coalesce(cond, False)
        
        self.cond = cond
    
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
811
812 2398 aaronmk
class ColValueCond(Code):
    '''A condition that applies a value condition to a column.'''
    
    def __init__(self, col, value):
        '''
        @param col The column passed to the value condition as its left side
        @param value The right side; coerced with as_ValueCond()
        '''
        Code.__init__(self)
        
        value = as_ValueCond(value)
        
        self.col = col
        self.value = value
    
    def to_str(self, db): return self.value.to_str(db, self.col)
822
823 2577 aaronmk
def combine_conds(conds, keyword=None):
    '''Joins condition strings with AND, one per line.
    @param conds Sequence of condition strings
    @param keyword The keyword to add before the conditions, if any
    @return The keyword alone for no conditions, "keyword cond" for a single
        condition, otherwise the keyword followed by the conditions on their
        own lines
    '''
    str_ = ''
    if keyword is not None:
        # `not conds` rather than `conds == []` so that an empty tuple is also
        # treated as having no conditions
        if not conds: whitespace = ''
        elif len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        str_ += keyword+whitespace
    
    str_ += '\nAND '.join(conds)
    return str_
836
837 2398 aaronmk
##### Condition column comparisons
838
839 2514 aaronmk
class ValueCond(BasicObject):
    '''Base class of conditions that compare some left-hand value against a
    stored right-hand value.'''
    
    def __init__(self, value):
        '''
        @param value The right-hand value; coerced with as_Value(), with any
            column renaming removed
        '''
        value = remove_col_rename(as_Value(value))
        
        self.value = value
    
    def to_str(self, db, left_value):
        '''
        @param left_value The Code object that the condition is being applied on
        '''
        # NotImplementedError, not NotImplemented: the latter is a constant
        # that can't be called or raised
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
852 2211 aaronmk
853
class CompareCond(ValueCond):
    '''Compares the left value to self.value using a binary operator.'''
    def __init__(self, value, operator='='):
        '''
        @param operator By default, compares NULL values literally. Use '~=' or
            '~!=' to pass NULLs through.
        '''
        ValueCond.__init__(self, value)
        self.operator = operator
    
    def to_str(self, db, left_value):
        lhs = remove_col_rename(as_Col(left_value))
        rhs = self.value
        
        # Parse operator: strip an optional leading '~' (pass NULLs through),
        # then an optional leading '!' (negate). The one-element lists are
        # out-params (presumably set to True when the prefix was removed).
        passthru_null = [False]
        negate = [False]
        op = strings.remove_prefix('~', self.operator, passthru_null)
        op = strings.remove_prefix('!', op, negate)
        is_equality = op.endswith('=') # also includes <=, >=
        
        # Handle nullable columns
        null_check = False
        if not passthru_null[0]: # NULLs compare equal
            try: lhs = ensure_not_null(db, lhs)
            except ensure_not_null_excs: # fall back to alternate method
                null_check = is_equality and isinstance(rhs, Col)
            else:
                if isinstance(lhs, EnsureNotNull):
                    # apply same function to both sides
                    rhs = ensure_not_null(db, rhs, lhs.type)
        
        if is_equality and is_null(rhs): op = 'IS'
        
        lhs_str = lhs.to_str(db)
        rhs_str = rhs.to_str(db)
        
        # Create str
        cond = lhs_str+' '+op+' '+rhs_str
        if null_check:
            # Also match when both sides are NULL
            both_null = lhs_str+' IS NULL AND '+rhs_str+' IS NULL'
            cond = '('+cond+' OR ('+both_null+'))'
        if negate[0]: cond = 'NOT '+cond
        return cond
897 2216 aaronmk
898 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
assume_literal = object()

def as_ValueCond(value, default_table=assume_literal):
    '''Coerces a value to a ValueCond.
    @param value ValueCond (returned as-is)|anything else (wrapped in a
        CompareCond)
    @param default_table If not assume_literal, a non-ValueCond value is first
        passed through with_default_table() with this table
    @return ValueCond
    '''
    if isinstance(value, ValueCond): return value
    
    if default_table is not assume_literal:
        value = with_default_table(value, default_table)
    return CompareCond(value)
907 2219 aaronmk
908 2335 aaronmk
##### Joins
909
910 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name

# Tells Join the left and right columns have the same name and are never NULL
join_same_not_null = object()

filter_out = object() # tells Join to filter out rows that match the join

class Join(BasicObject):
    '''A JOIN clause against `table`, rendered relative to a left table.'''
    def __init__(self, table, mapping=None, type_=None):
        '''
        @param table Table|str (a str is wrapped in Table)
        @param mapping dict(right_table_col=left_table_col, ...)
            or [using_col...]
            or None (same as an empty dict: no join condition)
            * if left_table_col is join_same: left_table_col = right_table_col
              * Note that right_table_col must be a string
            * if left_table_col is join_same_not_null:
              left_table_col = right_table_col and both have NOT NULL constraint
              * Note that right_table_col must be a string
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
            * filter_out: equivalent to 'LEFT' with the query filtered by
              `table_pkey IS NULL` (indicating no match)
        '''
        # A fresh dict per instance: a `{}` default would be one object shared
        # by every Join created without a mapping
        if mapping is None: mapping = {}
        
        if util.is_str(table): table = Table(table)
        if lists.is_seq(mapping):
            mapping = dict(((c, join_same_not_null) for c in mapping))
        assert type_ is None or util.is_str(type_) or type_ is filter_out
        
        self.table = table
        self.mapping = mapping
        self.type_ = type_
    
    def to_str(self, db, left_table_):
        '''
        @param left_table_ The table this join is being attached to
        @return The JOIN clause, including its USING/ON condition if any
        '''
        def join(entry):
            '''Parses non-USING joins'''
            right_table_col, left_table_col = entry
            
            # Switch order (right_table_col is on the left in the comparison)
            left = right_table_col
            right = left_table_col
            left_table = self.table
            right_table = left_table_
            
            # Parse left side
            left = with_default_table(left, left_table)
            
            # Parse special values
            left_on_right = Col(left.name, right_table)
            if right is join_same: right = left_on_right
            elif right is join_same_not_null:
                right = CompareCond(left_on_right, '~=')
            
            # Parse right side
            right = as_ValueCond(right, right_table)
            
            return right.to_str(db, left)
        
        # Create join condition
        type_ = self.type_
        joins = self.mapping
        if joins == {}: join_cond = None
        elif type_ is not filter_out and all(
            v is join_same_not_null for v in joins.itervalues()):
            # all cols w/ USING, so can use simpler USING syntax
            cols = [to_name_only_col(c) for c in joins.iterkeys()]
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
        else:
            join_cond = combine_conds([join(e) for e in joins.iteritems()],
                'ON')
        
        if isinstance(self.table, NamedTable): whitespace = '\n'
        else: whitespace = ' '
        
        # Create join
        if type_ is filter_out: type_ = 'LEFT'
        str_ = ''
        if type_ is not None: str_ += type_+' '
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
        if join_cond is not None: str_ += whitespace+join_cond
        return str_
    
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
988 2424 aaronmk
989
##### Value exprs
990
991 3089 aaronmk
# Custom code `*`
all_cols = CustomCode('*')

# Custom code for the DEFAULT keyword
default = CustomCode('DEFAULT')

# A call to the COUNT internal function with all_cols as its only argument
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
996 2674 aaronmk
997 3061 aaronmk
class Coalesce(FunctionCall):
    '''A call to the COALESCE internal function.'''
    def __init__(self, *args):
        '''
        @param args The function call's arguments, passed through unchanged
        '''
        func = InternalFunction('COALESCE')
        FunctionCall.__init__(self, func, *args)
1000 3060 aaronmk
1001 3062 aaronmk
class Nullif(FunctionCall):
    '''A call to the NULLIF internal function.'''
    def __init__(self, *args):
        '''
        @param args The function call's arguments, passed through unchanged
        '''
        func = InternalFunction('NULLIF')
        FunctionCall.__init__(self, func, *args)
1004
1005 5891 aaronmk
null = Literal(None) # the literal for None
null_as_str = Cast('text', null) # `null` cast to text

def to_text(value):
    '''Casts a value to text, coalescing the result with null_as_str.
    @return Coalesce
    '''
    as_text = Cast('text', value)
    return Coalesce(as_text, null_as_str)
1009
1010 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
1011 2958 aaronmk
# Maps a type name to the value that EnsureNotNull substitutes for NULL in a
# column of that type. Types without an entry raise KeyError on lookup.
null_sentinels = {
    # Text types
    'character varying': r'\N',
    'text': r'\N',
    # Numeric types
    'double precision': 'NaN',
    'integer': 2147483647, # 2**31 - 1
    # Date/time types
    'date': 'infinity',
    'timestamp with time zone': 'infinity',
    # Other types
    'taxonrank': 'unknown',
}
1020 2692 aaronmk
1021 3061 aaronmk
class EnsureNotNull(Coalesce):
    '''COALESCE of a column with a sentinel value of the column's type.'''
    def __init__(self, value, type_):
        '''
        @param value The value to wrap (converted with as_Col())
        @param type_ ArrayType|a key of null_sentinels
        @throws KeyError if type_ is not an ArrayType and has no null sentinel
        '''
        # Arrays use the empty array; other types use their registered sentinel
        sentinel = [] if isinstance(type_, ArrayType) else null_sentinels[type_]
        Coalesce.__init__(self, as_Col(value), Cast(type_, sentinel))
        
        self.type = type_
    
    def to_str(self, db):
        # If index_col() has something for the wrapped column, render that
        # instead of the COALESCE call
        substitute = index_col(self.args[0])
        if substitute != None: return substitute.to_str(db)
        return Coalesce.to_str(self, db)
1034 2850 aaronmk
1035 3523 aaronmk
#### Arrays
1036
1037 3535 aaronmk
class ArrayMerge(FunctionCall):
    '''A call to the array_to_string internal function.'''
    def __init__(self, sep, array):
        '''
        @param sep The separator, passed as the call's second argument
        @param array Converted with to_Array() and passed as the first argument
        '''
        func = InternalFunction('array_to_string')
        FunctionCall.__init__(self, func, to_Array(array), sep)
1042
1043 3537 aaronmk
def merge_not_null(db, sep, values):
    '''Merges values into one string expression using ArrayMerge.
    @param db Not used by this function
    @param sep The separator passed to ArrayMerge
    @param values The values to merge; each is converted with to_text()
    @return ArrayMerge
    '''
    as_text = [to_text(v) for v in values]
    return ArrayMerge(sep, as_text)
1045 3537 aaronmk
1046 2737 aaronmk
##### Table exprs
1047
1048
class Values(Code):
    '''A VALUES expression of one or more rows.'''
    def __init__(self, values):
        '''
        @param values [...]|[[...], ...] Can be one or multiple rows.
            Not modified: the converted rows are stored in a new list.
        '''
        Code.__init__(self)
        
        rows = values
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
            rows = [values]
        
        # Build new lists rather than assigning into `rows`, which for multiple
        # rows is the caller's own object (and may be an immutable tuple)
        self.rows = [[remove_col_rename(as_Value(v)) for v in row]
            for row in rows]
    
    def to_str(self, db):
        row_strs = (Tuple(*r).to_str(db) for r in self.rows)
        return 'VALUES '+(', '.join(row_strs))
1065 2737 aaronmk
1066 2740 aaronmk
def NamedValues(name, cols, values):
    '''Creates a NamedTable wrapping a Values expression.
    @param name The name given to the NamedTable
    @param cols None|[...]
    @param values Passed to Values()
    @return NamedTable
    @post `cols` will be changed to Col objects with the table set to `name`.
    '''
    values_code = Values(values)
    table = NamedTable(name, values_code, cols)
    if cols is not None:
        set_cols_table(table, cols)
    return table
1074 2740 aaronmk
1075 2674 aaronmk
##### Database structure
1076
1077 5074 aaronmk
def is_nullable(db, value):
    '''Whether a value may be NULL.
    @return For table columns, the `nullable` attribute of db.col_info(), or
        True if there is no underlying table. For anything else, whether the
        value itself is NULL.
    '''
    if is_table_col(value):
        try: return db.col_info(value).nullable
        except NoUnderlyingTableException: return True # not a table column
    return is_null(value)
1081
1082 4442 aaronmk
# Type names that count as text
text_types = {'character varying', 'text'}

def is_text_type(type_):
    '''Whether a type name is one of text_types'''
    return type_ in text_types

def is_text_col(db, col):
    '''Whether the type reported by db.col_info() for a column is a text type'''
    col_type = db.col_info(col).type
    return is_text_type(col_type)

def canon_type(type_):
    '''Maps every text type to 'text' and returns other types unchanged'''
    return 'text' if type_ in text_types else type_
1091
1092 2840 aaronmk
# The exceptions that ensure_not_null() is documented to throw
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)

def ensure_not_null(db, col, type_=None):
    '''Wraps a nullable column in EnsureNotNull.
    @param col If type_ is not set, must have an underlying column.
    @param type_ If set, overrides the underlying column's type and casts the
        column to it if needed.
    @return EnsureNotNull|Col
    @throws ensure_not_null_excs
    '''
    col = remove_col_rename(col)
    
    try: col_type = db.col_info(underlying_col(col)).type
    except NoUnderlyingTableException:
        if type_ is None and is_null(col): raise # NULL has no type
    else:
        if type_ is None: type_ = col_type
        elif type_ != col_type: col = Cast(type_, col)
    
    if is_nullable(db, col):
        try: col = EnsureNotNull(col, type_)
        except KeyError as e: # `as` form is valid in Python 2.6+ and 3
            # Warn of no null sentinel for type, even if caller catches error
            warnings.warn(UserWarning(exc.str_(e)))
            raise
    
    return col
1119 3536 aaronmk
1120
def try_mk_not_null(db, value):
    '''Applies ensure_not_null(), returning `value` unchanged if it throws
    one of ensure_not_null_excs.
    Warning: This function does not guarantee that its result is NOT NULL.
    '''
    try: result = ensure_not_null(db, value)
    except ensure_not_null_excs: result = value
    return result
1126 5367 aaronmk
1127
##### Expression transforming
1128
1129
# Boolean constants as they appear in expressions
true_expr = 'true'
false_expr = 'false'

# Regexps matching the boolean constants
true_re = true_expr
false_re = false_expr
bool_re = r'(?:%s|%s)' % (true_re, false_re) # either constant
# A boolean constant or a parenthesized group containing no nested parens
atom_re = r'(?:%s|\([^()]*\))' % bool_re
1136
1137
def logic_op_re(op, value_re, expr_re=''):
    '''Builds a regexp for a binary logic operation in either operand order.
    @param op The operator keyword; it is matched with one space on each side
    @param value_re Regexp for one operand
    @param expr_re Regexp for the other operand (empty by default)
    @return Non-capturing group matching `expr op value` or `value op expr`
    '''
    spaced_op = ' '+op+' '
    expr_first = expr_re+spaced_op+value_re
    value_first = value_re+spaced_op+expr_re
    return '(?:%s|%s)' % (expr_first, value_first)
1140
1141 5370 aaronmk
# Negated boolean constants
not_re = r'\bNOT '
not_false_re = r'%s%s\b' % (not_re, false_re)
not_true_re = r'%s%s\b' % (not_re, true_re)

# An atom AND-ed with false, or NOT true
and_false_re = logic_op_re('AND', false_re, atom_re)
and_false_not_true_re = '(?:%s|%s)' % (not_true_re, and_false_re)

# An AND-ed true, or an OR-ed boolean constant (with the operator itself)
# NOTE(review): unlike not_*_re, the regexps below have no \b anchors, so e.g.
# `true AND ` also matches inside `is_true AND ` -- confirm inputs rule this out
# NOTE(review): or_re also matches ` OR true`/`true OR `; removing those is not
# a logical identity -- confirm this is intended
and_true_re = logic_op_re('AND', true_re)
or_re = logic_op_re('OR', bool_re)
or_and_true_re = '(?:%s|%s)' % (and_true_re, or_re)
1149
1150
def simplify_parens(expr):
    '''Replaces every parenthesized atom_re match with the atom itself, using
    regexp.sub_nested()'''
    parenthesized_atom_re = r'\((%s)\)' % atom_re
    return regexp.sub_nested(parenthesized_atom_re, r'\1', expr)
1152
1153
def simplify_recursive(sub_func, expr):
    '''Runs sub_func via regexp.sub_recursive(), applying simplify_parens() to
    the string before each call to sub_func.
    @param sub_func See regexp.sub_recursive() sub_func param
    '''
    def parens_then_sub(str_):
        # simplify_parens() is also done at end in final iteration
        return sub_func(simplify_parens(str_))
    
    return regexp.sub_recursive(parens_then_sub, expr)
1159 5367 aaronmk
1160
def simplify_expr(expr):
    '''Simplifies constant boolean logic in an expression string.
    @param expr The expression, as a string
    @return The expression with `NULL IS [NOT] NULL` replaced by a boolean
        constant and the logic-op regexps applied via simplify_recursive()
    '''
    # (pattern, replacement) pairs, applied in this order on each pass.
    # Replacements are the *_expr constants, never the *_re patterns, so that
    # the patterns can gain regexp syntax without corrupting the output.
    substitutions = (
        (not_false_re, true_expr),
        (and_false_not_true_re, false_expr),
        (or_and_true_re, r''),
    )
    
    def simplify_logic_ops(expr):
        '''@return (new_expr, total number of substitutions made)'''
        total_n = 0
        for pattern, repl in substitutions:
            expr, n = re.subn(pattern, repl, expr)
            total_n += n
        return expr, total_n
    
    expr = expr.replace('NULL IS NULL', true_expr)
    expr = expr.replace('NULL IS NOT NULL', false_expr)
    expr = simplify_recursive(simplify_logic_ops, expr)
    return expr
1175
1176
# An unquoted name, or one or more adjacent double-quoted strings
name_re = r'(?:\w+|(?:"[^"]*")+)'

def parse_expr_col(str_):
    '''Extracts the name that follows the inner `(` in a string of the form
    `(name(name...))`. Strings of any other form are used whole.
    @return The extracted name (or the whole string), run through unesc_name()
    '''
    func_call_re = r'^\(%s\((%s).*\)\)$' % (name_re, name_re)
    match = re.match(func_call_re, str_)
    name = match.group(1) if match else str_
    return unesc_name(name)
1182
1183
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr The expression, as a string
    @param mapping dict(out_col=in_col, ...)
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return The mapped expression, passed through simplify_expr()
    '''
    # NOTE(review): replacements are applied one mapping entry at a time, so an
    # input name that equals a later entry's output name would be replaced
    # again -- confirm mappings never overlap this way
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = to_name_only_col(out)
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The name is escaped so that regexp metacharacters in it match
        # literally, and the replacement is a function so that backslashes in
        # in_str are inserted as-is instead of being parsed as a template
        unquoted_re = (r'(?<!["\'\.=\[])\b'+re.escape(out.name)
            +r'\b(?!["\',\.=\]])')
        expr = re.sub(unquoted_re, lambda match, repl=in_str: repl, expr)
        
        # replaced something
        if in_cols_found is not None and expr != orig_expr:
            in_cols_found.append(in_)
    
    return simplify_expr(expr)