Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 2568 aaronmk
import re
6 2653 aaronmk
import UserDict
7 2953 aaronmk
import warnings
8 2276 aaronmk
9 2667 aaronmk
import dicts
10 2953 aaronmk
import exc
11 2701 aaronmk
import iters
12
import lists
13 2360 aaronmk
import objects
14 2222 aaronmk
import strings
15 2227 aaronmk
import util
16 2211 aaronmk
17 2587 aaronmk
##### Names
18 2499 aaronmk
19 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
20 2587 aaronmk
21 2932 aaronmk
def concat(str_, suffix):
22 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
23
    to collisions.'''
24 2613 aaronmk
    # Preserve version
25 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
26 2985 aaronmk
    if match:
27
        str_, old_suffix = match.groups()
28
        suffix = old_suffix+suffix
29 2613 aaronmk
30 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
31 2587 aaronmk
32 2932 aaronmk
def truncate(str_): return concat(str_, '')
33 2842 aaronmk
34 2575 aaronmk
def is_safe_name(name):
35 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
36
    * contains only *lowercase* word (\w) characters
37
    * doesn't start with a digit
38
    * contains "_", so that it's not a keyword
39 2984 aaronmk
    '''
40
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
41 2568 aaronmk
42 2499 aaronmk
def esc_name(name, quote='"'):
43
    return quote + name.replace(quote, quote+quote) + quote
44
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
45
46 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
47
48 3041 aaronmk
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
49
50 2659 aaronmk
##### General SQL code objects
51 2219 aaronmk
52 2349 aaronmk
class MockDb:
53 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
54 2349 aaronmk
55 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
56 2859 aaronmk
57
    def col_info(self, col):
58
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
59
60 2349 aaronmk
mockDb = MockDb()
61
62 2514 aaronmk
class BasicObject(objects.BasicObject):
63
    def __init__(self, value): self.value = value
64
65
    def __str__(self): return clean_name(strings.repr_no_u(self))
66
67 2659 aaronmk
##### Unparameterized code objects
68
69 2514 aaronmk
class Code(BasicObject):
70 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
71 2349 aaronmk
72 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
73 2211 aaronmk
74 2269 aaronmk
class CustomCode(Code):
75 2256 aaronmk
    def __init__(self, str_): self.str_ = str_
76
77
    def to_str(self, db): return self.str_
78
79 2815 aaronmk
def as_Code(value, db=None):
80
    '''
81
    @param db If set, runs db.std_code() on the value.
82
    '''
83
    if util.is_str(value):
84
        if db != None: value = db.std_code(value)
85
        return CustomCode(value)
86 2659 aaronmk
    else: return Literal(value)
87
88 2540 aaronmk
class Expr(Code):
89
    def __init__(self, expr): self.expr = expr
90
91
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
92
93 2335 aaronmk
##### Literal values
94
95 2216 aaronmk
class Literal(Code):
96 2211 aaronmk
    def __init__(self, value): self.value = value
97 2213 aaronmk
98
    def to_str(self, db): return db.esc_value(self.value)
99 2211 aaronmk
100 2400 aaronmk
def as_Value(value):
101
    if isinstance(value, Code): return value
102
    else: return Literal(value)
103
104 2216 aaronmk
def is_null(value): return isinstance(value, Literal) and value.value == None
105
106 2711 aaronmk
##### Derived elements
107
108
src_self = object() # tells Col that it is its own source column
109
110
class Derived(Code):
111
    def __init__(self, srcs):
112 2712 aaronmk
        '''An element which was derived from some other element(s).
113 2711 aaronmk
        @param srcs See self.set_srcs()
114
        '''
115
        self.set_srcs(srcs)
116
117 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
118 2711 aaronmk
        '''
119
        @param srcs (self_type...)|src_self The element(s) this is derived from
120
        '''
121 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
122
123 2711 aaronmk
        if srcs == src_self: srcs = (self,)
124
        srcs = tuple(srcs) # make Col hashable
125
        self.srcs = srcs
126
127
    def _compare_on(self):
128
        compare_on = self.__dict__.copy()
129
        del compare_on['srcs'] # ignore
130
        return compare_on
131
132
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
133
134 2335 aaronmk
##### Tables
135
136 2712 aaronmk
class Table(Derived):
137 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
138 2211 aaronmk
        '''
139
        @param schema str|None (for no schema)
140 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
141 2211 aaronmk
        '''
142 2712 aaronmk
        Derived.__init__(self, srcs)
143
144 2843 aaronmk
        name = truncate(name)
145
146 2211 aaronmk
        self.name = name
147
        self.schema = schema
148 2991 aaronmk
        self.is_temp = is_temp
149 3000 aaronmk
        self.index_cols = {}
150 2211 aaronmk
151 2348 aaronmk
    def to_str(self, db):
152
        str_ = ''
153
        if self.schema != None: str_ += db.esc_name(self.schema)+'.'
154
        str_ += db.esc_name(self.name)
155
        return str_
156 2336 aaronmk
157
    def to_Table(self): return self
158 3000 aaronmk
159
    def _compare_on(self):
160
        compare_on = Derived._compare_on(self)
161
        del compare_on['index_cols'] # ignore
162
        return compare_on
163 2211 aaronmk
164 2835 aaronmk
def is_underlying_table(table):
165
    return isinstance(table, Table) and table.to_Table() is table
166 2832 aaronmk
167 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
168
169
def underlying_table(table):
170
    table = remove_table_rename(table)
171
    if not is_underlying_table(table): raise NoUnderlyingTableException
172
    return table
173
174 2776 aaronmk
def as_Table(table, schema=None):
175 2270 aaronmk
    if table == None or isinstance(table, Code): return table
176 2776 aaronmk
    else: return Table(table, schema)
177 2219 aaronmk
178 2707 aaronmk
def suffixed_table(table, suffix): return Table(table.name+suffix, table.schema)
179
180 2336 aaronmk
class NamedTable(Table):
181
    def __init__(self, name, code, cols=None):
182
        Table.__init__(self, name)
183
184 3016 aaronmk
        code = as_Table(code)
185 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
186 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
187 2336 aaronmk
188
        self.code = code
189
        self.cols = cols
190
191
    def to_str(self, db):
192 3026 aaronmk
        str_ = self.code.to_str(db)
193
        if str_.find('\n') >= 0: whitespace = '\n'
194
        else: whitespace = ' '
195
        str_ += whitespace+'AS '+Table.to_str(self, db)
196 2742 aaronmk
        if self.cols != None:
197
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
198 2336 aaronmk
        return str_
199
200
    def to_Table(self): return Table(self.name)
201
202 2753 aaronmk
def remove_table_rename(table):
203
    if isinstance(table, NamedTable): table = table.code
204
    return table
205
206 2335 aaronmk
##### Columns
207
208 2711 aaronmk
class Col(Derived):
209 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
210 2211 aaronmk
        '''
211
        @param table Table|None (for no table)
212 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
213 2211 aaronmk
        '''
214 2711 aaronmk
        Derived.__init__(self, srcs)
215
216 2843 aaronmk
        name = truncate(name)
217 2241 aaronmk
        if util.is_str(table): table = Table(table)
218 2211 aaronmk
        assert table == None or isinstance(table, Table)
219
220
        self.name = name
221
        self.table = table
222
223 2989 aaronmk
    def to_str(self, db, for_str=False):
224 2933 aaronmk
        str_ = db.esc_name(self.name)
225 2989 aaronmk
        if for_str: str_ = clean_name(str_)
226 2933 aaronmk
        if self.table != None:
227 2989 aaronmk
            table = self.table.to_Table()
228
            if for_str: str_ = concat(str(table), '.'+str_)
229
            else: str_ = table.to_str(db)+'.'+str_
230 2211 aaronmk
        return str_
231 2314 aaronmk
232 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
233 2933 aaronmk
234 2314 aaronmk
    def to_Col(self): return self
235 2211 aaronmk
236 2767 aaronmk
def is_table_col(col): return isinstance(col, Col) and col.table != None
237 2393 aaronmk
238 3000 aaronmk
def index_col(col):
239
    if not is_table_col(col): return None
240
    return col.table.index_cols.get(col.name, None)
241 2999 aaronmk
242 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
243 2996 aaronmk
244 2563 aaronmk
def as_Col(col, table=None, name=None):
245
    '''
246
    @param name If not None, any non-Col input will be renamed using NamedCol.
247
    '''
248
    if name != None:
249
        col = as_Value(col)
250
        if not isinstance(col, Col): col = NamedCol(name, col)
251 2333 aaronmk
252
    if isinstance(col, Code): return col
253 2260 aaronmk
    else: return Col(col, table)
254
255 2750 aaronmk
def with_default_table(col, table, overwrite=False):
256 2747 aaronmk
    col = as_Col(col)
257 2750 aaronmk
    if not isinstance(col, NamedCol) and (overwrite or col.table == None):
258 2748 aaronmk
        col = copy.copy(col) # don't modify input!
259
        col.table = table
260 2747 aaronmk
    return col
261
262 2744 aaronmk
def set_cols_table(table, cols):
263
    table = as_Table(table)
264
265
    for i, col in enumerate(cols):
266
        col = cols[i] = as_Col(col)
267
        col.table = table
268
269 2401 aaronmk
def to_name_only_col(col, check_table=None):
270
    col = as_Col(col)
271 3020 aaronmk
    if not is_table_col(col): return col
272 2401 aaronmk
273
    if check_table != None:
274
        table = col.table
275
        assert table == None or table == check_table
276
    return Col(col.name)
277
278 2993 aaronmk
def suffixed_col(col, suffix):
279
    return Col(concat(col.name, suffix), col.table, col.srcs)
280
281 2323 aaronmk
class NamedCol(Col):
282 2229 aaronmk
    def __init__(self, name, code):
283 2310 aaronmk
        Col.__init__(self, name)
284
285 3016 aaronmk
        code = as_Value(code)
286 2229 aaronmk
287
        self.code = code
288
289
    def to_str(self, db):
290 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
291 2314 aaronmk
292
    def to_Col(self): return Col(self.name)
293 2229 aaronmk
294 2462 aaronmk
def remove_col_rename(col):
295
    if isinstance(col, NamedCol): col = col.code
296
    return col
297
298 2830 aaronmk
def underlying_col(col):
299
    col = remove_col_rename(col)
300 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
301
302 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
303 2830 aaronmk
304 2703 aaronmk
def wrap(wrap_func, value):
305
    '''Wraps a value, propagating any column renaming to the returned value.'''
306
    if isinstance(value, NamedCol):
307
        return NamedCol(value.name, wrap_func(value.code))
308
    else: return wrap_func(value)
309
310 2667 aaronmk
class ColDict(dicts.DictProxy):
311 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
312
313 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
314 2667 aaronmk
        dicts.DictProxy.__init__(self, {})
315
316 2645 aaronmk
        keys_table = as_Table(keys_table)
317
318 2642 aaronmk
        self.db = db
319 2641 aaronmk
        self.table = keys_table
320 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
321 2641 aaronmk
322 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
323 2655 aaronmk
324 2667 aaronmk
    def __getitem__(self, key):
325
        return dicts.DictProxy.__getitem__(self, self._key(key))
326 2653 aaronmk
327 2564 aaronmk
    def __setitem__(self, key, value):
328 2642 aaronmk
        key = self._key(key)
329 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
330 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
331 2564 aaronmk
332 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
333 2564 aaronmk
334 2524 aaronmk
##### Functions
335
336 2912 aaronmk
Function = Table
337 2911 aaronmk
as_Function = as_Table
338
339 2691 aaronmk
class InternalFunction(CustomCode): pass
340
341 2941 aaronmk
class NamedArg(NamedCol):
342
    def __init__(self, name, value):
343
        NamedCol.__init__(self, name, value)
344
345
    def to_str(self, db):
346
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
347
348 2524 aaronmk
class FunctionCall(Code):
349 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
350 2524 aaronmk
        '''
351 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
352 2524 aaronmk
        '''
353 3016 aaronmk
        function = as_Function(function)
354 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
355
        args = map(filter_, args)
356
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
357 2524 aaronmk
358
        self.function = function
359
        self.args = args
360
361
    def to_str(self, db):
362
        args_str = ', '.join((v.to_str(db) for v in self.args))
363
        return self.function.to_str(db)+'('+args_str+')'
364
365 2533 aaronmk
def wrap_in_func(function, value):
366
    '''Wraps a value inside a function call.
367
    Propagates any column renaming to the returned value.
368
    '''
369 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
370 2533 aaronmk
371 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
372
    '''Unwraps any function call to its first argument.
373
    Also removes any column renaming.
374
    '''
375
    func_call = remove_col_rename(func_call)
376
    if not isinstance(func_call, FunctionCall): return func_call
377
378
    if check_name != None:
379
        name = func_call.function.name
380
        assert name == None or name == check_name
381
    return func_call.args[0]
382
383 2986 aaronmk
##### Casts
384
385
class Cast(FunctionCall):
386
    def __init__(self, type_, value):
387
        value = as_Value(value)
388
389
        self.type_ = type_
390
        self.value = value
391
392
    def to_str(self, db):
393
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
394
395 2335 aaronmk
##### Conditions
396 2259 aaronmk
397 2398 aaronmk
class ColValueCond(Code):
398
    def __init__(self, col, value):
399
        value = as_ValueCond(value)
400
401
        self.col = col
402
        self.value = value
403
404
    def to_str(self, db): return self.value.to_str(db, self.col)
405
406 2577 aaronmk
def combine_conds(conds, keyword=None):
407
    '''
408
    @param keyword The keyword to add before the conditions, if any
409
    '''
410
    str_ = ''
411
    if keyword != None:
412
        if conds == []: whitespace = ''
413
        elif len(conds) == 1: whitespace = ' '
414
        else: whitespace = '\n'
415
        str_ += keyword+whitespace
416
417
    str_ += '\nAND '.join(conds)
418
    return str_
419
420 2398 aaronmk
##### Condition column comparisons
421
422 2514 aaronmk
class ValueCond(BasicObject):
423 2213 aaronmk
    def __init__(self, value):
424 2858 aaronmk
        value = remove_col_rename(as_Value(value))
425 2213 aaronmk
426
        self.value = value
427 2214 aaronmk
428 2216 aaronmk
    def to_str(self, db, left_value):
429 2214 aaronmk
        '''
430 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
431 2214 aaronmk
        '''
432
        raise NotImplemented()
433 2228 aaronmk
434 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
435 2211 aaronmk
436
class CompareCond(ValueCond):
437
    def __init__(self, value, operator='='):
438 2222 aaronmk
        '''
439
        @param operator By default, compares NULL values literally. Use '~=' or
440
            '~!=' to pass NULLs through.
441
        '''
442 2211 aaronmk
        ValueCond.__init__(self, value)
443
        self.operator = operator
444
445 2216 aaronmk
    def to_str(self, db, left_value):
446 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
447 2216 aaronmk
448 2222 aaronmk
        right_value = self.value
449
450
        # Parse operator
451 2216 aaronmk
        operator = self.operator
452 2222 aaronmk
        passthru_null_ref = [False]
453
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
454
        neg_ref = [False]
455
        operator = strings.remove_prefix('!', operator, neg_ref)
456 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
457 2222 aaronmk
458 2825 aaronmk
        # Handle nullable columns
459
        check_null = False
460 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
461 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
462 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
463
                check_null = equals and isinstance(right_value, Col)
464 2837 aaronmk
            else:
465 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
466
                    right_value = ensure_not_null(db, right_value,
467
                        left_value.type) # apply same function to both sides
468 2825 aaronmk
469 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
470
471 2825 aaronmk
        left = left_value.to_str(db)
472
        right = right_value.to_str(db)
473
474 2222 aaronmk
        # Create str
475
        str_ = left+' '+operator+' '+right
476 2825 aaronmk
        if check_null:
477 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
478
        if neg_ref[0]: str_ = 'NOT '+str_
479 2222 aaronmk
        return str_
480 2216 aaronmk
481 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
482
assume_literal = object()
483
484
def as_ValueCond(value, default_table=assume_literal):
485
    if not isinstance(value, ValueCond):
486
        if default_table is not assume_literal:
487 2748 aaronmk
            value = with_default_table(value, default_table)
488 2260 aaronmk
        return CompareCond(value)
489 2216 aaronmk
    else: return value
490 2219 aaronmk
491 2335 aaronmk
##### Joins
492
493 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
494 2260 aaronmk
495 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
496
join_same_not_null = object()
497
498 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
499
500 2514 aaronmk
class Join(BasicObject):
501 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
502 2260 aaronmk
        '''
503
        @param mapping dict(right_table_col=left_table_col, ...)
504 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
505 2353 aaronmk
              * Note that right_table_col must be a string
506
            * if left_table_col is join_same_not_null:
507
              left_table_col = right_table_col and both have NOT NULL constraint
508
              * Note that right_table_col must be a string
509 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
510
            * filter_out: equivalent to 'LEFT' with the query filtered by
511
              `table_pkey IS NULL` (indicating no match)
512
        '''
513
        if util.is_str(table): table = Table(table)
514
        assert type_ == None or util.is_str(type_) or type_ is filter_out
515
516
        self.table = table
517
        self.mapping = mapping
518
        self.type_ = type_
519
520 2749 aaronmk
    def to_str(self, db, left_table_):
521 2260 aaronmk
        def join(entry):
522
            '''Parses non-USING joins'''
523
            right_table_col, left_table_col = entry
524
525 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
526
            left = right_table_col
527
            right = left_table_col
528 2749 aaronmk
            left_table = self.table
529
            right_table = left_table_
530 2353 aaronmk
531 2747 aaronmk
            # Parse left side
532 2748 aaronmk
            left = with_default_table(left, left_table)
533 2747 aaronmk
534 2260 aaronmk
            # Parse special values
535 2747 aaronmk
            left_on_right = Col(left.name, right_table)
536
            if right is join_same: right = left_on_right
537 2353 aaronmk
            elif right is join_same_not_null:
538 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
539 2260 aaronmk
540 2747 aaronmk
            # Parse right side
541 2353 aaronmk
            right = as_ValueCond(right, right_table)
542 2747 aaronmk
543
            return right.to_str(db, left)
544 2260 aaronmk
545 2265 aaronmk
        # Create join condition
546
        type_ = self.type_
547 2276 aaronmk
        joins = self.mapping
548 2746 aaronmk
        if joins == {}: join_cond = None
549
        elif type_ is not filter_out and reduce(operator.and_,
550 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
551 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
552 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
553
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
554 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
555 2260 aaronmk
556 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
557
        else: whitespace = ' '
558
559 2260 aaronmk
        # Create join
560
        if type_ is filter_out: type_ = 'LEFT'
561 2266 aaronmk
        str_ = ''
562
        if type_ != None: str_ += type_+' '
563 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
564
        if join_cond != None: str_ += whitespace+join_cond
565 2266 aaronmk
        return str_
566 2349 aaronmk
567 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
568 2424 aaronmk
569
##### Value exprs
570
571 2737 aaronmk
default = CustomCode('DEFAULT')
572
573 2424 aaronmk
row_count = CustomCode('count(*)')
574 2674 aaronmk
575 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
576 2958 aaronmk
null_sentinels = {
577
    'character varying': r'\N',
578
    'double precision': 'NaN',
579
    'integer': 2147483647,
580
    'text': r'\N',
581
    'timestamp with time zone': 'infinity'
582
}
583 2692 aaronmk
584 2850 aaronmk
class EnsureNotNull(FunctionCall):
585
    def __init__(self, value, type_):
586 2870 aaronmk
        FunctionCall.__init__(self, InternalFunction('COALESCE'), as_Col(value),
587 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
588 2850 aaronmk
589
        self.type = type_
590 3001 aaronmk
591
    def to_str(self, db):
592
        col = self.args[0]
593
        index_col_ = index_col(col)
594
        if index_col_ != None: return index_col_.to_str(db)
595
        return FunctionCall.to_str(self, db)
596 2850 aaronmk
597 2737 aaronmk
##### Table exprs
598
599
class Values(Code):
600
    def __init__(self, values):
601 2739 aaronmk
        '''
602
        @param values [...]|[[...], ...] Can be one or multiple rows.
603
        '''
604
        rows = values
605
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
606
            rows = [values]
607
        for i, row in enumerate(rows):
608
            rows[i] = map(remove_col_rename, map(as_Value, row))
609 2737 aaronmk
610 2739 aaronmk
        self.rows = rows
611 2737 aaronmk
612
    def to_str(self, db):
613 2739 aaronmk
        def row_str(row):
614
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
615
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
616 2737 aaronmk
617 2740 aaronmk
def NamedValues(name, cols, values):
618 2745 aaronmk
    '''
619
    @post `cols` will be changed to Col objects with the table set to `name`.
620
    '''
621 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
622
    set_cols_table(table, cols)
623
    return table
624 2740 aaronmk
625 2674 aaronmk
##### Database structure
626
627
class TypedCol(Col):
628 2871 aaronmk
    def __init__(self, name, type_, default=None, nullable=True,
629
        constraints=None):
630 2818 aaronmk
        assert default == None or isinstance(default, Code)
631
632 2674 aaronmk
        Col.__init__(self, name)
633
634
        self.type = type_
635 2818 aaronmk
        self.default = default
636
        self.nullable = nullable
637 2871 aaronmk
        self.constraints = constraints
638 2674 aaronmk
639 2818 aaronmk
    def to_str(self, db):
640
        str_ = Col.to_str(self, db)+' '+self.type
641
        if not self.nullable: str_ += ' NOT NULL'
642
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
643 2871 aaronmk
        if self.constraints != None: str_ += ' '+self.constraints
644 2818 aaronmk
        return str_
645 2674 aaronmk
646
    def to_Col(self): return Col(self.name)
647 2822 aaronmk
648 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
649
650 2851 aaronmk
def ensure_not_null(db, col, type_=None):
651 2840 aaronmk
    '''
652 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
653 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
654 2840 aaronmk
    @return EnsureNotNull|Col
655
    @throws ensure_not_null_excs
656
    '''
657 2855 aaronmk
    nullable = True
658
    try: typed_col = db.col_info(underlying_col(col))
659
    except NoUnderlyingTableException:
660
        if type_ == None: raise
661
    else:
662
        if type_ == None: type_ = typed_col.type
663
        nullable = typed_col.nullable
664
665 2953 aaronmk
    if nullable:
666
        try: col = EnsureNotNull(col, type_)
667
        except KeyError, e:
668
            # Warn of no null sentinel for type, even if caller catches error
669
            warnings.warn(UserWarning(exc.str_(e)))
670
            raise
671
672 2840 aaronmk
    return col