Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import itertools
5
import operator
6
from ordereddict import OrderedDict
7
import re
8
import UserDict
9
import warnings
10

    
11
import dicts
12
import exc
13
import iters
14
import lists
15
import objects
16
import strings
17
import util
18

    
19
##### Names
20

    
21
identifier_max_len = 63 # works for both PostgreSQL and MySQL
22

    
23
def concat(str_, suffix):
24
    '''Preserves version so that it won't be truncated off the string, leading
25
    to collisions.'''
26
    # Preserve version
27
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
28
    if match:
29
        str_, old_suffix = match.groups()
30
        suffix = old_suffix+suffix
31
    
32
    return strings.concat(str_, suffix, identifier_max_len)
33

    
34
def truncate(str_): return concat(str_, '')
35

    
36
def is_safe_name(name):
37
    '''A name is safe *and unambiguous* if it:
38
    * contains only *lowercase* word (\w) characters
39
    * doesn't start with a digit
40
    * contains "_", so that it's not a keyword
41
    '''
42
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
43

    
44
def esc_name(name, quote='"'):
45
    return quote + name.replace(quote, quote+quote) + quote
46
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
47

    
48
def unesc_name(name, quote='"'):
49
    removed_ref = [False]
50
    name = strings.remove_prefix(quote, name, removed_ref)
51
    if removed_ref[0]:
52
        name = strings.remove_suffix(quote, name, removed_ref)
53
        assert removed_ref[0]
54
        name = name.replace(quote+quote, quote)
55
    return name
56

    
57
def clean_name(name): return name.replace('"', '').replace('`', '')
58

    
59
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
60

    
61
def lstrip(str_):
62
    '''Also removes comments.'''
63
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
64
    return str_.lstrip()
65

    
66
##### General SQL code objects
67

    
68
class MockDb:
69
    def esc_value(self, value): return strings.repr_no_u(value)
70
    
71
    def esc_name(self, name): return esc_name(name)
72
    
73
    def col_info(self, col):
74
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
75

    
76
mockDb = MockDb()
77

    
78
class BasicObject(objects.BasicObject):
79
    def __str__(self): return clean_name(strings.repr_no_u(self))
80

    
81
##### Unparameterized code objects
82

    
83
class Code(BasicObject):
84
    def __init__(self, lang='sql'):
85
        self.lang = lang
86
    
87
    def to_str(self, db): raise NotImplementedError()
88
    
89
    def __repr__(self): return self.to_str(mockDb)
90

    
91
class CustomCode(Code):
92
    def __init__(self, str_):
93
        Code.__init__(self)
94
        
95
        self.str_ = str_
96
    
97
    def to_str(self, db): return self.str_
98

    
99
def as_Code(value, db=None):
100
    '''
101
    @param db If set, runs db.std_code() on the value.
102
    '''
103
    if isinstance(value, Code): return value
104
    
105
    if util.is_str(value):
106
        if db != None: value = db.std_code(value)
107
        return CustomCode(value)
108
    else: return Literal(value)
109

    
110
class Expr(Code):
111
    def __init__(self, expr):
112
        Code.__init__(self)
113
        
114
        self.expr = expr
115
    
116
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
117

    
118
##### Names
119

    
120
class Name(Code):
121
    def __init__(self, name):
122
        Code.__init__(self)
123
        
124
        name = truncate(name)
125
        
126
        self.name = name
127
    
128
    def to_str(self, db): return db.esc_name(self.name)
129

    
130
def as_Name(value):
131
    if isinstance(value, Code): return value
132
    else: return Name(value)
133

    
134
##### Literal values
135

    
136
#### Primitives
137

    
138
class Literal(Code):
139
    def __init__(self, value):
140
        Code.__init__(self)
141
        
142
        self.value = value
143
    
144
    def to_str(self, db): return db.esc_value(self.value)
145

    
146
def as_Value(value):
147
    if isinstance(value, Code): return value
148
    else: return Literal(value)
149

    
150
def is_literal(value): return isinstance(value, Literal)
151

    
152
def is_null(value): return is_literal(value) and value.value == None
153

    
154
#### Composites
155

    
156
class List(Code):
157
    def __init__(self, values):
158
        Code.__init__(self)
159
        
160
        self.values = values
161
    
162
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
163

    
164
class Tuple(List):
165
    def __init__(self, *values):
166
        List.__init__(self, values)
167
    
168
    def to_str(self, db): return '('+List.to_str(self, db)+')'
169

    
170
class Row(Tuple):
171
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
172

    
173
### Arrays
174

    
175
class Array(List):
176
    def __init__(self, values):
177
        values = map(remove_col_rename, values)
178
        
179
        List.__init__(self, values)
180
    
181
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
182

    
183
def to_Array(value):
184
    if isinstance(value, Array): return value
185
    return Array(lists.mk_seq(value))
186

    
187
##### Derived elements
188

    
189
src_self = object() # tells Col that it is its own source column
190

    
191
class Derived(Code):
192
    def __init__(self, srcs):
193
        '''An element which was derived from some other element(s).
194
        @param srcs See self.set_srcs()
195
        '''
196
        Code.__init__(self)
197
        
198
        self.set_srcs(srcs)
199
    
200
    def set_srcs(self, srcs, overwrite=True):
201
        '''
202
        @param srcs (self_type...)|src_self The element(s) this is derived from
203
        '''
204
        if not overwrite and self.srcs != (): return # already set
205
        
206
        if srcs == src_self: srcs = (self,)
207
        srcs = tuple(srcs) # make Col hashable
208
        self.srcs = srcs
209
    
210
    def _compare_on(self):
211
        compare_on = self.__dict__.copy()
212
        del compare_on['srcs'] # ignore
213
        return compare_on
214

    
215
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
216

    
217
##### Tables
218

    
219
class Table(Derived):
220
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
221
        '''
222
        @param schema str|None (for no schema)
223
        @param srcs (Table...)|src_self See Derived.set_srcs()
224
        '''
225
        Derived.__init__(self, srcs)
226
        
227
        if util.is_str(name): name = truncate(name)
228
        
229
        self.name = name
230
        self.schema = schema
231
        self.is_temp = is_temp
232
        self.index_cols = {}
233
    
234
    def to_str(self, db):
235
        str_ = ''
236
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
237
        str_ += as_Name(self.name).to_str(db)
238
        return str_
239
    
240
    def to_Table(self): return self
241
    
242
    def _compare_on(self):
243
        compare_on = Derived._compare_on(self)
244
        del compare_on['index_cols'] # ignore
245
        return compare_on
246

    
247
def is_underlying_table(table):
248
    return isinstance(table, Table) and table.to_Table() is table
249

    
250
class NoUnderlyingTableException(Exception):
251
    def __init__(self, ref):
252
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
253
        self.ref = ref
254

    
255
def underlying_table(table):
256
    table = remove_table_rename(table)
257
    if table != None and table.srcs:
258
        table, = table.srcs # for derived tables or row vars
259
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
260
    return table
261

    
262
def as_Table(table, schema=None):
263
    if table == None or isinstance(table, Code): return table
264
    else: return Table(table, schema)
265

    
266
def suffixed_table(table, suffix):
267
    table = copy.copy(table) # don't modify input!
268
    table.name = concat(table.name, suffix)
269
    return table
270

    
271
class NamedTable(Table):
272
    def __init__(self, name, code, cols=None):
273
        Table.__init__(self, name)
274
        
275
        code = as_Table(code)
276
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
277
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
278
        
279
        self.code = code
280
        self.cols = cols
281
    
282
    def to_str(self, db):
283
        str_ = self.code.to_str(db)
284
        if str_.find('\n') >= 0: whitespace = '\n'
285
        else: whitespace = ' '
286
        str_ += whitespace+'AS '+Table.to_str(self, db)
287
        if self.cols != None:
288
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
289
        return str_
290
    
291
    def to_Table(self): return Table(self.name)
292

    
293
def remove_table_rename(table):
294
    if isinstance(table, NamedTable): table = table.code
295
    return table
296

    
297
##### Columns
298

    
299
class Col(Derived):
300
    def __init__(self, name, table=None, srcs=()):
301
        '''
302
        @param table Table|None (for no table)
303
        @param srcs (Col...)|src_self See Derived.set_srcs()
304
        '''
305
        Derived.__init__(self, srcs)
306
        
307
        if util.is_str(name): name = truncate(name)
308
        if util.is_str(table): table = Table(table)
309
        assert table == None or isinstance(table, Table)
310
        
311
        self.name = name
312
        self.table = table
313
    
314
    def to_str(self, db, for_str=False):
315
        str_ = as_Name(self.name).to_str(db)
316
        if for_str: str_ = clean_name(str_)
317
        if self.table != None:
318
            table = self.table.to_Table()
319
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
320
            else: str_ = table.to_str(db)+'.'+str_
321
        return str_
322
    
323
    def __str__(self): return self.to_str(mockDb, for_str=True)
324
    
325
    def to_Col(self): return self
326

    
327
def is_col(col): return isinstance(col, Col)
328

    
329
def is_table_col(col): return is_col(col) and col.table != None
330

    
331
def index_col(col):
332
    if not is_table_col(col): return None
333
    
334
    table = col.table
335
    try: name = table.index_cols[col.name]
336
    except KeyError: return None
337
    else: return Col(name, table, col.srcs)
338

    
339
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
340

    
341
def as_Col(col, table=None, name=None):
342
    '''
343
    @param name If not None, any non-Col input will be renamed using NamedCol.
344
    '''
345
    if name != None:
346
        col = as_Value(col)
347
        if not isinstance(col, Col): col = NamedCol(name, col)
348
    
349
    if isinstance(col, Code): return col
350
    elif util.is_str(col): return Col(col, table)
351
    else: return Literal(col)
352

    
353
def with_table(col, table):
354
    if isinstance(col, NamedCol): pass # doesn't take a table
355
    elif isinstance(col, FunctionCall):
356
        col = copy.deepcopy(col) # don't modify input!
357
        col.args[0].table = table
358
    elif isinstance(col, Col):
359
        col = copy.copy(col) # don't modify input!
360
        col.table = table
361
    return col
362

    
363
def with_default_table(col, table):
364
    col = as_Col(col)
365
    if col.table == None: col = with_table(col, table)
366
    return col
367

    
368
def set_cols_table(table, cols):
369
    table = as_Table(table)
370
    
371
    for i, col in enumerate(cols):
372
        col = cols[i] = as_Col(col)
373
        col.table = table
374

    
375
def to_name_only_col(col, check_table=None):
376
    col = as_Col(col)
377
    if not is_table_col(col): return col
378
    
379
    if check_table != None:
380
        table = col.table
381
        assert table == None or table == check_table
382
    return Col(col.name)
383

    
384
def suffixed_col(col, suffix):
385
    return Col(concat(col.name, suffix), col.table, col.srcs)
386

    
387
def has_srcs(col): return is_col(col) and col.srcs
388

    
389
def cross_join_srcs(cols):
390
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
391
    srcs = [[s.name for s in c.srcs] for c in cols]
392
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
393

    
394
class NamedCol(Col):
395
    def __init__(self, name, code):
396
        Col.__init__(self, name)
397
        
398
        code = as_Value(code)
399
        
400
        self.code = code
401
    
402
    def to_str(self, db):
403
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
404
    
405
    def to_Col(self): return Col(self.name)
406

    
407
def remove_col_rename(col):
408
    if isinstance(col, NamedCol): col = col.code
409
    return col
410

    
411
def underlying_col(col):
412
    col = remove_col_rename(col)
413
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
414
    
415
    return Col(col.name, underlying_table(col.table), col.srcs)
416

    
417
def wrap(wrap_func, value):
418
    '''Wraps a value, propagating any column renaming to the returned value.'''
419
    if isinstance(value, NamedCol):
420
        return NamedCol(value.name, wrap_func(value.code))
421
    else: return wrap_func(value)
422

    
423
class ColDict(dicts.DictProxy):
424
    '''A dict that automatically makes inserted entries Col objects.
425
    Anything that isn't a column is wrapped in a NamedCol with the key's column
426
    name by `as_Col(value, name=key.name)`.
427
    '''
428
    
429
    def __init__(self, db, keys_table, dict_={}):
430
        dicts.DictProxy.__init__(self, OrderedDict())
431
        
432
        keys_table = as_Table(keys_table)
433
        
434
        self.db = db
435
        self.table = keys_table
436
        self.update(dict_) # after setting vars because __setitem__() needs them
437
    
438
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
439
    
440
    def __getitem__(self, key):
441
        return dicts.DictProxy.__getitem__(self, self._key(key))
442
    
443
    def __setitem__(self, key, value):
444
        key = self._key(key)
445
        if value == None:
446
            try: value = self.db.col_info(key).default
447
            except NoUnderlyingTableException: pass # not a table column
448
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
449
    
450
    def _key(self, key): return as_Col(key, self.table)
451

    
452
##### Definitions
453

    
454
class TypedCol(Col):
455
    def __init__(self, name, type_, default=None, nullable=True,
456
        constraints=None):
457
        assert default == None or isinstance(default, Code)
458
        
459
        Col.__init__(self, name)
460
        
461
        self.type = type_
462
        self.default = default
463
        self.nullable = nullable
464
        self.constraints = constraints
465
    
466
    def to_str(self, db):
467
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
468
        if not self.nullable: str_ += ' NOT NULL'
469
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
470
        if self.constraints != None: str_ += ' '+self.constraints
471
        return str_
472
    
473
    def to_Col(self): return Col(self.name)
474

    
475
class SetOf(Code):
476
    def __init__(self, type_):
477
        Code.__init__(self)
478
        
479
        self.type = type_
480
    
481
    def to_str(self, db):
482
        return 'SETOF '+self.type.to_str(db)
483

    
484
class RowType(Code):
485
    def __init__(self, table):
486
        Code.__init__(self)
487
        
488
        self.table = table
489
    
490
    def to_str(self, db):
491
        return self.table.to_str(db)+'%ROWTYPE'
492

    
493
class ColType(Code):
494
    def __init__(self, col):
495
        Code.__init__(self)
496
        
497
        self.col = col
498
    
499
    def to_str(self, db):
500
        return self.col.to_str(db)+'%TYPE'
501

    
502
class ArrayType(Code):
503
    def __init__(self, elem_type):
504
        Code.__init__(self)
505
        elem_type = as_Code(elem_type)
506
        
507
        self.elem_type = elem_type
508
    
509
    def to_str(self, db):
510
        return self.elem_type.to_str(db)+'[]'
511

    
512
##### Functions
513

    
514
Function = Table
515
as_Function = as_Table
516

    
517
class InternalFunction(CustomCode): pass
518

    
519
#### Calls
520

    
521
class NamedArg(NamedCol):
522
    def __init__(self, name, value):
523
        NamedCol.__init__(self, name, value)
524
    
525
    def to_str(self, db):
526
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
527

    
528
class FunctionCall(Code):
529
    def __init__(self, function, *args, **kw_args):
530
        '''
531
        @param args [Code|literal-value...] The function's arguments
532
        '''
533
        Code.__init__(self)
534
        
535
        function = as_Function(function)
536
        def filter_(arg): return remove_col_rename(as_Value(arg))
537
        args = map(filter_, args)
538
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
539
        
540
        self.function = function
541
        self.args = args
542
    
543
    def to_str(self, db):
544
        args_str = ', '.join((v.to_str(db) for v in self.args))
545
        return self.function.to_str(db)+'('+args_str+')'
546

    
547
def wrap_in_func(function, value):
548
    '''Wraps a value inside a function call.
549
    Propagates any column renaming to the returned value.
550
    '''
551
    return wrap(lambda v: FunctionCall(function, v), value)
552

    
553
def unwrap_func_call(func_call, check_name=None):
554
    '''Unwraps any function call to its first argument.
555
    Also removes any column renaming.
556
    '''
557
    func_call = remove_col_rename(func_call)
558
    if not isinstance(func_call, FunctionCall): return func_call
559
    
560
    if check_name != None:
561
        name = func_call.function.name
562
        assert name == None or name == check_name
563
    return func_call.args[0]
564

    
565
#### Definitions
566

    
567
class FunctionDef(Code):
568
    def __init__(self, function, return_type, body, params=[], modifiers=None):
569
        Code.__init__(self)
570
        
571
        return_type = as_Code(return_type)
572
        body = as_Code(body)
573
        
574
        self.function = function
575
        self.return_type = return_type
576
        self.body = body
577
        self.params = params
578
        self.modifiers = modifiers
579
    
580
    def to_str(self, db):
581
        params_str = (', '.join((p.to_str(db) for p in self.params)))
582
        str_ = '''\
583
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
584
RETURNS '''+self.return_type.to_str(db)+'''
585
LANGUAGE '''+self.body.lang+'''
586
'''
587
        if self.modifiers != None: str_ += self.modifiers+'\n'
588
        str_ += '''\
589
AS $$
590
'''+self.body.to_str(db)+'''
591
$$;
592
'''
593
        return str_
594

    
595
class FunctionParam(TypedCol):
596
    def __init__(self, name, type_, default=None, out=False):
597
        TypedCol.__init__(self, name, type_, default)
598
        
599
        self.out = out
600
    
601
    def to_str(self, db):
602
        str_ = TypedCol.to_str(self, db)
603
        if self.out: str_ = 'OUT '+str_
604
        return str_
605
    
606
    def to_Col(self): return Col(self.name)
607

    
608
### PL/pgSQL
609

    
610
class ReturnQuery(Code):
611
    def __init__(self, query):
612
        Code.__init__(self)
613
        
614
        query = as_Code(query)
615
        
616
        self.query = query
617
    
618
    def to_str(self, db):
619
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
620

    
621
## Exceptions
622

    
623
class BaseExcHandler(BasicObject):
624
    def to_str(self, db, body): raise NotImplementedError()
625
    
626
    def __repr__(self): return self.to_str(mockDb, '<body>')
627

    
628
suppress_exc = 'NULL;\n';
629

    
630
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
631

    
632
class ExcHandler(BaseExcHandler):
633
    def __init__(self, exc, handler=None):
634
        if handler != None: handler = as_Code(handler)
635
        
636
        self.exc = exc
637
        self.handler = handler
638
    
639
    def to_str(self, db, body):
640
        body = as_Code(body)
641
        
642
        if self.handler != None:
643
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
644
        else: handler_str = ' '+suppress_exc
645
        
646
        str_ = '''\
647
BEGIN
648
'''+strings.indent(body.to_str(db))+'''\
649
EXCEPTION
650
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
651
END;\
652
'''
653
        return str_
654

    
655
class NestedExcHandler(BaseExcHandler):
656
    def __init__(self, *handlers):
657
        '''
658
        @param handlers Sorted from outermost to innermost
659
        '''
660
        self.handlers = handlers
661
    
662
    def to_str(self, db, body):
663
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
664
        return body
665

    
666
class ExcToWarning(Code):
667
    def __init__(self, return_):
668
        '''
669
        @param return_ Statement to return a default value in case of error
670
        '''
671
        Code.__init__(self)
672
        
673
        return_ = as_Code(return_)
674
        
675
        self.return_ = return_
676
    
677
    def to_str(self, db):
678
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
679

    
680
unique_violation_handler = ExcHandler('unique_violation')
681

    
682
# Note doubled "\"s because inside Python string
683
plpythonu_error_handler = ExcHandler('internal_error', '''\
684
-- Handle PL/Python exceptions
685
DECLARE
686
    matches text[] := regexp_matches(SQLERRM,
687
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
688
    exc_name text := matches[1];
689
    msg text := matches[2];
690
BEGIN
691
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
692
    This allows the exception to be parsed like a native exception.
693
    Always raise as data_exception so it goes in the errors table. */
694
    IF exc_name IS NOT NULL THEN
695
        RAISE data_exception USING MESSAGE = msg;
696
    -- Re-raise non-PL/Python exceptions
697
    ELSE
698
        '''+reraise_exc+'''\
699
    END IF;
700
END;
701
''')
702

    
703
def data_exception_handler(handler):
704
    return ExcHandler('data_exception', handler)
705

    
706
row_var = Table('row')
707

    
708
class RowExcIgnore(Code):
709
    def __init__(self, row_type, select_query, with_row, cols=None,
710
        exc_handler=unique_violation_handler, row_var=row_var):
711
        '''
712
        @param row_type Ignored if a custom row_var is used.
713
        @pre If a custom row_var is used, it must already be defined.
714
        '''
715
        Code.__init__(self, lang='plpgsql')
716
        
717
        row_type = as_Code(row_type)
718
        select_query = as_Code(select_query)
719
        with_row = as_Code(with_row)
720
        row_var = as_Table(row_var)
721
        
722
        self.row_type = row_type
723
        self.select_query = select_query
724
        self.with_row = with_row
725
        self.cols = cols
726
        self.exc_handler = exc_handler
727
        self.row_var = row_var
728
    
729
    def to_str(self, db):
730
        if self.cols == None: row_vars = [self.row_var]
731
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
732
        
733
        # Need an EXCEPTION block for each individual row because "When an error
734
        # is caught by an EXCEPTION clause, [...] all changes to persistent
735
        # database state within the block are rolled back."
736
        # This is unfortunate because "A block containing an EXCEPTION clause is
737
        # significantly more expensive to enter and exit than a block without
738
        # one."
739
        # (http://www.postgresql.org/docs/8.3/static/\
740
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
741
        str_ = '''\
742
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
743
'''+strings.indent(self.select_query.to_str(db))+'''\
744
LOOP
745
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
746
END LOOP;
747
'''
748
        if self.row_var == row_var:
749
            str_ = '''\
750
DECLARE
751
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
752
BEGIN
753
'''+strings.indent(str_)+'''\
754
END;
755
'''
756
        return str_
757

    
758
##### Casts
759

    
760
class Cast(FunctionCall):
761
    def __init__(self, type_, value):
762
        type_ = as_Code(type_)
763
        value = as_Value(value)
764
        
765
        self.type_ = type_
766
        self.value = value
767
    
768
    def to_str(self, db):
769
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
770

    
771
def cast_literal(value):
772
    if not is_literal(value): return value
773
    
774
    if util.is_str(value.value): value = Cast('text', value)
775
    return value
776

    
777
##### Conditions
778

    
779
class NotCond(Code):
780
    def __init__(self, cond):
781
        Code.__init__(self)
782
        
783
        self.cond = cond
784
    
785
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
786

    
787
class ColValueCond(Code):
788
    def __init__(self, col, value):
789
        Code.__init__(self)
790
        
791
        value = as_ValueCond(value)
792
        
793
        self.col = col
794
        self.value = value
795
    
796
    def to_str(self, db): return self.value.to_str(db, self.col)
797

    
798
def combine_conds(conds, keyword=None):
799
    '''
800
    @param keyword The keyword to add before the conditions, if any
801
    '''
802
    str_ = ''
803
    if keyword != None:
804
        if conds == []: whitespace = ''
805
        elif len(conds) == 1: whitespace = ' '
806
        else: whitespace = '\n'
807
        str_ += keyword+whitespace
808
    
809
    str_ += '\nAND '.join(conds)
810
    return str_
811

    
812
##### Condition column comparisons
813

    
814
class ValueCond(BasicObject):
815
    def __init__(self, value):
816
        value = remove_col_rename(as_Value(value))
817
        
818
        self.value = value
819
    
820
    def to_str(self, db, left_value):
821
        '''
822
        @param left_value The Code object that the condition is being applied on
823
        '''
824
        raise NotImplemented()
825
    
826
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
827

    
828
class CompareCond(ValueCond):
829
    def __init__(self, value, operator='='):
830
        '''
831
        @param operator By default, compares NULL values literally. Use '~=' or
832
            '~!=' to pass NULLs through.
833
        '''
834
        ValueCond.__init__(self, value)
835
        self.operator = operator
836
    
837
    def to_str(self, db, left_value):
838
        left_value = remove_col_rename(as_Col(left_value))
839
        
840
        right_value = self.value
841
        
842
        # Parse operator
843
        operator = self.operator
844
        passthru_null_ref = [False]
845
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
846
        neg_ref = [False]
847
        operator = strings.remove_prefix('!', operator, neg_ref)
848
        equals = operator.endswith('=') # also includes <=, >=
849
        
850
        # Handle nullable columns
851
        check_null = False
852
        if not passthru_null_ref[0]: # NULLs compare equal
853
            try: left_value = ensure_not_null(db, left_value)
854
            except ensure_not_null_excs: # fall back to alternate method
855
                check_null = equals and isinstance(right_value, Col)
856
            else:
857
                if isinstance(left_value, EnsureNotNull):
858
                    right_value = ensure_not_null(db, right_value,
859
                        left_value.type) # apply same function to both sides
860
        
861
        if equals and is_null(right_value): operator = 'IS'
862
        
863
        left = left_value.to_str(db)
864
        right = right_value.to_str(db)
865
        
866
        # Create str
867
        str_ = left+' '+operator+' '+right
868
        if check_null:
869
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
870
        if neg_ref[0]: str_ = 'NOT '+str_
871
        return str_
872

    
873
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
874
assume_literal = object()
875

    
876
def as_ValueCond(value, default_table=assume_literal):
877
    if not isinstance(value, ValueCond):
878
        if default_table is not assume_literal:
879
            value = with_default_table(value, default_table)
880
        return CompareCond(value)
881
    else: return value
882

    
883
##### Joins
884

    
885
join_same = object() # tells Join the left and right columns have the same name
886

    
887
# Tells Join the left and right columns have the same name and are never NULL
888
join_same_not_null = object()
889

    
890
filter_out = object() # tells Join to filter out rows that match the join
891

    
892
class Join(BasicObject):
893
    def __init__(self, table, mapping={}, type_=None):
894
        '''
895
        @param mapping dict(right_table_col=left_table_col, ...)
896
            * if left_table_col is join_same: left_table_col = right_table_col
897
              * Note that right_table_col must be a string
898
            * if left_table_col is join_same_not_null:
899
              left_table_col = right_table_col and both have NOT NULL constraint
900
              * Note that right_table_col must be a string
901
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
902
            * filter_out: equivalent to 'LEFT' with the query filtered by
903
              `table_pkey IS NULL` (indicating no match)
904
        '''
905
        if util.is_str(table): table = Table(table)
906
        assert type_ == None or util.is_str(type_) or type_ is filter_out
907
        
908
        self.table = table
909
        self.mapping = mapping
910
        self.type_ = type_
911
    
912
    def to_str(self, db, left_table_):
913
        def join(entry):
914
            '''Parses non-USING joins'''
915
            right_table_col, left_table_col = entry
916
            
917
            # Switch order (right_table_col is on the left in the comparison)
918
            left = right_table_col
919
            right = left_table_col
920
            left_table = self.table
921
            right_table = left_table_
922
            
923
            # Parse left side
924
            left = with_default_table(left, left_table)
925
            
926
            # Parse special values
927
            left_on_right = Col(left.name, right_table)
928
            if right is join_same: right = left_on_right
929
            elif right is join_same_not_null:
930
                right = CompareCond(left_on_right, '~=')
931
            
932
            # Parse right side
933
            right = as_ValueCond(right, right_table)
934
            
935
            return right.to_str(db, left)
936
        
937
        # Create join condition
938
        type_ = self.type_
939
        joins = self.mapping
940
        if joins == {}: join_cond = None
941
        elif type_ is not filter_out and reduce(operator.and_,
942
            (v is join_same_not_null for v in joins.itervalues())):
943
            # all cols w/ USING, so can use simpler USING syntax
944
            cols = map(to_name_only_col, joins.iterkeys())
945
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
946
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
947
        
948
        if isinstance(self.table, NamedTable): whitespace = '\n'
949
        else: whitespace = ' '
950
        
951
        # Create join
952
        if type_ is filter_out: type_ = 'LEFT'
953
        str_ = ''
954
        if type_ != None: str_ += type_+' '
955
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
956
        if join_cond != None: str_ += whitespace+join_cond
957
        return str_
958
    
959
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
960

    
961
##### Value exprs
962

    
963
all_cols = CustomCode('*')
964

    
965
default = CustomCode('DEFAULT')
966

    
967
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
968

    
969
class Coalesce(FunctionCall):
970
    def __init__(self, *args):
971
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
972

    
973
class Nullif(FunctionCall):
974
    def __init__(self, *args):
975
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
976

    
977
null_as_str = Cast('text', 'NULL')
978

    
979
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
980

    
981
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
982
null_sentinels = {
983
    'character varying': r'\N',
984
    'double precision': 'NaN',
985
    'integer': 2147483647,
986
    'text': r'\N',
987
    'timestamp with time zone': 'infinity'
988
}
989

    
990
class EnsureNotNull(Coalesce):
991
    def __init__(self, value, type_):
992
        if isinstance(type_, ArrayType): null = []
993
        else: null = null_sentinels[type_]
994
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
995
        
996
        self.type = type_
997
    
998
    def to_str(self, db):
999
        col = self.args[0]
1000
        index_col_ = index_col(col)
1001
        if index_col_ != None: return index_col_.to_str(db)
1002
        return Coalesce.to_str(self, db)
1003

    
1004
#### Arrays
1005

    
1006
class ArrayMerge(FunctionCall):
1007
    def __init__(self, sep, array):
1008
        array = to_Array(array)
1009
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1010
            sep)
1011

    
1012
def merge_not_null(db, sep, values):
1013
    return ArrayMerge(sep, map(to_text, values))
1014

    
1015
##### Table exprs
1016

    
1017
class Values(Code):
1018
    def __init__(self, values):
1019
        '''
1020
        @param values [...]|[[...], ...] Can be one or multiple rows.
1021
        '''
1022
        Code.__init__(self)
1023
        
1024
        rows = values
1025
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1026
            rows = [values]
1027
        for i, row in enumerate(rows):
1028
            rows[i] = map(remove_col_rename, map(as_Value, row))
1029
        
1030
        self.rows = rows
1031
    
1032
    def to_str(self, db):
1033
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1034

    
1035
def NamedValues(name, cols, values):
1036
    '''
1037
    @param cols None|[...]
1038
    @post `cols` will be changed to Col objects with the table set to `name`.
1039
    '''
1040
    table = NamedTable(name, Values(values), cols)
1041
    if cols != None: set_cols_table(table, cols)
1042
    return table
1043

    
1044
##### Database structure
1045

    
1046
def is_nullable(db, value):
1047
    try: return is_null(value) or db.col_info(value).nullable
1048
    except NoUnderlyingTableException: return True # not a table column
1049

    
1050
text_types = set(['character varying', 'text'])
1051

    
1052
def is_text_col(db, col): return db.col_info(col).type in text_types
1053

    
1054
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1055

    
1056
def ensure_not_null(db, col, type_=None):
1057
    '''
1058
    @param col If type_ is not set, must have an underlying column.
1059
    @param type_ If set, overrides the underlying column's type and casts the
1060
        column to it if needed.
1061
    @return EnsureNotNull|Col
1062
    @throws ensure_not_null_excs
1063
    '''
1064
    nullable = True
1065
    try: typed_col = db.col_info(underlying_col(col))
1066
    except NoUnderlyingTableException:
1067
        col = remove_col_rename(col)
1068
        if is_literal(col) and not is_null(col): nullable = False
1069
        elif type_ == None: raise
1070
    else:
1071
        if type_ == None: type_ = typed_col.type
1072
        elif type_ != typed_col.type: col = Cast(type_, col)
1073
        nullable = typed_col.nullable
1074
    
1075
    if nullable:
1076
        try: col = EnsureNotNull(col, type_)
1077
        except KeyError, e:
1078
            # Warn of no null sentinel for type, even if caller catches error
1079
            warnings.warn(UserWarning(exc.str_(e)))
1080
            raise
1081
    
1082
    return col
1083

    
1084
def try_mk_not_null(db, value):
1085
    '''
1086
    Warning: This function does not guarantee that its result is NOT NULL.
1087
    '''
1088
    try: return ensure_not_null(db, value)
1089
    except ensure_not_null_excs: return value
(28-28/41)