Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3 2748 aaronmk
import copy
4 2276 aaronmk
import operator
5 2568 aaronmk
import re
6 2653 aaronmk
import UserDict
7 2953 aaronmk
import warnings
8 2276 aaronmk
9 2667 aaronmk
import dicts
10 2953 aaronmk
import exc
11 2701 aaronmk
import iters
12
import lists
13 2360 aaronmk
import objects
14 2222 aaronmk
import strings
15 2227 aaronmk
import util
16 2211 aaronmk
17 2587 aaronmk
##### Names
18 2499 aaronmk
19 2608 aaronmk
identifier_max_len = 63 # works for both PostgreSQL and MySQL
20 2587 aaronmk
21 2932 aaronmk
def concat(str_, suffix):
22 2609 aaronmk
    '''Preserves version so that it won't be truncated off the string, leading
23
    to collisions.'''
24 2613 aaronmk
    # Preserve version
25 2995 aaronmk
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
26 2985 aaronmk
    if match:
27
        str_, old_suffix = match.groups()
28
        suffix = old_suffix+suffix
29 2613 aaronmk
30 2932 aaronmk
    return strings.concat(str_, suffix, identifier_max_len)
31 2587 aaronmk
32 2932 aaronmk
def truncate(str_): return concat(str_, '')
33 2842 aaronmk
34 2575 aaronmk
def is_safe_name(name):
35 2583 aaronmk
    '''A name is safe *and unambiguous* if it:
36
    * contains only *lowercase* word (\w) characters
37
    * doesn't start with a digit
38
    * contains "_", so that it's not a keyword
39 2984 aaronmk
    '''
40
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
41 2568 aaronmk
42 2499 aaronmk
def esc_name(name, quote='"'):
43
    return quote + name.replace(quote, quote+quote) + quote
44
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
45
46 2513 aaronmk
def clean_name(name): return name.replace('"', '').replace('`', '')
47
48 2659 aaronmk
##### General SQL code objects
49 2219 aaronmk
50 2349 aaronmk
class MockDb:
51 2503 aaronmk
    def esc_value(self, value): return strings.repr_no_u(value)
52 2349 aaronmk
53 2499 aaronmk
    def esc_name(self, name): return esc_name(name)
54 2859 aaronmk
55
    def col_info(self, col):
56
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
57
58 2349 aaronmk
mockDb = MockDb()
59
60 2514 aaronmk
class BasicObject(objects.BasicObject):
61
    def __init__(self, value): self.value = value
62
63
    def __str__(self): return clean_name(strings.repr_no_u(self))
64
65 2659 aaronmk
##### Unparameterized code objects
66
67 2514 aaronmk
class Code(BasicObject):
68 2658 aaronmk
    def to_str(self, db): raise NotImplementedError()
69 2349 aaronmk
70 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb)
71 2211 aaronmk
72 2269 aaronmk
class CustomCode(Code):
73 2256 aaronmk
    def __init__(self, str_): self.str_ = str_
74
75
    def to_str(self, db): return self.str_
76
77 2815 aaronmk
def as_Code(value, db=None):
78
    '''
79
    @param db If set, runs db.std_code() on the value.
80
    '''
81
    if util.is_str(value):
82
        if db != None: value = db.std_code(value)
83
        return CustomCode(value)
84 2659 aaronmk
    else: return Literal(value)
85
86 2540 aaronmk
class Expr(Code):
87
    def __init__(self, expr): self.expr = expr
88
89
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
90
91 2335 aaronmk
##### Literal values
92
93 2216 aaronmk
class Literal(Code):
94 2211 aaronmk
    def __init__(self, value): self.value = value
95 2213 aaronmk
96
    def to_str(self, db): return db.esc_value(self.value)
97 2211 aaronmk
98 2400 aaronmk
def as_Value(value):
99
    if isinstance(value, Code): return value
100
    else: return Literal(value)
101
102 2216 aaronmk
def is_null(value): return isinstance(value, Literal) and value.value == None
103
104 2711 aaronmk
##### Derived elements
105
106
src_self = object() # tells Col that it is its own source column
107
108
class Derived(Code):
109
    def __init__(self, srcs):
110 2712 aaronmk
        '''An element which was derived from some other element(s).
111 2711 aaronmk
        @param srcs See self.set_srcs()
112
        '''
113
        self.set_srcs(srcs)
114
115 2713 aaronmk
    def set_srcs(self, srcs, overwrite=True):
116 2711 aaronmk
        '''
117
        @param srcs (self_type...)|src_self The element(s) this is derived from
118
        '''
119 2713 aaronmk
        if not overwrite and self.srcs != (): return # already set
120
121 2711 aaronmk
        if srcs == src_self: srcs = (self,)
122
        srcs = tuple(srcs) # make Col hashable
123
        self.srcs = srcs
124
125
    def _compare_on(self):
126
        compare_on = self.__dict__.copy()
127
        del compare_on['srcs'] # ignore
128
        return compare_on
129
130
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
131
132 2335 aaronmk
##### Tables
133
134 2712 aaronmk
class Table(Derived):
135 2991 aaronmk
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
136 2211 aaronmk
        '''
137
        @param schema str|None (for no schema)
138 2712 aaronmk
        @param srcs (Table...)|src_self See Derived.set_srcs()
139 2211 aaronmk
        '''
140 2712 aaronmk
        Derived.__init__(self, srcs)
141
142 2843 aaronmk
        name = truncate(name)
143
144 2211 aaronmk
        self.name = name
145
        self.schema = schema
146 2991 aaronmk
        self.is_temp = is_temp
147 3000 aaronmk
        self.index_cols = {}
148 2211 aaronmk
149 2348 aaronmk
    def to_str(self, db):
150
        str_ = ''
151
        if self.schema != None: str_ += db.esc_name(self.schema)+'.'
152
        str_ += db.esc_name(self.name)
153
        return str_
154 2336 aaronmk
155
    def to_Table(self): return self
156 3000 aaronmk
157
    def _compare_on(self):
158
        compare_on = Derived._compare_on(self)
159
        del compare_on['index_cols'] # ignore
160
        return compare_on
161 2211 aaronmk
162 2835 aaronmk
def is_underlying_table(table):
163
    return isinstance(table, Table) and table.to_Table() is table
164 2832 aaronmk
165 2902 aaronmk
class NoUnderlyingTableException(Exception): pass
166
167
def underlying_table(table):
168
    table = remove_table_rename(table)
169
    if not is_underlying_table(table): raise NoUnderlyingTableException
170
    return table
171
172 2776 aaronmk
def as_Table(table, schema=None):
173 2270 aaronmk
    if table == None or isinstance(table, Code): return table
174 2776 aaronmk
    else: return Table(table, schema)
175 2219 aaronmk
176 2707 aaronmk
def suffixed_table(table, suffix): return Table(table.name+suffix, table.schema)
177
178 2336 aaronmk
class NamedTable(Table):
179
    def __init__(self, name, code, cols=None):
180
        Table.__init__(self, name)
181
182 3016 aaronmk
        code = as_Table(code)
183 2741 aaronmk
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
184 3020 aaronmk
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
185 2336 aaronmk
186
        self.code = code
187
        self.cols = cols
188
189
    def to_str(self, db):
190 3026 aaronmk
        str_ = self.code.to_str(db)
191
        if str_.find('\n') >= 0: whitespace = '\n'
192
        else: whitespace = ' '
193
        str_ += whitespace+'AS '+Table.to_str(self, db)
194 2742 aaronmk
        if self.cols != None:
195
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
196 2336 aaronmk
        return str_
197
198
    def to_Table(self): return Table(self.name)
199
200 2753 aaronmk
def remove_table_rename(table):
201
    if isinstance(table, NamedTable): table = table.code
202
    return table
203
204 2335 aaronmk
##### Columns
205
206 2711 aaronmk
class Col(Derived):
207 2701 aaronmk
    def __init__(self, name, table=None, srcs=()):
208 2211 aaronmk
        '''
209
        @param table Table|None (for no table)
210 2711 aaronmk
        @param srcs (Col...)|src_self See Derived.set_srcs()
211 2211 aaronmk
        '''
212 2711 aaronmk
        Derived.__init__(self, srcs)
213
214 2843 aaronmk
        name = truncate(name)
215 2241 aaronmk
        if util.is_str(table): table = Table(table)
216 2211 aaronmk
        assert table == None or isinstance(table, Table)
217
218
        self.name = name
219
        self.table = table
220
221 2989 aaronmk
    def to_str(self, db, for_str=False):
222 2933 aaronmk
        str_ = db.esc_name(self.name)
223 2989 aaronmk
        if for_str: str_ = clean_name(str_)
224 2933 aaronmk
        if self.table != None:
225 2989 aaronmk
            table = self.table.to_Table()
226
            if for_str: str_ = concat(str(table), '.'+str_)
227
            else: str_ = table.to_str(db)+'.'+str_
228 2211 aaronmk
        return str_
229 2314 aaronmk
230 2989 aaronmk
    def __str__(self): return self.to_str(mockDb, for_str=True)
231 2933 aaronmk
232 2314 aaronmk
    def to_Col(self): return self
233 2211 aaronmk
234 2767 aaronmk
def is_table_col(col): return isinstance(col, Col) and col.table != None
235 2393 aaronmk
236 3000 aaronmk
def index_col(col):
237
    if not is_table_col(col): return None
238
    return col.table.index_cols.get(col.name, None)
239 2999 aaronmk
240 3024 aaronmk
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
241 2996 aaronmk
242 2563 aaronmk
def as_Col(col, table=None, name=None):
243
    '''
244
    @param name If not None, any non-Col input will be renamed using NamedCol.
245
    '''
246
    if name != None:
247
        col = as_Value(col)
248
        if not isinstance(col, Col): col = NamedCol(name, col)
249 2333 aaronmk
250
    if isinstance(col, Code): return col
251 2260 aaronmk
    else: return Col(col, table)
252
253 2750 aaronmk
def with_default_table(col, table, overwrite=False):
254 2747 aaronmk
    col = as_Col(col)
255 2750 aaronmk
    if not isinstance(col, NamedCol) and (overwrite or col.table == None):
256 2748 aaronmk
        col = copy.copy(col) # don't modify input!
257
        col.table = table
258 2747 aaronmk
    return col
259
260 2744 aaronmk
def set_cols_table(table, cols):
261
    table = as_Table(table)
262
263
    for i, col in enumerate(cols):
264
        col = cols[i] = as_Col(col)
265
        col.table = table
266
267 2401 aaronmk
def to_name_only_col(col, check_table=None):
268
    col = as_Col(col)
269 3020 aaronmk
    if not is_table_col(col): return col
270 2401 aaronmk
271
    if check_table != None:
272
        table = col.table
273
        assert table == None or table == check_table
274
    return Col(col.name)
275
276 2993 aaronmk
def suffixed_col(col, suffix):
277
    return Col(concat(col.name, suffix), col.table, col.srcs)
278
279 2323 aaronmk
class NamedCol(Col):
280 2229 aaronmk
    def __init__(self, name, code):
281 2310 aaronmk
        Col.__init__(self, name)
282
283 3016 aaronmk
        code = as_Value(code)
284 2229 aaronmk
285
        self.code = code
286
287
    def to_str(self, db):
288 2310 aaronmk
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
289 2314 aaronmk
290
    def to_Col(self): return Col(self.name)
291 2229 aaronmk
292 2462 aaronmk
def remove_col_rename(col):
293
    if isinstance(col, NamedCol): col = col.code
294
    return col
295
296 2830 aaronmk
def underlying_col(col):
297
    col = remove_col_rename(col)
298 2849 aaronmk
    if not isinstance(col, Col): raise NoUnderlyingTableException
299
300 2902 aaronmk
    return Col(col.name, underlying_table(col.table), col.srcs)
301 2830 aaronmk
302 2703 aaronmk
def wrap(wrap_func, value):
303
    '''Wraps a value, propagating any column renaming to the returned value.'''
304
    if isinstance(value, NamedCol):
305
        return NamedCol(value.name, wrap_func(value.code))
306
    else: return wrap_func(value)
307
308 2667 aaronmk
class ColDict(dicts.DictProxy):
309 2564 aaronmk
    '''A dict that automatically makes inserted entries Col objects'''
310
311 2645 aaronmk
    def __init__(self, db, keys_table, dict_={}):
312 2667 aaronmk
        dicts.DictProxy.__init__(self, {})
313
314 2645 aaronmk
        keys_table = as_Table(keys_table)
315
316 2642 aaronmk
        self.db = db
317 2641 aaronmk
        self.table = keys_table
318 2653 aaronmk
        self.update(dict_) # after setting vars because __setitem__() needs them
319 2641 aaronmk
320 2667 aaronmk
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
321 2655 aaronmk
322 2667 aaronmk
    def __getitem__(self, key):
323
        return dicts.DictProxy.__getitem__(self, self._key(key))
324 2653 aaronmk
325 2564 aaronmk
    def __setitem__(self, key, value):
326 2642 aaronmk
        key = self._key(key)
327 2819 aaronmk
        if value == None: value = self.db.col_info(key).default
328 2667 aaronmk
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
329 2564 aaronmk
330 2641 aaronmk
    def _key(self, key): return as_Col(key, self.table)
331 2564 aaronmk
332 2524 aaronmk
##### Functions
333
334 2912 aaronmk
Function = Table
335 2911 aaronmk
as_Function = as_Table
336
337 2691 aaronmk
class InternalFunction(CustomCode): pass
338
339 2941 aaronmk
class NamedArg(NamedCol):
340
    def __init__(self, name, value):
341
        NamedCol.__init__(self, name, value)
342
343
    def to_str(self, db):
344
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
345
346 2524 aaronmk
class FunctionCall(Code):
347 2941 aaronmk
    def __init__(self, function, *args, **kw_args):
348 2524 aaronmk
        '''
349 2690 aaronmk
        @param args [Code|literal-value...] The function's arguments
350 2524 aaronmk
        '''
351 3016 aaronmk
        function = as_Function(function)
352 2941 aaronmk
        def filter_(arg): return remove_col_rename(as_Value(arg))
353
        args = map(filter_, args)
354
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
355 2524 aaronmk
356
        self.function = function
357
        self.args = args
358
359
    def to_str(self, db):
360
        args_str = ', '.join((v.to_str(db) for v in self.args))
361
        return self.function.to_str(db)+'('+args_str+')'
362
363 2533 aaronmk
def wrap_in_func(function, value):
364
    '''Wraps a value inside a function call.
365
    Propagates any column renaming to the returned value.
366
    '''
367 2703 aaronmk
    return wrap(lambda v: FunctionCall(function, v), value)
368 2533 aaronmk
369 2561 aaronmk
def unwrap_func_call(func_call, check_name=None):
370
    '''Unwraps any function call to its first argument.
371
    Also removes any column renaming.
372
    '''
373
    func_call = remove_col_rename(func_call)
374
    if not isinstance(func_call, FunctionCall): return func_call
375
376
    if check_name != None:
377
        name = func_call.function.name
378
        assert name == None or name == check_name
379
    return func_call.args[0]
380
381 2986 aaronmk
##### Casts
382
383
class Cast(FunctionCall):
384
    def __init__(self, type_, value):
385
        value = as_Value(value)
386
387
        self.type_ = type_
388
        self.value = value
389
390
    def to_str(self, db):
391
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
392
393 2335 aaronmk
##### Conditions
394 2259 aaronmk
395 2398 aaronmk
class ColValueCond(Code):
396
    def __init__(self, col, value):
397
        value = as_ValueCond(value)
398
399
        self.col = col
400
        self.value = value
401
402
    def to_str(self, db): return self.value.to_str(db, self.col)
403
404 2577 aaronmk
def combine_conds(conds, keyword=None):
405
    '''
406
    @param keyword The keyword to add before the conditions, if any
407
    '''
408
    str_ = ''
409
    if keyword != None:
410
        if conds == []: whitespace = ''
411
        elif len(conds) == 1: whitespace = ' '
412
        else: whitespace = '\n'
413
        str_ += keyword+whitespace
414
415
    str_ += '\nAND '.join(conds)
416
    return str_
417
418 2398 aaronmk
##### Condition column comparisons
419
420 2514 aaronmk
class ValueCond(BasicObject):
421 2213 aaronmk
    def __init__(self, value):
422 2858 aaronmk
        value = remove_col_rename(as_Value(value))
423 2213 aaronmk
424
        self.value = value
425 2214 aaronmk
426 2216 aaronmk
    def to_str(self, db, left_value):
427 2214 aaronmk
        '''
428 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
429 2214 aaronmk
        '''
430
        raise NotImplemented()
431 2228 aaronmk
432 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
433 2211 aaronmk
434
class CompareCond(ValueCond):
435
    def __init__(self, value, operator='='):
436 2222 aaronmk
        '''
437
        @param operator By default, compares NULL values literally. Use '~=' or
438
            '~!=' to pass NULLs through.
439
        '''
440 2211 aaronmk
        ValueCond.__init__(self, value)
441
        self.operator = operator
442
443 2216 aaronmk
    def to_str(self, db, left_value):
444 2858 aaronmk
        left_value = remove_col_rename(as_Col(left_value))
445 2216 aaronmk
446 2222 aaronmk
        right_value = self.value
447
448
        # Parse operator
449 2216 aaronmk
        operator = self.operator
450 2222 aaronmk
        passthru_null_ref = [False]
451
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
452
        neg_ref = [False]
453
        operator = strings.remove_prefix('!', operator, neg_ref)
454 2844 aaronmk
        equals = operator.endswith('=') # also includes <=, >=
455 2222 aaronmk
456 2825 aaronmk
        # Handle nullable columns
457
        check_null = False
458 2844 aaronmk
        if not passthru_null_ref[0]: # NULLs compare equal
459 2857 aaronmk
            try: left_value = ensure_not_null(db, left_value)
460 2844 aaronmk
            except ensure_not_null_excs: # fall back to alternate method
461
                check_null = equals and isinstance(right_value, Col)
462 2837 aaronmk
            else:
463 2857 aaronmk
                if isinstance(left_value, EnsureNotNull):
464
                    right_value = ensure_not_null(db, right_value,
465
                        left_value.type) # apply same function to both sides
466 2825 aaronmk
467 2844 aaronmk
        if equals and is_null(right_value): operator = 'IS'
468
469 2825 aaronmk
        left = left_value.to_str(db)
470
        right = right_value.to_str(db)
471
472 2222 aaronmk
        # Create str
473
        str_ = left+' '+operator+' '+right
474 2825 aaronmk
        if check_null:
475 2578 aaronmk
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
476
        if neg_ref[0]: str_ = 'NOT '+str_
477 2222 aaronmk
        return str_
478 2216 aaronmk
479 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
480
assume_literal = object()
481
482
def as_ValueCond(value, default_table=assume_literal):
483
    if not isinstance(value, ValueCond):
484
        if default_table is not assume_literal:
485 2748 aaronmk
            value = with_default_table(value, default_table)
486 2260 aaronmk
        return CompareCond(value)
487 2216 aaronmk
    else: return value
488 2219 aaronmk
489 2335 aaronmk
##### Joins
490
491 2352 aaronmk
join_same = object() # tells Join the left and right columns have the same name
492 2260 aaronmk
493 2353 aaronmk
# Tells Join the left and right columns have the same name and are never NULL
494
join_same_not_null = object()
495
496 2260 aaronmk
filter_out = object() # tells Join to filter out rows that match the join
497
498 2514 aaronmk
class Join(BasicObject):
499 2746 aaronmk
    def __init__(self, table, mapping={}, type_=None):
500 2260 aaronmk
        '''
501
        @param mapping dict(right_table_col=left_table_col, ...)
502 2352 aaronmk
            * if left_table_col is join_same: left_table_col = right_table_col
503 2353 aaronmk
              * Note that right_table_col must be a string
504
            * if left_table_col is join_same_not_null:
505
              left_table_col = right_table_col and both have NOT NULL constraint
506
              * Note that right_table_col must be a string
507 2260 aaronmk
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
508
            * filter_out: equivalent to 'LEFT' with the query filtered by
509
              `table_pkey IS NULL` (indicating no match)
510
        '''
511
        if util.is_str(table): table = Table(table)
512
        assert type_ == None or util.is_str(type_) or type_ is filter_out
513
514
        self.table = table
515
        self.mapping = mapping
516
        self.type_ = type_
517
518 2749 aaronmk
    def to_str(self, db, left_table_):
519 2260 aaronmk
        def join(entry):
520
            '''Parses non-USING joins'''
521
            right_table_col, left_table_col = entry
522
523 2353 aaronmk
            # Switch order (right_table_col is on the left in the comparison)
524
            left = right_table_col
525
            right = left_table_col
526 2749 aaronmk
            left_table = self.table
527
            right_table = left_table_
528 2353 aaronmk
529 2747 aaronmk
            # Parse left side
530 2748 aaronmk
            left = with_default_table(left, left_table)
531 2747 aaronmk
532 2260 aaronmk
            # Parse special values
533 2747 aaronmk
            left_on_right = Col(left.name, right_table)
534
            if right is join_same: right = left_on_right
535 2353 aaronmk
            elif right is join_same_not_null:
536 2747 aaronmk
                right = CompareCond(left_on_right, '~=')
537 2260 aaronmk
538 2747 aaronmk
            # Parse right side
539 2353 aaronmk
            right = as_ValueCond(right, right_table)
540 2747 aaronmk
541
            return right.to_str(db, left)
542 2260 aaronmk
543 2265 aaronmk
        # Create join condition
544
        type_ = self.type_
545 2276 aaronmk
        joins = self.mapping
546 2746 aaronmk
        if joins == {}: join_cond = None
547
        elif type_ is not filter_out and reduce(operator.and_,
548 2460 aaronmk
            (v is join_same_not_null for v in joins.itervalues())):
549 2260 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
550 2747 aaronmk
            cols = map(to_name_only_col, joins.iterkeys())
551
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
552 2757 aaronmk
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
553 2260 aaronmk
554 2757 aaronmk
        if isinstance(self.table, NamedTable): whitespace = '\n'
555
        else: whitespace = ' '
556
557 2260 aaronmk
        # Create join
558
        if type_ is filter_out: type_ = 'LEFT'
559 2266 aaronmk
        str_ = ''
560
        if type_ != None: str_ += type_+' '
561 2757 aaronmk
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
562
        if join_cond != None: str_ += whitespace+join_cond
563 2266 aaronmk
        return str_
564 2349 aaronmk
565 2514 aaronmk
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
566 2424 aaronmk
567
##### Value exprs
568
569 2737 aaronmk
default = CustomCode('DEFAULT')
570
571 2424 aaronmk
row_count = CustomCode('count(*)')
572 2674 aaronmk
573 2850 aaronmk
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
574 2958 aaronmk
null_sentinels = {
575
    'character varying': r'\N',
576
    'double precision': 'NaN',
577
    'integer': 2147483647,
578
    'text': r'\N',
579
    'timestamp with time zone': 'infinity'
580
}
581 2692 aaronmk
582 2850 aaronmk
class EnsureNotNull(FunctionCall):
583
    def __init__(self, value, type_):
584 2870 aaronmk
        FunctionCall.__init__(self, InternalFunction('COALESCE'), as_Col(value),
585 2988 aaronmk
            Cast(type_, null_sentinels[type_]))
586 2850 aaronmk
587
        self.type = type_
588 3001 aaronmk
589
    def to_str(self, db):
590
        col = self.args[0]
591
        index_col_ = index_col(col)
592
        if index_col_ != None: return index_col_.to_str(db)
593
        return FunctionCall.to_str(self, db)
594 2850 aaronmk
595 2737 aaronmk
##### Table exprs
596
597
class Values(Code):
598
    def __init__(self, values):
599 2739 aaronmk
        '''
600
        @param values [...]|[[...], ...] Can be one or multiple rows.
601
        '''
602
        rows = values
603
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
604
            rows = [values]
605
        for i, row in enumerate(rows):
606
            rows[i] = map(remove_col_rename, map(as_Value, row))
607 2737 aaronmk
608 2739 aaronmk
        self.rows = rows
609 2737 aaronmk
610
    def to_str(self, db):
611 2739 aaronmk
        def row_str(row):
612
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
613
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
614 2737 aaronmk
615 2740 aaronmk
def NamedValues(name, cols, values):
616 2745 aaronmk
    '''
617
    @post `cols` will be changed to Col objects with the table set to `name`.
618
    '''
619 2834 aaronmk
    table = NamedTable(name, Values(values), cols)
620
    set_cols_table(table, cols)
621
    return table
622 2740 aaronmk
623 2674 aaronmk
##### Database structure
624
625
class TypedCol(Col):
626 2871 aaronmk
    def __init__(self, name, type_, default=None, nullable=True,
627
        constraints=None):
628 2818 aaronmk
        assert default == None or isinstance(default, Code)
629
630 2674 aaronmk
        Col.__init__(self, name)
631
632
        self.type = type_
633 2818 aaronmk
        self.default = default
634
        self.nullable = nullable
635 2871 aaronmk
        self.constraints = constraints
636 2674 aaronmk
637 2818 aaronmk
    def to_str(self, db):
638
        str_ = Col.to_str(self, db)+' '+self.type
639
        if not self.nullable: str_ += ' NOT NULL'
640
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
641 2871 aaronmk
        if self.constraints != None: str_ += ' '+self.constraints
642 2818 aaronmk
        return str_
643 2674 aaronmk
644
    def to_Col(self): return Col(self.name)
645 2822 aaronmk
646 2840 aaronmk
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
647
648 2851 aaronmk
def ensure_not_null(db, col, type_=None):
649 2840 aaronmk
    '''
650 2855 aaronmk
    @param col If type_ is not set, must have an underlying column.
651 2851 aaronmk
    @param type_ If set, overrides the underlying column's type.
652 2840 aaronmk
    @return EnsureNotNull|Col
653
    @throws ensure_not_null_excs
654
    '''
655 2855 aaronmk
    nullable = True
656
    try: typed_col = db.col_info(underlying_col(col))
657
    except NoUnderlyingTableException:
658
        if type_ == None: raise
659
    else:
660
        if type_ == None: type_ = typed_col.type
661
        nullable = typed_col.nullable
662
663 2953 aaronmk
    if nullable:
664
        try: col = EnsureNotNull(col, type_)
665
        except KeyError, e:
666
            # Warn of no null sentinel for type, even if caller catches error
667
            warnings.warn(UserWarning(exc.str_(e)))
668
            raise
669
670 2840 aaronmk
    return col