Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import itertools
5
import operator
6
from ordereddict import OrderedDict
7
import re
8
import UserDict
9
import warnings
10

    
11
import dicts
12
import exc
13
import iters
14
import lists
15
import objects
16
import strings
17
import util
18

    
19
##### Names
20

    
21
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22

    
23
def concat(str_, suffix):
24
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26
    # Preserve version
27
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31
    
32
    return strings.concat(str_, suffix, identifier_max_len)
33

    
34
def truncate(str_): return concat(str_, '')
35

    
36
def is_safe_name(name):
37
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43

    
44
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47

    
48
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56

    
57
def clean_name(name): return name.replace('"', '').replace('`', '')
58

    
59
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60

    
61
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65

    
66
##### General SQL code objects
67

    
68
class MockDb:
69
    def esc_value(self, value): return strings.repr_no_u(value)
70
    
71
    def esc_name(self, name): return esc_name(name)
72
    
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75

    
76
mockDb = MockDb()
77

    
78
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80

    
81
##### Unparameterized code objects
82

    
83
class Code(BasicObject):
84
    def __init__(self, lang='sql'):
85
        self.lang = lang
86
    
87
    def to_str(self, db): raise NotImplementedError()
88
    
89
    def __repr__(self): return self.to_str(mockDb)
90

    
91
class CustomCode(Code):
92
    def __init__(self, str_):
93
        Code.__init__(self)
94
        
95
        self.str_ = str_
96
    
97
    def to_str(self, db): return self.str_
98

    
99
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103
    if isinstance(value, Code): return value
104
    
105
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108
    else: return Literal(value)
109

    
110
class Expr(Code):
111
    def __init__(self, expr):
112
        Code.__init__(self)
113
        
114
        self.expr = expr
115
    
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117

    
118
##### Names
119

    
120
class Name(Code):
121
    def __init__(self, name):
122
        Code.__init__(self)
123
        
124
        name = truncate(name)
125
        
126
        self.name = name
127
    
128
    def to_str(self, db): return db.esc_name(self.name)
129

    
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133

    
134
##### Literal values
135

    
136
#### Primitives
137

    
138
class Literal(Code):
139
    def __init__(self, value):
140
        Code.__init__(self)
141
        
142
        self.value = value
143
    
144
    def to_str(self, db): return db.esc_value(self.value)
145

    
146
def as_Value(value):
147
    if isinstance(value, Code): return value
148
    else: return Literal(value)
149

    
150
def is_literal(value): return isinstance(value, Literal)
151

    
152
def is_null(value): return is_literal(value) and value.value == None
153

    
154
#### Composites
155

    
156
class List(Code):
157
    def __init__(self, values):
158
        Code.__init__(self)
159
        
160
        self.values = values
161
    
162
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
163

    
164
class Tuple(List):
165
    def __init__(self, *values):
166
        List.__init__(self, values)
167
    
168
    def to_str(self, db): return '('+List.to_str(self, db)+')'
169

    
170
class Row(Tuple):
171
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
172

    
173
### Arrays
174

    
175
class Array(List):
176
    def __init__(self, values):
177
        values = map(remove_col_rename, values)
178
        
179
        List.__init__(self, values)
180
    
181
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
182

    
183
def to_Array(value):
184
    if isinstance(value, Array): return value
185
    return Array(lists.mk_seq(value))
186

    
187
##### Derived elements
188

    
189
src_self = object() # tells Col that it is its own source column
190

    
191
class Derived(Code):
192
    def __init__(self, srcs):
193
        '''An element which was derived from some other element(s).
194
        @param srcs See self.set_srcs()
195
        '''
196
        Code.__init__(self)
197
        
198
        self.set_srcs(srcs)
199
    
200
    def set_srcs(self, srcs, overwrite=True):
201
        '''
202
        @param srcs (self_type...)|src_self The element(s) this is derived from
203
        '''
204
        if not overwrite and self.srcs != (): return # already set
205
        
206
        if srcs == src_self: srcs = (self,)
207
        srcs = tuple(srcs) # make Col hashable
208
        self.srcs = srcs
209
    
210
    def _compare_on(self):
211
        compare_on = self.__dict__.copy()
212
        del compare_on['srcs'] # ignore
213
        return compare_on
214

    
215
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
216

    
217
##### Tables
218

    
219
class Table(Derived):
220
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
221
        '''
222
        @param schema str|None (for no schema)
223
        @param srcs (Table...)|src_self See Derived.set_srcs()
224
        '''
225
        Derived.__init__(self, srcs)
226
        
227
        if util.is_str(name): name = truncate(name)
228
        
229
        self.name = name
230
        self.schema = schema
231
        self.is_temp = is_temp
232
        self.index_cols = {}
233
    
234
    def to_str(self, db):
235
        str_ = ''
236
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
237
        str_ += as_Name(self.name).to_str(db)
238
        return str_
239
    
240
    def to_Table(self): return self
241
    
242
    def _compare_on(self):
243
        compare_on = Derived._compare_on(self)
244
        del compare_on['index_cols'] # ignore
245
        return compare_on
246

    
247
def is_underlying_table(table):
248
    return isinstance(table, Table) and table.to_Table() is table
249

    
250
class NoUnderlyingTableException(Exception): pass
251

    
252
def underlying_table(table):
253
    table = remove_table_rename(table)
254
    if not is_underlying_table(table): raise NoUnderlyingTableException
255
    return table
256

    
257
def as_Table(table, schema=None):
258
    if table == None or isinstance(table, Code): return table
259
    else: return Table(table, schema)
260

    
261
def suffixed_table(table, suffix):
262
    table = copy.copy(table) # don't modify input!
263
    table.name = concat(table.name, suffix)
264
    return table
265

    
266
class NamedTable(Table):
267
    def __init__(self, name, code, cols=None):
268
        Table.__init__(self, name)
269
        
270
        code = as_Table(code)
271
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
272
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
273
        
274
        self.code = code
275
        self.cols = cols
276
    
277
    def to_str(self, db):
278
        str_ = self.code.to_str(db)
279
        if str_.find('\n') >= 0: whitespace = '\n'
280
        else: whitespace = ' '
281
        str_ += whitespace+'AS '+Table.to_str(self, db)
282
        if self.cols != None:
283
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
284
        return str_
285
    
286
    def to_Table(self): return Table(self.name)
287

    
288
def remove_table_rename(table):
289
    if isinstance(table, NamedTable): table = table.code
290
    return table
291

    
292
##### Columns
293

    
294
class Col(Derived):
295
    def __init__(self, name, table=None, srcs=()):
296
        '''
297
        @param table Table|None (for no table)
298
        @param srcs (Col...)|src_self See Derived.set_srcs()
299
        '''
300
        Derived.__init__(self, srcs)
301
        
302
        if util.is_str(name): name = truncate(name)
303
        if util.is_str(table): table = Table(table)
304
        assert table == None or isinstance(table, Table)
305
        
306
        self.name = name
307
        self.table = table
308
    
309
    def to_str(self, db, for_str=False):
310
        str_ = as_Name(self.name).to_str(db)
311
        if for_str: str_ = clean_name(str_)
312
        if self.table != None:
313
            table = self.table.to_Table()
314
            if for_str: str_ = concat(str(table), '.'+str_)
315
            else: str_ = table.to_str(db)+'.'+str_
316
        return str_
317
    
318
    def __str__(self): return self.to_str(mockDb, for_str=True)
319
    
320
    def to_Col(self): return self
321

    
322
def is_col(col): return isinstance(col, Col)
323

    
324
def is_table_col(col): return is_col(col) and col.table != None
325

    
326
def index_col(col):
327
    if not is_table_col(col): return None
328
    
329
    table = col.table
330
    try: name = table.index_cols[col.name]
331
    except KeyError: return None
332
    else: return Col(name, table, col.srcs)
333

    
334
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
335

    
336
def as_Col(col, table=None, name=None):
337
    '''
338
    @param name If not None, any non-Col input will be renamed using NamedCol.
339
    '''
340
    if name != None:
341
        col = as_Value(col)
342
        if not isinstance(col, Col): col = NamedCol(name, col)
343
    
344
    if isinstance(col, Code): return col
345
    elif util.is_str(col): return Col(col, table)
346
    else: return Literal(col)
347

    
348
def with_table(col, table):
349
    if isinstance(col, NamedCol): pass # doesn't take a table
350
    elif isinstance(col, FunctionCall):
351
        col = copy.deepcopy(col) # don't modify input!
352
        col.args[0].table = table
353
    elif isinstance(col, Col):
354
        col = copy.copy(col) # don't modify input!
355
        col.table = table
356
    return col
357

    
358
def with_default_table(col, table):
359
    col = as_Col(col)
360
    if col.table == None: col = with_table(col, table)
361
    return col
362

    
363
def set_cols_table(table, cols):
364
    table = as_Table(table)
365
    
366
    for i, col in enumerate(cols):
367
        col = cols[i] = as_Col(col)
368
        col.table = table
369

    
370
def to_name_only_col(col, check_table=None):
371
    col = as_Col(col)
372
    if not is_table_col(col): return col
373
    
374
    if check_table != None:
375
        table = col.table
376
        assert table == None or table == check_table
377
    return Col(col.name)
378

    
379
def suffixed_col(col, suffix):
380
    return Col(concat(col.name, suffix), col.table, col.srcs)
381

    
382
def has_srcs(col): return is_col(col) and col.srcs
383

    
384
def cross_join_srcs(cols):
385
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
386
    srcs = [[s.name for s in c.srcs] for c in cols]
387
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
388

    
389
class NamedCol(Col):
390
    def __init__(self, name, code):
391
        Col.__init__(self, name)
392
        
393
        code = as_Value(code)
394
        
395
        self.code = code
396
    
397
    def to_str(self, db):
398
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
399
    
400
    def to_Col(self): return Col(self.name)
401

    
402
def remove_col_rename(col):
403
    if isinstance(col, NamedCol): col = col.code
404
    return col
405

    
406
def underlying_col(col):
407
    col = remove_col_rename(col)
408
    if not isinstance(col, Col): raise NoUnderlyingTableException
409
    
410
    return Col(col.name, underlying_table(col.table), col.srcs)
411

    
412
def wrap(wrap_func, value):
413
    '''Wraps a value, propagating any column renaming to the returned value.'''
414
    if isinstance(value, NamedCol):
415
        return NamedCol(value.name, wrap_func(value.code))
416
    else: return wrap_func(value)
417

    
418
class ColDict(dicts.DictProxy):
419
    '''A dict that automatically makes inserted entries Col objects'''
420
    
421
    def __init__(self, db, keys_table, dict_={}):
422
        dicts.DictProxy.__init__(self, OrderedDict())
423
        
424
        keys_table = as_Table(keys_table)
425
        
426
        self.db = db
427
        self.table = keys_table
428
        self.update(dict_) # after setting vars because __setitem__() needs them
429
    
430
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
431
    
432
    def __getitem__(self, key):
433
        return dicts.DictProxy.__getitem__(self, self._key(key))
434
    
435
    def __setitem__(self, key, value):
436
        key = self._key(key)
437
        if value == None: value = self.db.col_info(key).default
438
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
439
    
440
    def _key(self, key): return as_Col(key, self.table)
441

    
442
##### Definitions
443

    
444
class TypedCol(Col):
445
    def __init__(self, name, type_, default=None, nullable=True,
446
        constraints=None):
447
        assert default == None or isinstance(default, Code)
448
        
449
        Col.__init__(self, name)
450
        
451
        self.type = type_
452
        self.default = default
453
        self.nullable = nullable
454
        self.constraints = constraints
455
    
456
    def to_str(self, db):
457
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
458
        if not self.nullable: str_ += ' NOT NULL'
459
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
460
        if self.constraints != None: str_ += ' '+self.constraints
461
        return str_
462
    
463
    def to_Col(self): return Col(self.name)
464

    
465
class SetOf(Code):
466
    def __init__(self, type_):
467
        Code.__init__(self)
468
        
469
        self.type = type_
470
    
471
    def to_str(self, db):
472
        return 'SETOF '+self.type.to_str(db)
473

    
474
class RowType(Code):
475
    def __init__(self, table):
476
        Code.__init__(self)
477
        
478
        self.table = table
479
    
480
    def to_str(self, db):
481
        return self.table.to_str(db)+'%ROWTYPE'
482

    
483
class ColType(Code):
484
    def __init__(self, col):
485
        Code.__init__(self)
486
        
487
        self.col = col
488
    
489
    def to_str(self, db):
490
        return self.col.to_str(db)+'%TYPE'
491

    
492
##### Functions
493

    
494
Function = Table
495
as_Function = as_Table
496

    
497
class InternalFunction(CustomCode): pass
498

    
499
#### Calls
500

    
501
class NamedArg(NamedCol):
502
    def __init__(self, name, value):
503
        NamedCol.__init__(self, name, value)
504
    
505
    def to_str(self, db):
506
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
507

    
508
class FunctionCall(Code):
509
    def __init__(self, function, *args, **kw_args):
510
        '''
511
        @param args [Code|literal-value...] The function's arguments
512
        '''
513
        Code.__init__(self)
514
        
515
        function = as_Function(function)
516
        def filter_(arg): return remove_col_rename(as_Value(arg))
517
        args = map(filter_, args)
518
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
519
        
520
        self.function = function
521
        self.args = args
522
    
523
    def to_str(self, db):
524
        args_str = ', '.join((v.to_str(db) for v in self.args))
525
        return self.function.to_str(db)+'('+args_str+')'
526

    
527
def wrap_in_func(function, value):
528
    '''Wraps a value inside a function call.
529
    Propagates any column renaming to the returned value.
530
    '''
531
    return wrap(lambda v: FunctionCall(function, v), value)
532

    
533
def unwrap_func_call(func_call, check_name=None):
534
    '''Unwraps any function call to its first argument.
535
    Also removes any column renaming.
536
    '''
537
    func_call = remove_col_rename(func_call)
538
    if not isinstance(func_call, FunctionCall): return func_call
539
    
540
    if check_name != None:
541
        name = func_call.function.name
542
        assert name == None or name == check_name
543
    return func_call.args[0]
544

    
545
#### Definitions
546

    
547
class FunctionDef(Code):
548
    def __init__(self, function, return_type, body, params=[], modifiers=None):
549
        Code.__init__(self)
550
        
551
        return_type = as_Code(return_type)
552
        body = as_Code(body)
553
        
554
        self.function = function
555
        self.return_type = return_type
556
        self.body = body
557
        self.params = params
558
        self.modifiers = modifiers
559
    
560
    def to_str(self, db):
561
        params_str = (', '.join((p.to_str(db) for p in self.params)))
562
        str_ = '''\
563
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
564
RETURNS '''+self.return_type.to_str(db)+'''
565
LANGUAGE '''+self.body.lang+'''
566
'''
567
        if self.modifiers != None: str_ += self.modifiers+'\n'
568
        str_ += '''\
569
AS $$
570
'''+self.body.to_str(db)+'''
571
$$;
572
'''
573
        return str_
574

    
575
class FunctionParam(TypedCol):
576
    def __init__(self, name, type_, default=None, out=False):
577
        TypedCol.__init__(self, name, type_, default)
578
        
579
        self.out = out
580
    
581
    def to_str(self, db):
582
        str_ = TypedCol.to_str(self, db)
583
        if self.out: str_ = 'OUT '+str_
584
        return str_
585
    
586
    def to_Col(self): return Col(self.name)
587

    
588
### PL/pgSQL
589

    
590
class ReturnQuery(Code):
591
    def __init__(self, query):
592
        Code.__init__(self)
593
        
594
        query = as_Code(query)
595
        
596
        self.query = query
597
    
598
    def to_str(self, db):
599
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
600

    
601
## Exceptions
602

    
603
class BaseExcHandler(BasicObject):
604
    def to_str(self, db, body): raise NotImplementedError()
605
    
606
    def __repr__(self): return self.to_str(mockDb, '<body>')
607

    
608
class ExcHandler(BaseExcHandler):
609
    def __init__(self, exc, handler=None):
610
        if handler != None: handler = as_Code(handler)
611
        
612
        self.exc = exc
613
        self.handler = handler
614
    
615
    def to_str(self, db, body):
616
        body = as_Code(body)
617
        
618
        if self.handler != None:
619
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
620
        else: handler_str = ' NULL;\n'
621
        
622
        str_ = '''\
623
BEGIN
624
'''+strings.indent(body.to_str(db))+'''\
625
EXCEPTION
626
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
627
END;\
628
'''
629
        return str_
630

    
631
class NestedExcHandler(BaseExcHandler):
632
    def __init__(self, *handlers):
633
        '''
634
        @param handlers Sorted from outermost to innermost
635
        '''
636
        self.handlers = handlers
637
    
638
    def to_str(self, db, body):
639
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
640
        return body
641

    
642
class ExcToWarning(Code):
643
    def __init__(self, return_):
644
        '''
645
        @param return_ Statement to return a default value in case of error
646
        '''
647
        Code.__init__(self)
648
        
649
        return_ = as_Code(return_)
650
        
651
        self.return_ = return_
652
    
653
    def to_str(self, db):
654
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
655

    
656
unique_violation_handler = ExcHandler('unique_violation')
657

    
658
plpythonu_error_handler = ExcHandler('internal_error', '''\
659
RAISE data_exception USING MESSAGE =
660
    regexp_replace(SQLERRM, E'^PL/Python: \\w+: ', '');
661
''')
662

    
663
def data_exception_handler(handler):
664
    return ExcHandler('data_exception', handler)
665

    
666
class RowExcIgnore(Code):
667
    def __init__(self, row_type, select_query, with_row, cols=None,
668
        exc_handler=unique_violation_handler, row_var='row'):
669
        Code.__init__(self, lang='plpgsql')
670
        
671
        row_type = as_Code(row_type)
672
        select_query = as_Code(select_query)
673
        with_row = as_Code(with_row)
674
        row_var = as_Table(row_var)
675
        
676
        self.row_type = row_type
677
        self.select_query = select_query
678
        self.with_row = with_row
679
        self.cols = cols
680
        self.exc_handler = exc_handler
681
        self.row_var = row_var
682
    
683
    def to_str(self, db):
684
        if self.cols == None: row_vars = [self.row_var]
685
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
686
        
687
        str_ = '''\
688
DECLARE
689
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
690
BEGIN
691
    /* Need an EXCEPTION block for each individual row because "When
692
    an error is caught by an EXCEPTION clause, [...] all changes to
693
    persistent database state within the block are rolled back."
694
    This is unfortunate because "A block containing an EXCEPTION
695
    clause is significantly more expensive to enter and exit than a
696
    block without one."
697
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
698
#PLPGSQL-ERROR-TRAPPING)
699
    */
700
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
701
'''+strings.indent(self.select_query.to_str(db), 2)+'''\
702
    LOOP
703
'''+strings.indent(self.exc_handler.to_str(db, self.with_row), 2)+'''\
704
    END LOOP;
705
END;\
706
'''
707
        return str_
708

    
709
##### Casts
710

    
711
class Cast(FunctionCall):
712
    def __init__(self, type_, value):
713
        value = as_Value(value)
714
        
715
        self.type_ = type_
716
        self.value = value
717
    
718
    def to_str(self, db):
719
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_+')'
720

    
721
def cast_literal(value):
722
    if not is_literal(value): return value
723
    
724
    if util.is_str(value.value): value = Cast('text', value)
725
    return value
726

    
727
##### Conditions
728

    
729
class NotCond(Code):
730
    def __init__(self, cond):
731
        Code.__init__(self)
732
        
733
        self.cond = cond
734
    
735
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
736

    
737
class ColValueCond(Code):
738
    def __init__(self, col, value):
739
        Code.__init__(self)
740
        
741
        value = as_ValueCond(value)
742
        
743
        self.col = col
744
        self.value = value
745
    
746
    def to_str(self, db): return self.value.to_str(db, self.col)
747

    
748
def combine_conds(conds, keyword=None):
749
    '''
750
    @param keyword The keyword to add before the conditions, if any
751
    '''
752
    str_ = ''
753
    if keyword != None:
754
        if conds == []: whitespace = ''
755
        elif len(conds) == 1: whitespace = ' '
756
        else: whitespace = '\n'
757
        str_ += keyword+whitespace
758
    
759
    str_ += '\nAND '.join(conds)
760
    return str_
761

    
762
##### Condition column comparisons
763

    
764
class ValueCond(BasicObject):
765
    def __init__(self, value):
766
        value = remove_col_rename(as_Value(value))
767
        
768
        self.value = value
769
    
770
    def to_str(self, db, left_value):
771
        '''
772
        @param left_value The Code object that the condition is being applied on
773
        '''
774
        raise NotImplemented()
775
    
776
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
777

    
778
class CompareCond(ValueCond):
779
    def __init__(self, value, operator='='):
780
        '''
781
        @param operator By default, compares NULL values literally. Use '~=' or
782
            '~!=' to pass NULLs through.
783
        '''
784
        ValueCond.__init__(self, value)
785
        self.operator = operator
786
    
787
    def to_str(self, db, left_value):
788
        left_value = remove_col_rename(as_Col(left_value))
789
        
790
        right_value = self.value
791
        
792
        # Parse operator
793
        operator = self.operator
794
        passthru_null_ref = [False]
795
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
796
        neg_ref = [False]
797
        operator = strings.remove_prefix('!', operator, neg_ref)
798
        equals = operator.endswith('=') # also includes <=, >=
799
        
800
        # Handle nullable columns
801
        check_null = False
802
        if not passthru_null_ref[0]: # NULLs compare equal
803
            try: left_value = ensure_not_null(db, left_value)
804
            except ensure_not_null_excs: # fall back to alternate method
805
                check_null = equals and isinstance(right_value, Col)
806
            else:
807
                if isinstance(left_value, EnsureNotNull):
808
                    right_value = ensure_not_null(db, right_value,
809
                        left_value.type) # apply same function to both sides
810
        
811
        if equals and is_null(right_value): operator = 'IS'
812
        
813
        left = left_value.to_str(db)
814
        right = right_value.to_str(db)
815
        
816
        # Create str
817
        str_ = left+' '+operator+' '+right
818
        if check_null:
819
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
820
        if neg_ref[0]: str_ = 'NOT '+str_
821
        return str_
822

    
823
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
824
assume_literal = object()
825

    
826
def as_ValueCond(value, default_table=assume_literal):
827
    if not isinstance(value, ValueCond):
828
        if default_table is not assume_literal:
829
            value = with_default_table(value, default_table)
830
        return CompareCond(value)
831
    else: return value
832

    
833
##### Joins
834

    
835
join_same = object() # tells Join the left and right columns have the same name
836

    
837
# Tells Join the left and right columns have the same name and are never NULL
838
join_same_not_null = object()
839

    
840
filter_out = object() # tells Join to filter out rows that match the join
841

    
842
class Join(BasicObject):
843
    def __init__(self, table, mapping={}, type_=None):
844
        '''
845
        @param mapping dict(right_table_col=left_table_col, ...)
846
            * if left_table_col is join_same: left_table_col = right_table_col
847
              * Note that right_table_col must be a string
848
            * if left_table_col is join_same_not_null:
849
              left_table_col = right_table_col and both have NOT NULL constraint
850
              * Note that right_table_col must be a string
851
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
852
            * filter_out: equivalent to 'LEFT' with the query filtered by
853
              `table_pkey IS NULL` (indicating no match)
854
        '''
855
        if util.is_str(table): table = Table(table)
856
        assert type_ == None or util.is_str(type_) or type_ is filter_out
857
        
858
        self.table = table
859
        self.mapping = mapping
860
        self.type_ = type_
861
    
862
    def to_str(self, db, left_table_):
863
        def join(entry):
864
            '''Parses non-USING joins'''
865
            right_table_col, left_table_col = entry
866
            
867
            # Switch order (right_table_col is on the left in the comparison)
868
            left = right_table_col
869
            right = left_table_col
870
            left_table = self.table
871
            right_table = left_table_
872
            
873
            # Parse left side
874
            left = with_default_table(left, left_table)
875
            
876
            # Parse special values
877
            left_on_right = Col(left.name, right_table)
878
            if right is join_same: right = left_on_right
879
            elif right is join_same_not_null:
880
                right = CompareCond(left_on_right, '~=')
881
            
882
            # Parse right side
883
            right = as_ValueCond(right, right_table)
884
            
885
            return right.to_str(db, left)
886
        
887
        # Create join condition
888
        type_ = self.type_
889
        joins = self.mapping
890
        if joins == {}: join_cond = None
891
        elif type_ is not filter_out and reduce(operator.and_,
892
            (v is join_same_not_null for v in joins.itervalues())):
893
            # all cols w/ USING, so can use simpler USING syntax
894
            cols = map(to_name_only_col, joins.iterkeys())
895
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
896
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
897
        
898
        if isinstance(self.table, NamedTable): whitespace = '\n'
899
        else: whitespace = ' '
900
        
901
        # Create join
902
        if type_ is filter_out: type_ = 'LEFT'
903
        str_ = ''
904
        if type_ != None: str_ += type_+' '
905
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
906
        if join_cond != None: str_ += whitespace+join_cond
907
        return str_
908
    
909
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
910

    
911
##### Value exprs
912

    
913
all_cols = CustomCode('*')
914

    
915
default = CustomCode('DEFAULT')
916

    
917
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
918

    
919
class Coalesce(FunctionCall):
920
    def __init__(self, *args):
921
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
922

    
923
class Nullif(FunctionCall):
924
    def __init__(self, *args):
925
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
926

    
927
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
928
null_sentinels = {
929
    'character varying': r'\N',
930
    'double precision': 'NaN',
931
    'integer': 2147483647,
932
    'text': r'\N',
933
    'timestamp with time zone': 'infinity'
934
}
935

    
936
class EnsureNotNull(Coalesce):
937
    def __init__(self, value, type_):
938
        Coalesce.__init__(self, as_Col(value),
939
            Cast(type_, null_sentinels[type_]))
940
        
941
        self.type = type_
942
    
943
    def to_str(self, db):
944
        col = self.args[0]
945
        index_col_ = index_col(col)
946
        if index_col_ != None: return index_col_.to_str(db)
947
        return Coalesce.to_str(self, db)
948

    
949
##### Table exprs
950

    
951
class Values(Code):
952
    def __init__(self, values):
953
        '''
954
        @param values [...]|[[...], ...] Can be one or multiple rows.
955
        '''
956
        Code.__init__(self)
957
        
958
        rows = values
959
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
960
            rows = [values]
961
        for i, row in enumerate(rows):
962
            rows[i] = map(remove_col_rename, map(as_Value, row))
963
        
964
        self.rows = rows
965
    
966
    def to_str(self, db):
967
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
968

    
969
def NamedValues(name, cols, values):
970
    '''
971
    @param cols None|[...]
972
    @post `cols` will be changed to Col objects with the table set to `name`.
973
    '''
974
    table = NamedTable(name, Values(values), cols)
975
    if cols != None: set_cols_table(table, cols)
976
    return table
977

    
978
##### Database structure
979

    
980
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
981

    
982
def ensure_not_null(db, col, type_=None):
983
    '''
984
    @param col If type_ is not set, must have an underlying column.
985
    @param type_ If set, overrides the underlying column's type.
986
    @return EnsureNotNull|Col
987
    @throws ensure_not_null_excs
988
    '''
989
    nullable = True
990
    try: typed_col = db.col_info(underlying_col(col))
991
    except NoUnderlyingTableException:
992
        col = remove_col_rename(col)
993
        if is_literal(col) and not is_null(col): nullable = False
994
        elif type_ == None: raise
995
    else:
996
        if type_ == None: type_ = typed_col.type
997
        nullable = typed_col.nullable
998
    
999
    if nullable:
1000
        try: col = EnsureNotNull(col, type_)
1001
        except KeyError, e:
1002
            # Warn of no null sentinel for type, even if caller catches error
1003
            warnings.warn(UserWarning(exc.str_(e)))
1004
            raise
1005
    
1006
    return col
(25-25/37)