Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import itertools
5
import operator
6
from ordereddict import OrderedDict
7
import re
8
import UserDict
9
import warnings
10

    
11
import dicts
12
import exc
13
import iters
14
import lists
15
import objects
16
import strings
17
import util
18

    
19
##### Names
20

    
21
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22

    
23
def concat(str_, suffix):
24
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26
    # Preserve version
27
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31
    
32
    return strings.concat(str_, suffix, identifier_max_len)
33

    
34
def truncate(str_): return concat(str_, '')
35

    
36
def is_safe_name(name):
37
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43

    
44
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47

    
48
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56

    
57
def clean_name(name): return name.replace('"', '').replace('`', '')
58

    
59
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60

    
61
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65

    
66
##### General SQL code objects
67

    
68
class MockDb:
69
    def esc_value(self, value): return strings.repr_no_u(value)
70
    
71
    def esc_name(self, name): return esc_name(name)
72
    
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75

    
76
mockDb = MockDb()
77

    
78
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80

    
81
##### Unparameterized code objects
82

    
83
class Code(BasicObject):
84
    def __init__(self, lang='sql'):
85
        self.lang = lang
86
    
87
    def to_str(self, db): raise NotImplementedError()
88
    
89
    def __repr__(self): return self.to_str(mockDb)
90

    
91
class CustomCode(Code):
92
    def __init__(self, str_):
93
        Code.__init__(self)
94
        
95
        self.str_ = str_
96
    
97
    def to_str(self, db): return self.str_
98

    
99
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103
    if isinstance(value, Code): return value
104
    
105
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108
    else: return Literal(value)
109

    
110
class Expr(Code):
111
    def __init__(self, expr):
112
        Code.__init__(self)
113
        
114
        self.expr = expr
115
    
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117

    
118
##### Names
119

    
120
class Name(Code):
121
    def __init__(self, name):
122
        Code.__init__(self)
123
        
124
        name = truncate(name)
125
        
126
        self.name = name
127
    
128
    def to_str(self, db): return db.esc_name(self.name)
129

    
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133

    
134
##### Literal values
135

    
136
class Literal(Code):
137
    def __init__(self, value):
138
        Code.__init__(self)
139
        
140
        self.value = value
141
    
142
    def to_str(self, db): return db.esc_value(self.value)
143

    
144
def as_Value(value):
145
    if isinstance(value, Code): return value
146
    else: return Literal(value)
147

    
148
def is_literal(value): return isinstance(value, Literal)
149

    
150
def is_null(value): return is_literal(value) and value.value == None
151

    
152
##### Derived elements
153

    
154
src_self = object() # tells Col that it is its own source column
155

    
156
class Derived(Code):
157
    def __init__(self, srcs):
158
        '''An element which was derived from some other element(s).
159
        @param srcs See self.set_srcs()
160
        '''
161
        Code.__init__(self)
162
        
163
        self.set_srcs(srcs)
164
    
165
    def set_srcs(self, srcs, overwrite=True):
166
        '''
167
        @param srcs (self_type...)|src_self The element(s) this is derived from
168
        '''
169
        if not overwrite and self.srcs != (): return # already set
170
        
171
        if srcs == src_self: srcs = (self,)
172
        srcs = tuple(srcs) # make Col hashable
173
        self.srcs = srcs
174
    
175
    def _compare_on(self):
176
        compare_on = self.__dict__.copy()
177
        del compare_on['srcs'] # ignore
178
        return compare_on
179

    
180
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
181

    
182
##### Tables
183

    
184
class Table(Derived):
185
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
186
        '''
187
        @param schema str|None (for no schema)
188
        @param srcs (Table...)|src_self See Derived.set_srcs()
189
        '''
190
        Derived.__init__(self, srcs)
191
        
192
        if util.is_str(name): name = truncate(name)
193
        
194
        self.name = name
195
        self.schema = schema
196
        self.is_temp = is_temp
197
        self.index_cols = {}
198
    
199
    def to_str(self, db):
200
        str_ = ''
201
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
202
        str_ += as_Name(self.name).to_str(db)
203
        return str_
204
    
205
    def to_Table(self): return self
206
    
207
    def _compare_on(self):
208
        compare_on = Derived._compare_on(self)
209
        del compare_on['index_cols'] # ignore
210
        return compare_on
211

    
212
def is_underlying_table(table):
213
    return isinstance(table, Table) and table.to_Table() is table
214

    
215
class NoUnderlyingTableException(Exception): pass
216

    
217
def underlying_table(table):
218
    table = remove_table_rename(table)
219
    if not is_underlying_table(table): raise NoUnderlyingTableException
220
    return table
221

    
222
def as_Table(table, schema=None):
223
    if table == None or isinstance(table, Code): return table
224
    else: return Table(table, schema)
225

    
226
def suffixed_table(table, suffix):
227
    table = copy.copy(table) # don't modify input!
228
    table.name = concat(table.name, suffix)
229
    return table
230

    
231
class NamedTable(Table):
232
    def __init__(self, name, code, cols=None):
233
        Table.__init__(self, name)
234
        
235
        code = as_Table(code)
236
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
237
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
238
        
239
        self.code = code
240
        self.cols = cols
241
    
242
    def to_str(self, db):
243
        str_ = self.code.to_str(db)
244
        if str_.find('\n') >= 0: whitespace = '\n'
245
        else: whitespace = ' '
246
        str_ += whitespace+'AS '+Table.to_str(self, db)
247
        if self.cols != None:
248
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
249
        return str_
250
    
251
    def to_Table(self): return Table(self.name)
252

    
253
def remove_table_rename(table):
254
    if isinstance(table, NamedTable): table = table.code
255
    return table
256

    
257
##### Columns
258

    
259
class Col(Derived):
260
    def __init__(self, name, table=None, srcs=()):
261
        '''
262
        @param table Table|None (for no table)
263
        @param srcs (Col...)|src_self See Derived.set_srcs()
264
        '''
265
        Derived.__init__(self, srcs)
266
        
267
        if util.is_str(name): name = truncate(name)
268
        if util.is_str(table): table = Table(table)
269
        assert table == None or isinstance(table, Table)
270
        
271
        self.name = name
272
        self.table = table
273
    
274
    def to_str(self, db, for_str=False):
275
        str_ = as_Name(self.name).to_str(db)
276
        if for_str: str_ = clean_name(str_)
277
        if self.table != None:
278
            table = self.table.to_Table()
279
            if for_str: str_ = concat(str(table), '.'+str_)
280
            else: str_ = table.to_str(db)+'.'+str_
281
        return str_
282
    
283
    def __str__(self): return self.to_str(mockDb, for_str=True)
284
    
285
    def to_Col(self): return self
286

    
287
def is_col(col): return isinstance(col, Col)
288

    
289
def is_table_col(col): return is_col(col) and col.table != None
290

    
291
def index_col(col):
292
    if not is_table_col(col): return None
293
    
294
    table = col.table
295
    try: name = table.index_cols[col.name]
296
    except KeyError: return None
297
    else: return Col(name, table, col.srcs)
298

    
299
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
300

    
301
def as_Col(col, table=None, name=None):
302
    '''
303
    @param name If not None, any non-Col input will be renamed using NamedCol.
304
    '''
305
    if name != None:
306
        col = as_Value(col)
307
        if not isinstance(col, Col): col = NamedCol(name, col)
308
    
309
    if isinstance(col, Code): return col
310
    elif util.is_str(col): return Col(col, table)
311
    else: return Literal(col)
312

    
313
def with_table(col, table):
314
    if isinstance(col, NamedCol): pass # doesn't take a table
315
    elif isinstance(col, FunctionCall):
316
        col = copy.deepcopy(col) # don't modify input!
317
        col.args[0].table = table
318
    elif isinstance(col, Col):
319
        col = copy.copy(col) # don't modify input!
320
        col.table = table
321
    return col
322

    
323
def with_default_table(col, table):
324
    col = as_Col(col)
325
    if col.table == None: col = with_table(col, table)
326
    return col
327

    
328
def set_cols_table(table, cols):
329
    table = as_Table(table)
330
    
331
    for i, col in enumerate(cols):
332
        col = cols[i] = as_Col(col)
333
        col.table = table
334

    
335
def to_name_only_col(col, check_table=None):
336
    col = as_Col(col)
337
    if not is_table_col(col): return col
338
    
339
    if check_table != None:
340
        table = col.table
341
        assert table == None or table == check_table
342
    return Col(col.name)
343

    
344
def suffixed_col(col, suffix):
345
    return Col(concat(col.name, suffix), col.table, col.srcs)
346

    
347
def has_srcs(col): return is_col(col) and col.srcs
348

    
349
def cross_join_srcs(cols):
350
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
351
    srcs = [[s.name for s in c.srcs] for c in cols]
352
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
353

    
354
class NamedCol(Col):
355
    def __init__(self, name, code):
356
        Col.__init__(self, name)
357
        
358
        code = as_Value(code)
359
        
360
        self.code = code
361
    
362
    def to_str(self, db):
363
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
364
    
365
    def to_Col(self): return Col(self.name)
366

    
367
def remove_col_rename(col):
368
    if isinstance(col, NamedCol): col = col.code
369
    return col
370

    
371
def underlying_col(col):
372
    col = remove_col_rename(col)
373
    if not isinstance(col, Col): raise NoUnderlyingTableException
374
    
375
    return Col(col.name, underlying_table(col.table), col.srcs)
376

    
377
def wrap(wrap_func, value):
378
    '''Wraps a value, propagating any column renaming to the returned value.'''
379
    if isinstance(value, NamedCol):
380
        return NamedCol(value.name, wrap_func(value.code))
381
    else: return wrap_func(value)
382

    
383
class ColDict(dicts.DictProxy):
384
    '''A dict that automatically makes inserted entries Col objects'''
385
    
386
    def __init__(self, db, keys_table, dict_={}):
387
        dicts.DictProxy.__init__(self, OrderedDict())
388
        
389
        keys_table = as_Table(keys_table)
390
        
391
        self.db = db
392
        self.table = keys_table
393
        self.update(dict_) # after setting vars because __setitem__() needs them
394
    
395
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
396
    
397
    def __getitem__(self, key):
398
        return dicts.DictProxy.__getitem__(self, self._key(key))
399
    
400
    def __setitem__(self, key, value):
401
        key = self._key(key)
402
        if value == None: value = self.db.col_info(key).default
403
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
404
    
405
    def _key(self, key): return as_Col(key, self.table)
406

    
407
##### Composite types
408

    
409
class List(Code):
410
    def __init__(self, *values):
411
        Code.__init__(self)
412
        
413
        self.values = values
414
    
415
    def to_str(self, db):
416
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
417

    
418
class Tuple(List):
419
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
420

    
421
#### Definitions
422

    
423
class TypedCol(Col):
424
    def __init__(self, name, type_, default=None, nullable=True,
425
        constraints=None):
426
        assert default == None or isinstance(default, Code)
427
        
428
        Col.__init__(self, name)
429
        
430
        self.type = type_
431
        self.default = default
432
        self.nullable = nullable
433
        self.constraints = constraints
434
    
435
    def to_str(self, db):
436
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
437
        if not self.nullable: str_ += ' NOT NULL'
438
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
439
        if self.constraints != None: str_ += ' '+self.constraints
440
        return str_
441
    
442
    def to_Col(self): return Col(self.name)
443

    
444
class SetOf(Code):
445
    def __init__(self, type_):
446
        Code.__init__(self)
447
        
448
        self.type = type_
449
    
450
    def to_str(self, db):
451
        return 'SETOF '+self.type.to_str(db)
452

    
453
class RowType(Code):
454
    def __init__(self, table):
455
        Code.__init__(self)
456
        
457
        self.table = table
458
    
459
    def to_str(self, db):
460
        return self.table.to_str(db)+'%ROWTYPE'
461

    
462
class ColType(Code):
463
    def __init__(self, col):
464
        Code.__init__(self)
465
        
466
        self.col = col
467
    
468
    def to_str(self, db):
469
        return self.col.to_str(db)+'%TYPE'
470

    
471
##### Functions
472

    
473
Function = Table
474
as_Function = as_Table
475

    
476
class InternalFunction(CustomCode): pass
477

    
478
#### Calls
479

    
480
class NamedArg(NamedCol):
481
    def __init__(self, name, value):
482
        NamedCol.__init__(self, name, value)
483
    
484
    def to_str(self, db):
485
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
486

    
487
class FunctionCall(Code):
488
    def __init__(self, function, *args, **kw_args):
489
        '''
490
        @param args [Code|literal-value...] The function's arguments
491
        '''
492
        Code.__init__(self)
493
        
494
        function = as_Function(function)
495
        def filter_(arg): return remove_col_rename(as_Value(arg))
496
        args = map(filter_, args)
497
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
498
        
499
        self.function = function
500
        self.args = args
501
    
502
    def to_str(self, db):
503
        args_str = ', '.join((v.to_str(db) for v in self.args))
504
        return self.function.to_str(db)+'('+args_str+')'
505

    
506
def wrap_in_func(function, value):
507
    '''Wraps a value inside a function call.
508
    Propagates any column renaming to the returned value.
509
    '''
510
    return wrap(lambda v: FunctionCall(function, v), value)
511

    
512
def unwrap_func_call(func_call, check_name=None):
513
    '''Unwraps any function call to its first argument.
514
    Also removes any column renaming.
515
    '''
516
    func_call = remove_col_rename(func_call)
517
    if not isinstance(func_call, FunctionCall): return func_call
518
    
519
    if check_name != None:
520
        name = func_call.function.name
521
        assert name == None or name == check_name
522
    return func_call.args[0]
523

    
524
#### Definitions
525

    
526
class FunctionDef(Code):
527
    def __init__(self, function, return_type, body, params=[], modifiers=None):
528
        Code.__init__(self)
529
        
530
        return_type = as_Code(return_type)
531
        body = as_Code(body)
532
        
533
        self.function = function
534
        self.return_type = return_type
535
        self.body = body
536
        self.params = params
537
        self.modifiers = modifiers
538
    
539
    def to_str(self, db):
540
        params_str = (', '.join((p.to_str(db) for p in self.params)))
541
        str_ = '''\
542
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
543
RETURNS '''+self.return_type.to_str(db)+'''
544
LANGUAGE '''+self.body.lang+'''
545
'''
546
        if self.modifiers != None: str_ += self.modifiers+'\n'
547
        str_ += '''\
548
AS $$
549
'''+self.body.to_str(db)+'''
550
$$;
551
'''
552
        return str_
553

    
554
class FunctionParam(TypedCol):
555
    def __init__(self, name, type_, default=None, out=False):
556
        TypedCol.__init__(self, name, type_, default)
557
        
558
        self.out = out
559
    
560
    def to_str(self, db):
561
        str_ = TypedCol.to_str(self, db)
562
        if self.out: str_ = 'OUT '+str_
563
        return str_
564
    
565
    def to_Col(self): return Col(self.name)
566

    
567
### PL/pgSQL
568

    
569
class ReturnQuery(Code):
570
    def __init__(self, query):
571
        Code.__init__(self)
572
        
573
        query = as_Code(query)
574
        
575
        self.query = query
576
    
577
    def to_str(self, db):
578
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
579

    
580
## Exceptions
581

    
582
class BaseExcHandler(BasicObject):
583
    def to_str(self, db, body): raise NotImplementedError()
584
    
585
    def __repr__(self): return self.to_str(mockDb, '<body>')
586

    
587
class ExcHandler(BaseExcHandler):
588
    def __init__(self, exc, handler=None):
589
        if handler != None: handler = as_Code(handler)
590
        
591
        self.exc = exc
592
        self.handler = handler
593
    
594
    def to_str(self, db, body):
595
        body = as_Code(body)
596
        
597
        if self.handler != None:
598
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
599
        else: handler_str = ' NULL;\n'
600
        
601
        str_ = '''\
602
BEGIN
603
'''+strings.indent(body.to_str(db))+'''\
604
EXCEPTION
605
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
606
END;\
607
'''
608
        return str_
609

    
610
class NestedExcHandler(BaseExcHandler):
611
    def __init__(self, *handlers):
612
        '''
613
        @param handlers Sorted from outermost to innermost
614
        '''
615
        self.handlers = handlers
616
    
617
    def to_str(self, db, body):
618
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
619
        return body
620

    
621
class ExcToWarning(Code):
622
    def __init__(self, return_):
623
        '''
624
        @param return_ Statement to return a default value in case of error
625
        '''
626
        Code.__init__(self)
627
        
628
        return_ = as_Code(return_)
629
        
630
        self.return_ = return_
631
    
632
    def to_str(self, db):
633
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
634

    
635
unique_violation_handler = ExcHandler('unique_violation')
636

    
637
plpythonu_error_handler = ExcHandler('internal_error', '''\
638
RAISE data_exception USING MESSAGE =
639
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
640
''')
641

    
642
def data_exception_handler(handler):
643
    return ExcHandler('data_exception', handler)
644

    
645
class RowExcIgnore(Code):
646
    def __init__(self, row_type, select_query, with_row, cols=None,
647
        exc_handler=unique_violation_handler, row_var='row'):
648
        Code.__init__(self, lang='plpgsql')
649
        
650
        row_type = as_Code(row_type)
651
        select_query = as_Code(select_query)
652
        with_row = as_Code(with_row)
653
        row_var = as_Table(row_var)
654
        
655
        self.row_type = row_type
656
        self.select_query = select_query
657
        self.with_row = with_row
658
        self.cols = cols
659
        self.exc_handler = exc_handler
660
        self.row_var = row_var
661
    
662
    def to_str(self, db):
663
        if self.cols == None: row_vars = [self.row_var]
664
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
665
        
666
        str_ = '''\
667
DECLARE
668
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
669
BEGIN
670
    /* Need an EXCEPTION block for each individual row because "When
671
    an error is caught by an EXCEPTION clause, [...] all changes to
672
    persistent database state within the block are rolled back."
673
    This is unfortunate because "A block containing an EXCEPTION
674
    clause is significantly more expensive to enter and exit than a
675
    block without one."
676
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
677
#PLPGSQL-ERROR-TRAPPING)
678
    */
679
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
680
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
681
    LOOP
682
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
683
    END LOOP;
684
END;\
685
'''
686
        return str_
687

    
688
##### Casts
689

    
690
class Cast(FunctionCall):
691
    def __init__(self, type_, value):
692
        value = as_Value(value)
693
        
694
        self.type_ = type_
695
        self.value = value
696
    
697
    def to_str(self, db):
698
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
699

    
700
def cast_literal(value):
701
    if not is_literal(value): return value
702
    
703
    if util.is_str(value.value): value = Cast('text', value)
704
    return value
705

    
706
##### Conditions
707

    
708
class NotCond(Code):
709
    def __init__(self, cond):
710
        Code.__init__(self)
711
        
712
        self.cond = cond
713
    
714
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
715

    
716
class ColValueCond(Code):
717
    def __init__(self, col, value):
718
        Code.__init__(self)
719
        
720
        value = as_ValueCond(value)
721
        
722
        self.col = col
723
        self.value = value
724
    
725
    def to_str(self, db): return self.value.to_str(db, self.col)
726

    
727
def combine_conds(conds, keyword=None):
728
    '''
729
    @param keyword The keyword to add before the conditions, if any
730
    '''
731
    str_ = ''
732
    if keyword != None:
733
        if conds == []: whitespace = ''
734
        elif len(conds) == 1: whitespace = ' '
735
        else: whitespace = '\n'
736
        str_ += keyword+whitespace
737
    
738
    str_ += '\nAND '.join(conds)
739
    return str_
740

    
741
##### Condition column comparisons
742

    
743
class ValueCond(BasicObject):
744
    def __init__(self, value):
745
        value = remove_col_rename(as_Value(value))
746
        
747
        self.value = value
748
    
749
    def to_str(self, db, left_value):
750
        '''
751
        @param left_value The Code object that the condition is being applied on
752
        '''
753
        raise NotImplemented()
754
    
755
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
756

    
757
class CompareCond(ValueCond):
758
    def __init__(self, value, operator='='):
759
        '''
760
        @param operator By default, compares NULL values literally. Use '~=' or
761
            '~!=' to pass NULLs through.
762
        '''
763
        ValueCond.__init__(self, value)
764
        self.operator = operator
765
    
766
    def to_str(self, db, left_value):
767
        left_value = remove_col_rename(as_Col(left_value))
768
        
769
        right_value = self.value
770
        
771
        # Parse operator
772
        operator = self.operator
773
        passthru_null_ref = [False]
774
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
775
        neg_ref = [False]
776
        operator = strings.remove_prefix('!', operator, neg_ref)
777
        equals = operator.endswith('=') # also includes <=, >=
778
        
779
        # Handle nullable columns
780
        check_null = False
781
        if not passthru_null_ref[0]: # NULLs compare equal
782
            try: left_value = ensure_not_null(db, left_value)
783
            except ensure_not_null_excs: # fall back to alternate method
784
                check_null = equals and isinstance(right_value, Col)
785
            else:
786
                if isinstance(left_value, EnsureNotNull):
787
                    right_value = ensure_not_null(db, right_value,
788
                        left_value.type) # apply same function to both sides
789
        
790
        if equals and is_null(right_value): operator = 'IS'
791
        
792
        left = left_value.to_str(db)
793
        right = right_value.to_str(db)
794
        
795
        # Create str
796
        str_ = left+' '+operator+' '+right
797
        if check_null:
798
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
799
        if neg_ref[0]: str_ = 'NOT '+str_
800
        return str_
801

    
802
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
803
assume_literal = object()
804

    
805
def as_ValueCond(value, default_table=assume_literal):
806
    if not isinstance(value, ValueCond):
807
        if default_table is not assume_literal:
808
            value = with_default_table(value, default_table)
809
        return CompareCond(value)
810
    else: return value
811

    
812
##### Joins
813

    
814
join_same = object() # tells Join the left and right columns have the same name
815

    
816
# Tells Join the left and right columns have the same name and are never NULL
817
join_same_not_null = object()
818

    
819
filter_out = object() # tells Join to filter out rows that match the join
820

    
821
class Join(BasicObject):
822
    def __init__(self, table, mapping={}, type_=None):
823
        '''
824
        @param mapping dict(right_table_col=left_table_col, ...)
825
            * if left_table_col is join_same: left_table_col = right_table_col
826
              * Note that right_table_col must be a string
827
            * if left_table_col is join_same_not_null:
828
              left_table_col = right_table_col and both have NOT NULL constraint
829
              * Note that right_table_col must be a string
830
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
831
            * filter_out: equivalent to 'LEFT' with the query filtered by
832
              `table_pkey IS NULL` (indicating no match)
833
        '''
834
        if util.is_str(table): table = Table(table)
835
        assert type_ == None or util.is_str(type_) or type_ is filter_out
836
        
837
        self.table = table
838
        self.mapping = mapping
839
        self.type_ = type_
840
    
841
    def to_str(self, db, left_table_):
842
        def join(entry):
843
            '''Parses non-USING joins'''
844
            right_table_col, left_table_col = entry
845
            
846
            # Switch order (right_table_col is on the left in the comparison)
847
            left = right_table_col
848
            right = left_table_col
849
            left_table = self.table
850
            right_table = left_table_
851
            
852
            # Parse left side
853
            left = with_default_table(left, left_table)
854
            
855
            # Parse special values
856
            left_on_right = Col(left.name, right_table)
857
            if right is join_same: right = left_on_right
858
            elif right is join_same_not_null:
859
                right = CompareCond(left_on_right, '~=')
860
            
861
            # Parse right side
862
            right = as_ValueCond(right, right_table)
863
            
864
            return right.to_str(db, left)
865
        
866
        # Create join condition
867
        type_ = self.type_
868
        joins = self.mapping
869
        if joins == {}: join_cond = None
870
        elif type_ is not filter_out and reduce(operator.and_,
871
            (v is join_same_not_null for v in joins.itervalues())):
872
            # all cols w/ USING, so can use simpler USING syntax
873
            cols = map(to_name_only_col, joins.iterkeys())
874
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
875
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
876
        
877
        if isinstance(self.table, NamedTable): whitespace = '\n'
878
        else: whitespace = ' '
879
        
880
        # Create join
881
        if type_ is filter_out: type_ = 'LEFT'
882
        str_ = ''
883
        if type_ != None: str_ += type_+' '
884
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
885
        if join_cond != None: str_ += whitespace+join_cond
886
        return str_
887
    
888
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
889

    
890
##### Value exprs
891

    
892
all_cols = CustomCode('*')
893

    
894
default = CustomCode('DEFAULT')
895

    
896
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
897

    
898
class Coalesce(FunctionCall):
899
    def __init__(self, *args):
900
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
901

    
902
class Nullif(FunctionCall):
903
    def __init__(self, *args):
904
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
905

    
906
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
907
null_sentinels = {
908
    'character varying': r'\N',
909
    'double precision': 'NaN',
910
    'integer': 2147483647,
911
    'text': r'\N',
912
    'timestamp with time zone': 'infinity'
913
}
914

    
915
class EnsureNotNull(Coalesce):
916
    def __init__(self, value, type_):
917
        Coalesce.__init__(self, as_Col(value),
918
            Cast(type_, null_sentinels[type_]))
919
        
920
        self.type = type_
921
    
922
    def to_str(self, db):
923
        col = self.args[0]
924
        index_col_ = index_col(col)
925
        if index_col_ != None: return index_col_.to_str(db)
926
        return Coalesce.to_str(self, db)
927

    
928
##### Table exprs
929

    
930
class Values(Code):
931
    def __init__(self, values):
932
        '''
933
        @param values [...]|[[...], ...] Can be one or multiple rows.
934
        '''
935
        Code.__init__(self)
936
        
937
        rows = values
938
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
939
            rows = [values]
940
        for i, row in enumerate(rows):
941
            rows[i] = map(remove_col_rename, map(as_Value, row))
942
        
943
        self.rows = rows
944
    
945
    def to_str(self, db):
946
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
947

    
948
def NamedValues(name, cols, values):
949
    '''
950
    @param cols None|[...]
951
    @post `cols` will be changed to Col objects with the table set to `name`.
952
    '''
953
    table = NamedTable(name, Values(values), cols)
954
    if cols != None: set_cols_table(table, cols)
955
    return table
956

    
957
##### Database structure
958

    
959
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
960

    
961
def ensure_not_null(db, col, type_=None):
962
    '''
963
    @param col If type_ is not set, must have an underlying column.
964
    @param type_ If set, overrides the underlying column's type.
965
    @return EnsureNotNull|Col
966
    @throws ensure_not_null_excs
967
    '''
968
    nullable = True
969
    try: typed_col = db.col_info(underlying_col(col))
970
    except NoUnderlyingTableException:
971
        col = remove_col_rename(col)
972
        if is_literal(col) and not is_null(col): nullable = False
973
        elif type_ == None: raise
974
    else:
975
        if type_ == None: type_ = typed_col.type
976
        nullable = typed_col.nullable
977
    
978
    if nullable:
979
        try: col = EnsureNotNull(col, type_)
980
        except KeyError, e:
981
            # Warn of no null sentinel for type, even if caller catches error
982
            warnings.warn(UserWarning(exc.str_(e)))
983
            raise
984
    
985
    return col
(25-25/37)