Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import operator
5
from ordereddict import OrderedDict
6
import re
7
import UserDict
8
import warnings
9

    
10
import dicts
11
import exc
12
import iters
13
import lists
14
import objects
15
import strings
16
import util
17

    
18
##### Names
19

    
20
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21

    
22
def concat(str_, suffix):
23
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25
    # Preserve version
26
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30
    
31
    return strings.concat(str_, suffix, identifier_max_len)
32

    
33
def truncate(str_): return concat(str_, '')
34

    
35
def is_safe_name(name):
36
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42

    
43
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46

    
47
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55

    
56
def clean_name(name): return name.replace('"', '').replace('`', '')
57

    
58
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59

    
60
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64

    
65
##### General SQL code objects
66

    
67
class MockDb:
68
    def esc_value(self, value): return strings.repr_no_u(value)
69
    
70
    def esc_name(self, name): return esc_name(name)
71
    
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74

    
75
mockDb = MockDb()
76

    
77
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79

    
80
##### Unparameterized code objects
81

    
82
class Code(BasicObject):
83
    def __init__(self, lang='sql'):
84
        self.lang = lang
85
    
86
    def to_str(self, db): raise NotImplementedError()
87
    
88
    def __repr__(self): return self.to_str(mockDb)
89

    
90
class CustomCode(Code):
91
    def __init__(self, str_):
92
        Code.__init__(self)
93
        
94
        self.str_ = str_
95
    
96
    def to_str(self, db): return self.str_
97

    
98
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102
    if isinstance(value, Code): return value
103
    
104
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107
    else: return Literal(value)
108

    
109
class Expr(Code):
110
    def __init__(self, expr):
111
        Code.__init__(self)
112
        
113
        self.expr = expr
114
    
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116

    
117
##### Names
118

    
119
class Name(Code):
120
    def __init__(self, name):
121
        Code.__init__(self)
122
        
123
        name = truncate(name)
124
        
125
        self.name = name
126
    
127
    def to_str(self, db): return db.esc_name(self.name)
128

    
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132

    
133
##### Literal values
134

    
135
class Literal(Code):
136
    def __init__(self, value):
137
        Code.__init__(self)
138
        
139
        self.value = value
140
    
141
    def to_str(self, db): return db.esc_value(self.value)
142

    
143
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146

    
147
def is_literal(value): return isinstance(value, Literal)
148

    
149
def is_null(value): return is_literal(value) and value.value == None
150

    
151
##### Derived elements
152

    
153
src_self = object() # tells Col that it is its own source column
154

    
155
class Derived(Code):
156
    def __init__(self, srcs):
157
        '''An element which was derived from some other element(s).
158
        @param srcs See self.set_srcs()
159
        '''
160
        Code.__init__(self)
161
        
162
        self.set_srcs(srcs)
163
    
164
    def set_srcs(self, srcs, overwrite=True):
165
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168
        if not overwrite and self.srcs != (): return # already set
169
        
170
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
    
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178

    
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180

    
181
##### Tables
182

    
183
class Table(Derived):
184
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185
        '''
186
        @param schema str|None (for no schema)
187
        @param srcs (Table...)|src_self See Derived.set_srcs()
188
        '''
189
        Derived.__init__(self, srcs)
190
        
191
        if util.is_str(name): name = truncate(name)
192
        
193
        self.name = name
194
        self.schema = schema
195
        self.is_temp = is_temp
196
        self.index_cols = {}
197
    
198
    def to_str(self, db):
199
        str_ = ''
200
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202
        return str_
203
    
204
    def to_Table(self): return self
205
    
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210

    
211
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213

    
214
class NoUnderlyingTableException(Exception): pass
215

    
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220

    
221
def as_Table(table, schema=None):
222
    if table == None or isinstance(table, Code): return table
223
    else: return Table(table, schema)
224

    
225
def suffixed_table(table, suffix):
226
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229

    
230
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
        
234
        code = as_Table(code)
235
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237
        
238
        self.code = code
239
        self.cols = cols
240
    
241
    def to_str(self, db):
242
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248
        return str_
249
    
250
    def to_Table(self): return Table(self.name)
251

    
252
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255

    
256
##### Columns
257

    
258
class Col(Derived):
259
    def __init__(self, name, table=None, srcs=()):
260
        '''
261
        @param table Table|None (for no table)
262
        @param srcs (Col...)|src_self See Derived.set_srcs()
263
        '''
264
        Derived.__init__(self, srcs)
265
        
266
        if util.is_str(name): name = truncate(name)
267
        if util.is_str(table): table = Table(table)
268
        assert table == None or isinstance(table, Table)
269
        
270
        self.name = name
271
        self.table = table
272
    
273
    def to_str(self, db, for_str=False):
274
        str_ = as_Name(self.name).to_str(db)
275
        if for_str: str_ = clean_name(str_)
276
        if self.table != None:
277
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280
        return str_
281
    
282
    def __str__(self): return self.to_str(mockDb, for_str=True)
283
    
284
    def to_Col(self): return self
285

    
286
def is_table_col(col): return isinstance(col, Col) and col.table != None
287

    
288
def index_col(col):
289
    if not is_table_col(col): return None
290
    
291
    table = col.table
292
    try: name = table.index_cols[col.name]
293
    except KeyError: return None
294
    else: return Col(name, table, col.srcs)
295

    
296
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
297

    
298
def as_Col(col, table=None, name=None):
299
    '''
300
    @param name If not None, any non-Col input will be renamed using NamedCol.
301
    '''
302
    if name != None:
303
        col = as_Value(col)
304
        if not isinstance(col, Col): col = NamedCol(name, col)
305
    
306
    if isinstance(col, Code): return col
307
    else: return Col(col, table)
308

    
309
def with_table(col, table):
310
    if isinstance(col, NamedCol): pass # doesn't take a table
311
    elif isinstance(col, FunctionCall):
312
        col = copy.deepcopy(col) # don't modify input!
313
        col.args[0].table = table
314
    else:
315
        col = copy.copy(col) # don't modify input!
316
        col.table = table
317
    return col
318

    
319
def with_default_table(col, table):
320
    col = as_Col(col)
321
    if col.table == None: col = with_table(col, table)
322
    return col
323

    
324
def set_cols_table(table, cols):
325
    table = as_Table(table)
326
    
327
    for i, col in enumerate(cols):
328
        col = cols[i] = as_Col(col)
329
        col.table = table
330

    
331
def to_name_only_col(col, check_table=None):
332
    col = as_Col(col)
333
    if not is_table_col(col): return col
334
    
335
    if check_table != None:
336
        table = col.table
337
        assert table == None or table == check_table
338
    return Col(col.name)
339

    
340
def suffixed_col(col, suffix):
341
    return Col(concat(col.name, suffix), col.table, col.srcs)
342

    
343
class NamedCol(Col):
344
    def __init__(self, name, code):
345
        Col.__init__(self, name)
346
        
347
        code = as_Value(code)
348
        
349
        self.code = code
350
    
351
    def to_str(self, db):
352
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
353
    
354
    def to_Col(self): return Col(self.name)
355

    
356
def remove_col_rename(col):
357
    if isinstance(col, NamedCol): col = col.code
358
    return col
359

    
360
def underlying_col(col):
361
    col = remove_col_rename(col)
362
    if not isinstance(col, Col): raise NoUnderlyingTableException
363
    
364
    return Col(col.name, underlying_table(col.table), col.srcs)
365

    
366
def wrap(wrap_func, value):
367
    '''Wraps a value, propagating any column renaming to the returned value.'''
368
    if isinstance(value, NamedCol):
369
        return NamedCol(value.name, wrap_func(value.code))
370
    else: return wrap_func(value)
371

    
372
class ColDict(dicts.DictProxy):
373
    '''A dict that automatically makes inserted entries Col objects'''
374
    
375
    def __init__(self, db, keys_table, dict_={}):
376
        dicts.DictProxy.__init__(self, OrderedDict())
377
        
378
        keys_table = as_Table(keys_table)
379
        
380
        self.db = db
381
        self.table = keys_table
382
        self.update(dict_) # after setting vars because __setitem__() needs them
383
    
384
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
385
    
386
    def __getitem__(self, key):
387
        return dicts.DictProxy.__getitem__(self, self._key(key))
388
    
389
    def __setitem__(self, key, value):
390
        key = self._key(key)
391
        if value == None: value = self.db.col_info(key).default
392
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
393
    
394
    def _key(self, key): return as_Col(key, self.table)
395

    
396
#### Definitions
397

    
398
class TypedCol(Col):
399
    def __init__(self, name, type_, default=None, nullable=True,
400
        constraints=None):
401
        assert default == None or isinstance(default, Code)
402
        
403
        Col.__init__(self, name)
404
        
405
        self.type = type_
406
        self.default = default
407
        self.nullable = nullable
408
        self.constraints = constraints
409
    
410
    def to_str(self, db):
411
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
412
        if not self.nullable: str_ += ' NOT NULL'
413
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
414
        if self.constraints != None: str_ += ' '+self.constraints
415
        return str_
416
    
417
    def to_Col(self): return Col(self.name)
418

    
419
class SetOf(Code):
420
    def __init__(self, type_):
421
        Code.__init__(self)
422
        
423
        self.type = type_
424
    
425
    def to_str(self, db):
426
        return 'SETOF '+self.type.to_str(db)
427

    
428
class RowType(Code):
429
    def __init__(self, table):
430
        Code.__init__(self)
431
        
432
        self.table = table
433
    
434
    def to_str(self, db):
435
        return self.table.to_str(db)+'%ROWTYPE'
436

    
437
class ColType(Code):
438
    def __init__(self, col):
439
        Code.__init__(self)
440
        
441
        self.col = col
442
    
443
    def to_str(self, db):
444
        return self.col.to_str(db)+'%TYPE'
445

    
446
##### Functions
447

    
448
Function = Table
449
as_Function = as_Table
450

    
451
class InternalFunction(CustomCode): pass
452

    
453
#### Calls
454

    
455
class NamedArg(NamedCol):
456
    def __init__(self, name, value):
457
        NamedCol.__init__(self, name, value)
458
    
459
    def to_str(self, db):
460
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
461

    
462
class FunctionCall(Code):
463
    def __init__(self, function, *args, **kw_args):
464
        '''
465
        @param args [Code|literal-value...] The function's arguments
466
        '''
467
        Code.__init__(self)
468
        
469
        function = as_Function(function)
470
        def filter_(arg): return remove_col_rename(as_Value(arg))
471
        args = map(filter_, args)
472
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
473
        
474
        self.function = function
475
        self.args = args
476
    
477
    def to_str(self, db):
478
        args_str = ', '.join((v.to_str(db) for v in self.args))
479
        return self.function.to_str(db)+'('+args_str+')'
480

    
481
def wrap_in_func(function, value):
482
    '''Wraps a value inside a function call.
483
    Propagates any column renaming to the returned value.
484
    '''
485
    return wrap(lambda v: FunctionCall(function, v), value)
486

    
487
def unwrap_func_call(func_call, check_name=None):
488
    '''Unwraps any function call to its first argument.
489
    Also removes any column renaming.
490
    '''
491
    func_call = remove_col_rename(func_call)
492
    if not isinstance(func_call, FunctionCall): return func_call
493
    
494
    if check_name != None:
495
        name = func_call.function.name
496
        assert name == None or name == check_name
497
    return func_call.args[0]
498

    
499
#### Definitions
500

    
501
class FunctionDef(Code):
502
    def __init__(self, function, return_type, body, params=[], modifiers=None):
503
        Code.__init__(self)
504
        
505
        return_type = as_Code(return_type)
506
        body = as_Code(body)
507
        
508
        self.function = function
509
        self.return_type = return_type
510
        self.body = body
511
        self.params = params
512
        self.modifiers = modifiers
513
    
514
    def to_str(self, db):
515
        params_str = (', '.join((p.to_str(db) for p in self.params)))
516
        str_ = '''\
517
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
518
RETURNS '''+self.return_type.to_str(db)+'''
519
LANGUAGE '''+self.body.lang+'''
520
'''
521
        if self.modifiers != None: str_ += self.modifiers+'\n'
522
        str_ += '''\
523
AS $$
524
'''+self.body.to_str(db)+'''
525
$$;
526
'''
527
        return str_
528

    
529
class FunctionParam(TypedCol):
530
    def __init__(self, name, type_, default=None, out=False):
531
        TypedCol.__init__(self, name, type_, default)
532
        
533
        self.out = out
534
    
535
    def to_str(self, db):
536
        str_ = TypedCol.to_str(self, db)
537
        if self.out: str_ = 'OUT '+str_
538
        return str_
539
    
540
    def to_Col(self): return Col(self.name)
541

    
542
### PL/pgSQL
543

    
544
class ExcHandler(BasicObject):
545
    def __init__(self, exc, handler=None):
546
        if handler != None: handler = as_Code(handler)
547
        
548
        self.exc = exc
549
        self.handler = handler
550
    
551
    def to_str(self, db, body):
552
        body = as_Code(body)
553
        
554
        if self.handler != None:
555
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
556
        else: handler_str = ' NULL;\n'
557
        
558
        str_ = '''\
559
BEGIN
560
'''+strings.indent(body.to_str(db))+'''\
561
EXCEPTION
562
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
563
END;\
564
'''
565
        return str_
566
    
567
    def __repr__(self): return self.to_str(mockDb, '<body>')
568

    
569
unique_violation_handler = ExcHandler('unique_violation')
570

    
571
plpythonu_error_handler = ExcHandler('internal_error', '''\
572
RAISE data_exception USING MESSAGE =
573
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
574
''')
575

    
576
class RowExcIgnore(Code):
577
    def __init__(self, row_type, select_query, with_row, cols=None,
578
        exc_handler=unique_violation_handler, row_var='row'):
579
        Code.__init__(self, lang='plpgsql')
580
        
581
        row_type = as_Code(row_type)
582
        select_query = as_Code(select_query)
583
        with_row = as_Code(with_row)
584
        row_var = as_Table(row_var)
585
        
586
        self.row_type = row_type
587
        self.select_query = select_query
588
        self.with_row = with_row
589
        self.cols = cols
590
        self.exc_handler = exc_handler
591
        self.row_var = row_var
592
    
593
    def to_str(self, db):
594
        if self.cols == None: row_vars = [self.row_var]
595
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
596
        
597
        str_ = '''\
598
DECLARE
599
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
600
BEGIN
601
    /* Need an EXCEPTION block for each individual row because "When
602
    an error is caught by an EXCEPTION clause, [...] all changes to
603
    persistent database state within the block are rolled back."
604
    This is unfortunate because "A block containing an EXCEPTION
605
    clause is significantly more expensive to enter and exit than a
606
    block without one."
607
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
608
#PLPGSQL-ERROR-TRAPPING)
609
    */
610
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
611
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
612
    LOOP
613
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
614
    END LOOP;
615
END;\
616
'''
617
        return str_
618

    
619
##### Casts
620

    
621
class Cast(FunctionCall):
622
    def __init__(self, type_, value):
623
        value = as_Value(value)
624
        
625
        self.type_ = type_
626
        self.value = value
627
    
628
    def to_str(self, db):
629
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
630

    
631
def cast_literal(value):
632
    if not is_literal(value): return value
633
    
634
    if util.is_str(value.value): value = Cast('text', value)
635
    return value
636

    
637
##### Conditions
638

    
639
class NotCond(Code):
640
    def __init__(self, cond):
641
        Code.__init__(self)
642
        
643
        self.cond = cond
644
    
645
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
646

    
647
class ColValueCond(Code):
648
    def __init__(self, col, value):
649
        Code.__init__(self)
650
        
651
        value = as_ValueCond(value)
652
        
653
        self.col = col
654
        self.value = value
655
    
656
    def to_str(self, db): return self.value.to_str(db, self.col)
657

    
658
def combine_conds(conds, keyword=None):
659
    '''
660
    @param keyword The keyword to add before the conditions, if any
661
    '''
662
    str_ = ''
663
    if keyword != None:
664
        if conds == []: whitespace = ''
665
        elif len(conds) == 1: whitespace = ' '
666
        else: whitespace = '\n'
667
        str_ += keyword+whitespace
668
    
669
    str_ += '\nAND '.join(conds)
670
    return str_
671

    
672
##### Condition column comparisons
673

    
674
class ValueCond(BasicObject):
675
    def __init__(self, value):
676
        value = remove_col_rename(as_Value(value))
677
        
678
        self.value = value
679
    
680
    def to_str(self, db, left_value):
681
        '''
682
        @param left_value The Code object that the condition is being applied on
683
        '''
684
        raise NotImplemented()
685
    
686
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
687

    
688
class CompareCond(ValueCond):
689
    def __init__(self, value, operator='='):
690
        '''
691
        @param operator By default, compares NULL values literally. Use '~=' or
692
            '~!=' to pass NULLs through.
693
        '''
694
        ValueCond.__init__(self, value)
695
        self.operator = operator
696
    
697
    def to_str(self, db, left_value):
698
        left_value = remove_col_rename(as_Col(left_value))
699
        
700
        right_value = self.value
701
        
702
        # Parse operator
703
        operator = self.operator
704
        passthru_null_ref = [False]
705
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
706
        neg_ref = [False]
707
        operator = strings.remove_prefix('!', operator, neg_ref)
708
        equals = operator.endswith('=') # also includes <=, >=
709
        
710
        # Handle nullable columns
711
        check_null = False
712
        if not passthru_null_ref[0]: # NULLs compare equal
713
            try: left_value = ensure_not_null(db, left_value)
714
            except ensure_not_null_excs: # fall back to alternate method
715
                check_null = equals and isinstance(right_value, Col)
716
            else:
717
                if isinstance(left_value, EnsureNotNull):
718
                    right_value = ensure_not_null(db, right_value,
719
                        left_value.type) # apply same function to both sides
720
        
721
        if equals and is_null(right_value): operator = 'IS'
722
        
723
        left = left_value.to_str(db)
724
        right = right_value.to_str(db)
725
        
726
        # Create str
727
        str_ = left+' '+operator+' '+right
728
        if check_null:
729
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
730
        if neg_ref[0]: str_ = 'NOT '+str_
731
        return str_
732

    
733
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
734
assume_literal = object()
735

    
736
def as_ValueCond(value, default_table=assume_literal):
737
    if not isinstance(value, ValueCond):
738
        if default_table is not assume_literal:
739
            value = with_default_table(value, default_table)
740
        return CompareCond(value)
741
    else: return value
742

    
743
##### Joins
744

    
745
join_same = object() # tells Join the left and right columns have the same name
746

    
747
# Tells Join the left and right columns have the same name and are never NULL
748
join_same_not_null = object()
749

    
750
filter_out = object() # tells Join to filter out rows that match the join
751

    
752
class Join(BasicObject):
753
    def __init__(self, table, mapping={}, type_=None):
754
        '''
755
        @param mapping dict(right_table_col=left_table_col, ...)
756
            * if left_table_col is join_same: left_table_col = right_table_col
757
              * Note that right_table_col must be a string
758
            * if left_table_col is join_same_not_null:
759
              left_table_col = right_table_col and both have NOT NULL constraint
760
              * Note that right_table_col must be a string
761
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
762
            * filter_out: equivalent to 'LEFT' with the query filtered by
763
              `table_pkey IS NULL` (indicating no match)
764
        '''
765
        if util.is_str(table): table = Table(table)
766
        assert type_ == None or util.is_str(type_) or type_ is filter_out
767
        
768
        self.table = table
769
        self.mapping = mapping
770
        self.type_ = type_
771
    
772
    def to_str(self, db, left_table_):
773
        def join(entry):
774
            '''Parses non-USING joins'''
775
            right_table_col, left_table_col = entry
776
            
777
            # Switch order (right_table_col is on the left in the comparison)
778
            left = right_table_col
779
            right = left_table_col
780
            left_table = self.table
781
            right_table = left_table_
782
            
783
            # Parse left side
784
            left = with_default_table(left, left_table)
785
            
786
            # Parse special values
787
            left_on_right = Col(left.name, right_table)
788
            if right is join_same: right = left_on_right
789
            elif right is join_same_not_null:
790
                right = CompareCond(left_on_right, '~=')
791
            
792
            # Parse right side
793
            right = as_ValueCond(right, right_table)
794
            
795
            return right.to_str(db, left)
796
        
797
        # Create join condition
798
        type_ = self.type_
799
        joins = self.mapping
800
        if joins == {}: join_cond = None
801
        elif type_ is not filter_out and reduce(operator.and_,
802
            (v is join_same_not_null for v in joins.itervalues())):
803
            # all cols w/ USING, so can use simpler USING syntax
804
            cols = map(to_name_only_col, joins.iterkeys())
805
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
806
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
807
        
808
        if isinstance(self.table, NamedTable): whitespace = '\n'
809
        else: whitespace = ' '
810
        
811
        # Create join
812
        if type_ is filter_out: type_ = 'LEFT'
813
        str_ = ''
814
        if type_ != None: str_ += type_+' '
815
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
816
        if join_cond != None: str_ += whitespace+join_cond
817
        return str_
818
    
819
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
820

    
821
##### Value exprs
822

    
823
all_cols = CustomCode('*')
824

    
825
default = CustomCode('DEFAULT')
826

    
827
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
828

    
829
class Coalesce(FunctionCall):
830
    def __init__(self, *args):
831
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
832

    
833
class Nullif(FunctionCall):
834
    def __init__(self, *args):
835
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
836

    
837
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
838
null_sentinels = {
839
    'character varying': r'\N',
840
    'double precision': 'NaN',
841
    'integer': 2147483647,
842
    'text': r'\N',
843
    'timestamp with time zone': 'infinity'
844
}
845

    
846
class EnsureNotNull(Coalesce):
847
    def __init__(self, value, type_):
848
        Coalesce.__init__(self, as_Col(value),
849
            Cast(type_, null_sentinels[type_]))
850
        
851
        self.type = type_
852
    
853
    def to_str(self, db):
854
        col = self.args[0]
855
        index_col_ = index_col(col)
856
        if index_col_ != None: return index_col_.to_str(db)
857
        return Coalesce.to_str(self, db)
858

    
859
##### Table exprs
860

    
861
class Values(Code):
862
    def __init__(self, values):
863
        '''
864
        @param values [...]|[[...], ...] Can be one or multiple rows.
865
        '''
866
        Code.__init__(self)
867
        
868
        rows = values
869
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
870
            rows = [values]
871
        for i, row in enumerate(rows):
872
            rows[i] = map(remove_col_rename, map(as_Value, row))
873
        
874
        self.rows = rows
875
    
876
    def to_str(self, db):
877
        def row_str(row):
878
            return '('+(', '.join((v.to_str(db) for v in row)))+')'
879
        return 'VALUES '+(', '.join(map(row_str, self.rows)))
880

    
881
def NamedValues(name, cols, values):
882
    '''
883
    @param cols None|[...]
884
    @post `cols` will be changed to Col objects with the table set to `name`.
885
    '''
886
    table = NamedTable(name, Values(values), cols)
887
    if cols != None: set_cols_table(table, cols)
888
    return table
889

    
890
##### Database structure
891

    
892
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
893

    
894
def ensure_not_null(db, col, type_=None):
895
    '''
896
    @param col If type_ is not set, must have an underlying column.
897
    @param type_ If set, overrides the underlying column's type.
898
    @return EnsureNotNull|Col
899
    @throws ensure_not_null_excs
900
    '''
901
    nullable = True
902
    try: typed_col = db.col_info(underlying_col(col))
903
    except NoUnderlyingTableException:
904
        col = remove_col_rename(col)
905
        if is_literal(col) and not is_null(col): nullable = False
906
        elif type_ == None: raise
907
    else:
908
        if type_ == None: type_ = typed_col.type
909
        nullable = typed_col.nullable
910
    
911
    if nullable:
912
        try: col = EnsureNotNull(col, type_)
913
        except KeyError, e:
914
            # Warn of no null sentinel for type, even if caller catches error
915
            warnings.warn(UserWarning(exc.str_(e)))
916
            raise
917
    
918
    return col
(25-25/37)