Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import operator
5
from ordereddict import OrderedDict
6
import re
7
import UserDict
8
import warnings
9

    
10
import dicts
11
import exc
12
import iters
13
import lists
14
import objects
15
import strings
16
import util
17

    
18
##### Names
19

    
20
identifier_max_len = 63 # works for both PostgreSQL and MySQL
21

    
22
def concat(str_, suffix):
23
    '''Preserves version so that it won't be truncated off the string, leading
24
    to collisions.'''
25
    # Preserve version
26
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
27
    if match:
28
        str_, old_suffix = match.groups()
29
        suffix = old_suffix+suffix
30
    
31
    return strings.concat(str_, suffix, identifier_max_len)
32

    
33
def truncate(str_): return concat(str_, '')
34

    
35
def is_safe_name(name):
36
    '''A name is safe *and unambiguous* if it:
37
    * contains only *lowercase* word (\w) characters
38
    * doesn't start with a digit
39
    * contains "_", so that it's not a keyword
40
    '''
41
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
42

    
43
def esc_name(name, quote='"'):
44
    return quote + name.replace(quote, quote+quote) + quote
45
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
46

    
47
def unesc_name(name, quote='"'):
48
    removed_ref = [False]
49
    name = strings.remove_prefix(quote, name, removed_ref)
50
    if removed_ref[0]:
51
        name = strings.remove_suffix(quote, name, removed_ref)
52
        assert removed_ref[0]
53
        name = name.replace(quote+quote, quote)
54
    return name
55

    
56
def clean_name(name): return name.replace('"', '').replace('`', '')
57

    
58
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
59

    
60
def lstrip(str_):
61
    '''Also removes comments.'''
62
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
63
    return str_.lstrip()
64

    
65
##### General SQL code objects
66

    
67
class MockDb:
68
    def esc_value(self, value): return strings.repr_no_u(value)
69
    
70
    def esc_name(self, name): return esc_name(name)
71
    
72
    def col_info(self, col):
73
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
74

    
75
mockDb = MockDb()
76

    
77
class BasicObject(objects.BasicObject):
78
    def __str__(self): return clean_name(strings.repr_no_u(self))
79

    
80
##### Unparameterized code objects
81

    
82
class Code(BasicObject):
83
    def __init__(self, lang='sql'):
84
        self.lang = lang
85
    
86
    def to_str(self, db): raise NotImplementedError()
87
    
88
    def __repr__(self): return self.to_str(mockDb)
89

    
90
class CustomCode(Code):
91
    def __init__(self, str_):
92
        Code.__init__(self)
93
        
94
        self.str_ = str_
95
    
96
    def to_str(self, db): return self.str_
97

    
98
def as_Code(value, db=None):
99
    '''
100
    @param db If set, runs db.std_code() on the value.
101
    '''
102
    if isinstance(value, Code): return value
103
    
104
    if util.is_str(value):
105
        if db != None: value = db.std_code(value)
106
        return CustomCode(value)
107
    else: return Literal(value)
108

    
109
class Expr(Code):
110
    def __init__(self, expr):
111
        Code.__init__(self)
112
        
113
        self.expr = expr
114
    
115
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
116

    
117
##### Names
118

    
119
class Name(Code):
120
    def __init__(self, name):
121
        Code.__init__(self)
122
        
123
        name = truncate(name)
124
        
125
        self.name = name
126
    
127
    def to_str(self, db): return db.esc_name(self.name)
128

    
129
def as_Name(value):
130
    if isinstance(value, Code): return value
131
    else: return Name(value)
132

    
133
##### Literal values
134

    
135
class Literal(Code):
136
    def __init__(self, value):
137
        Code.__init__(self)
138
        
139
        self.value = value
140
    
141
    def to_str(self, db): return db.esc_value(self.value)
142

    
143
def as_Value(value):
144
    if isinstance(value, Code): return value
145
    else: return Literal(value)
146

    
147
def is_literal(value): return isinstance(value, Literal)
148

    
149
def is_null(value): return is_literal(value) and value.value == None
150

    
151
##### Derived elements
152

    
153
src_self = object() # tells Col that it is its own source column
154

    
155
class Derived(Code):
156
    def __init__(self, srcs):
157
        '''An element which was derived from some other element(s).
158
        @param srcs See self.set_srcs()
159
        '''
160
        Code.__init__(self)
161
        
162
        self.set_srcs(srcs)
163
    
164
    def set_srcs(self, srcs, overwrite=True):
165
        '''
166
        @param srcs (self_type...)|src_self The element(s) this is derived from
167
        '''
168
        if not overwrite and self.srcs != (): return # already set
169
        
170
        if srcs == src_self: srcs = (self,)
171
        srcs = tuple(srcs) # make Col hashable
172
        self.srcs = srcs
173
    
174
    def _compare_on(self):
175
        compare_on = self.__dict__.copy()
176
        del compare_on['srcs'] # ignore
177
        return compare_on
178

    
179
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
180

    
181
##### Tables
182

    
183
class Table(Derived):
184
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
185
        '''
186
        @param schema str|None (for no schema)
187
        @param srcs (Table...)|src_self See Derived.set_srcs()
188
        '''
189
        Derived.__init__(self, srcs)
190
        
191
        if util.is_str(name): name = truncate(name)
192
        
193
        self.name = name
194
        self.schema = schema
195
        self.is_temp = is_temp
196
        self.index_cols = {}
197
    
198
    def to_str(self, db):
199
        str_ = ''
200
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
201
        str_ += as_Name(self.name).to_str(db)
202
        return str_
203
    
204
    def to_Table(self): return self
205
    
206
    def _compare_on(self):
207
        compare_on = Derived._compare_on(self)
208
        del compare_on['index_cols'] # ignore
209
        return compare_on
210

    
211
def is_underlying_table(table):
212
    return isinstance(table, Table) and table.to_Table() is table
213

    
214
class NoUnderlyingTableException(Exception): pass
215

    
216
def underlying_table(table):
217
    table = remove_table_rename(table)
218
    if not is_underlying_table(table): raise NoUnderlyingTableException
219
    return table
220

    
221
def as_Table(table, schema=None):
222
    if table == None or isinstance(table, Code): return table
223
    else: return Table(table, schema)
224

    
225
def suffixed_table(table, suffix):
226
    table = copy.copy(table) # don't modify input!
227
    table.name = concat(table.name, suffix)
228
    return table
229

    
230
class NamedTable(Table):
231
    def __init__(self, name, code, cols=None):
232
        Table.__init__(self, name)
233
        
234
        code = as_Table(code)
235
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
236
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
237
        
238
        self.code = code
239
        self.cols = cols
240
    
241
    def to_str(self, db):
242
        str_ = self.code.to_str(db)
243
        if str_.find('\n') >= 0: whitespace = '\n'
244
        else: whitespace = ' '
245
        str_ += whitespace+'AS '+Table.to_str(self, db)
246
        if self.cols != None:
247
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
248
        return str_
249
    
250
    def to_Table(self): return Table(self.name)
251

    
252
def remove_table_rename(table):
253
    if isinstance(table, NamedTable): table = table.code
254
    return table
255

    
256
##### Columns
257

    
258
class Col(Derived):
259
    def __init__(self, name, table=None, srcs=()):
260
        '''
261
        @param table Table|None (for no table)
262
        @param srcs (Col...)|src_self See Derived.set_srcs()
263
        '''
264
        Derived.__init__(self, srcs)
265
        
266
        if util.is_str(name): name = truncate(name)
267
        if util.is_str(table): table = Table(table)
268
        assert table == None or isinstance(table, Table)
269
        
270
        self.name = name
271
        self.table = table
272
    
273
    def to_str(self, db, for_str=False):
274
        str_ = as_Name(self.name).to_str(db)
275
        if for_str: str_ = clean_name(str_)
276
        if self.table != None:
277
            table = self.table.to_Table()
278
            if for_str: str_ = concat(str(table), '.'+str_)
279
            else: str_ = table.to_str(db)+'.'+str_
280
        return str_
281
    
282
    def __str__(self): return self.to_str(mockDb, for_str=True)
283
    
284
    def to_Col(self): return self
285

    
286
def is_table_col(col): return isinstance(col, Col) and col.table != None
287

    
288
def index_col(col):
289
    if not is_table_col(col): return None
290
    
291
    table = col.table
292
    try: name = table.index_cols[col.name]
293
    except KeyError: return None
294
    else: return Col(name, table, col.srcs)
295

    
296
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
297

    
298
def as_Col(col, table=None, name=None):
299
    '''
300
    @param name If not None, any non-Col input will be renamed using NamedCol.
301
    '''
302
    if name != None:
303
        col = as_Value(col)
304
        if not isinstance(col, Col): col = NamedCol(name, col)
305
    
306
    if isinstance(col, Code): return col
307
    else: return Col(col, table)
308

    
309
def with_table(col, table):
310
    if isinstance(col, NamedCol): pass # doesn't take a table
311
    elif isinstance(col, FunctionCall):
312
        col = copy.deepcopy(col) # don't modify input!
313
        col.args[0].table = table
314
    elif isinstance(col, Col):
315
        col = copy.copy(col) # don't modify input!
316
        col.table = table
317
    return col
318

    
319
def with_default_table(col, table):
320
    col = as_Col(col)
321
    if col.table == None: col = with_table(col, table)
322
    return col
323

    
324
def set_cols_table(table, cols):
325
    table = as_Table(table)
326
    
327
    for i, col in enumerate(cols):
328
        col = cols[i] = as_Col(col)
329
        col.table = table
330

    
331
def to_name_only_col(col, check_table=None):
332
    col = as_Col(col)
333
    if not is_table_col(col): return col
334
    
335
    if check_table != None:
336
        table = col.table
337
        assert table == None or table == check_table
338
    return Col(col.name)
339

    
340
def suffixed_col(col, suffix):
341
    return Col(concat(col.name, suffix), col.table, col.srcs)
342

    
343
class NamedCol(Col):
344
    def __init__(self, name, code):
345
        Col.__init__(self, name)
346
        
347
        code = as_Value(code)
348
        
349
        self.code = code
350
    
351
    def to_str(self, db):
352
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
353
    
354
    def to_Col(self): return Col(self.name)
355

    
356
def remove_col_rename(col):
357
    if isinstance(col, NamedCol): col = col.code
358
    return col
359

    
360
def underlying_col(col):
361
    col = remove_col_rename(col)
362
    if not isinstance(col, Col): raise NoUnderlyingTableException
363
    
364
    return Col(col.name, underlying_table(col.table), col.srcs)
365

    
366
def wrap(wrap_func, value):
367
    '''Wraps a value, propagating any column renaming to the returned value.'''
368
    if isinstance(value, NamedCol):
369
        return NamedCol(value.name, wrap_func(value.code))
370
    else: return wrap_func(value)
371

    
372
class ColDict(dicts.DictProxy):
373
    '''A dict that automatically makes inserted entries Col objects'''
374
    
375
    def __init__(self, db, keys_table, dict_={}):
376
        dicts.DictProxy.__init__(self, OrderedDict())
377
        
378
        keys_table = as_Table(keys_table)
379
        
380
        self.db = db
381
        self.table = keys_table
382
        self.update(dict_) # after setting vars because __setitem__() needs them
383
    
384
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
385
    
386
    def __getitem__(self, key):
387
        return dicts.DictProxy.__getitem__(self, self._key(key))
388
    
389
    def __setitem__(self, key, value):
390
        key = self._key(key)
391
        if value == None: value = self.db.col_info(key).default
392
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
393
    
394
    def _key(self, key): return as_Col(key, self.table)
395

    
396
##### Composite types
397

    
398
class List(Code):
399
    def __init__(self, *values):
400
        Code.__init__(self)
401
        
402
        self.values = values
403
    
404
    def to_str(self, db):
405
        return '('+(', '.join((v.to_str(db) for v in self.values)))+')'
406

    
407
class Tuple(List):
408
    def to_str(self, db): return 'ROW'+List.to_str(self, db)
409

    
410
#### Definitions
411

    
412
class TypedCol(Col):
413
    def __init__(self, name, type_, default=None, nullable=True,
414
        constraints=None):
415
        assert default == None or isinstance(default, Code)
416
        
417
        Col.__init__(self, name)
418
        
419
        self.type = type_
420
        self.default = default
421
        self.nullable = nullable
422
        self.constraints = constraints
423
    
424
    def to_str(self, db):
425
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
426
        if not self.nullable: str_ += ' NOT NULL'
427
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
428
        if self.constraints != None: str_ += ' '+self.constraints
429
        return str_
430
    
431
    def to_Col(self): return Col(self.name)
432

    
433
class SetOf(Code):
434
    def __init__(self, type_):
435
        Code.__init__(self)
436
        
437
        self.type = type_
438
    
439
    def to_str(self, db):
440
        return 'SETOF '+self.type.to_str(db)
441

    
442
class RowType(Code):
443
    def __init__(self, table):
444
        Code.__init__(self)
445
        
446
        self.table = table
447
    
448
    def to_str(self, db):
449
        return self.table.to_str(db)+'%ROWTYPE'
450

    
451
class ColType(Code):
452
    def __init__(self, col):
453
        Code.__init__(self)
454
        
455
        self.col = col
456
    
457
    def to_str(self, db):
458
        return self.col.to_str(db)+'%TYPE'
459

    
460
##### Functions
461

    
462
Function = Table
463
as_Function = as_Table
464

    
465
class InternalFunction(CustomCode): pass
466

    
467
#### Calls
468

    
469
class NamedArg(NamedCol):
470
    def __init__(self, name, value):
471
        NamedCol.__init__(self, name, value)
472
    
473
    def to_str(self, db):
474
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
475

    
476
class FunctionCall(Code):
477
    def __init__(self, function, *args, **kw_args):
478
        '''
479
        @param args [Code|literal-value...] The function's arguments
480
        '''
481
        Code.__init__(self)
482
        
483
        function = as_Function(function)
484
        def filter_(arg): return remove_col_rename(as_Value(arg))
485
        args = map(filter_, args)
486
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
487
        
488
        self.function = function
489
        self.args = args
490
    
491
    def to_str(self, db):
492
        args_str = ', '.join((v.to_str(db) for v in self.args))
493
        return self.function.to_str(db)+'('+args_str+')'
494

    
495
def wrap_in_func(function, value):
496
    '''Wraps a value inside a function call.
497
    Propagates any column renaming to the returned value.
498
    '''
499
    return wrap(lambda v: FunctionCall(function, v), value)
500

    
501
def unwrap_func_call(func_call, check_name=None):
502
    '''Unwraps any function call to its first argument.
503
    Also removes any column renaming.
504
    '''
505
    func_call = remove_col_rename(func_call)
506
    if not isinstance(func_call, FunctionCall): return func_call
507
    
508
    if check_name != None:
509
        name = func_call.function.name
510
        assert name == None or name == check_name
511
    return func_call.args[0]
512

    
513
#### Definitions
514

    
515
class FunctionDef(Code):
516
    def __init__(self, function, return_type, body, params=[], modifiers=None):
517
        Code.__init__(self)
518
        
519
        return_type = as_Code(return_type)
520
        body = as_Code(body)
521
        
522
        self.function = function
523
        self.return_type = return_type
524
        self.body = body
525
        self.params = params
526
        self.modifiers = modifiers
527
    
528
    def to_str(self, db):
529
        params_str = (', '.join((p.to_str(db) for p in self.params)))
530
        str_ = '''\
531
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
532
RETURNS '''+self.return_type.to_str(db)+'''
533
LANGUAGE '''+self.body.lang+'''
534
'''
535
        if self.modifiers != None: str_ += self.modifiers+'\n'
536
        str_ += '''\
537
AS $$
538
'''+self.body.to_str(db)+'''
539
$$;
540
'''
541
        return str_
542

    
543
class FunctionParam(TypedCol):
544
    def __init__(self, name, type_, default=None, out=False):
545
        TypedCol.__init__(self, name, type_, default)
546
        
547
        self.out = out
548
    
549
    def to_str(self, db):
550
        str_ = TypedCol.to_str(self, db)
551
        if self.out: str_ = 'OUT '+str_
552
        return str_
553
    
554
    def to_Col(self): return Col(self.name)
555

    
556
### PL/pgSQL
557

    
558
class ReturnQuery(Code):
559
    def __init__(self, query):
560
        Code.__init__(self)
561
        
562
        query = as_Code(query)
563
        
564
        self.query = query
565
    
566
    def to_str(self, db):
567
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
568

    
569
class BaseExcHandler(BasicObject):
570
    def to_str(self, db, body): raise NotImplementedError()
571

    
572
class ExcHandler(BaseExcHandler):
573
    def __init__(self, exc, handler=None):
574
        if handler != None: handler = as_Code(handler)
575
        
576
        self.exc = exc
577
        self.handler = handler
578
    
579
    def to_str(self, db, body):
580
        body = as_Code(body)
581
        
582
        if self.handler != None:
583
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
584
        else: handler_str = ' NULL;\n'
585
        
586
        str_ = '''\
587
BEGIN
588
'''+strings.indent(body.to_str(db))+'''\
589
EXCEPTION
590
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
591
END;\
592
'''
593
        return str_
594
    
595
    def __repr__(self): return self.to_str(mockDb, '<body>')
596

    
597
class ExcToWarning(Code):
598
    def __init__(self, return_):
599
        '''
600
        @param return_ Statement to return a default value in case of error
601
        '''
602
        Code.__init__(self)
603
        
604
        return_ = as_Code(return_)
605
        
606
        self.return_ = return_
607
    
608
    def to_str(self, db):
609
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
610

    
611
unique_violation_handler = ExcHandler('unique_violation')
612

    
613
plpythonu_error_handler = ExcHandler('internal_error', '''\
614
RAISE data_exception USING MESSAGE =
615
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
616
''')
617

    
618
def data_exception_handler(handler):
619
    return ExcHandler('data_exception', handler)
620

    
621
class RowExcIgnore(Code):
622
    def __init__(self, row_type, select_query, with_row, cols=None,
623
        exc_handler=unique_violation_handler, row_var='row'):
624
        Code.__init__(self, lang='plpgsql')
625
        
626
        row_type = as_Code(row_type)
627
        select_query = as_Code(select_query)
628
        with_row = as_Code(with_row)
629
        row_var = as_Table(row_var)
630
        
631
        self.row_type = row_type
632
        self.select_query = select_query
633
        self.with_row = with_row
634
        self.cols = cols
635
        self.exc_handler = exc_handler
636
        self.row_var = row_var
637
    
638
    def to_str(self, db):
639
        if self.cols == None: row_vars = [self.row_var]
640
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
641
        
642
        str_ = '''\
643
DECLARE
644
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
645
BEGIN
646
    /* Need an EXCEPTION block for each individual row because "When
647
    an error is caught by an EXCEPTION clause, [...] all changes to
648
    persistent database state within the block are rolled back."
649
    This is unfortunate because "A block containing an EXCEPTION
650
    clause is significantly more expensive to enter and exit than a
651
    block without one."
652
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
653
#PLPGSQL-ERROR-TRAPPING)
654
    */
655
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
656
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
657
    LOOP
658
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
659
    END LOOP;
660
END;\
661
'''
662
        return str_
663

    
664
##### Casts
665

    
666
class Cast(FunctionCall):
667
    def __init__(self, type_, value):
668
        value = as_Value(value)
669
        
670
        self.type_ = type_
671
        self.value = value
672
    
673
    def to_str(self, db):
674
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
675

    
676
def cast_literal(value):
677
    if not is_literal(value): return value
678
    
679
    if util.is_str(value.value): value = Cast('text', value)
680
    return value
681

    
682
##### Conditions
683

    
684
class NotCond(Code):
685
    def __init__(self, cond):
686
        Code.__init__(self)
687
        
688
        self.cond = cond
689
    
690
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
691

    
692
class ColValueCond(Code):
693
    def __init__(self, col, value):
694
        Code.__init__(self)
695
        
696
        value = as_ValueCond(value)
697
        
698
        self.col = col
699
        self.value = value
700
    
701
    def to_str(self, db): return self.value.to_str(db, self.col)
702

    
703
def combine_conds(conds, keyword=None):
704
    '''
705
    @param keyword The keyword to add before the conditions, if any
706
    '''
707
    str_ = ''
708
    if keyword != None:
709
        if conds == []: whitespace = ''
710
        elif len(conds) == 1: whitespace = ' '
711
        else: whitespace = '\n'
712
        str_ += keyword+whitespace
713
    
714
    str_ += '\nAND '.join(conds)
715
    return str_
716

    
717
##### Condition column comparisons
718

    
719
class ValueCond(BasicObject):
720
    def __init__(self, value):
721
        value = remove_col_rename(as_Value(value))
722
        
723
        self.value = value
724
    
725
    def to_str(self, db, left_value):
726
        '''
727
        @param left_value The Code object that the condition is being applied on
728
        '''
729
        raise NotImplemented()
730
    
731
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
732

    
733
class CompareCond(ValueCond):
734
    def __init__(self, value, operator='='):
735
        '''
736
        @param operator By default, compares NULL values literally. Use '~=' or
737
            '~!=' to pass NULLs through.
738
        '''
739
        ValueCond.__init__(self, value)
740
        self.operator = operator
741
    
742
    def to_str(self, db, left_value):
743
        left_value = remove_col_rename(as_Col(left_value))
744
        
745
        right_value = self.value
746
        
747
        # Parse operator
748
        operator = self.operator
749
        passthru_null_ref = [False]
750
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
751
        neg_ref = [False]
752
        operator = strings.remove_prefix('!', operator, neg_ref)
753
        equals = operator.endswith('=') # also includes <=, >=
754
        
755
        # Handle nullable columns
756
        check_null = False
757
        if not passthru_null_ref[0]: # NULLs compare equal
758
            try: left_value = ensure_not_null(db, left_value)
759
            except ensure_not_null_excs: # fall back to alternate method
760
                check_null = equals and isinstance(right_value, Col)
761
            else:
762
                if isinstance(left_value, EnsureNotNull):
763
                    right_value = ensure_not_null(db, right_value,
764
                        left_value.type) # apply same function to both sides
765
        
766
        if equals and is_null(right_value): operator = 'IS'
767
        
768
        left = left_value.to_str(db)
769
        right = right_value.to_str(db)
770
        
771
        # Create str
772
        str_ = left+' '+operator+' '+right
773
        if check_null:
774
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
775
        if neg_ref[0]: str_ = 'NOT '+str_
776
        return str_
777

    
778
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
779
assume_literal = object()
780

    
781
def as_ValueCond(value, default_table=assume_literal):
782
    if not isinstance(value, ValueCond):
783
        if default_table is not assume_literal:
784
            value = with_default_table(value, default_table)
785
        return CompareCond(value)
786
    else: return value
787

    
788
##### Joins
789

    
790
join_same = object() # tells Join the left and right columns have the same name
791

    
792
# Tells Join the left and right columns have the same name and are never NULL
793
join_same_not_null = object()
794

    
795
filter_out = object() # tells Join to filter out rows that match the join
796

    
797
class Join(BasicObject):
798
    def __init__(self, table, mapping={}, type_=None):
799
        '''
800
        @param mapping dict(right_table_col=left_table_col, ...)
801
            * if left_table_col is join_same: left_table_col = right_table_col
802
              * Note that right_table_col must be a string
803
            * if left_table_col is join_same_not_null:
804
              left_table_col = right_table_col and both have NOT NULL constraint
805
              * Note that right_table_col must be a string
806
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
807
            * filter_out: equivalent to 'LEFT' with the query filtered by
808
              `table_pkey IS NULL` (indicating no match)
809
        '''
810
        if util.is_str(table): table = Table(table)
811
        assert type_ == None or util.is_str(type_) or type_ is filter_out
812
        
813
        self.table = table
814
        self.mapping = mapping
815
        self.type_ = type_
816
    
817
    def to_str(self, db, left_table_):
818
        def join(entry):
819
            '''Parses non-USING joins'''
820
            right_table_col, left_table_col = entry
821
            
822
            # Switch order (right_table_col is on the left in the comparison)
823
            left = right_table_col
824
            right = left_table_col
825
            left_table = self.table
826
            right_table = left_table_
827
            
828
            # Parse left side
829
            left = with_default_table(left, left_table)
830
            
831
            # Parse special values
832
            left_on_right = Col(left.name, right_table)
833
            if right is join_same: right = left_on_right
834
            elif right is join_same_not_null:
835
                right = CompareCond(left_on_right, '~=')
836
            
837
            # Parse right side
838
            right = as_ValueCond(right, right_table)
839
            
840
            return right.to_str(db, left)
841
        
842
        # Create join condition
843
        type_ = self.type_
844
        joins = self.mapping
845
        if joins == {}: join_cond = None
846
        elif type_ is not filter_out and reduce(operator.and_,
847
            (v is join_same_not_null for v in joins.itervalues())):
848
            # all cols w/ USING, so can use simpler USING syntax
849
            cols = map(to_name_only_col, joins.iterkeys())
850
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
851
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
852
        
853
        if isinstance(self.table, NamedTable): whitespace = '\n'
854
        else: whitespace = ' '
855
        
856
        # Create join
857
        if type_ is filter_out: type_ = 'LEFT'
858
        str_ = ''
859
        if type_ != None: str_ += type_+' '
860
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
861
        if join_cond != None: str_ += whitespace+join_cond
862
        return str_
863
    
864
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
865

    
866
##### Value exprs
867

    
868
all_cols = CustomCode('*')
869

    
870
default = CustomCode('DEFAULT')
871

    
872
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
873

    
874
class Coalesce(FunctionCall):
875
    def __init__(self, *args):
876
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
877

    
878
class Nullif(FunctionCall):
879
    def __init__(self, *args):
880
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
881

    
882
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
883
null_sentinels = {
884
    'character varying': r'\N',
885
    'double precision': 'NaN',
886
    'integer': 2147483647,
887
    'text': r'\N',
888
    'timestamp with time zone': 'infinity'
889
}
890

    
891
class EnsureNotNull(Coalesce):
892
    def __init__(self, value, type_):
893
        Coalesce.__init__(self, as_Col(value),
894
            Cast(type_, null_sentinels[type_]))
895
        
896
        self.type = type_
897
    
898
    def to_str(self, db):
899
        col = self.args[0]
900
        index_col_ = index_col(col)
901
        if index_col_ != None: return index_col_.to_str(db)
902
        return Coalesce.to_str(self, db)
903

    
904
##### Table exprs
905

    
906
class Values(Code):
907
    def __init__(self, values):
908
        '''
909
        @param values [...]|[[...], ...] Can be one or multiple rows.
910
        '''
911
        Code.__init__(self)
912
        
913
        rows = values
914
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
915
            rows = [values]
916
        for i, row in enumerate(rows):
917
            rows[i] = map(remove_col_rename, map(as_Value, row))
918
        
919
        self.rows = rows
920
    
921
    def to_str(self, db):
922
        return 'VALUES '+(', '.join((List(*r).to_str(db) for r in self.rows)))
923

    
924
def NamedValues(name, cols, values):
925
    '''
926
    @param cols None|[...]
927
    @post `cols` will be changed to Col objects with the table set to `name`.
928
    '''
929
    table = NamedTable(name, Values(values), cols)
930
    if cols != None: set_cols_table(table, cols)
931
    return table
932

    
933
##### Database structure
934

    
935
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
936

    
937
def ensure_not_null(db, col, type_=None):
938
    '''
939
    @param col If type_ is not set, must have an underlying column.
940
    @param type_ If set, overrides the underlying column's type.
941
    @return EnsureNotNull|Col
942
    @throws ensure_not_null_excs
943
    '''
944
    nullable = True
945
    try: typed_col = db.col_info(underlying_col(col))
946
    except NoUnderlyingTableException:
947
        col = remove_col_rename(col)
948
        if is_literal(col) and not is_null(col): nullable = False
949
        elif type_ == None: raise
950
    else:
951
        if type_ == None: type_ = typed_col.type
952
        nullable = typed_col.nullable
953
    
954
    if nullable:
955
        try: col = EnsureNotNull(col, type_)
956
        except KeyError, e:
957
            # Warn of no null sentinel for type, even if caller catches error
958
            warnings.warn(UserWarning(exc.str_(e)))
959
            raise
960
    
961
    return col
(25-25/37)