Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import operator
5
from ordereddict import OrderedDict
6
import re
7
import UserDict
8
import warnings
9

    
10
import dicts
11
import exc
12
import iters
13
import lists
14
import objects
15
import strings
16
import util
17

    
18
##### Names
19

    
20
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21

    
22
def concat(str_, suffix):
23
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25
    # Preserve version
26
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30
    
31
    return strings.concat(str_, suffix, identifier_max_len)
32

    
33
def truncate(str_): return concat(str_, '')
34

    
35
def is_safe_name(name):
36
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42

    
43
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46

    
47
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55

    
56
def clean_name(name): return name.replace('"', '').replace('`', '')
57

    
58
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59

    
60
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64

    
65
##### General SQL code objects
66

    
67
class MockDb:
68
    def esc_value(self, value): return strings.repr_no_u(value)
69
    
70
    def esc_name(self, name): return esc_name(name)
71
    
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74

    
75
mockDb = MockDb()
76

    
77
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79

    
80
##### Unparameterized code objects
81

    
82
class Code(BasicObject):
83
    def __init__(self, lang='sql'):
84
        self.lang = lang
85
    
86
    def to_str(self, db): raise NotImplementedError()
87
    
88
    def __repr__(self): return self.to_str(mockDb)
89

    
90
class CustomCode(Code):
91
    def __init__(self, str_):
92
        Code.__init__(self)
93
        
94
        self.str_ = str_
95
    
96
    def to_str(self, db): return self.str_
97

    
98
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102
    if isinstance(value, Code): return value
103
    
104
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107
    else: return Literal(value)
108

    
109
class Expr(Code):
110
    def __init__(self, expr):
111
        Code.__init__(self)
112
        
113
        self.expr = expr
114
    
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116

    
117
##### Names
118

    
119
class Name(Code):
120
    def __init__(self, name):
121
        Code.__init__(self)
122
        
123
        name = truncate(name)
124
        
125
        self.name = name
126
    
127
    def to_str(self, db): return db.esc_name(self.name)
128

    
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132

    
133
##### Literal values
134

    
135
class Literal(Code):
136
    def __init__(self, value):
137
        Code.__init__(self)
138
        
139
        self.value = value
140
    
141
    def to_str(self, db): return db.esc_value(self.value)
142

    
143
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146

    
147
def is_literal(value): return isinstance(value, Literal)
148

    
149
def is_null(value): return is_literal(value) and value.value == None
150

    
151
##### Derived elements
152

    
153
src_self = object() # tells Col that it is its own source column
154

    
155
class Derived(Code):
156
    def __init__(self, srcs):
157
        '''An element which was derived from some other element(s).
158
        @param srcs See self.set_srcs()
159
        '''
160
        Code.__init__(self)
161
        
162
        self.set_srcs(srcs)
163
    
164
    def set_srcs(self, srcs, overwrite=True):
165
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168
        if not overwrite and self.srcs != (): return # already set
169
        
170
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
    
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178

    
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180

    
181
##### Tables
182

    
183
class Table(Derived):
184
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185
        '''
186
        @param schema str|None (for no schema)
187
        @param srcs (Table...)|src_self See Derived.set_srcs()
188
        '''
189
        Derived.__init__(self, srcs)
190
        
191
        if util.is_str(name): name = truncate(name)
192
        
193
        self.name = name
194
        self.schema = schema
195
        self.is_temp = is_temp
196
        self.index_cols = {}
197
    
198
    def to_str(self, db):
199
        str_ = ''
200
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202
        return str_
203
    
204
    def to_Table(self): return self
205
    
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210

    
211
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213

    
214
class NoUnderlyingTableException(Exception): pass
215

    
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220

    
221
def as_Table(table, schema=None):
222
    if table == None or isinstance(table, Code): return table
223
    else: return Table(table, schema)
224

    
225
def suffixed_table(table, suffix):
226
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229

    
230
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
        
234
        code = as_Table(code)
235
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237
        
238
        self.code = code
239
        self.cols = cols
240
    
241
    def to_str(self, db):
242
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248
        return str_
249
    
250
    def to_Table(self): return Table(self.name)
251

    
252
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255

    
256
##### Columns
257

    
258
class Col(Derived):
259
    def __init__(self, name, table=None, srcs=()):
260
        '''
261
        @param table Table|None (for no table)
262
        @param srcs (Col...)|src_self See Derived.set_srcs()
263
        '''
264
        Derived.__init__(self, srcs)
265
        
266
        if util.is_str(name): name = truncate(name)
267
        if util.is_str(table): table = Table(table)
268
        assert table == None or isinstance(table, Table)
269
        
270
        self.name = name
271
        self.table = table
272
    
273
    def to_str(self, db, for_str=False):
274
        str_ = as_Name(self.name).to_str(db)
275
        if for_str: str_ = clean_name(str_)
276
        if self.table != None:
277
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280
        return str_
281
    
282
    def __str__(self): return self.to_str(mockDb, for_str=True)
283
    
284
    def to_Col(self): return self
285

    
286
def is_col(col): return isinstance(col, Col)
287

    
288
def is_table_col(col): return is_col(col) and col.table != None
289

    
290
def index_col(col):
291
    if not is_table_col(col): return None
292
    
293
    table = col.table
294
    try: name = table.index_cols[col.name]
295
    except KeyError: return None
296
    else: return Col(name, table, col.srcs)
297

    
298
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
299

    
300
def as_Col(col, table=None, name=None):
301
    '''
302
    @param name If not None, any non-Col input will be renamed using NamedCol.
303
    '''
304
    if name != None:
305
        col = as_Value(col)
306
        if not isinstance(col, Col): col = NamedCol(name, col)
307
    
308
    if isinstance(col, Code): return col
309
    elif util.is_str(col): return Col(col, table)
310
    else: return Literal(col)
311

    
312
def with_table(col, table):
313
    if isinstance(col, NamedCol): pass # doesn't take a table
314
    elif isinstance(col, FunctionCall):
315
        col = copy.deepcopy(col) # don't modify input!
316
        col.args[0].table = table
317
    elif isinstance(col, Col):
318
        col = copy.copy(col) # don't modify input!
319
        col.table = table
320
    return col
321

    
322
def with_default_table(col, table):
323
    col = as_Col(col)
324
    if col.table == None: col = with_table(col, table)
325
    return col
326

    
327
def set_cols_table(table, cols):
328
    table = as_Table(table)
329
    
330
    for i, col in enumerate(cols):
331
        col = cols[i] = as_Col(col)
332
        col.table = table
333

    
334
def to_name_only_col(col, check_table=None):
335
    col = as_Col(col)
336
    if not is_table_col(col): return col
337
    
338
    if check_table != None:
339
        table = col.table
340
        assert table == None or table == check_table
341
    return Col(col.name)
342

    
343
def suffixed_col(col, suffix):
344
    return Col(concat(col.name, suffix), col.table, col.srcs)
345

    
346
def has_srcs(col): return is_col(col) and col.srcs
347

    
348
def srcs_str(cols):
349
    cols = filter(is_col, cols)
350
    return ','.join(('+'.join((s.name for s in c.srcs)) for c in cols))
351

    
352
class NamedCol(Col):
353
    def __init__(self, name, code):
354
        Col.__init__(self, name)
355
        
356
        code = as_Value(code)
357
        
358
        self.code = code
359
    
360
    def to_str(self, db):
361
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
362
    
363
    def to_Col(self): return Col(self.name)
364

    
365
def remove_col_rename(col):
366
    if isinstance(col, NamedCol): col = col.code
367
    return col
368

    
369
def underlying_col(col):
370
    col = remove_col_rename(col)
371
    if not isinstance(col, Col): raise NoUnderlyingTableException
372
    
373
    return Col(col.name, underlying_table(col.table), col.srcs)
374

    
375
def wrap(wrap_func, value):
376
    '''Wraps a value, propagating any column renaming to the returned value.'''
377
    if isinstance(value, NamedCol):
378
        return NamedCol(value.name, wrap_func(value.code))
379
    else: return wrap_func(value)
380

    
381
class ColDict(dicts.DictProxy):
382
    '''A dict that automatically makes inserted entries Col objects'''
383
    
384
    def __init__(self, db, keys_table, dict_={}):
385
        dicts.DictProxy.__init__(self, OrderedDict())
386
        
387
        keys_table = as_Table(keys_table)
388
        
389
        self.db = db
390
        self.table = keys_table
391
        self.update(dict_) # after setting vars because __setitem__() needs them
392
    
393
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
394
    
395
    def __getitem__(self, key):
396
        return dicts.DictProxy.__getitem__(self, self._key(key))
397
    
398
    def __setitem__(self, key, value):
399
        key = self._key(key)
400
        if value == None: value = self.db.col_info(key).default
401
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
402
    
403
    def _key(self, key): return as_Col(key, self.table)
404

    
405
##### Composite types
406

    
407
class List(Code):
408
    def __init__(self, *values):
409
        Code.__init__(self)
410
        
411
        self.values = values
412
    
413
    def to_str(self, db):
414
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
415

    
416
class Tuple(List):
417
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
418

    
419
#### Definitions
420

    
421
class TypedCol(Col):
422
    def __init__(self, name, type_, default=None, nullable=True,
423
        constraints=None):
424
        assert default == None or isinstance(default, Code)
425
        
426
        Col.__init__(self, name)
427
        
428
        self.type = type_
429
        self.default = default
430
        self.nullable = nullable
431
        self.constraints = constraints
432
    
433
    def to_str(self, db):
434
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
435
        if not self.nullable: str_ += ' NOT NULL'
436
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
437
        if self.constraints != None: str_ += ' '+self.constraints
438
        return str_
439
    
440
    def to_Col(self): return Col(self.name)
441

    
442
class SetOf(Code):
443
    def __init__(self, type_):
444
        Code.__init__(self)
445
        
446
        self.type = type_
447
    
448
    def to_str(self, db):
449
        return 'SETOF '+self.type.to_str(db)
450

    
451
class RowType(Code):
452
    def __init__(self, table):
453
        Code.__init__(self)
454
        
455
        self.table = table
456
    
457
    def to_str(self, db):
458
        return self.table.to_str(db)+'%ROWTYPE'
459

    
460
class ColType(Code):
461
    def __init__(self, col):
462
        Code.__init__(self)
463
        
464
        self.col = col
465
    
466
    def to_str(self, db):
467
        return self.col.to_str(db)+'%TYPE'
468

    
469
##### Functions
470

    
471
Function = Table
472
as_Function = as_Table
473

    
474
class InternalFunction(CustomCode): pass
475

    
476
#### Calls
477

    
478
class NamedArg(NamedCol):
479
    def __init__(self, name, value):
480
        NamedCol.__init__(self, name, value)
481
    
482
    def to_str(self, db):
483
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
484

    
485
class FunctionCall(Code):
486
    def __init__(self, function, *args, **kw_args):
487
        '''
488
        @param args [Code|literal-value...] The function's arguments
489
        '''
490
        Code.__init__(self)
491
        
492
        function = as_Function(function)
493
        def filter_(arg): return remove_col_rename(as_Value(arg))
494
        args = map(filter_, args)
495
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
496
        
497
        self.function = function
498
        self.args = args
499
    
500
    def to_str(self, db):
501
        args_str = ', '.join((v.to_str(db) for v in self.args))
502
        return self.function.to_str(db)+'('+args_str+')'
503

    
504
def wrap_in_func(function, value):
505
    '''Wraps a value inside a function call.
506
    Propagates any column renaming to the returned value.
507
    '''
508
    return wrap(lambda v: FunctionCall(function, v), value)
509

    
510
def unwrap_func_call(func_call, check_name=None):
511
    '''Unwraps any function call to its first argument.
512
    Also removes any column renaming.
513
    '''
514
    func_call = remove_col_rename(func_call)
515
    if not isinstance(func_call, FunctionCall): return func_call
516
    
517
    if check_name != None:
518
        name = func_call.function.name
519
        assert name == None or name == check_name
520
    return func_call.args[0]
521

    
522
#### Definitions
523

    
524
class FunctionDef(Code):
525
    def __init__(self, function, return_type, body, params=[], modifiers=None):
526
        Code.__init__(self)
527
        
528
        return_type = as_Code(return_type)
529
        body = as_Code(body)
530
        
531
        self.function = function
532
        self.return_type = return_type
533
        self.body = body
534
        self.params = params
535
        self.modifiers = modifiers
536
    
537
    def to_str(self, db):
538
        params_str = (', '.join((p.to_str(db) for p in self.params)))
539
        str_ = '''\
540
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
541
RETURNS '''+self.return_type.to_str(db)+'''
542
LANGUAGE '''+self.body.lang+'''
543
'''
544
        if self.modifiers != None: str_ += self.modifiers+'\n'
545
        str_ += '''\
546
AS $$
547
'''+self.body.to_str(db)+'''
548
$$;
549
'''
550
        return str_
551

    
552
class FunctionParam(TypedCol):
553
    def __init__(self, name, type_, default=None, out=False):
554
        TypedCol.__init__(self, name, type_, default)
555
        
556
        self.out = out
557
    
558
    def to_str(self, db):
559
        str_ = TypedCol.to_str(self, db)
560
        if self.out: str_ = 'OUT '+str_
561
        return str_
562
    
563
    def to_Col(self): return Col(self.name)
564

    
565
### PL/pgSQL
566

    
567
class ReturnQuery(Code):
568
    def __init__(self, query):
569
        Code.__init__(self)
570
        
571
        query = as_Code(query)
572
        
573
        self.query = query
574
    
575
    def to_str(self, db):
576
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
577

    
578
## Exceptions
579

    
580
class BaseExcHandler(BasicObject):
581
    def to_str(self, db, body): raise NotImplementedError()
582
    
583
    def __repr__(self): return self.to_str(mockDb, '<body>')
584

    
585
class ExcHandler(BaseExcHandler):
586
    def __init__(self, exc, handler=None):
587
        if handler != None: handler = as_Code(handler)
588
        
589
        self.exc = exc
590
        self.handler = handler
591
    
592
    def to_str(self, db, body):
593
        body = as_Code(body)
594
        
595
        if self.handler != None:
596
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
597
        else: handler_str = ' NULL;\n'
598
        
599
        str_ = '''\
600
BEGIN
601
'''+strings.indent(body.to_str(db))+'''\
602
EXCEPTION
603
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
604
END;\
605
'''
606
        return str_
607

    
608
class NestedExcHandler(BaseExcHandler):
609
    def __init__(self, *handlers):
610
        '''
611
        @param handlers Sorted from outermost to innermost
612
        '''
613
        self.handlers = handlers
614
    
615
    def to_str(self, db, body):
616
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
617
        return body
618

    
619
class ExcToWarning(Code):
620
    def __init__(self, return_):
621
        '''
622
        @param return_ Statement to return a default value in case of error
623
        '''
624
        Code.__init__(self)
625
        
626
        return_ = as_Code(return_)
627
        
628
        self.return_ = return_
629
    
630
    def to_str(self, db):
631
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
632

    
633
unique_violation_handler = ExcHandler('unique_violation')
634

    
635
plpythonu_error_handler = ExcHandler('internal_error', '''\
636
RAISE data_exception USING MESSAGE =
637
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
638
''')
639

    
640
def data_exception_handler(handler):
641
    return ExcHandler('data_exception', handler)
642

    
643
class RowExcIgnore(Code):
644
    def __init__(self, row_type, select_query, with_row, cols=None,
645
        exc_handler=unique_violation_handler, row_var='row'):
646
        Code.__init__(self, lang='plpgsql')
647
        
648
        row_type = as_Code(row_type)
649
        select_query = as_Code(select_query)
650
        with_row = as_Code(with_row)
651
        row_var = as_Table(row_var)
652
        
653
        self.row_type = row_type
654
        self.select_query = select_query
655
        self.with_row = with_row
656
        self.cols = cols
657
        self.exc_handler = exc_handler
658
        self.row_var = row_var
659
    
660
    def to_str(self, db):
661
        if self.cols == None: row_vars = [self.row_var]
662
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
663
        
664
        str_ = '''\
665
DECLARE
666
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
667
BEGIN
668
    /* Need an EXCEPTION block for each individual row because "When
669
    an error is caught by an EXCEPTION clause, [...] all changes to
670
    persistent database state within the block are rolled back."
671
    This is unfortunate because "A block containing an EXCEPTION
672
    clause is significantly more expensive to enter and exit than a
673
    block without one."
674
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
675
#PLPGSQL-ERROR-TRAPPING)
676
    */
677
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
678
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
679
    LOOP
680
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
681
    END LOOP;
682
END;\
683
'''
684
        return str_
685

    
686
##### Casts
687

    
688
class Cast(FunctionCall):
689
    def __init__(self, type_, value):
690
        value = as_Value(value)
691
        
692
        self.type_ = type_
693
        self.value = value
694
    
695
    def to_str(self, db):
696
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
697

    
698
def cast_literal(value):
699
    if not is_literal(value): return value
700
    
701
    if util.is_str(value.value): value = Cast('text', value)
702
    return value
703

    
704
##### Conditions
705

    
706
class NotCond(Code):
707
    def __init__(self, cond):
708
        Code.__init__(self)
709
        
710
        self.cond = cond
711
    
712
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
713

    
714
class ColValueCond(Code):
715
    def __init__(self, col, value):
716
        Code.__init__(self)
717
        
718
        value = as_ValueCond(value)
719
        
720
        self.col = col
721
        self.value = value
722
    
723
    def to_str(self, db): return self.value.to_str(db, self.col)
724

    
725
def combine_conds(conds, keyword=None):
726
    '''
727
    @param keyword The keyword to add before the conditions, if any
728
    '''
729
    str_ = ''
730
    if keyword != None:
731
        if conds == []: whitespace = ''
732
        elif len(conds) == 1: whitespace = ' '
733
        else: whitespace = '\n'
734
        str_ += keyword+whitespace
735
    
736
    str_ += '\nAND '.join(conds)
737
    return str_
738

    
739
##### Condition column comparisons
740

    
741
class ValueCond(BasicObject):
742
    def __init__(self, value):
743
        value = remove_col_rename(as_Value(value))
744
        
745
        self.value = value
746
    
747
    def to_str(self, db, left_value):
748
        '''
749
        @param left_value The Code object that the condition is being applied on
750
        '''
751
        raise NotImplemented()
752
    
753
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
754

    
755
class CompareCond(ValueCond):
756
    def __init__(self, value, operator='='):
757
        '''
758
        @param operator By default, compares NULL values literally. Use '~=' or
759
            '~!=' to pass NULLs through.
760
        '''
761
        ValueCond.__init__(self, value)
762
        self.operator = operator
763
    
764
    def to_str(self, db, left_value):
765
        left_value = remove_col_rename(as_Col(left_value))
766
        
767
        right_value = self.value
768
        
769
        # Parse operator
770
        operator = self.operator
771
        passthru_null_ref = [False]
772
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
773
        neg_ref = [False]
774
        operator = strings.remove_prefix('!', operator, neg_ref)
775
        equals = operator.endswith('=') # also includes <=, >=
776
        
777
        # Handle nullable columns
778
        check_null = False
779
        if not passthru_null_ref[0]: # NULLs compare equal
780
            try: left_value = ensure_not_null(db, left_value)
781
            except ensure_not_null_excs: # fall back to alternate method
782
                check_null = equals and isinstance(right_value, Col)
783
            else:
784
                if isinstance(left_value, EnsureNotNull):
785
                    right_value = ensure_not_null(db, right_value,
786
                        left_value.type) # apply same function to both sides
787
        
788
        if equals and is_null(right_value): operator = 'IS'
789
        
790
        left = left_value.to_str(db)
791
        right = right_value.to_str(db)
792
        
793
        # Create str
794
        str_ = left+' '+operator+' '+right
795
        if check_null:
796
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
797
        if neg_ref[0]: str_ = 'NOT '+str_
798
        return str_
799

    
800
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
801
assume_literal = object()
802

    
803
def as_ValueCond(value, default_table=assume_literal):
804
    if not isinstance(value, ValueCond):
805
        if default_table is not assume_literal:
806
            value = with_default_table(value, default_table)
807
        return CompareCond(value)
808
    else: return value
809

    
810
##### Joins
811

    
812
join_same = object() # tells Join the left and right columns have the same name
813

    
814
# Tells Join the left and right columns have the same name and are never NULL
815
join_same_not_null = object()
816

    
817
filter_out = object() # tells Join to filter out rows that match the join
818

    
819
class Join(BasicObject):
820
    def __init__(self, table, mapping={}, type_=None):
821
        '''
822
        @param mapping dict(right_table_col=left_table_col, ...)
823
            * if left_table_col is join_same: left_table_col = right_table_col
824
              * Note that right_table_col must be a string
825
            * if left_table_col is join_same_not_null:
826
              left_table_col = right_table_col and both have NOT NULL constraint
827
              * Note that right_table_col must be a string
828
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
829
            * filter_out: equivalent to 'LEFT' with the query filtered by
830
              `table_pkey IS NULL` (indicating no match)
831
        '''
832
        if util.is_str(table): table = Table(table)
833
        assert type_ == None or util.is_str(type_) or type_ is filter_out
834
        
835
        self.table = table
836
        self.mapping = mapping
837
        self.type_ = type_
838
    
839
    def to_str(self, db, left_table_):
840
        def join(entry):
841
            '''Parses non-USING joins'''
842
            right_table_col, left_table_col = entry
843
            
844
            # Switch order (right_table_col is on the left in the comparison)
845
            left = right_table_col
846
            right = left_table_col
847
            left_table = self.table
848
            right_table = left_table_
849
            
850
            # Parse left side
851
            left = with_default_table(left, left_table)
852
            
853
            # Parse special values
854
            left_on_right = Col(left.name, right_table)
855
            if right is join_same: right = left_on_right
856
            elif right is join_same_not_null:
857
                right = CompareCond(left_on_right, '~=')
858
            
859
            # Parse right side
860
            right = as_ValueCond(right, right_table)
861
            
862
            return right.to_str(db, left)
863
        
864
        # Create join condition
865
        type_ = self.type_
866
        joins = self.mapping
867
        if joins == {}: join_cond = None
868
        elif type_ is not filter_out and reduce(operator.and_,
869
            (v is join_same_not_null for v in joins.itervalues())):
870
            # all cols w/ USING, so can use simpler USING syntax
871
            cols = map(to_name_only_col, joins.iterkeys())
872
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
873
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
874
        
875
        if isinstance(self.table, NamedTable): whitespace = '\n'
876
        else: whitespace = ' '
877
        
878
        # Create join
879
        if type_ is filter_out: type_ = 'LEFT'
880
        str_ = ''
881
        if type_ != None: str_ += type_+' '
882
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
883
        if join_cond != None: str_ += whitespace+join_cond
884
        return str_
885
    
886
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
887

    
888
##### Value exprs
889

    
890
all_cols = CustomCode('*')
891

    
892
default = CustomCode('DEFAULT')
893

    
894
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
895

    
896
class Coalesce(FunctionCall):
897
    def __init__(self, *args):
898
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
899

    
900
class Nullif(FunctionCall):
901
    def __init__(self, *args):
902
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
903

    
904
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
905
null_sentinels = {
906
    'character varying': r'\N',
907
    'double precision': 'NaN',
908
    'integer': 2147483647,
909
    'text': r'\N',
910
    'timestamp with time zone': 'infinity'
911
}
912

    
913
class EnsureNotNull(Coalesce):
914
    def __init__(self, value, type_):
915
        Coalesce.__init__(self, as_Col(value),
916
            Cast(type_, null_sentinels[type_]))
917
        
918
        self.type = type_
919
    
920
    def to_str(self, db):
921
        col = self.args[0]
922
        index_col_ = index_col(col)
923
        if index_col_ != None: return index_col_.to_str(db)
924
        return Coalesce.to_str(self, db)
925

    
926
##### Table exprs
927

    
928
class Values(Code):
929
    def __init__(self, values):
930
        '''
931
        @param values [...]|[[...], ...] Can be one or multiple rows.
932
        '''
933
        Code.__init__(self)
934
        
935
        rows = values
936
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
937
            rows = [values]
938
        for i, row in enumerate(rows):
939
            rows[i] = map(remove_col_rename, map(as_Value, row))
940
        
941
        self.rows = rows
942
    
943
    def to_str(self, db):
944
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
945

    
946
def NamedValues(name, cols, values):
947
    '''
948
    @param cols None|[...]
949
    @post `cols` will be changed to Col objects with the table set to `name`.
950
    '''
951
    table = NamedTable(name, Values(values), cols)
952
    if cols != None: set_cols_table(table, cols)
953
    return table
954

    
955
##### Database structure
956

    
957
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
958

    
959
def ensure_not_null(db, col, type_=None):
960
    '''
961
    @param col If type_ is not set, must have an underlying column.
962
    @param type_ If set, overrides the underlying column's type.
963
    @return EnsureNotNull|Col
964
    @throws ensure_not_null_excs
965
    '''
966
    nullable = True
967
    try: typed_col = db.col_info(underlying_col(col))
968
    except NoUnderlyingTableException:
969
        col = remove_col_rename(col)
970
        if is_literal(col) and not is_null(col): nullable = False
971
        elif type_ == None: raise
972
    else:
973
        if type_ == None: type_ = typed_col.type
974
        nullable = typed_col.nullable
975
    
976
    if nullable:
977
        try: col = EnsureNotNull(col, type_)
978
        except KeyError, e:
979
            # Warn of no null sentinel for type, even if caller catches error
980
            warnings.warn(UserWarning(exc.str_(e)))
981
            raise
982
    
983
    return col
(25-25/37)