Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import itertools
5
import operator
6
from collections import OrderedDict
7
import re
8
import UserDict
9
import warnings
10

    
11
import dicts
12
import exc
13
import iters
14
import lists
15
import objects
16
import regexp
17
import strings
18
import util
19

    
20
##### Names
21

    
22
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23

    
24
def concat(str_, suffix):
25
    '''Preserves version so that it won't be truncated off the string, leading
26
    to collisions.'''
27
    # Preserve version
28
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$', str_)
29
    if match:
30
        str_, old_suffix = match.groups()
31
        suffix = old_suffix+suffix
32
    
33
    return strings.concat(str_, suffix, identifier_max_len)
34

    
35
def truncate(str_): return concat(str_, '')
36

    
37
def is_safe_name(name):
38
    '''A name is safe *and unambiguous* if it:
39
    * contains only *lowercase* word (\w) characters
40
    * doesn't start with a digit
41
    * contains "_", so that it's not a keyword
42
    '''
43
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
44

    
45
def esc_name(name, quote='"'):
46
    return quote + name.replace(quote, quote+quote) + quote
47
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
48

    
49
def unesc_name(name, quote='"'):
50
    removed_ref = [False]
51
    name = strings.remove_prefix(quote, name, removed_ref)
52
    if removed_ref[0]:
53
        name = strings.remove_suffix(quote, name, removed_ref)
54
        assert removed_ref[0]
55
        name = name.replace(quote+quote, quote)
56
    return name
57

    
58
def clean_name(name): return name.replace('"', '').replace('`', '')
59

    
60
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
61

    
62
def lstrip(str_):
63
    '''Also removes comments.'''
64
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
65
    return str_.lstrip()
66

    
67
##### General SQL code objects
68

    
69
class MockDb:
70
    def esc_value(self, value): return strings.repr_no_u(value)
71
    
72
    def esc_name(self, name): return esc_name(name)
73
    
74
    def col_info(self, col):
75
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76

    
77
mockDb = MockDb()
78

    
79
class BasicObject(objects.BasicObject):
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81

    
82
##### Unparameterized code objects
83

    
84
class Code(BasicObject):
85
    def __init__(self, lang='sql'):
86
        self.lang = lang
87
    
88
    def to_str(self, db): raise NotImplementedError()
89
    
90
    def __repr__(self): return self.to_str(mockDb)
91

    
92
class CustomCode(Code):
93
    def __init__(self, str_):
94
        Code.__init__(self)
95
        
96
        self.str_ = str_
97
    
98
    def to_str(self, db): return self.str_
99

    
100
def as_Code(value, db=None):
101
    '''
102
    @param db If set, runs db.std_code() on the value.
103
    '''
104
    if isinstance(value, Code): return value
105
    
106
    if util.is_str(value):
107
        if db != None: value = db.std_code(value)
108
        return CustomCode(value)
109
    else: return Literal(value)
110

    
111
class Expr(Code):
112
    def __init__(self, expr):
113
        Code.__init__(self)
114
        
115
        self.expr = expr
116
    
117
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118

    
119
##### Names
120

    
121
class Name(Code):
122
    def __init__(self, name):
123
        Code.__init__(self)
124
        
125
        name = truncate(name)
126
        
127
        self.name = name
128
    
129
    def to_str(self, db): return db.esc_name(self.name)
130

    
131
def as_Name(value):
132
    if isinstance(value, Code): return value
133
    else: return Name(value)
134

    
135
##### Literal values
136

    
137
#### Primitives
138

    
139
class Literal(Code):
140
    def __init__(self, value):
141
        Code.__init__(self)
142
        
143
        self.value = value
144
    
145
    def to_str(self, db): return db.esc_value(self.value)
146

    
147
def as_Value(value):
148
    if isinstance(value, Code): return value
149
    else: return Literal(value)
150

    
151
def get_value(value):
152
    '''Unwraps a Literal's value'''
153
    value = remove_col_rename(value)
154
    if isinstance(value, Literal): return value.value
155
    else:
156
        assert not isinstance(value, Code)
157
        return value
158

    
159
def is_literal(value): return isinstance(value, Literal)
160

    
161
def is_null(value): return is_literal(value) and value.value == None
162

    
163
#### Composites
164

    
165
class List(Code):
166
    def __init__(self, values):
167
        Code.__init__(self)
168
        
169
        self.values = values
170
    
171
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
172

    
173
class Tuple(List):
174
    def __init__(self, *values):
175
        List.__init__(self, values)
176
    
177
    def to_str(self, db): return '('+List.to_str(self, db)+')'
178

    
179
class Row(Tuple):
180
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
181

    
182
### Arrays
183

    
184
class Array(List):
185
    def __init__(self, values):
186
        values = map(remove_col_rename, values)
187
        
188
        List.__init__(self, values)
189
    
190
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
191

    
192
def to_Array(value):
193
    if isinstance(value, Array): return value
194
    return Array(lists.mk_seq(value))
195

    
196
##### Derived elements
197

    
198
src_self = object() # tells Col that it is its own source column
199

    
200
class Derived(Code):
201
    def __init__(self, srcs):
202
        '''An element which was derived from some other element(s).
203
        @param srcs See self.set_srcs()
204
        '''
205
        Code.__init__(self)
206
        
207
        self.set_srcs(srcs)
208
    
209
    def set_srcs(self, srcs, overwrite=True):
210
        '''
211
        @param srcs (self_type...)|src_self The element(s) this is derived from
212
        '''
213
        if not overwrite and self.srcs != (): return # already set
214
        
215
        if srcs == src_self: srcs = (self,)
216
        srcs = tuple(srcs) # make Col hashable
217
        self.srcs = srcs
218
    
219
    def _compare_on(self):
220
        compare_on = self.__dict__.copy()
221
        del compare_on['srcs'] # ignore
222
        return compare_on
223

    
224
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
225

    
226
##### Tables
227

    
228
class Table(Derived):
229
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
230
        '''
231
        @param schema str|None (for no schema)
232
        @param srcs (Table...)|src_self See Derived.set_srcs()
233
        '''
234
        Derived.__init__(self, srcs)
235
        
236
        if util.is_str(name): name = truncate(name)
237
        
238
        self.name = name
239
        self.schema = schema
240
        self.is_temp = is_temp
241
        self.order_by = None
242
        self.index_cols = {}
243
    
244
    def to_str(self, db):
245
        str_ = ''
246
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
247
        str_ += as_Name(self.name).to_str(db)
248
        return str_
249
    
250
    def to_Table(self): return self
251
    
252
    def _compare_on(self):
253
        compare_on = Derived._compare_on(self)
254
        del compare_on['order_by'] # ignore
255
        del compare_on['index_cols'] # ignore
256
        return compare_on
257

    
258
def is_underlying_table(table):
259
    return isinstance(table, Table) and table.to_Table() is table
260

    
261
def table2regclass_text(db, table):
262
    assert isinstance(table, Table)
263
    return db.esc_value(table.to_str(db))
264

    
265
class NoUnderlyingTableException(Exception):
266
    def __init__(self, ref):
267
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
268
        self.ref = ref
269

    
270
def underlying_table(table):
271
    table = remove_table_rename(table)
272
    if table != None and table.srcs:
273
        table, = table.srcs # for derived tables or row vars
274
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
275
    return table
276

    
277
def as_Table(table, schema=None):
278
    if table == None or isinstance(table, Code): return table
279
    else: return Table(table, schema)
280

    
281
def suffixed_table(table, suffix):
282
    table = copy.copy(table) # don't modify input!
283
    table.name = concat(table.name, suffix)
284
    return table
285

    
286
class NamedTable(Table):
287
    def __init__(self, name, code, cols=None):
288
        Table.__init__(self, name)
289
        
290
        code = as_Table(code)
291
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
292
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
293
        
294
        self.code = code
295
        self.cols = cols
296
    
297
    def to_str(self, db):
298
        str_ = self.code.to_str(db)
299
        if str_.find('\n') >= 0: whitespace = '\n'
300
        else: whitespace = ' '
301
        str_ += whitespace+'AS '+Table.to_str(self, db)
302
        if self.cols != None:
303
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
304
        return str_
305
    
306
    def to_Table(self): return Table(self.name)
307

    
308
def remove_table_rename(table):
309
    if isinstance(table, NamedTable): table = table.code
310
    return table
311

    
312
##### Columns
313

    
314
class Col(Derived):
315
    def __init__(self, name, table=None, srcs=()):
316
        '''
317
        @param table Table|None (for no table)
318
        @param srcs (Col...)|src_self See Derived.set_srcs()
319
        '''
320
        Derived.__init__(self, srcs)
321
        
322
        if util.is_str(name): name = truncate(name)
323
        if util.is_str(table): table = Table(table)
324
        assert table == None or isinstance(table, Table)
325
        
326
        self.name = name
327
        self.table = table
328
    
329
    def to_str(self, db, for_str=False):
330
        str_ = as_Name(self.name).to_str(db)
331
        if for_str: str_ = clean_name(str_)
332
        if self.table != None:
333
            table = self.table.to_Table()
334
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
335
            else: str_ = table.to_str(db)+'.'+str_
336
        return str_
337
    
338
    def __str__(self): return self.to_str(mockDb, for_str=True)
339
    
340
    def to_Col(self): return self
341

    
342
def is_col(col): return isinstance(col, Col)
343

    
344
def is_table_col(col): return is_col(col) and col.table != None
345

    
346
def col2col_ref(db, col):
347
    assert isinstance(col, Col)
348
    return db.esc_value((col.table.to_str(db), col.name))
349

    
350
def index_col(col):
351
    if not is_table_col(col): return None
352
    
353
    table = col.table
354
    try: name = table.index_cols[col.name]
355
    except KeyError: return None
356
    else: return Col(name, table, col.srcs)
357

    
358
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
359

    
360
def as_Col(col, table=None, name=None):
361
    '''
362
    @param name If not None, any non-Col input will be renamed using NamedCol.
363
    '''
364
    if name != None:
365
        col = as_Value(col)
366
        if not isinstance(col, Col): col = NamedCol(name, col)
367
    
368
    if isinstance(col, Code): return col
369
    elif util.is_str(col): return Col(col, table)
370
    else: return Literal(col)
371

    
372
def with_table(col, table):
373
    if isinstance(col, NamedCol): pass # doesn't take a table
374
    elif isinstance(col, FunctionCall):
375
        col = copy.deepcopy(col) # don't modify input!
376
        col.args[0].table = table
377
    elif isinstance(col, Col):
378
        col = copy.copy(col) # don't modify input!
379
        col.table = table
380
    return col
381

    
382
def with_default_table(col, table):
383
    col = as_Col(col)
384
    if col.table == None: col = with_table(col, table)
385
    return col
386

    
387
def set_cols_table(table, cols):
388
    table = as_Table(table)
389
    
390
    for i, col in enumerate(cols):
391
        col = cols[i] = as_Col(col)
392
        col.table = table
393

    
394
def to_name_only_col(col, check_table=None):
395
    col = as_Col(col)
396
    if not is_table_col(col): return col
397
    
398
    if check_table != None:
399
        table = col.table
400
        assert table == None or table == check_table
401
    return Col(col.name)
402

    
403
def suffixed_col(col, suffix):
404
    return Col(concat(col.name, suffix), col.table, col.srcs)
405

    
406
def has_srcs(col): return is_col(col) and col.srcs
407

    
408
def cross_join_srcs(cols):
409
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
410
    srcs = [[s.name for s in c.srcs] for c in cols]
411
    if not srcs: return [] # itertools.product() returns [()] for empty input
412
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
413

    
414
class NamedCol(Col):
415
    def __init__(self, name, code):
416
        Col.__init__(self, name)
417
        
418
        code = as_Value(code)
419
        
420
        self.code = code
421
    
422
    def to_str(self, db):
423
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
424
    
425
    def to_Col(self): return Col(self.name)
426

    
427
def remove_col_rename(col):
428
    if isinstance(col, NamedCol): col = col.code
429
    return col
430

    
431
def underlying_col(col):
432
    col = remove_col_rename(col)
433
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
434
    
435
    return Col(col.name, underlying_table(col.table), col.srcs)
436

    
437
def wrap(wrap_func, value):
438
    '''Wraps a value, propagating any column renaming to the returned value.'''
439
    if isinstance(value, NamedCol):
440
        return NamedCol(value.name, wrap_func(value.code))
441
    else: return wrap_func(value)
442

    
443
class ColDict(dicts.DictProxy):
444
    '''A dict that automatically makes inserted entries Col objects.
445
    Anything that isn't a column is wrapped in a NamedCol with the key's column
446
    name by `as_Col(value, name=key.name)`.
447
    '''
448
    
449
    def __init__(self, db, keys_table, dict_={}):
450
        dicts.DictProxy.__init__(self, OrderedDict())
451
        
452
        keys_table = as_Table(keys_table)
453
        
454
        self.db = db
455
        self.table = keys_table
456
        self.update(dict_) # after setting vars because __setitem__() needs them
457
    
458
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
459
    
460
    def __getitem__(self, key):
461
        return dicts.DictProxy.__getitem__(self, self._key(key))
462
    
463
    def __setitem__(self, key, value):
464
        key = self._key(key)
465
        if value == None:
466
            try: value = self.db.col_info(key).default
467
            except NoUnderlyingTableException: pass # not a table column
468
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
469
    
470
    def _key(self, key): return as_Col(key, self.table)
471

    
472
##### Definitions
473

    
474
class TypedCol(Col):
475
    def __init__(self, name, type_, default=None, nullable=True,
476
        constraints=None):
477
        assert default == None or isinstance(default, Code)
478
        
479
        Col.__init__(self, name)
480
        
481
        self.type = type_
482
        self.default = default
483
        self.nullable = nullable
484
        self.constraints = constraints
485
    
486
    def to_str(self, db):
487
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
488
        if not self.nullable: str_ += ' NOT NULL'
489
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
490
        if self.constraints != None: str_ += ' '+self.constraints
491
        return str_
492
    
493
    def to_Col(self): return Col(self.name)
494

    
495
class SetOf(Code):
496
    def __init__(self, type_):
497
        Code.__init__(self)
498
        
499
        self.type = type_
500
    
501
    def to_str(self, db):
502
        return 'SETOF '+self.type.to_str(db)
503

    
504
class RowType(Code):
505
    def __init__(self, table):
506
        Code.__init__(self)
507
        
508
        self.table = table
509
    
510
    def to_str(self, db):
511
        return self.table.to_str(db)+'%ROWTYPE'
512

    
513
class ColType(Code):
514
    def __init__(self, col):
515
        Code.__init__(self)
516
        
517
        self.col = col
518
    
519
    def to_str(self, db):
520
        return self.col.to_str(db)+'%TYPE'
521

    
522
class ArrayType(Code):
523
    def __init__(self, elem_type):
524
        Code.__init__(self)
525
        elem_type = as_Code(elem_type)
526
        
527
        self.elem_type = elem_type
528
    
529
    def to_str(self, db):
530
        return self.elem_type.to_str(db)+'[]'
531

    
532
##### Functions
533

    
534
Function = Table
535
as_Function = as_Table
536

    
537
class InternalFunction(CustomCode): pass
538

    
539
#### Calls
540

    
541
class NamedArg(NamedCol):
542
    def __init__(self, name, value):
543
        NamedCol.__init__(self, name, value)
544
    
545
    def to_str(self, db):
546
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
547

    
548
class FunctionCall(Code):
549
    def __init__(self, function, *args, **kw_args):
550
        '''
551
        @param args [Code|literal-value...] The function's arguments
552
        '''
553
        Code.__init__(self)
554
        
555
        function = as_Function(function)
556
        def filter_(arg): return remove_col_rename(as_Value(arg))
557
        args = map(filter_, args)
558
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
559
        
560
        self.function = function
561
        self.args = args
562
    
563
    def to_str(self, db):
564
        args_str = ', '.join((v.to_str(db) for v in self.args))
565
        return self.function.to_str(db)+'('+args_str+')'
566

    
567
def wrap_in_func(function, value):
568
    '''Wraps a value inside a function call.
569
    Propagates any column renaming to the returned value.
570
    '''
571
    return wrap(lambda v: FunctionCall(function, v), value)
572

    
573
def unwrap_func_call(func_call, check_name=None):
574
    '''Unwraps any function call to its first argument.
575
    Also removes any column renaming.
576
    '''
577
    func_call = remove_col_rename(func_call)
578
    if not isinstance(func_call, FunctionCall): return func_call
579
    
580
    if check_name != None:
581
        name = func_call.function.name
582
        assert name == None or name == check_name
583
    return func_call.args[0]
584

    
585
#### Definitions
586

    
587
class FunctionDef(Code):
588
    def __init__(self, function, return_type, body, params=[], modifiers=None):
589
        Code.__init__(self)
590
        
591
        return_type = as_Code(return_type)
592
        body = as_Code(body)
593
        
594
        self.function = function
595
        self.return_type = return_type
596
        self.body = body
597
        self.params = params
598
        self.modifiers = modifiers
599
    
600
    def to_str(self, db):
601
        params_str = (', '.join((p.to_str(db) for p in self.params)))
602
        str_ = '''\
603
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
604
RETURNS '''+self.return_type.to_str(db)+'''
605
LANGUAGE '''+self.body.lang+'''
606
'''
607
        if self.modifiers != None: str_ += self.modifiers+'\n'
608
        str_ += '''\
609
AS $$
610
'''+self.body.to_str(db)+'''
611
$$;
612
'''
613
        return str_
614

    
615
class FunctionParam(TypedCol):
616
    def __init__(self, name, type_, default=None, out=False):
617
        TypedCol.__init__(self, name, type_, default)
618
        
619
        self.out = out
620
    
621
    def to_str(self, db):
622
        str_ = TypedCol.to_str(self, db)
623
        if self.out: str_ = 'OUT '+str_
624
        return str_
625
    
626
    def to_Col(self): return Col(self.name)
627

    
628
### PL/pgSQL
629

    
630
class ReturnQuery(Code):
631
    def __init__(self, query):
632
        Code.__init__(self)
633
        
634
        query = as_Code(query)
635
        
636
        self.query = query
637
    
638
    def to_str(self, db):
639
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
640

    
641
## Exceptions
642

    
643
class BaseExcHandler(BasicObject):
644
    def to_str(self, db, body): raise NotImplementedError()
645
    
646
    def __repr__(self): return self.to_str(mockDb, '<body>')
647

    
648
suppress_exc = 'NULL;\n';
649

    
650
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
651

    
652
class ExcHandler(BaseExcHandler):
653
    def __init__(self, exc, handler=None):
654
        if handler != None: handler = as_Code(handler)
655
        
656
        self.exc = exc
657
        self.handler = handler
658
    
659
    def to_str(self, db, body):
660
        body = as_Code(body)
661
        
662
        if self.handler != None:
663
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
664
        else: handler_str = ' '+suppress_exc
665
        
666
        str_ = '''\
667
BEGIN
668
'''+strings.indent(body.to_str(db))+'''\
669
EXCEPTION
670
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
671
END;\
672
'''
673
        return str_
674

    
675
class NestedExcHandler(BaseExcHandler):
676
    def __init__(self, *handlers):
677
        '''
678
        @param handlers Sorted from outermost to innermost
679
        '''
680
        self.handlers = handlers
681
    
682
    def to_str(self, db, body):
683
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
684
        return body
685

    
686
class ExcToWarning(Code):
687
    def __init__(self, return_):
688
        '''
689
        @param return_ Statement to return a default value in case of error
690
        '''
691
        Code.__init__(self)
692
        
693
        return_ = as_Code(return_)
694
        
695
        self.return_ = return_
696
    
697
    def to_str(self, db):
698
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
699

    
700
unique_violation_handler = ExcHandler('unique_violation')
701

    
702
# Note doubled "\"s because inside Python string
703
plpythonu_error_handler = ExcHandler('internal_error', '''\
704
-- Handle PL/Python exceptions
705
DECLARE
706
    matches text[] := regexp_matches(SQLERRM,
707
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
708
    exc_name text := matches[1];
709
    msg text := matches[2];
710
BEGIN
711
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
712
    This allows the exception to be parsed like a native exception.
713
    Always raise as data_exception so it goes in the errors table. */
714
    IF exc_name IS NOT NULL THEN
715
        RAISE data_exception USING MESSAGE = msg;
716
    -- Re-raise non-PL/Python exceptions
717
    ELSE
718
        '''+reraise_exc+'''\
719
    END IF;
720
END;
721
''')
722

    
723
def data_exception_handler(handler):
724
    return ExcHandler('data_exception', handler)
725

    
726
row_var = Table('row')
727

    
728
class RowExcIgnore(Code):
729
    def __init__(self, row_type, select_query, with_row, cols=None,
730
        exc_handler=unique_violation_handler, row_var=row_var):
731
        '''
732
        @param row_type Ignored if a custom row_var is used.
733
        @pre If a custom row_var is used, it must already be defined.
734
        '''
735
        Code.__init__(self, lang='plpgsql')
736
        
737
        row_type = as_Code(row_type)
738
        select_query = as_Code(select_query)
739
        with_row = as_Code(with_row)
740
        row_var = as_Table(row_var)
741
        
742
        self.row_type = row_type
743
        self.select_query = select_query
744
        self.with_row = with_row
745
        self.cols = cols
746
        self.exc_handler = exc_handler
747
        self.row_var = row_var
748
    
749
    def to_str(self, db):
750
        if self.cols == None: row_vars = [self.row_var]
751
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
752
        
753
        # Need an EXCEPTION block for each individual row because "When an error
754
        # is caught by an EXCEPTION clause, [...] all changes to persistent
755
        # database state within the block are rolled back."
756
        # This is unfortunate because "A block containing an EXCEPTION clause is
757
        # significantly more expensive to enter and exit than a block without
758
        # one."
759
        # (http://www.postgresql.org/docs/8.3/static/\
760
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
761
        str_ = '''\
762
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
763
'''+strings.indent(self.select_query.to_str(db))+'''\
764
LOOP
765
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
766
END LOOP;
767
'''
768
        if self.row_var == row_var:
769
            str_ = '''\
770
DECLARE
771
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
772
BEGIN
773
'''+strings.indent(str_)+'''\
774
END;
775
'''
776
        return str_
777

    
778
##### Casts
779

    
780
class Cast(FunctionCall):
781
    def __init__(self, type_, value):
782
        type_ = as_Code(type_)
783
        value = as_Value(value)
784
        
785
        # Most types cannot be cast directly to unknown
786
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)
787
        
788
        self.type_ = type_
789
        self.value = value
790
    
791
    def to_str(self, db):
792
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
793

    
794
def cast_literal(value):
795
    if not is_literal(value): return value
796
    
797
    if util.is_str(value.value): value = Cast('text', value)
798
    return value
799

    
800
##### Conditions
801

    
802
class NotCond(Code):
803
    def __init__(self, cond):
804
        Code.__init__(self)
805
        
806
        if not isinstance(cond, Coalesce): cond = Coalesce(cond, False)
807
        
808
        self.cond = cond
809
    
810
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
811

    
812
custom_cond = object() # tells ColValueCond that value is a plain SQL cond
813

    
814
class ColValueCond(Code):
815
    def __init__(self, col, value):
816
        Code.__init__(self)
817
        
818
        if col is not custom_cond: value = as_ValueCond(value)
819
        
820
        self.col = col
821
        self.value = value
822
    
823
    def to_str(self, db):
824
        if self.col is custom_cond: return self.value.to_str(db)
825
        else: return self.value.to_str(db, self.col)
826

    
827
def combine_conds(conds, keyword=None):
828
    '''
829
    @param keyword The keyword to add before the conditions, if any
830
    '''
831
    str_ = ''
832
    if keyword != None:
833
        if conds == []: whitespace = ''
834
        elif len(conds) == 1: whitespace = ' '
835
        else: whitespace = '\n'
836
        str_ += keyword+whitespace
837
    
838
    str_ += '\nAND '.join(conds)
839
    return str_
840

    
841
##### Condition column comparisons
842

    
843
class ValueCond(BasicObject):
844
    def __init__(self, value):
845
        value = remove_col_rename(as_Value(value))
846
        
847
        self.value = value
848
    
849
    def to_str(self, db, left_value):
850
        '''
851
        @param left_value The Code object that the condition is being applied on
852
        '''
853
        raise NotImplemented()
854
    
855
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
856

    
857
class CompareCond(ValueCond):
858
    def __init__(self, value, operator='='):
859
        '''
860
        @param operator By default, compares NULL values literally. Use '~=' or
861
            '~!=' to pass NULLs through.
862
        '''
863
        ValueCond.__init__(self, value)
864
        self.operator = operator
865
    
866
    def to_str(self, db, left_value):
867
        left_value = remove_col_rename(as_Col(left_value))
868
        
869
        right_value = self.value
870
        
871
        # Parse operator
872
        operator = self.operator
873
        passthru_null_ref = [False]
874
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
875
        neg_ref = [False]
876
        operator = strings.remove_prefix('!', operator, neg_ref)
877
        equals = operator.endswith('=') # also includes <=, >=
878
        
879
        # Handle nullable columns
880
        check_null = False
881
        if not passthru_null_ref[0]: # NULLs compare equal
882
            try: left_value = ensure_not_null(db, left_value)
883
            except ensure_not_null_excs: # fall back to alternate method
884
                check_null = equals and isinstance(right_value, Col)
885
            else:
886
                if isinstance(left_value, EnsureNotNull):
887
                    right_value = ensure_not_null(db, right_value,
888
                        left_value.type) # apply same function to both sides
889
        
890
        if equals and is_null(right_value): operator = 'IS'
891
        
892
        left = left_value.to_str(db)
893
        right = right_value.to_str(db)
894
        
895
        # Create str
896
        str_ = left+' '+operator+' '+right
897
        if check_null:
898
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
899
        if neg_ref[0]: str_ = 'NOT '+str_
900
        return str_
901

    
902
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
903
assume_literal = object()
904

    
905
def as_ValueCond(value, default_table=assume_literal):
906
    if not isinstance(value, ValueCond):
907
        if default_table is not assume_literal:
908
            value = with_default_table(value, default_table)
909
        return CompareCond(value)
910
    else: return value
911

    
912
##### Joins
913

    
914
join_same = object() # tells Join the left and right columns have the same name
915

    
916
# Tells Join the left and right columns have the same name and are never NULL
917
join_same_not_null = object()
918

    
919
filter_out = object() # tells Join to filter out rows that match the join
920

    
921
class Join(BasicObject):
922
    def __init__(self, table, mapping={}, type_=None, custom_cond=None):
923
        '''
924
        @param mapping dict(right_table_col=left_table_col, ...)
925
            or [using_col...]
926
            * if left_table_col is join_same: left_table_col = right_table_col
927
              * Note that right_table_col must be a string
928
            * if left_table_col is join_same_not_null:
929
              left_table_col = right_table_col and both have NOT NULL constraint
930
              * Note that right_table_col must be a string
931
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
932
            * filter_out: equivalent to 'LEFT' with the query filtered by
933
              `table_pkey IS NULL` (indicating no match)
934
        '''
935
        if util.is_str(table): table = Table(table)
936
        if lists.is_seq(mapping):
937
            mapping = dict(((c, join_same_not_null) for c in mapping))
938
        assert type_ == None or util.is_str(type_) or type_ is filter_out
939
        
940
        self.table = table
941
        self.mapping = mapping
942
        self.type_ = type_
943
        self.custom_cond = custom_cond
944
    
945
    def to_str(self, db, left_table_):
946
        def join(entry):
947
            '''Parses non-USING joins'''
948
            right_table_col, left_table_col = entry
949
            
950
            # Switch order (right_table_col is on the left in the comparison)
951
            left = right_table_col
952
            right = left_table_col
953
            left_table = self.table
954
            right_table = left_table_
955
            
956
            # Parse left side
957
            left = with_default_table(left, left_table)
958
            
959
            # Parse special values
960
            left_on_right = Col(left.name, right_table)
961
            if right is join_same: right = left_on_right
962
            elif right is join_same_not_null:
963
                right = CompareCond(left_on_right, '~=')
964
            
965
            # Parse right side
966
            right = as_ValueCond(right, right_table)
967
            
968
            return right.to_str(db, left)
969
        
970
        # Create join condition
971
        type_ = self.type_
972
        joins = self.mapping
973
        if joins == {}: join_cond = None
974
        elif type_ is not filter_out and reduce(operator.and_,
975
            (v is join_same_not_null for v in joins.itervalues())):
976
            # all cols w/ USING, so can use simpler USING syntax
977
            cols = map(to_name_only_col, joins.iterkeys())
978
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
979
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
980
        
981
        if isinstance(self.table, NamedTable): whitespace = '\n'
982
        else: whitespace = ' '
983
        
984
        # Create join
985
        if type_ is filter_out: type_ = 'LEFT'
986
        str_ = ''
987
        if type_ != None: str_ += type_+' '
988
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
989
        if join_cond != None: str_ += whitespace+join_cond
990
        if self.custom_cond != None: str_ += '\nAND '+self.custom_cond
991
        return str_
992
    
993
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
994

    
995
##### Value exprs
996

    
997
all_cols = CustomCode('*')
998

    
999
default = CustomCode('DEFAULT')
1000

    
1001
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
1002

    
1003
class Coalesce(FunctionCall):
1004
    def __init__(self, *args):
1005
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
1006

    
1007
class Nullif(FunctionCall):
1008
    def __init__(self, *args):
1009
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
1010

    
1011
null = Literal(None)
1012
null_as_str = Cast('text', null)
1013

    
1014
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
1015

    
1016
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
1017
null_sentinels = {
1018
    'character varying': r'\N',
1019
    'double precision': 'NaN',
1020
    'integer': 2147483647,
1021
    'text': r'\N',
1022
    'date': 'infinity',
1023
    'timestamp with time zone': 'infinity',
1024
    'taxonrank': 'unknown',
1025
}
1026

    
1027
class EnsureNotNull(Coalesce):
1028
    def __init__(self, value, type_):
1029
        if isinstance(type_, ArrayType): null = []
1030
        else: null = null_sentinels[type_]
1031
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
1032
        
1033
        self.type = type_
1034
    
1035
    def to_str(self, db):
1036
        col = self.args[0]
1037
        index_col_ = index_col(col)
1038
        if index_col_ != None: return index_col_.to_str(db)
1039
        return Coalesce.to_str(self, db)
1040

    
1041
#### Arrays
1042

    
1043
class ArrayMerge(FunctionCall):
1044
    def __init__(self, sep, array):
1045
        array = to_Array(array)
1046
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1047
            sep)
1048

    
1049
def merge_not_null(db, sep, values):
1050
    return ArrayMerge(sep, map(to_text, values))
1051

    
1052
##### Table exprs
1053

    
1054
class Values(Code):
1055
    def __init__(self, values):
1056
        '''
1057
        @param values [...]|[[...], ...] Can be one or multiple rows.
1058
        '''
1059
        Code.__init__(self)
1060
        
1061
        rows = values
1062
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1063
            rows = [values]
1064
        for i, row in enumerate(rows):
1065
            rows[i] = map(remove_col_rename, map(as_Value, row))
1066
        
1067
        self.rows = rows
1068
    
1069
    def to_str(self, db):
1070
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1071

    
1072
def NamedValues(name, cols, values):
1073
    '''
1074
    @param cols None|[...]
1075
    @post `cols` will be changed to Col objects with the table set to `name`.
1076
    '''
1077
    table = NamedTable(name, Values(values), cols)
1078
    if cols != None: set_cols_table(table, cols)
1079
    return table
1080

    
1081
##### Database structure
1082

    
1083
def is_nullable(db, value):
1084
    if not is_table_col(value): return is_null(value)
1085
    try: return db.col_info(value).nullable
1086
    except NoUnderlyingTableException: return True # not a table column
1087

    
1088
text_types = set(['character varying', 'text'])
1089

    
1090
def is_text_type(type_): return type_ in text_types
1091

    
1092
def is_text_col(db, col): return is_text_type(db.col_info(col).type)
1093

    
1094
def canon_type(type_):
1095
    if type_ in text_types: return 'text'
1096
    else: return type_
1097

    
1098
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1099

    
1100
def ensure_not_null(db, col, type_=None):
1101
    '''
1102
    @param col If type_ is not set, must have an underlying column.
1103
    @param type_ If set, overrides the underlying column's type and casts the
1104
        column to it if needed.
1105
    @return EnsureNotNull|Col
1106
    @throws ensure_not_null_excs
1107
    '''
1108
    col = remove_col_rename(col)
1109
    
1110
    try: col_type = db.col_info(underlying_col(col)).type
1111
    except NoUnderlyingTableException:
1112
        if type_ == None and is_null(col): raise # NULL has no type
1113
    else:
1114
        if type_ == None: type_ = col_type
1115
        elif type_ != col_type: col = Cast(type_, col)
1116
    
1117
    if is_nullable(db, col):
1118
        try: col = EnsureNotNull(col, type_)
1119
        except KeyError, e:
1120
            # Warn of no null sentinel for type, even if caller catches error
1121
            warnings.warn(UserWarning(exc.str_(e)))
1122
            raise
1123
    
1124
    return col
1125

    
1126
def try_mk_not_null(db, value):
1127
    '''
1128
    Warning: This function does not guarantee that its result is NOT NULL.
1129
    '''
1130
    try: return ensure_not_null(db, value)
1131
    except ensure_not_null_excs: return value
1132

    
1133
##### Expression transforming
1134

    
1135
true_expr = 'true'
1136
false_expr = 'false'
1137

    
1138
true_re = true_expr
1139
false_re = false_expr
1140
bool_re = r'(?:'+true_re+r'|'+false_re+r')'
1141
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
1142

    
1143
def logic_op_re(op, value_re, expr_re=''):
1144
    op_re = ' '+op+' '
1145
    return '(?:'+expr_re+op_re+value_re+'|'+value_re+op_re+expr_re+')'
1146

    
1147
not_re = r'\bNOT '
1148
not_false_re = not_re+false_re+r'\b'
1149
not_true_re = not_re+true_re+r'\b'
1150
and_false_re = logic_op_re('AND', false_re, atom_re)
1151
and_false_not_true_re = '(?:'+not_true_re+'|'+and_false_re+')'
1152
and_true_re = logic_op_re('AND', true_re)
1153
or_re = logic_op_re('OR', bool_re)
1154
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
1155

    
1156
def simplify_parens(expr):
1157
    return regexp.sub_nested(r'\(('+atom_re+')\)', r'\1', expr)
1158

    
1159
def simplify_recursive(sub_func, expr):
1160
    '''
1161
    @param sub_func See regexp.sub_recursive() sub_func param
1162
    '''
1163
    return regexp.sub_recursive(lambda s: sub_func(simplify_parens(s)), expr)
1164
        # simplify_parens() is also done at end in final iteration
1165

    
1166
def simplify_expr(expr):
1167
    '''
1168
    this can also be done in Postgres with expression substitution
1169
    (wiki.vegpath.org/Postgres_queries#expression-substitution)
1170
    '''
1171
    def simplify_logic_ops(expr):
1172
        total_n = 0
1173
        expr, n = re.subn(not_false_re, true_re, expr)
1174
        total_n += n
1175
        expr, n = re.subn(and_false_not_true_re, false_expr, expr)
1176
        total_n += n
1177
        expr, n = re.subn(or_and_true_re, r'', expr)
1178
        total_n += n
1179
        return expr, total_n
1180
    
1181
    expr = expr.replace('NULL IS NULL', true_expr)
1182
    expr = expr.replace('NULL IS NOT NULL', false_expr)
1183
    expr = simplify_recursive(simplify_logic_ops, expr)
1184
    return expr
1185

    
1186
name_re = r'(?:\w+|(?:"[^"]*")+)'
1187

    
1188
def parse_expr_col(str_):
1189
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1190
    if match: str_ = match.group(1)
1191
    return unesc_name(str_)
1192

    
1193
def map_expr(db, expr, mapping, in_cols_found=None):
1194
    '''Replaces output columns with input columns in an expression.
1195
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1196
    
1197
    this can also be done in Postgres with expression substitution
1198
    (wiki.vegpath.org/Postgres_queries#expression-substitution)
1199
    
1200
    this is a special case of bin/repl SQL identifier handling which does not
1201
    handle entire source files, but which does simplify the resulting expression
1202
    '''
1203
    for out, in_ in mapping.iteritems():
1204
        orig_expr = expr
1205
        out = to_name_only_col(out)
1206
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
1207
        
1208
        # Replace out both with and without quotes
1209
        expr = expr.replace(out.to_str(db), in_str)
1210
        expr = re.sub(r'(?<!["\'\.=\[])\b'+out.name+r'\b(?!["\',\.=\]])',
1211
            in_str, expr)
1212
        
1213
        if in_cols_found != None and expr != orig_expr: # replaced something
1214
            in_cols_found.append(in_)
1215
    
1216
    return simplify_expr(expr)
(36-36/49)