Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 3158 aaronmk
from ordereddict import OrderedDict
6 2568 aaronmk
import re
7 2653 aaronmk
import UserDict
8 2953 aaronmk
import warnings
9 2276 aaronmk
10 2667 aaronmk
import dicts
11 2953 aaronmk
import exc
12 2701 aaronmk
import iters
13
import lists
14 2360 aaronmk
import objects
15 2222 aaronmk
import strings
16 2227 aaronmk
import util
17 2211 aaronmk
18 2587 aaronmk
##### Names
19 2499 aaronmk
20 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21 2587 aaronmk
22 2932 aaronmk
def concat(str_, suffix):
23 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25 2613 aaronmk
    # Preserve version
26 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27 2985 aaronmk
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30 2613 aaronmk
31 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
32 2587 aaronmk
33 2932 aaronmk
def truncate(str_): return concat(str_, '')
34 2842 aaronmk
35 2575 aaronmk
def is_safe_name(name):
36 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40 2984 aaronmk
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42 2568 aaronmk
43 2499 aaronmk
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46
47 3320 aaronmk
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55
56 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
57
58 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59
60 3182 aaronmk
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64
65 2659 aaronmk
##### General SQL code objects
66 2219 aaronmk
67 2349 aaronmk
class MockDb:
68 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
69 2349 aaronmk
70 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
71 2859 aaronmk
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74
75 2349 aaronmk
mockDb = MockDb()
76
77 2514 aaronmk
class BasicObject(objects.BasicObject):
78
    def __init__(self, value): self.value = value
79
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
85 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
86 2349 aaronmk
87 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
88 2211 aaronmk
89 2269 aaronmk
class CustomCode(Code):
90 2256 aaronmk
    def __init__(self, str_): self.str_ = str_
91
92
    def to_str(self, db): return self.str_
93
94 2815 aaronmk
def as_Code(value, db=None):
95
    '''
96
    @param db If set, runs db.std_code() on the value.
97
    '''
98
    if util.is_str(value):
99
        if db != None: value = db.std_code(value)
100
        return CustomCode(value)
101 2659 aaronmk
    else: return Literal(value)
102
103 2540 aaronmk
class Expr(Code):
104
    def __init__(self, expr): self.expr = expr
105
106
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
107
108 3086 aaronmk
##### Names
109
110
class Name(Code):
111 3087 aaronmk
    def __init__(self, name):
112
        name = truncate(name)
113
114
        self.name = name
115 3086 aaronmk
116
    def to_str(self, db): return db.esc_name(self.name)
117
118
def as_Name(value):
119
    if isinstance(value, Code): return value
120
    else: return Name(value)
121
122 2335 aaronmk
##### Literal values
123
124 2216 aaronmk
class Literal(Code):
125 2211 aaronmk
    def __init__(self, value): self.value = value
126 2213 aaronmk
127
    def to_str(self, db): return db.esc_value(self.value)
128 2211 aaronmk
129 2400 aaronmk
def as_Value(value):
130
    if isinstance(value, Code): return value
131
    else: return Literal(value)
132
133 2216 aaronmk
def is_null(value): return isinstance(value, Literal) and value.value == None
134
135 2711 aaronmk
##### Derived elements
136
137
src_self = object() # tells Col that it is its own source column
138
139
class Derived(Code):
140
    def __init__(self, srcs):
141 2712 aaronmk
        '''An element which was derived from some other element(s).
142 2711 aaronmk
        @param srcs See self.set_srcs()
143
        '''
144
        self.set_srcs(srcs)
145
146 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
147 2711 aaronmk
        '''
148
        @param srcs (self_type...)|src_self The element(s) this is derived from
149
        '''
150 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
151
152 2711 aaronmk
        if srcs == src_self: srcs = (self,)
153
        srcs = tuple(srcs) # make Col hashable
154
        self.srcs = srcs
155
156
    def _compare_on(self):
157
        compare_on = self.__dict__.copy()
158
        del compare_on['srcs'] # ignore
159
        return compare_on
160
161
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
162
163 2335 aaronmk
##### Tables
164
165 2712 aaronmk
class Table(Derived):
166 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
167 2211 aaronmk
        '''
168
        @param schema str|None (for no schema)
169 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
170 2211 aaronmk
        '''
171 2712 aaronmk
        Derived.__init__(self, srcs)
172
173 3091 aaronmk
        if util.is_str(name): name = truncate(name)
174 2843 aaronmk
175 2211 aaronmk
        self.name = name
176
        self.schema = schema
177 2991 aaronmk
        self.is_temp = is_temp
178 3000 aaronmk
        self.index_cols = {}
179 2211 aaronmk
180 2348 aaronmk
    def to_str(self, db):
181
        str_ = ''
182 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
183
        str_ += as_Name(self.name).to_str(db)
184 2348 aaronmk
        return str_
185 2336 aaronmk
186
    def to_Table(self): return self
187 3000 aaronmk
188
    def _compare_on(self):
189
        compare_on = Derived._compare_on(self)
190
        del compare_on['index_cols'] # ignore
191
        return compare_on
192 2211 aaronmk
193 2835 aaronmk
def is_underlying_table(table):
194
    return isinstance(table, Table) and table.to_Table() is table
195 2832 aaronmk
196 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
197
198
def underlying_table(table):
199
    table = remove_table_rename(table)
200
    if not is_underlying_table(table): raise NoUnderlyingTableException
201
    return table
202
203 2776 aaronmk
def as_Table(table, schema=None):
204 2270 aaronmk
    if table == None or isinstance(table, Code): return table
205 2776 aaronmk
    else: return Table(table, schema)
206 2219 aaronmk
207 3101 aaronmk
def suffixed_table(table, suffix):
208 3128 aaronmk
    table = copy.copy(table) # don't modify input!
209
    table.name = concat(table.name, suffix)
210
    return table
211 2707 aaronmk
212 2336 aaronmk
class NamedTable(Table):
213
    def __init__(self, name, code, cols=None):
214
        Table.__init__(self, name)
215
216 3016 aaronmk
        code = as_Table(code)
217 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
218 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
219 2336 aaronmk
220
        self.code = code
221
        self.cols = cols
222
223
    def to_str(self, db):
224 3026 aaronmk
        str_ = self.code.to_str(db)
225
        if str_.find('\n') >= 0: whitespace = '\n'
226
        else: whitespace = ' '
227
        str_ += whitespace+'AS '+Table.to_str(self, db)
228 2742 aaronmk
        if self.cols != None:
229
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
230 2336 aaronmk
        return str_
231
232
    def to_Table(self): return Table(self.name)
233
234 2753 aaronmk
def remove_table_rename(table):
235
    if isinstance(table, NamedTable): table = table.code
236
    return table
237
238 2335 aaronmk
##### Columns
239
240 2711 aaronmk
class Col(Derived):
241 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
242 2211 aaronmk
        '''
243
        @param table Table|None (for no table)
244 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
245 2211 aaronmk
        '''
246 2711 aaronmk
        Derived.__init__(self, srcs)
247
248 3091 aaronmk
        if util.is_str(name): name = truncate(name)
249 2241 aaronmk
        if util.is_str(table): table = Table(table)
250 2211 aaronmk
        assert table == None or isinstance(table, Table)
251
252
        self.name = name
253
        self.table = table
254
255 2989 aaronmk
    def to_str(self, db, for_str=False):
256 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
257 2989 aaronmk
        if for_str: str_ = clean_name(str_)
258 2933 aaronmk
        if self.table != None:
259 2989 aaronmk
            table = self.table.to_Table()
260
            if for_str: str_ = concat(str(table), '.'+str_)
261
            else: str_ = table.to_str(db)+'.'+str_
262 2211 aaronmk
        return str_
263 2314 aaronmk
264 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
265 2933 aaronmk
266 2314 aaronmk
    def to_Col(self): return self
267 2211 aaronmk
268 2767 aaronmk
def is_table_col(col): return isinstance(col, Col) and col.table != None
269 2393 aaronmk
270 3000 aaronmk
def index_col(col):
271
    if not is_table_col(col): return None
272 3104 aaronmk
273
    table = col.table
274
    try: name = table.index_cols[col.name]
275
    except KeyError: return None
276
    else: return Col(name, table, col.srcs)
277 2999 aaronmk
278 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
279 2996 aaronmk
280 2563 aaronmk
def as_Col(col, table=None, name=None):
281
    '''
282
    @param name If not None, any non-Col input will be renamed using NamedCol.
283
    '''
284
    if name != None:
285
        col = as_Value(col)
286
        if not isinstance(col, Col): col = NamedCol(name, col)
287 2333 aaronmk
288
    if isinstance(col, Code): return col
289 2260 aaronmk
    else: return Col(col, table)
290
291 3093 aaronmk
def with_table(col, table):
292 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
293
    elif isinstance(col, FunctionCall):
294
        col = copy.deepcopy(col) # don't modify input!
295
        col.args[0].table = table
296
    else:
297 3098 aaronmk
        col = copy.copy(col) # don't modify input!
298
        col.table = table
299 3093 aaronmk
    return col
300
301 3100 aaronmk
def with_default_table(col, table):
302 2747 aaronmk
    col = as_Col(col)
303 3100 aaronmk
    if col.table == None: col = with_table(col, table)
304 2747 aaronmk
    return col
305
306 2744 aaronmk
def set_cols_table(table, cols):
307
    table = as_Table(table)
308
309
    for i, col in enumerate(cols):
310
        col = cols[i] = as_Col(col)
311
        col.table = table
312
313 2401 aaronmk
def to_name_only_col(col, check_table=None):
314
    col = as_Col(col)
315 3020 aaronmk
    if not is_table_col(col): return col
316 2401 aaronmk
317
    if check_table != None:
318
        table = col.table
319
        assert table == None or table == check_table
320
    return Col(col.name)
321
322 2993 aaronmk
def suffixed_col(col, suffix):
323
    return Col(concat(col.name, suffix), col.table, col.srcs)
324
325 2323 aaronmk
class NamedCol(Col):
326 2229 aaronmk
    def __init__(self, name, code):
327 2310 aaronmk
        Col.__init__(self, name)
328
329 3016 aaronmk
        code = as_Value(code)
330 2229 aaronmk
331
        self.code = code
332
333
    def to_str(self, db):
334 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
335 2314 aaronmk
336
    def to_Col(self): return Col(self.name)
337 2229 aaronmk
338 2462 aaronmk
def remove_col_rename(col):
339
    if isinstance(col, NamedCol): col = col.code
340
    return col
341
342 2830 aaronmk
def underlying_col(col):
343
    col = remove_col_rename(col)
344 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
345
346 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
347 2830 aaronmk
348 2703 aaronmk
def wrap(wrap_func, value):
349
    '''Wraps a value, propagating any column renaming to the returned value.'''
350
    if isinstance(value, NamedCol):
351
        return NamedCol(value.name, wrap_func(value.code))
352
    else: return wrap_func(value)
353
354 2667 aaronmk
class ColDict(dicts.DictProxy):
355 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
356
357 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
358 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
359 2667 aaronmk
360 2645 aaronmk
        keys_table = as_Table(keys_table)
361
362 2642 aaronmk
        self.db = db
363 2641 aaronmk
        self.table = keys_table
364 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
365 2641 aaronmk
366 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
367 2655 aaronmk
368 2667 aaronmk
    def __getitem__(self, key):
369
        return dicts.DictProxy.__getitem__(self, self._key(key))
370 2653 aaronmk
371 2564 aaronmk
    def __setitem__(self, key, value):
372 2642 aaronmk
        key = self._key(key)
373 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
374 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
375 2564 aaronmk
376 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
377 2564 aaronmk
378 2524 aaronmk
##### Functions
379
380 2912 aaronmk
Function = Table
381 2911 aaronmk
as_Function = as_Table
382
383 2691 aaronmk
class InternalFunction(CustomCode): pass
384
385 2941 aaronmk
class NamedArg(NamedCol):
386
    def __init__(self, name, value):
387
        NamedCol.__init__(self, name, value)
388
389
    def to_str(self, db):
390
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
391
392 2524 aaronmk
class FunctionCall(Code):
393 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
394 2524 aaronmk
        '''
395 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
396 2524 aaronmk
        '''
397 3016 aaronmk
        function = as_Function(function)
398 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
399
        args = map(filter_, args)
400
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
401 2524 aaronmk
402
        self.function = function
403
        self.args = args
404
405
    def to_str(self, db):
406
        args_str = ', '.join((v.to_str(db) for v in self.args))
407
        return self.function.to_str(db)+'('+args_str+')'
408
409 2533 aaronmk
def wrap_in_func(function, value):
410
    '''Wraps a value inside a function call.
411
    Propagates any column renaming to the returned value.
412
    '''
413 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
414 2533 aaronmk
415 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
416
    '''Unwraps any function call to its first argument.
417
    Also removes any column renaming.
418
    '''
419
    func_call = remove_col_rename(func_call)
420
    if not isinstance(func_call, FunctionCall): return func_call
421
422
    if check_name != None:
423
        name = func_call.function.name
424
        assert name == None or name == check_name
425
    return func_call.args[0]
426
427 2986 aaronmk
##### Casts
428
429
class Cast(FunctionCall):
430
    def __init__(self, type_, value):
431
        value = as_Value(value)
432
433
        self.type_ = type_
434
        self.value = value
435
436
    def to_str(self, db):
437
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
438
439 3354 aaronmk
def cast_literal(value):
440
    if not isinstance(value, Literal): return value
441
442
    if util.is_str(value.value): value = Cast('text', value)
443
    return value
444
445 2335 aaronmk
##### Conditions
446 2259 aaronmk
447 3350 aaronmk
class NotCond(Code):
448
    def __init__(self, cond):
449
        self.cond = cond
450
451
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
452
453 2398 aaronmk
class ColValueCond(Code):
454
    def __init__(self, col, value):
455
        value = as_ValueCond(value)
456
457
        self.col = col
458
        self.value = value
459
460
    def to_str(self, db): return self.value.to_str(db, self.col)
461
462 2577 aaronmk
def combine_conds(conds, keyword=None):
463
    '''
464
    @param keyword The keyword to add before the conditions, if any
465
    '''
466
    str_ = ''
467
    if keyword != None:
468
        if conds == []: whitespace = ''
469
        elif len(conds) == 1: whitespace = ' '
470
        else: whitespace = '\n'
471
        str_ += keyword+whitespace
472
473
    str_ += '\nAND '.join(conds)
474
    return str_
475
476 2398 aaronmk
##### Condition column comparisons
477
478 2514 aaronmk
class ValueCond(BasicObject):
479 2213 aaronmk
    def __init__(self, value):
480 2858 aaronmk
        value = remove_col_rename(as_Value(value))
481 2213 aaronmk
482
        self.value = value
483 2214 aaronmk
484 2216 aaronmk
    def to_str(self, db, left_value):
485 2214 aaronmk
        '''
486 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
487 2214 aaronmk
        '''
488
        raise NotImplemented()
489 2228 aaronmk
490 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
491 2211 aaronmk
492
class CompareCond(ValueCond):
493
    def __init__(self, value, operator='='):
494 2222 aaronmk
        '''
495
        @param operator By default, compares NULL values literally. Use '~=' or
496
            '~!=' to pass NULLs through.
497
        '''
498 2211 aaronmk
        ValueCond.__init__(self, value)
499
        self.operator = operator
500
501 2216 aaronmk
    def to_str(self, db, left_value):
502 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
503 2216 aaronmk
504 2222 aaronmk
        right_value = self.value
505
506
        # Parse operator
507 2216 aaronmk
        operator = self.operator
508 2222 aaronmk
        passthru_null_ref = [False]
509
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
510
        neg_ref = [False]
511
        operator = strings.remove_prefix('!', operator, neg_ref)
512 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
513 2222 aaronmk
514 2825 aaronmk
        # Handle nullable columns
515
        check_null = False
516 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
517 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
518 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
519
                check_null = equals and isinstance(right_value, Col)
520 2837 aaronmk
            else:
521 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
522
                    right_value = ensure_not_null(db, right_value,
523
                        left_value.type) # apply same function to both sides
524 2825 aaronmk
525 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
526
527 2825 aaronmk
        left = left_value.to_str(db)
528
        right = right_value.to_str(db)
529
530 2222 aaronmk
        # Create str
531
        str_ = left+' '+operator+' '+right
532 2825 aaronmk
        if check_null:
533 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
534
        if neg_ref[0]: str_ = 'NOT '+str_
535 2222 aaronmk
        return str_
536 2216 aaronmk
537 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
538
assume_literal = object()
539
540
def as_ValueCond(value, default_table=assume_literal):
541
    if not isinstance(value, ValueCond):
542
        if default_table is not assume_literal:
543 2748 aaronmk
            value = with_default_table(value, default_table)
544 2260 aaronmk
        return CompareCond(value)
545 2216 aaronmk
    else: return value
546 2219 aaronmk
547 2335 aaronmk
##### Joins
548
549 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
550 2260 aaronmk
551 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
552
join_same_not_null = object()
553
554 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
555
556 2514 aaronmk
class Join(BasicObject):
557 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
558 2260 aaronmk
        '''
559
        @param mapping dict(right_table_col=left_table_col, ...)
560 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
561 2353 aaronmk
              * Note that right_table_col must be a string
562
            * if left_table_col is join_same_not_null:
563
              left_table_col = right_table_col and both have NOT NULL constraint
564
              * Note that right_table_col must be a string
565 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
566
            * filter_out: equivalent to 'LEFT' with the query filtered by
567
              `table_pkey IS NULL` (indicating no match)
568
        '''
569
        if util.is_str(table): table = Table(table)
570
        assert type_ == None or util.is_str(type_) or type_ is filter_out
571
572
        self.table = table
573
        self.mapping = mapping
574
        self.type_ = type_
575
576 2749 aaronmk
    def to_str(self, db, left_table_):
577 2260 aaronmk
        def join(entry):
578
            '''Parses non-USING joins'''
579
            right_table_col, left_table_col = entry
580
581 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
582
            left = right_table_col
583
            right = left_table_col
584 2749 aaronmk
            left_table = self.table
585
            right_table = left_table_
586 2353 aaronmk
587 2747 aaronmk
            # Parse left side
588 2748 aaronmk
            left = with_default_table(left, left_table)
589 2747 aaronmk
590 2260 aaronmk
            # Parse special values
591 2747 aaronmk
            left_on_right = Col(left.name, right_table)
592
            if right is join_same: right = left_on_right
593 2353 aaronmk
            elif right is join_same_not_null:
594 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
595 2260 aaronmk
596 2747 aaronmk
            # Parse right side
597 2353 aaronmk
            right = as_ValueCond(right, right_table)
598 2747 aaronmk
599
            return right.to_str(db, left)
600 2260 aaronmk
601 2265 aaronmk
        # Create join condition
602
        type_ = self.type_
603 2276 aaronmk
        joins = self.mapping
604 2746 aaronmk
        if joins == {}: join_cond = None
605
        elif type_ is not filter_out and reduce(operator.and_,
606 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
607 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
608 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
609
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
610 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
611 2260 aaronmk
612 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
613
        else: whitespace = ' '
614
615 2260 aaronmk
        # Create join
616
        if type_ is filter_out: type_ = 'LEFT'
617 2266 aaronmk
        str_ = ''
618
        if type_ != None: str_ += type_+' '
619 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
620
        if join_cond != None: str_ += whitespace+join_cond
621 2266 aaronmk
        return str_
622 2349 aaronmk
623 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
624 2424 aaronmk
625
##### Value exprs
626
627 3089 aaronmk
all_cols = CustomCode('*')
628
629 2737 aaronmk
default = CustomCode('DEFAULT')
630
631 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
632 2674 aaronmk
633 3061 aaronmk
class Coalesce(FunctionCall):
634
    def __init__(self, *args):
635
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
636 3060 aaronmk
637 3062 aaronmk
class Nullif(FunctionCall):
638
    def __init__(self, *args):
639
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
640
641 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
642 2958 aaronmk
null_sentinels = {
643
    'character varying': r'\N',
644
    'double precision': 'NaN',
645
    'integer': 2147483647,
646
    'text': r'\N',
647
    'timestamp with time zone': 'infinity'
648
}
649 2692 aaronmk
650 3061 aaronmk
class EnsureNotNull(Coalesce):
651 2850 aaronmk
    def __init__(self, value, type_):
652 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
653 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
654 2850 aaronmk
655
        self.type = type_
656 3001 aaronmk
657
    def to_str(self, db):
658
        col = self.args[0]
659
        index_col_ = index_col(col)
660
        if index_col_ != None: return index_col_.to_str(db)
661 3061 aaronmk
        return Coalesce.to_str(self, db)
662 2850 aaronmk
663 2737 aaronmk
##### Table exprs
664
665
class Values(Code):
666
    def __init__(self, values):
667 2739 aaronmk
        '''
668
        @param values [...]|[[...], ...] Can be one or multiple rows.
669
        '''
670
        rows = values
671
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
672
            rows = [values]
673
        for i, row in enumerate(rows):
674
            rows[i] = map(remove_col_rename, map(as_Value, row))
675 2737 aaronmk
676 2739 aaronmk
        self.rows = rows
677 2737 aaronmk
678
    def to_str(self, db):
679 2739 aaronmk
        def row_str(row):
680
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
681
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
682 2737 aaronmk
683 2740 aaronmk
def NamedValues(name, cols, values):
684 2745 aaronmk
    '''
685 3048 aaronmk
    @param cols None|[...]
686 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
687
    '''
688 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
689 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
690 2834 aaronmk
    return table
691 2740 aaronmk
692 2674 aaronmk
##### Database structure
693
694
class TypedCol(Col):
695 2871 aaronmk
    def __init__(self, name, type_, default=None, nullable=True,
696
        constraints=None):
697 2818 aaronmk
        assert default == None or isinstance(default, Code)
698
699 2674 aaronmk
        Col.__init__(self, name)
700
701
        self.type = type_
702 2818 aaronmk
        self.default = default
703
        self.nullable = nullable
704 2871 aaronmk
        self.constraints = constraints
705 2674 aaronmk
706 2818 aaronmk
    def to_str(self, db):
707
        str_ = Col.to_str(self, db)+' '+self.type
708
        if not self.nullable: str_ += ' NOT NULL'
709
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
710 2871 aaronmk
        if self.constraints != None: str_ += ' '+self.constraints
711 2818 aaronmk
        return str_
712 2674 aaronmk
713
    def to_Col(self): return Col(self.name)
714 2822 aaronmk
715 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
716
717 2851 aaronmk
def ensure_not_null(db, col, type_=None):
718 2840 aaronmk
    '''
719 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
720 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
721 2840 aaronmk
    @return EnsureNotNull|Col
722
    @throws ensure_not_null_excs
723
    '''
724 2855 aaronmk
    nullable = True
725
    try: typed_col = db.col_info(underlying_col(col))
726
    except NoUnderlyingTableException:
727 3355 aaronmk
        col = remove_col_rename(col)
728
        if isinstance(col, Literal) and not is_null(col): nullable = False
729
        elif type_ == None: raise
730 2855 aaronmk
    else:
731
        if type_ == None: type_ = typed_col.type
732
        nullable = typed_col.nullable
733
734 2953 aaronmk
    if nullable:
735
        try: col = EnsureNotNull(col, type_)
736
        except KeyError, e:
737
            # Warn of no null sentinel for type, even if caller catches error
738
            warnings.warn(UserWarning(exc.str_(e)))
739
            raise
740
741 2840 aaronmk
    return col