Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 3158 aaronmk
from ordereddict import OrderedDict
6 2568 aaronmk
import re
7 2653 aaronmk
import UserDict
8 2953 aaronmk
import warnings
9 2276 aaronmk
10 2667 aaronmk
import dicts
11 2953 aaronmk
import exc
12 2701 aaronmk
import iters
13
import lists
14 2360 aaronmk
import objects
15 2222 aaronmk
import strings
16 2227 aaronmk
import util
17 2211 aaronmk
18 2587 aaronmk
##### Names
19 2499 aaronmk
20 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21 2587 aaronmk
22 2932 aaronmk
def concat(str_, suffix):
23 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25 2613 aaronmk
    # Preserve version
26 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27 2985 aaronmk
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30 2613 aaronmk
31 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
32 2587 aaronmk
33 2932 aaronmk
def truncate(str_): return concat(str_, '')
34 2842 aaronmk
35 2575 aaronmk
def is_safe_name(name):
36 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40 2984 aaronmk
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42 2568 aaronmk
43 2499 aaronmk
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46
47 3320 aaronmk
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55
56 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
57
58 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59
60 3182 aaronmk
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64
65 2659 aaronmk
##### General SQL code objects
66 2219 aaronmk
67 2349 aaronmk
class MockDb:
68 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
69 2349 aaronmk
70 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
71 2859 aaronmk
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74
75 2349 aaronmk
mockDb = MockDb()
76
77 2514 aaronmk
class BasicObject(objects.BasicObject):
78
    def __init__(self, value): self.value = value
79
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81
82 2659 aaronmk
##### Unparameterized code objects
83
84 2514 aaronmk
class Code(BasicObject):
85 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
86 2349 aaronmk
87 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
88 2211 aaronmk
89 2269 aaronmk
class CustomCode(Code):
90 2256 aaronmk
    def __init__(self, str_): self.str_ = str_
91
92
    def to_str(self, db): return self.str_
93
94 2815 aaronmk
def as_Code(value, db=None):
95
    '''
96
    @param db If set, runs db.std_code() on the value.
97
    '''
98
    if util.is_str(value):
99
        if db != None: value = db.std_code(value)
100
        return CustomCode(value)
101 2659 aaronmk
    else: return Literal(value)
102
103 2540 aaronmk
class Expr(Code):
104
    def __init__(self, expr): self.expr = expr
105
106
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
107
108 3086 aaronmk
##### Names
109
110
class Name(Code):
111 3087 aaronmk
    def __init__(self, name):
112
        name = truncate(name)
113
114
        self.name = name
115 3086 aaronmk
116
    def to_str(self, db): return db.esc_name(self.name)
117
118
def as_Name(value):
119
    if isinstance(value, Code): return value
120
    else: return Name(value)
121
122 2335 aaronmk
##### Literal values
123
124 2216 aaronmk
class Literal(Code):
125 2211 aaronmk
    def __init__(self, value): self.value = value
126 2213 aaronmk
127
    def to_str(self, db): return db.esc_value(self.value)
128 2211 aaronmk
129 2400 aaronmk
def as_Value(value):
130
    if isinstance(value, Code): return value
131
    else: return Literal(value)
132
133 3429 aaronmk
def is_literal(value): return isinstance(value, Literal)
134 2216 aaronmk
135 3429 aaronmk
def is_null(value): return is_literal(value) and value.value == None
136
137 2711 aaronmk
##### Derived elements
138
139
src_self = object() # tells Col that it is its own source column
140
141
class Derived(Code):
142
    def __init__(self, srcs):
143 2712 aaronmk
        '''An element which was derived from some other element(s).
144 2711 aaronmk
        @param srcs See self.set_srcs()
145
        '''
146
        self.set_srcs(srcs)
147
148 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
149 2711 aaronmk
        '''
150
        @param srcs (self_type...)|src_self The element(s) this is derived from
151
        '''
152 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
153
154 2711 aaronmk
        if srcs == src_self: srcs = (self,)
155
        srcs = tuple(srcs) # make Col hashable
156
        self.srcs = srcs
157
158
    def _compare_on(self):
159
        compare_on = self.__dict__.copy()
160
        del compare_on['srcs'] # ignore
161
        return compare_on
162
163
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
164
165 2335 aaronmk
##### Tables
166
167 2712 aaronmk
class Table(Derived):
168 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
169 2211 aaronmk
        '''
170
        @param schema str|None (for no schema)
171 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
172 2211 aaronmk
        '''
173 2712 aaronmk
        Derived.__init__(self, srcs)
174
175 3091 aaronmk
        if util.is_str(name): name = truncate(name)
176 2843 aaronmk
177 2211 aaronmk
        self.name = name
178
        self.schema = schema
179 2991 aaronmk
        self.is_temp = is_temp
180 3000 aaronmk
        self.index_cols = {}
181 2211 aaronmk
182 2348 aaronmk
    def to_str(self, db):
183
        str_ = ''
184 3088 aaronmk
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
185
        str_ += as_Name(self.name).to_str(db)
186 2348 aaronmk
        return str_
187 2336 aaronmk
188
    def to_Table(self): return self
189 3000 aaronmk
190
    def _compare_on(self):
191
        compare_on = Derived._compare_on(self)
192
        del compare_on['index_cols'] # ignore
193
        return compare_on
194 2211 aaronmk
195 2835 aaronmk
def is_underlying_table(table):
196
    return isinstance(table, Table) and table.to_Table() is table
197 2832 aaronmk
198 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
199
200
def underlying_table(table):
201
    table = remove_table_rename(table)
202
    if not is_underlying_table(table): raise NoUnderlyingTableException
203
    return table
204
205 2776 aaronmk
def as_Table(table, schema=None):
206 2270 aaronmk
    if table == None or isinstance(table, Code): return table
207 2776 aaronmk
    else: return Table(table, schema)
208 2219 aaronmk
209 3101 aaronmk
def suffixed_table(table, suffix):
210 3128 aaronmk
    table = copy.copy(table) # don't modify input!
211
    table.name = concat(table.name, suffix)
212
    return table
213 2707 aaronmk
214 2336 aaronmk
class NamedTable(Table):
215
    def __init__(self, name, code, cols=None):
216
        Table.__init__(self, name)
217
218 3016 aaronmk
        code = as_Table(code)
219 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
220 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
221 2336 aaronmk
222
        self.code = code
223
        self.cols = cols
224
225
    def to_str(self, db):
226 3026 aaronmk
        str_ = self.code.to_str(db)
227
        if str_.find('\n') >= 0: whitespace = '\n'
228
        else: whitespace = ' '
229
        str_ += whitespace+'AS '+Table.to_str(self, db)
230 2742 aaronmk
        if self.cols != None:
231
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
232 2336 aaronmk
        return str_
233
234
    def to_Table(self): return Table(self.name)
235
236 2753 aaronmk
def remove_table_rename(table):
237
    if isinstance(table, NamedTable): table = table.code
238
    return table
239
240 2335 aaronmk
##### Columns
241
242 2711 aaronmk
class Col(Derived):
243 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
244 2211 aaronmk
        '''
245
        @param table Table|None (for no table)
246 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
247 2211 aaronmk
        '''
248 2711 aaronmk
        Derived.__init__(self, srcs)
249
250 3091 aaronmk
        if util.is_str(name): name = truncate(name)
251 2241 aaronmk
        if util.is_str(table): table = Table(table)
252 2211 aaronmk
        assert table == None or isinstance(table, Table)
253
254
        self.name = name
255
        self.table = table
256
257 2989 aaronmk
    def to_str(self, db, for_str=False):
258 3088 aaronmk
        str_ = as_Name(self.name).to_str(db)
259 2989 aaronmk
        if for_str: str_ = clean_name(str_)
260 2933 aaronmk
        if self.table != None:
261 2989 aaronmk
            table = self.table.to_Table()
262
            if for_str: str_ = concat(str(table), '.'+str_)
263
            else: str_ = table.to_str(db)+'.'+str_
264 2211 aaronmk
        return str_
265 2314 aaronmk
266 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
267 2933 aaronmk
268 2314 aaronmk
    def to_Col(self): return self
269 2211 aaronmk
270 2767 aaronmk
def is_table_col(col): return isinstance(col, Col) and col.table != None
271 2393 aaronmk
272 3000 aaronmk
def index_col(col):
273
    if not is_table_col(col): return None
274 3104 aaronmk
275
    table = col.table
276
    try: name = table.index_cols[col.name]
277
    except KeyError: return None
278
    else: return Col(name, table, col.srcs)
279 2999 aaronmk
280 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
281 2996 aaronmk
282 2563 aaronmk
def as_Col(col, table=None, name=None):
283
    '''
284
    @param name If not None, any non-Col input will be renamed using NamedCol.
285
    '''
286
    if name != None:
287
        col = as_Value(col)
288
        if not isinstance(col, Col): col = NamedCol(name, col)
289 2333 aaronmk
290
    if isinstance(col, Code): return col
291 2260 aaronmk
    else: return Col(col, table)
292
293 3093 aaronmk
def with_table(col, table):
294 3105 aaronmk
    if isinstance(col, NamedCol): pass # doesn't take a table
295
    elif isinstance(col, FunctionCall):
296
        col = copy.deepcopy(col) # don't modify input!
297
        col.args[0].table = table
298
    else:
299 3098 aaronmk
        col = copy.copy(col) # don't modify input!
300
        col.table = table
301 3093 aaronmk
    return col
302
303 3100 aaronmk
def with_default_table(col, table):
304 2747 aaronmk
    col = as_Col(col)
305 3100 aaronmk
    if col.table == None: col = with_table(col, table)
306 2747 aaronmk
    return col
307
308 2744 aaronmk
def set_cols_table(table, cols):
309
    table = as_Table(table)
310
311
    for i, col in enumerate(cols):
312
        col = cols[i] = as_Col(col)
313
        col.table = table
314
315 2401 aaronmk
def to_name_only_col(col, check_table=None):
316
    col = as_Col(col)
317 3020 aaronmk
    if not is_table_col(col): return col
318 2401 aaronmk
319
    if check_table != None:
320
        table = col.table
321
        assert table == None or table == check_table
322
    return Col(col.name)
323
324 2993 aaronmk
def suffixed_col(col, suffix):
325
    return Col(concat(col.name, suffix), col.table, col.srcs)
326
327 2323 aaronmk
class NamedCol(Col):
328 2229 aaronmk
    def __init__(self, name, code):
329 2310 aaronmk
        Col.__init__(self, name)
330
331 3016 aaronmk
        code = as_Value(code)
332 2229 aaronmk
333
        self.code = code
334
335
    def to_str(self, db):
336 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
337 2314 aaronmk
338
    def to_Col(self): return Col(self.name)
339 2229 aaronmk
340 2462 aaronmk
def remove_col_rename(col):
341
    if isinstance(col, NamedCol): col = col.code
342
    return col
343
344 2830 aaronmk
def underlying_col(col):
345
    col = remove_col_rename(col)
346 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
347
348 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
349 2830 aaronmk
350 2703 aaronmk
def wrap(wrap_func, value):
351
    '''Wraps a value, propagating any column renaming to the returned value.'''
352
    if isinstance(value, NamedCol):
353
        return NamedCol(value.name, wrap_func(value.code))
354
    else: return wrap_func(value)
355
356 2667 aaronmk
class ColDict(dicts.DictProxy):
357 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
358
359 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
360 3158 aaronmk
        dicts.DictProxy.__init__(self, OrderedDict())
361 2667 aaronmk
362 2645 aaronmk
        keys_table = as_Table(keys_table)
363
364 2642 aaronmk
        self.db = db
365 2641 aaronmk
        self.table = keys_table
366 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
367 2641 aaronmk
368 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
369 2655 aaronmk
370 2667 aaronmk
    def __getitem__(self, key):
371
        return dicts.DictProxy.__getitem__(self, self._key(key))
372 2653 aaronmk
373 2564 aaronmk
    def __setitem__(self, key, value):
374 2642 aaronmk
        key = self._key(key)
375 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
376 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
377 2564 aaronmk
378 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
379 2564 aaronmk
380 2524 aaronmk
##### Functions
381
382 2912 aaronmk
Function = Table
383 2911 aaronmk
as_Function = as_Table
384
385 2691 aaronmk
class InternalFunction(CustomCode): pass
386
387 3442 aaronmk
#### Calls
388
389 2941 aaronmk
class NamedArg(NamedCol):
390
    def __init__(self, name, value):
391
        NamedCol.__init__(self, name, value)
392
393
    def to_str(self, db):
394
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
395
396 2524 aaronmk
class FunctionCall(Code):
397 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
398 2524 aaronmk
        '''
399 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
400 2524 aaronmk
        '''
401 3016 aaronmk
        function = as_Function(function)
402 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
403
        args = map(filter_, args)
404
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
405 2524 aaronmk
406
        self.function = function
407
        self.args = args
408
409
    def to_str(self, db):
410
        args_str = ', '.join((v.to_str(db) for v in self.args))
411
        return self.function.to_str(db)+'('+args_str+')'
412
413 2533 aaronmk
def wrap_in_func(function, value):
414
    '''Wraps a value inside a function call.
415
    Propagates any column renaming to the returned value.
416
    '''
417 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
418 2533 aaronmk
419 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
420
    '''Unwraps any function call to its first argument.
421
    Also removes any column renaming.
422
    '''
423
    func_call = remove_col_rename(func_call)
424
    if not isinstance(func_call, FunctionCall): return func_call
425
426
    if check_name != None:
427
        name = func_call.function.name
428
        assert name == None or name == check_name
429
    return func_call.args[0]
430
431 3442 aaronmk
#### Definitions
432
433
class FunctionDef(Code):
434
    def __init__(self, function, return_type, body, lang='sql'):
435
        self.function = function
436
        self.return_type = return_type
437
        self.body = body
438
        self.lang = lang
439
440
    def to_str(self, db):
441
        str_ = '''\
442
CREATE FUNCTION '''+self.function.to_str(db)+'''()
443
RETURNS SETOF '''+self.return_type+'''
444
LANGUAGE '''+self.lang+'''
445
AS $$
446
'''+self.body+'''
447
$$;
448
'''
449
        return str_
450
451 2986 aaronmk
##### Casts
452
453
class Cast(FunctionCall):
454
    def __init__(self, type_, value):
455
        value = as_Value(value)
456
457
        self.type_ = type_
458
        self.value = value
459
460
    def to_str(self, db):
461
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
462
463 3354 aaronmk
def cast_literal(value):
464 3429 aaronmk
    if not is_literal(value): return value
465 3354 aaronmk
466
    if util.is_str(value.value): value = Cast('text', value)
467
    return value
468
469 2335 aaronmk
##### Conditions
470 2259 aaronmk
471 3350 aaronmk
class NotCond(Code):
472
    def __init__(self, cond):
473
        self.cond = cond
474
475
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
476
477 2398 aaronmk
class ColValueCond(Code):
478
    def __init__(self, col, value):
479
        value = as_ValueCond(value)
480
481
        self.col = col
482
        self.value = value
483
484
    def to_str(self, db): return self.value.to_str(db, self.col)
485
486 2577 aaronmk
def combine_conds(conds, keyword=None):
487
    '''
488
    @param keyword The keyword to add before the conditions, if any
489
    '''
490
    str_ = ''
491
    if keyword != None:
492
        if conds == []: whitespace = ''
493
        elif len(conds) == 1: whitespace = ' '
494
        else: whitespace = '\n'
495
        str_ += keyword+whitespace
496
497
    str_ += '\nAND '.join(conds)
498
    return str_
499
500 2398 aaronmk
##### Condition column comparisons
501
502 2514 aaronmk
class ValueCond(BasicObject):
503 2213 aaronmk
    def __init__(self, value):
504 2858 aaronmk
        value = remove_col_rename(as_Value(value))
505 2213 aaronmk
506
        self.value = value
507 2214 aaronmk
508 2216 aaronmk
    def to_str(self, db, left_value):
509 2214 aaronmk
        '''
510 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
511 2214 aaronmk
        '''
512
        raise NotImplemented()
513 2228 aaronmk
514 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
515 2211 aaronmk
516
class CompareCond(ValueCond):
517
    def __init__(self, value, operator='='):
518 2222 aaronmk
        '''
519
        @param operator By default, compares NULL values literally. Use '~=' or
520
            '~!=' to pass NULLs through.
521
        '''
522 2211 aaronmk
        ValueCond.__init__(self, value)
523
        self.operator = operator
524
525 2216 aaronmk
    def to_str(self, db, left_value):
526 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
527 2216 aaronmk
528 2222 aaronmk
        right_value = self.value
529
530
        # Parse operator
531 2216 aaronmk
        operator = self.operator
532 2222 aaronmk
        passthru_null_ref = [False]
533
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
534
        neg_ref = [False]
535
        operator = strings.remove_prefix('!', operator, neg_ref)
536 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
537 2222 aaronmk
538 2825 aaronmk
        # Handle nullable columns
539
        check_null = False
540 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
541 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
542 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
543
                check_null = equals and isinstance(right_value, Col)
544 2837 aaronmk
            else:
545 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
546
                    right_value = ensure_not_null(db, right_value,
547
                        left_value.type) # apply same function to both sides
548 2825 aaronmk
549 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
550
551 2825 aaronmk
        left = left_value.to_str(db)
552
        right = right_value.to_str(db)
553
554 2222 aaronmk
        # Create str
555
        str_ = left+' '+operator+' '+right
556 2825 aaronmk
        if check_null:
557 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
558
        if neg_ref[0]: str_ = 'NOT '+str_
559 2222 aaronmk
        return str_
560 2216 aaronmk
561 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
562
assume_literal = object()
563
564
def as_ValueCond(value, default_table=assume_literal):
565
    if not isinstance(value, ValueCond):
566
        if default_table is not assume_literal:
567 2748 aaronmk
            value = with_default_table(value, default_table)
568 2260 aaronmk
        return CompareCond(value)
569 2216 aaronmk
    else: return value
570 2219 aaronmk
571 2335 aaronmk
##### Joins
572
573 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
574 2260 aaronmk
575 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
576
join_same_not_null = object()
577
578 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
579
580 2514 aaronmk
class Join(BasicObject):
581 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
582 2260 aaronmk
        '''
583
        @param mapping dict(right_table_col=left_table_col, ...)
584 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
585 2353 aaronmk
              * Note that right_table_col must be a string
586
            * if left_table_col is join_same_not_null:
587
              left_table_col = right_table_col and both have NOT NULL constraint
588
              * Note that right_table_col must be a string
589 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
590
            * filter_out: equivalent to 'LEFT' with the query filtered by
591
              `table_pkey IS NULL` (indicating no match)
592
        '''
593
        if util.is_str(table): table = Table(table)
594
        assert type_ == None or util.is_str(type_) or type_ is filter_out
595
596
        self.table = table
597
        self.mapping = mapping
598
        self.type_ = type_
599
600 2749 aaronmk
    def to_str(self, db, left_table_):
601 2260 aaronmk
        def join(entry):
602
            '''Parses non-USING joins'''
603
            right_table_col, left_table_col = entry
604
605 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
606
            left = right_table_col
607
            right = left_table_col
608 2749 aaronmk
            left_table = self.table
609
            right_table = left_table_
610 2353 aaronmk
611 2747 aaronmk
            # Parse left side
612 2748 aaronmk
            left = with_default_table(left, left_table)
613 2747 aaronmk
614 2260 aaronmk
            # Parse special values
615 2747 aaronmk
            left_on_right = Col(left.name, right_table)
616
            if right is join_same: right = left_on_right
617 2353 aaronmk
            elif right is join_same_not_null:
618 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
619 2260 aaronmk
620 2747 aaronmk
            # Parse right side
621 2353 aaronmk
            right = as_ValueCond(right, right_table)
622 2747 aaronmk
623
            return right.to_str(db, left)
624 2260 aaronmk
625 2265 aaronmk
        # Create join condition
626
        type_ = self.type_
627 2276 aaronmk
        joins = self.mapping
628 2746 aaronmk
        if joins == {}: join_cond = None
629
        elif type_ is not filter_out and reduce(operator.and_,
630 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
631 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
632 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
633
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
634 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
635 2260 aaronmk
636 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
637
        else: whitespace = ' '
638
639 2260 aaronmk
        # Create join
640
        if type_ is filter_out: type_ = 'LEFT'
641 2266 aaronmk
        str_ = ''
642
        if type_ != None: str_ += type_+' '
643 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
644
        if join_cond != None: str_ += whitespace+join_cond
645 2266 aaronmk
        return str_
646 2349 aaronmk
647 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
648 2424 aaronmk
649
##### Value exprs
650
651 3089 aaronmk
all_cols = CustomCode('*')
652
653 2737 aaronmk
default = CustomCode('DEFAULT')
654
655 3090 aaronmk
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
656 2674 aaronmk
657 3061 aaronmk
class Coalesce(FunctionCall):
658
    def __init__(self, *args):
659
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
660 3060 aaronmk
661 3062 aaronmk
class Nullif(FunctionCall):
662
    def __init__(self, *args):
663
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
664
665 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
666 2958 aaronmk
null_sentinels = {
667
    'character varying': r'\N',
668
    'double precision': 'NaN',
669
    'integer': 2147483647,
670
    'text': r'\N',
671
    'timestamp with time zone': 'infinity'
672
}
673 2692 aaronmk
674 3061 aaronmk
class EnsureNotNull(Coalesce):
675 2850 aaronmk
    def __init__(self, value, type_):
676 3061 aaronmk
        Coalesce.__init__(self, as_Col(value),
677 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
678 2850 aaronmk
679
        self.type = type_
680 3001 aaronmk
681
    def to_str(self, db):
682
        col = self.args[0]
683
        index_col_ = index_col(col)
684
        if index_col_ != None: return index_col_.to_str(db)
685 3061 aaronmk
        return Coalesce.to_str(self, db)
686 2850 aaronmk
687 2737 aaronmk
##### Table exprs
688
689
class Values(Code):
690
    def __init__(self, values):
691 2739 aaronmk
        '''
692
        @param values [...]|[[...], ...] Can be one or multiple rows.
693
        '''
694
        rows = values
695
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
696
            rows = [values]
697
        for i, row in enumerate(rows):
698
            rows[i] = map(remove_col_rename, map(as_Value, row))
699 2737 aaronmk
700 2739 aaronmk
        self.rows = rows
701 2737 aaronmk
702
    def to_str(self, db):
703 2739 aaronmk
        def row_str(row):
704
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
705
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
706 2737 aaronmk
707 2740 aaronmk
def NamedValues(name, cols, values):
708 2745 aaronmk
    '''
709 3048 aaronmk
    @param cols None|[...]
710 2745 aaronmk
    @post `cols` will be changed to Col objects with the table set to `name`.
711
    '''
712 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
713 3048 aaronmk
    if cols != None: set_cols_table(table, cols)
714 2834 aaronmk
    return table
715 2740 aaronmk
716 2674 aaronmk
##### Database structure
717
718
class TypedCol(Col):
719 2871 aaronmk
    def __init__(self, name, type_, default=None, nullable=True,
720
        constraints=None):
721 2818 aaronmk
        assert default == None or isinstance(default, Code)
722
723 2674 aaronmk
        Col.__init__(self, name)
724
725
        self.type = type_
726 2818 aaronmk
        self.default = default
727
        self.nullable = nullable
728 2871 aaronmk
        self.constraints = constraints
729 2674 aaronmk
730 2818 aaronmk
    def to_str(self, db):
731
        str_ = Col.to_str(self, db)+' '+self.type
732
        if not self.nullable: str_ += ' NOT NULL'
733
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
734 2871 aaronmk
        if self.constraints != None: str_ += ' '+self.constraints
735 2818 aaronmk
        return str_
736 2674 aaronmk
737
    def to_Col(self): return Col(self.name)
738 2822 aaronmk
739 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
740
741 2851 aaronmk
def ensure_not_null(db, col, type_=None):
742 2840 aaronmk
    '''
743 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
744 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
745 2840 aaronmk
    @return EnsureNotNull|Col
746
    @throws ensure_not_null_excs
747
    '''
748 2855 aaronmk
    nullable = True
749
    try: typed_col = db.col_info(underlying_col(col))
750
    except NoUnderlyingTableException:
751 3355 aaronmk
        col = remove_col_rename(col)
752 3429 aaronmk
        if is_literal(col) and not is_null(col): nullable = False
753 3355 aaronmk
        elif type_ == None: raise
754 2855 aaronmk
    else:
755
        if type_ == None: type_ = typed_col.type
756
        nullable = typed_col.nullable
757
758 2953 aaronmk
    if nullable:
759
        try: col = EnsureNotNull(col, type_)
760
        except KeyError, e:
761
            # Warn of no null sentinel for type, even if caller catches error
762
            warnings.warn(UserWarning(exc.str_(e)))
763
            raise
764
765 2840 aaronmk
    return col