Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import operator
5
from ordereddict import OrderedDict
6
import re
7
import UserDict
8
import warnings
9

    
10
import dicts
11
import exc
12
import iters
13
import lists
14
import objects
15
import strings
16
import util
17

    
18
##### Names
19

    
20
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21

    
22
def concat(str_, suffix):
23
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25
    # Preserve version
26
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30
    
31
    return strings.concat(str_, suffix, identifier_max_len)
32

    
33
def truncate(str_): return concat(str_, '')
34

    
35
def is_safe_name(name):
36
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42

    
43
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46

    
47
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55

    
56
def clean_name(name): return name.replace('"', '').replace('`', '')
57

    
58
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59

    
60
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64

    
65
##### General SQL code objects
66

    
67
class MockDb:
68
    def esc_value(self, value): return strings.repr_no_u(value)
69
    
70
    def esc_name(self, name): return esc_name(name)
71
    
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74

    
75
mockDb = MockDb()
76

    
77
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79

    
80
##### Unparameterized code objects
81

    
82
class Code(BasicObject):
83
    def __init__(self, lang='sql'):
84
        self.lang = lang
85
    
86
    def to_str(self, db): raise NotImplementedError()
87
    
88
    def __repr__(self): return self.to_str(mockDb)
89

    
90
class CustomCode(Code):
91
    def __init__(self, str_):
92
        Code.__init__(self)
93
        
94
        self.str_ = str_
95
    
96
    def to_str(self, db): return self.str_
97

    
98
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102
    if isinstance(value, Code): return value
103
    
104
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107
    else: return Literal(value)
108

    
109
class Expr(Code):
110
    def __init__(self, expr):
111
        Code.__init__(self)
112
        
113
        self.expr = expr
114
    
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116

    
117
##### Names
118

    
119
class Name(Code):
120
    def __init__(self, name):
121
        Code.__init__(self)
122
        
123
        name = truncate(name)
124
        
125
        self.name = name
126
    
127
    def to_str(self, db): return db.esc_name(self.name)
128

    
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132

    
133
##### Literal values
134

    
135
class Literal(Code):
136
    def __init__(self, value):
137
        Code.__init__(self)
138
        
139
        self.value = value
140
    
141
    def to_str(self, db): return db.esc_value(self.value)
142

    
143
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146

    
147
def is_literal(value): return isinstance(value, Literal)
148

    
149
def is_null(value): return is_literal(value) and value.value == None
150

    
151
##### Derived elements
152

    
153
src_self = object() # tells Col that it is its own source column
154

    
155
class Derived(Code):
156
    def __init__(self, srcs):
157
        '''An element which was derived from some other element(s).
158
        @param srcs See self.set_srcs()
159
        '''
160
        Code.__init__(self)
161
        
162
        self.set_srcs(srcs)
163
    
164
    def set_srcs(self, srcs, overwrite=True):
165
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168
        if not overwrite and self.srcs != (): return # already set
169
        
170
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
    
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178

    
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180

    
181
##### Tables
182

    
183
class Table(Derived):
184
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185
        '''
186
        @param schema str|None (for no schema)
187
        @param srcs (Table...)|src_self See Derived.set_srcs()
188
        '''
189
        Derived.__init__(self, srcs)
190
        
191
        if util.is_str(name): name = truncate(name)
192
        
193
        self.name = name
194
        self.schema = schema
195
        self.is_temp = is_temp
196
        self.index_cols = {}
197
    
198
    def to_str(self, db):
199
        str_ = ''
200
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202
        return str_
203
    
204
    def to_Table(self): return self
205
    
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210

    
211
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213

    
214
class NoUnderlyingTableException(Exception): pass
215

    
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220

    
221
def as_Table(table, schema=None):
222
    if table == None or isinstance(table, Code): return table
223
    else: return Table(table, schema)
224

    
225
def suffixed_table(table, suffix):
226
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229

    
230
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
        
234
        code = as_Table(code)
235
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237
        
238
        self.code = code
239
        self.cols = cols
240
    
241
    def to_str(self, db):
242
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248
        return str_
249
    
250
    def to_Table(self): return Table(self.name)
251

    
252
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255

    
256
##### Columns
257

    
258
class Col(Derived):
259
    def __init__(self, name, table=None, srcs=()):
260
        '''
261
        @param table Table|None (for no table)
262
        @param srcs (Col...)|src_self See Derived.set_srcs()
263
        '''
264
        Derived.__init__(self, srcs)
265
        
266
        if util.is_str(name): name = truncate(name)
267
        if util.is_str(table): table = Table(table)
268
        assert table == None or isinstance(table, Table)
269
        
270
        self.name = name
271
        self.table = table
272
    
273
    def to_str(self, db, for_str=False):
274
        str_ = as_Name(self.name).to_str(db)
275
        if for_str: str_ = clean_name(str_)
276
        if self.table != None:
277
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280
        return str_
281
    
282
    def __str__(self): return self.to_str(mockDb, for_str=True)
283
    
284
    def to_Col(self): return self
285

    
286
def is_table_col(col): return isinstance(col, Col) and col.table != None
287

    
288
def index_col(col):
289
    if not is_table_col(col): return None
290
    
291
    table = col.table
292
    try: name = table.index_cols[col.name]
293
    except KeyError: return None
294
    else: return Col(name, table, col.srcs)
295

    
296
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
297

    
298
def as_Col(col, table=None, name=None):
299
    '''
300
    @param name If not None, any non-Col input will be renamed using NamedCol.
301
    '''
302
    if name != None:
303
        col = as_Value(col)
304
        if not isinstance(col, Col): col = NamedCol(name, col)
305
    
306
    if isinstance(col, Code): return col
307
    else: return Col(col, table)
308

    
309
def with_table(col, table):
310
    if isinstance(col, NamedCol): pass # doesn't take a table
311
    elif isinstance(col, FunctionCall):
312
        col = copy.deepcopy(col) # don't modify input!
313
        col.args[0].table = table
314
    else:
315
        col = copy.copy(col) # don't modify input!
316
        col.table = table
317
    return col
318

    
319
def with_default_table(col, table):
320
    col = as_Col(col)
321
    if col.table == None: col = with_table(col, table)
322
    return col
323

    
324
def set_cols_table(table, cols):
325
    table = as_Table(table)
326
    
327
    for i, col in enumerate(cols):
328
        col = cols[i] = as_Col(col)
329
        col.table = table
330

    
331
def to_name_only_col(col, check_table=None):
332
    col = as_Col(col)
333
    if not is_table_col(col): return col
334
    
335
    if check_table != None:
336
        table = col.table
337
        assert table == None or table == check_table
338
    return Col(col.name)
339

    
340
def suffixed_col(col, suffix):
341
    return Col(concat(col.name, suffix), col.table, col.srcs)
342

    
343
class NamedCol(Col):
344
    def __init__(self, name, code):
345
        Col.__init__(self, name)
346
        
347
        code = as_Value(code)
348
        
349
        self.code = code
350
    
351
    def to_str(self, db):
352
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
353
    
354
    def to_Col(self): return Col(self.name)
355

    
356
def remove_col_rename(col):
357
    if isinstance(col, NamedCol): col = col.code
358
    return col
359

    
360
def underlying_col(col):
361
    col = remove_col_rename(col)
362
    if not isinstance(col, Col): raise NoUnderlyingTableException
363
    
364
    return Col(col.name, underlying_table(col.table), col.srcs)
365

    
366
def wrap(wrap_func, value):
367
    '''Wraps a value, propagating any column renaming to the returned value.'''
368
    if isinstance(value, NamedCol):
369
        return NamedCol(value.name, wrap_func(value.code))
370
    else: return wrap_func(value)
371

    
372
class ColDict(dicts.DictProxy):
373
    '''A dict that automatically makes inserted entries Col objects'''
374
    
375
    def __init__(self, db, keys_table, dict_={}):
376
        dicts.DictProxy.__init__(self, OrderedDict())
377
        
378
        keys_table = as_Table(keys_table)
379
        
380
        self.db = db
381
        self.table = keys_table
382
        self.update(dict_) # after setting vars because __setitem__() needs them
383
    
384
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
385
    
386
    def __getitem__(self, key):
387
        return dicts.DictProxy.__getitem__(self, self._key(key))
388
    
389
    def __setitem__(self, key, value):
390
        key = self._key(key)
391
        if value == None: value = self.db.col_info(key).default
392
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
393
    
394
    def _key(self, key): return as_Col(key, self.table)
395

    
396
##### Functions
397

    
398
Function = Table
399
as_Function = as_Table
400

    
401
class InternalFunction(CustomCode): pass
402

    
403
#### Calls
404

    
405
class NamedArg(NamedCol):
406
    def __init__(self, name, value):
407
        NamedCol.__init__(self, name, value)
408
    
409
    def to_str(self, db):
410
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
411

    
412
class FunctionCall(Code):
413
    def __init__(self, function, *args, **kw_args):
414
        '''
415
        @param args [Code|literal-value...] The function's arguments
416
        '''
417
        Code.__init__(self)
418
        
419
        function = as_Function(function)
420
        def filter_(arg): return remove_col_rename(as_Value(arg))
421
        args = map(filter_, args)
422
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
423
        
424
        self.function = function
425
        self.args = args
426
    
427
    def to_str(self, db):
428
        args_str = ', '.join((v.to_str(db) for v in self.args))
429
        return self.function.to_str(db)+'('+args_str+')'
430

    
431
def wrap_in_func(function, value):
432
    '''Wraps a value inside a function call.
433
    Propagates any column renaming to the returned value.
434
    '''
435
    return wrap(lambda v: FunctionCall(function, v), value)
436

    
437
def unwrap_func_call(func_call, check_name=None):
438
    '''Unwraps any function call to its first argument.
439
    Also removes any column renaming.
440
    '''
441
    func_call = remove_col_rename(func_call)
442
    if not isinstance(func_call, FunctionCall): return func_call
443
    
444
    if check_name != None:
445
        name = func_call.function.name
446
        assert name == None or name == check_name
447
    return func_call.args[0]
448

    
449
#### Definitions
450

    
451
class FunctionDef(Code):
452
    def __init__(self, function, return_type, body, args=[], modifiers=None):
453
        Code.__init__(self)
454
        
455
        body = as_Code(body)
456
        
457
        self.function = function
458
        self.return_type = return_type
459
        self.body = body
460
        self.args = args
461
        self.modifiers = modifiers
462
    
463
    def to_str(self, db):
464
        str_ = '''\
465
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+(', '.join(self.args))+''')
466
RETURNS '''+self.return_type+'''
467
LANGUAGE '''+self.body.lang+'''
468
'''
469
        if self.modifiers != None: str_ += self.modifiers+'\n'
470
        str_ += '''\
471
AS $$
472
'''+self.body.to_str(db)+'''
473
$$;
474
'''
475
        return str_
476

    
477
### PL/pgSQL
478

    
479
class ExcHandler(BasicObject):
480
    def __init__(self, exc, handler=None):
481
        if handler != None: handler = as_Code(handler)
482
        
483
        self.exc = exc
484
        self.handler = handler
485
    
486
    def to_str(self, db, body):
487
        body = as_Code(body)
488
        
489
        if self.handler != None: handler_str = '\n'+self.handler.to_str(db)
490
        else: handler_str = ' NULL;\n'
491
        
492
        str_ = '''\
493
BEGIN
494
'''+body.to_str(db)+'''\
495
EXCEPTION
496
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
497
END;\
498
'''
499
        return str_
500
    
501
    def __repr__(self): return self.to_str(mockDb, '<body>')
502

    
503
unique_violation_handler = ExcHandler('unique_violation')
504

    
505
class RowExcIgnore(Code):
506
    def __init__(self, row_type, select_query, with_row, cols=None,
507
        exc_handler=unique_violation_handler, row_var='row'):
508
        Code.__init__(self, lang='plpgsql')
509
        
510
        select_query = as_Code(select_query)
511
        with_row = as_Code(with_row)
512
        row_var = as_Table(row_var)
513
        
514
        self.row_type = row_type
515
        self.select_query = select_query
516
        self.with_row = with_row
517
        self.cols = cols
518
        self.exc_handler = exc_handler
519
        self.row_var = row_var
520
    
521
    def to_str(self, db):
522
        if self.cols == None: row_vars = [self.row_var]
523
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
524
        
525
        str_ = '''\
526
DECLARE
527
    '''+self.row_var.to_str(db)+''' '''+self.row_type+''';
528
BEGIN
529
    /* Need an EXCEPTION block for each individual row because "When an error is
530
    caught by an EXCEPTION clause, [...] all changes to persistent database
531
    state within the block are rolled back."
532
    This is unfortunate because "A block containing an EXCEPTION clause is
533
    significantly more expensive to enter and exit than a block without one."
534
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
535
#PLPGSQL-ERROR-TRAPPING)
536
    */
537
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
538
'''+self.select_query.to_str(db)+'''
539
    LOOP
540
'''+self.exc_handler.to_str(db, self.with_row)+'''
541
    END LOOP;
542
END;\
543
'''
544
        return str_
545

    
546
##### Casts
547

    
548
class Cast(FunctionCall):
549
    def __init__(self, type_, value):
550
        value = as_Value(value)
551
        
552
        self.type_ = type_
553
        self.value = value
554
    
555
    def to_str(self, db):
556
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
557

    
558
def cast_literal(value):
559
    if not is_literal(value): return value
560
    
561
    if util.is_str(value.value): value = Cast('text', value)
562
    return value
563

    
564
##### Conditions
565

    
566
class NotCond(Code):
567
    def __init__(self, cond):
568
        Code.__init__(self)
569
        
570
        self.cond = cond
571
    
572
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
573

    
574
class ColValueCond(Code):
575
    def __init__(self, col, value):
576
        Code.__init__(self)
577
        
578
        value = as_ValueCond(value)
579
        
580
        self.col = col
581
        self.value = value
582
    
583
    def to_str(self, db): return self.value.to_str(db, self.col)
584

    
585
def combine_conds(conds, keyword=None):
586
    '''
587
    @param keyword The keyword to add before the conditions, if any
588
    '''
589
    str_ = ''
590
    if keyword != None:
591
        if conds == []: whitespace = ''
592
        elif len(conds) == 1: whitespace = ' '
593
        else: whitespace = '\n'
594
        str_ += keyword+whitespace
595
    
596
    str_ += '\nAND '.join(conds)
597
    return str_
598

    
599
##### Condition column comparisons
600

    
601
class ValueCond(BasicObject):
602
    def __init__(self, value):
603
        value = remove_col_rename(as_Value(value))
604
        
605
        self.value = value
606
    
607
    def to_str(self, db, left_value):
608
        '''
609
        @param left_value The Code object that the condition is being applied on
610
        '''
611
        raise NotImplemented()
612
    
613
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
614

    
615
class CompareCond(ValueCond):
616
    def __init__(self, value, operator='='):
617
        '''
618
        @param operator By default, compares NULL values literally. Use '~=' or
619
            '~!=' to pass NULLs through.
620
        '''
621
        ValueCond.__init__(self, value)
622
        self.operator = operator
623
    
624
    def to_str(self, db, left_value):
625
        left_value = remove_col_rename(as_Col(left_value))
626
        
627
        right_value = self.value
628
        
629
        # Parse operator
630
        operator = self.operator
631
        passthru_null_ref = [False]
632
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
633
        neg_ref = [False]
634
        operator = strings.remove_prefix('!', operator, neg_ref)
635
        equals = operator.endswith('=') # also includes <=, >=
636
        
637
        # Handle nullable columns
638
        check_null = False
639
        if not passthru_null_ref[0]: # NULLs compare equal
640
            try: left_value = ensure_not_null(db, left_value)
641
            except ensure_not_null_excs: # fall back to alternate method
642
                check_null = equals and isinstance(right_value, Col)
643
            else:
644
                if isinstance(left_value, EnsureNotNull):
645
                    right_value = ensure_not_null(db, right_value,
646
                        left_value.type) # apply same function to both sides
647
        
648
        if equals and is_null(right_value): operator = 'IS'
649
        
650
        left = left_value.to_str(db)
651
        right = right_value.to_str(db)
652
        
653
        # Create str
654
        str_ = left+' '+operator+' '+right
655
        if check_null:
656
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
657
        if neg_ref[0]: str_ = 'NOT '+str_
658
        return str_
659

    
660
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
661
assume_literal = object()
662

    
663
def as_ValueCond(value, default_table=assume_literal):
664
    if not isinstance(value, ValueCond):
665
        if default_table is not assume_literal:
666
            value = with_default_table(value, default_table)
667
        return CompareCond(value)
668
    else: return value
669

    
670
##### Joins
671

    
672
join_same = object() # tells Join the left and right columns have the same name
673

    
674
# Tells Join the left and right columns have the same name and are never NULL
675
join_same_not_null = object()
676

    
677
filter_out = object() # tells Join to filter out rows that match the join
678

    
679
class Join(BasicObject):
680
    def __init__(self, table, mapping={}, type_=None):
681
        '''
682
        @param mapping dict(right_table_col=left_table_col, ...)
683
            * if left_table_col is join_same: left_table_col = right_table_col
684
              * Note that right_table_col must be a string
685
            * if left_table_col is join_same_not_null:
686
              left_table_col = right_table_col and both have NOT NULL constraint
687
              * Note that right_table_col must be a string
688
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
689
            * filter_out: equivalent to 'LEFT' with the query filtered by
690
              `table_pkey IS NULL` (indicating no match)
691
        '''
692
        if util.is_str(table): table = Table(table)
693
        assert type_ == None or util.is_str(type_) or type_ is filter_out
694
        
695
        self.table = table
696
        self.mapping = mapping
697
        self.type_ = type_
698
    
699
    def to_str(self, db, left_table_):
700
        def join(entry):
701
            '''Parses non-USING joins'''
702
            right_table_col, left_table_col = entry
703
            
704
            # Switch order (right_table_col is on the left in the comparison)
705
            left = right_table_col
706
            right = left_table_col
707
            left_table = self.table
708
            right_table = left_table_
709
            
710
            # Parse left side
711
            left = with_default_table(left, left_table)
712
            
713
            # Parse special values
714
            left_on_right = Col(left.name, right_table)
715
            if right is join_same: right = left_on_right
716
            elif right is join_same_not_null:
717
                right = CompareCond(left_on_right, '~=')
718
            
719
            # Parse right side
720
            right = as_ValueCond(right, right_table)
721
            
722
            return right.to_str(db, left)
723
        
724
        # Create join condition
725
        type_ = self.type_
726
        joins = self.mapping
727
        if joins == {}: join_cond = None
728
        elif type_ is not filter_out and reduce(operator.and_,
729
            (v is join_same_not_null for v in joins.itervalues())):
730
            # all cols w/ USING, so can use simpler USING syntax
731
            cols = map(to_name_only_col, joins.iterkeys())
732
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
733
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
734
        
735
        if isinstance(self.table, NamedTable): whitespace = '\n'
736
        else: whitespace = ' '
737
        
738
        # Create join
739
        if type_ is filter_out: type_ = 'LEFT'
740
        str_ = ''
741
        if type_ != None: str_ += type_+' '
742
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
743
        if join_cond != None: str_ += whitespace+join_cond
744
        return str_
745
    
746
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
747

    
748
##### Value exprs
749

    
750
all_cols = CustomCode('*')
751

    
752
default = CustomCode('DEFAULT')
753

    
754
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
755

    
756
class Coalesce(FunctionCall):
757
    def __init__(self, *args):
758
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
759

    
760
class Nullif(FunctionCall):
761
    def __init__(self, *args):
762
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
763

    
764
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
765
null_sentinels = {
766
    'character varying': r'\N',
767
    'double precision': 'NaN',
768
    'integer': 2147483647,
769
    'text': r'\N',
770
    'timestamp with time zone': 'infinity'
771
}
772

    
773
class EnsureNotNull(Coalesce):
774
    def __init__(self, value, type_):
775
        Coalesce.__init__(self, as_Col(value),
776
            Cast(type_, null_sentinels[type_]))
777
        
778
        self.type = type_
779
    
780
    def to_str(self, db):
781
        col = self.args[0]
782
        index_col_ = index_col(col)
783
        if index_col_ != None: return index_col_.to_str(db)
784
        return Coalesce.to_str(self, db)
785

    
786
##### Table exprs
787

    
788
class Values(Code):
789
    def __init__(self, values):
790
        '''
791
        @param values [...]|[[...], ...] Can be one or multiple rows.
792
        '''
793
        Code.__init__(self)
794
        
795
        rows = values
796
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
797
            rows = [values]
798
        for i, row in enumerate(rows):
799
            rows[i] = map(remove_col_rename, map(as_Value, row))
800
        
801
        self.rows = rows
802
    
803
    def to_str(self, db):
804
        def row_str(row):
805
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
806
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
807

    
808
def NamedValues(name, cols, values):
809
    '''
810
    @param cols None|[...]
811
    @post `cols` will be changed to Col objects with the table set to `name`.
812
    '''
813
    table = NamedTable(name, Values(values), cols)
814
    if cols != None: set_cols_table(table, cols)
815
    return table
816

    
817
##### Database structure
818

    
819
class TypedCol(Col):
820
    def __init__(self, name, type_, default=None, nullable=True,
821
        constraints=None):
822
        assert default == None or isinstance(default, Code)
823
        
824
        Col.__init__(self, name)
825
        
826
        self.type = type_
827
        self.default = default
828
        self.nullable = nullable
829
        self.constraints = constraints
830
    
831
    def to_str(self, db):
832
        str_ = Col.to_str(self, db)+' '+self.type
833
        if not self.nullable: str_ += ' NOT NULL'
834
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
835
        if self.constraints != None: str_ += ' '+self.constraints
836
        return str_
837
    
838
    def to_Col(self): return Col(self.name)
839

    
840
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
841

    
842
def ensure_not_null(db, col, type_=None):
843
    '''
844
    @param col If type_ is not set, must have an underlying column.
845
    @param type_ If set, overrides the underlying column's type.
846
    @return EnsureNotNull|Col
847
    @throws ensure_not_null_excs
848
    '''
849
    nullable = True
850
    try: typed_col = db.col_info(underlying_col(col))
851
    except NoUnderlyingTableException:
852
        col = remove_col_rename(col)
853
        if is_literal(col) and not is_null(col): nullable = False
854
        elif type_ == None: raise
855
    else:
856
        if type_ == None: type_ = typed_col.type
857
        nullable = typed_col.nullable
858
    
859
    if nullable:
860
        try: col = EnsureNotNull(col, type_)
861
        except KeyError, e:
862
            # Warn of no null sentinel for type, even if caller catches error
863
            warnings.warn(UserWarning(exc.str_(e)))
864
            raise
865
    
866
    return col
(25-25/37)