Project

General

Profile

1
# SQL code generation
2

    
3
import copy
4
import itertools
5
import operator
6
from ordereddict import OrderedDict
7
import re
8
import UserDict
9
import warnings
10

    
11
import dicts
12
import exc
13
import iters
14
import lists
15
import objects
16
import regexp
17
import strings
18
import util
19

    
20
##### Names
21

    
22
identifier_max_len = 63 # works for both PostgreSQL and MySQL
23

    
24
def concat(str_, suffix):
25
    '''Preserves version so that it won't be truncated off the string, leading
26
    to collisions.'''
27
    # Preserve version
28
    match = re.match(r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)?(?:::[\w ]+)*)$', str_)
29
    if match:
30
        str_, old_suffix = match.groups()
31
        suffix = old_suffix+suffix
32
    
33
    return strings.concat(str_, suffix, identifier_max_len)
34

    
35
def truncate(str_): return concat(str_, '')
36

    
37
def is_safe_name(name):
38
    '''A name is safe *and unambiguous* if it:
39
    * contains only *lowercase* word (\w) characters
40
    * doesn't start with a digit
41
    * contains "_", so that it's not a keyword
42
    '''
43
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+$', name)
44

    
45
def esc_name(name, quote='"'):
46
    return quote + name.replace(quote, quote+quote) + quote
47
        # doubling an embedded quote escapes it in both PostgreSQL and MySQL
48

    
49
def unesc_name(name, quote='"'):
50
    removed_ref = [False]
51
    name = strings.remove_prefix(quote, name, removed_ref)
52
    if removed_ref[0]:
53
        name = strings.remove_suffix(quote, name, removed_ref)
54
        assert removed_ref[0]
55
        name = name.replace(quote+quote, quote)
56
    return name
57

    
58
def clean_name(name): return name.replace('"', '').replace('`', '')
59

    
60
def esc_comment(comment): return '/*'+comment.replace('*/', '* /')+'*/'
61

    
62
def lstrip(str_):
63
    '''Also removes comments.'''
64
    if str_.startswith('/*'): comment, sep, str_ = str_.partition('*/')
65
    return str_.lstrip()
66

    
67
##### General SQL code objects
68

    
69
class MockDb:
70
    def esc_value(self, value): return strings.repr_no_u(value)
71
    
72
    def esc_name(self, name): return esc_name(name)
73
    
74
    def col_info(self, col):
75
        return TypedCol(col.name, '<type>', CustomCode('<default>'), True)
76

    
77
mockDb = MockDb()
78

    
79
class BasicObject(objects.BasicObject):
80
    def __str__(self): return clean_name(strings.repr_no_u(self))
81

    
82
##### Unparameterized code objects
83

    
84
class Code(BasicObject):
85
    def __init__(self, lang='sql'):
86
        self.lang = lang
87
    
88
    def to_str(self, db): raise NotImplementedError()
89
    
90
    def __repr__(self): return self.to_str(mockDb)
91

    
92
class CustomCode(Code):
93
    def __init__(self, str_):
94
        Code.__init__(self)
95
        
96
        self.str_ = str_
97
    
98
    def to_str(self, db): return self.str_
99

    
100
def as_Code(value, db=None):
101
    '''
102
    @param db If set, runs db.std_code() on the value.
103
    '''
104
    if isinstance(value, Code): return value
105
    
106
    if util.is_str(value):
107
        if db != None: value = db.std_code(value)
108
        return CustomCode(value)
109
    else: return Literal(value)
110

    
111
class Expr(Code):
112
    def __init__(self, expr):
113
        Code.__init__(self)
114
        
115
        self.expr = expr
116
    
117
    def to_str(self, db): return '('+self.expr.to_str(db)+')'
118

    
119
##### Names
120

    
121
class Name(Code):
122
    def __init__(self, name):
123
        Code.__init__(self)
124
        
125
        name = truncate(name)
126
        
127
        self.name = name
128
    
129
    def to_str(self, db): return db.esc_name(self.name)
130

    
131
def as_Name(value):
132
    if isinstance(value, Code): return value
133
    else: return Name(value)
134

    
135
##### Literal values
136

    
137
#### Primitives
138

    
139
class Literal(Code):
140
    def __init__(self, value):
141
        Code.__init__(self)
142
        
143
        self.value = value
144
    
145
    def to_str(self, db): return db.esc_value(self.value)
146

    
147
def as_Value(value):
148
    if isinstance(value, Code): return value
149
    else: return Literal(value)
150

    
151
def is_literal(value): return isinstance(value, Literal)
152

    
153
def is_null(value): return is_literal(value) and value.value == None
154

    
155
#### Composites
156

    
157
class List(Code):
158
    def __init__(self, values):
159
        Code.__init__(self)
160
        
161
        self.values = values
162
    
163
    def to_str(self, db): return ', '.join((v.to_str(db) for v in self.values))
164

    
165
class Tuple(List):
166
    def __init__(self, *values):
167
        List.__init__(self, values)
168
    
169
    def to_str(self, db): return '('+List.to_str(self, db)+')'
170

    
171
class Row(Tuple):
172
    def to_str(self, db): return 'ROW'+Tuple.to_str(self, db)
173

    
174
### Arrays
175

    
176
class Array(List):
177
    def __init__(self, values):
178
        values = map(remove_col_rename, values)
179
        
180
        List.__init__(self, values)
181
    
182
    def to_str(self, db): return 'ARRAY['+List.to_str(self, db)+']'
183

    
184
def to_Array(value):
185
    if isinstance(value, Array): return value
186
    return Array(lists.mk_seq(value))
187

    
188
##### Derived elements
189

    
190
src_self = object() # tells Col that it is its own source column
191

    
192
class Derived(Code):
193
    def __init__(self, srcs):
194
        '''An element which was derived from some other element(s).
195
        @param srcs See self.set_srcs()
196
        '''
197
        Code.__init__(self)
198
        
199
        self.set_srcs(srcs)
200
    
201
    def set_srcs(self, srcs, overwrite=True):
202
        '''
203
        @param srcs (self_type...)|src_self The element(s) this is derived from
204
        '''
205
        if not overwrite and self.srcs != (): return # already set
206
        
207
        if srcs == src_self: srcs = (self,)
208
        srcs = tuple(srcs) # make Col hashable
209
        self.srcs = srcs
210
    
211
    def _compare_on(self):
212
        compare_on = self.__dict__.copy()
213
        del compare_on['srcs'] # ignore
214
        return compare_on
215

    
216
def cols_srcs(cols): return lists.uniqify(iters.flatten((v.srcs for v in cols)))
217

    
218
##### Tables
219

    
220
class Table(Derived):
221
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
222
        '''
223
        @param schema str|None (for no schema)
224
        @param srcs (Table...)|src_self See Derived.set_srcs()
225
        '''
226
        Derived.__init__(self, srcs)
227
        
228
        if util.is_str(name): name = truncate(name)
229
        
230
        self.name = name
231
        self.schema = schema
232
        self.is_temp = is_temp
233
        self.index_cols = {}
234
    
235
    def to_str(self, db):
236
        str_ = ''
237
        if self.schema != None: str_ += as_Name(self.schema).to_str(db)+'.'
238
        str_ += as_Name(self.name).to_str(db)
239
        return str_
240
    
241
    def to_Table(self): return self
242
    
243
    def _compare_on(self):
244
        compare_on = Derived._compare_on(self)
245
        del compare_on['index_cols'] # ignore
246
        return compare_on
247

    
248
def is_underlying_table(table):
249
    return isinstance(table, Table) and table.to_Table() is table
250

    
251
class NoUnderlyingTableException(Exception):
252
    def __init__(self, ref):
253
        Exception.__init__(self, 'for: '+strings.as_tt(strings.urepr(ref)))
254
        self.ref = ref
255

    
256
def underlying_table(table):
257
    table = remove_table_rename(table)
258
    if table != None and table.srcs:
259
        table, = table.srcs # for derived tables or row vars
260
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
261
    return table
262

    
263
def as_Table(table, schema=None):
264
    if table == None or isinstance(table, Code): return table
265
    else: return Table(table, schema)
266

    
267
def suffixed_table(table, suffix):
268
    table = copy.copy(table) # don't modify input!
269
    table.name = concat(table.name, suffix)
270
    return table
271

    
272
class NamedTable(Table):
273
    def __init__(self, name, code, cols=None):
274
        Table.__init__(self, name)
275
        
276
        code = as_Table(code)
277
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
278
        if cols != None: cols = [to_name_only_col(c).to_Col() for c in cols]
279
        
280
        self.code = code
281
        self.cols = cols
282
    
283
    def to_str(self, db):
284
        str_ = self.code.to_str(db)
285
        if str_.find('\n') >= 0: whitespace = '\n'
286
        else: whitespace = ' '
287
        str_ += whitespace+'AS '+Table.to_str(self, db)
288
        if self.cols != None:
289
            str_ += ' ('+(', '.join((c.to_str(db) for c in self.cols)))+')'
290
        return str_
291
    
292
    def to_Table(self): return Table(self.name)
293

    
294
def remove_table_rename(table):
295
    if isinstance(table, NamedTable): table = table.code
296
    return table
297

    
298
##### Columns
299

    
300
class Col(Derived):
301
    def __init__(self, name, table=None, srcs=()):
302
        '''
303
        @param table Table|None (for no table)
304
        @param srcs (Col...)|src_self See Derived.set_srcs()
305
        '''
306
        Derived.__init__(self, srcs)
307
        
308
        if util.is_str(name): name = truncate(name)
309
        if util.is_str(table): table = Table(table)
310
        assert table == None or isinstance(table, Table)
311
        
312
        self.name = name
313
        self.table = table
314
    
315
    def to_str(self, db, for_str=False):
316
        str_ = as_Name(self.name).to_str(db)
317
        if for_str: str_ = clean_name(str_)
318
        if self.table != None:
319
            table = self.table.to_Table()
320
            if for_str: str_ = concat(strings.ustr(table), '.'+str_)
321
            else: str_ = table.to_str(db)+'.'+str_
322
        return str_
323
    
324
    def __str__(self): return self.to_str(mockDb, for_str=True)
325
    
326
    def to_Col(self): return self
327

    
328
def is_col(col): return isinstance(col, Col)
329

    
330
def is_table_col(col): return is_col(col) and col.table != None
331

    
332
def index_col(col):
333
    if not is_table_col(col): return None
334
    
335
    table = col.table
336
    try: name = table.index_cols[col.name]
337
    except KeyError: return None
338
    else: return Col(name, table, col.srcs)
339

    
340
def is_temp_col(col): return is_table_col(col) and col.table.is_temp
341

    
342
def as_Col(col, table=None, name=None):
343
    '''
344
    @param name If not None, any non-Col input will be renamed using NamedCol.
345
    '''
346
    if name != None:
347
        col = as_Value(col)
348
        if not isinstance(col, Col): col = NamedCol(name, col)
349
    
350
    if isinstance(col, Code): return col
351
    elif util.is_str(col): return Col(col, table)
352
    else: return Literal(col)
353

    
354
def with_table(col, table):
355
    if isinstance(col, NamedCol): pass # doesn't take a table
356
    elif isinstance(col, FunctionCall):
357
        col = copy.deepcopy(col) # don't modify input!
358
        col.args[0].table = table
359
    elif isinstance(col, Col):
360
        col = copy.copy(col) # don't modify input!
361
        col.table = table
362
    return col
363

    
364
def with_default_table(col, table):
365
    col = as_Col(col)
366
    if col.table == None: col = with_table(col, table)
367
    return col
368

    
369
def set_cols_table(table, cols):
370
    table = as_Table(table)
371
    
372
    for i, col in enumerate(cols):
373
        col = cols[i] = as_Col(col)
374
        col.table = table
375

    
376
def to_name_only_col(col, check_table=None):
377
    col = as_Col(col)
378
    if not is_table_col(col): return col
379
    
380
    if check_table != None:
381
        table = col.table
382
        assert table == None or table == check_table
383
    return Col(col.name)
384

    
385
def suffixed_col(col, suffix):
386
    return Col(concat(col.name, suffix), col.table, col.srcs)
387

    
388
def has_srcs(col): return is_col(col) and col.srcs
389

    
390
def cross_join_srcs(cols):
391
    cols = filter(has_srcs, cols) # empty srcs will mess up the cross join
392
    srcs = [[s.name for s in c.srcs] for c in cols]
393
    return [Col(','.join(s)) for s in itertools.product(*srcs)]
394

    
395
class NamedCol(Col):
396
    def __init__(self, name, code):
397
        Col.__init__(self, name)
398
        
399
        code = as_Value(code)
400
        
401
        self.code = code
402
    
403
    def to_str(self, db):
404
        return self.code.to_str(db)+' AS '+Col.to_str(self, db)
405
    
406
    def to_Col(self): return Col(self.name)
407

    
408
def remove_col_rename(col):
409
    if isinstance(col, NamedCol): col = col.code
410
    return col
411

    
412
def underlying_col(col):
413
    col = remove_col_rename(col)
414
    if not isinstance(col, Col): raise NoUnderlyingTableException(col)
415
    
416
    return Col(col.name, underlying_table(col.table), col.srcs)
417

    
418
def wrap(wrap_func, value):
419
    '''Wraps a value, propagating any column renaming to the returned value.'''
420
    if isinstance(value, NamedCol):
421
        return NamedCol(value.name, wrap_func(value.code))
422
    else: return wrap_func(value)
423

    
424
class ColDict(dicts.DictProxy):
425
    '''A dict that automatically makes inserted entries Col objects.
426
    Anything that isn't a column is wrapped in a NamedCol with the key's column
427
    name by `as_Col(value, name=key.name)`.
428
    '''
429
    
430
    def __init__(self, db, keys_table, dict_={}):
431
        dicts.DictProxy.__init__(self, OrderedDict())
432
        
433
        keys_table = as_Table(keys_table)
434
        
435
        self.db = db
436
        self.table = keys_table
437
        self.update(dict_) # after setting vars because __setitem__() needs them
438
    
439
    def copy(self): return ColDict(self.db, self.table, self.inner.copy())
440
    
441
    def __getitem__(self, key):
442
        return dicts.DictProxy.__getitem__(self, self._key(key))
443
    
444
    def __setitem__(self, key, value):
445
        key = self._key(key)
446
        if value == None:
447
            try: value = self.db.col_info(key).default
448
            except NoUnderlyingTableException: pass # not a table column
449
        dicts.DictProxy.__setitem__(self, key, as_Col(value, name=key.name))
450
    
451
    def _key(self, key): return as_Col(key, self.table)
452

    
453
##### Definitions
454

    
455
class TypedCol(Col):
456
    def __init__(self, name, type_, default=None, nullable=True,
457
        constraints=None):
458
        assert default == None or isinstance(default, Code)
459
        
460
        Col.__init__(self, name)
461
        
462
        self.type = type_
463
        self.default = default
464
        self.nullable = nullable
465
        self.constraints = constraints
466
    
467
    def to_str(self, db):
468
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
469
        if not self.nullable: str_ += ' NOT NULL'
470
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
471
        if self.constraints != None: str_ += ' '+self.constraints
472
        return str_
473
    
474
    def to_Col(self): return Col(self.name)
475

    
476
class SetOf(Code):
477
    def __init__(self, type_):
478
        Code.__init__(self)
479
        
480
        self.type = type_
481
    
482
    def to_str(self, db):
483
        return 'SETOF '+self.type.to_str(db)
484

    
485
class RowType(Code):
486
    def __init__(self, table):
487
        Code.__init__(self)
488
        
489
        self.table = table
490
    
491
    def to_str(self, db):
492
        return self.table.to_str(db)+'%ROWTYPE'
493

    
494
class ColType(Code):
495
    def __init__(self, col):
496
        Code.__init__(self)
497
        
498
        self.col = col
499
    
500
    def to_str(self, db):
501
        return self.col.to_str(db)+'%TYPE'
502

    
503
class ArrayType(Code):
504
    def __init__(self, elem_type):
505
        Code.__init__(self)
506
        elem_type = as_Code(elem_type)
507
        
508
        self.elem_type = elem_type
509
    
510
    def to_str(self, db):
511
        return self.elem_type.to_str(db)+'[]'
512

    
513
##### Functions
514

    
515
Function = Table
516
as_Function = as_Table
517

    
518
class InternalFunction(CustomCode): pass
519

    
520
#### Calls
521

    
522
class NamedArg(NamedCol):
523
    def __init__(self, name, value):
524
        NamedCol.__init__(self, name, value)
525
    
526
    def to_str(self, db):
527
        return Col.to_str(self, db)+' := '+self.code.to_str(db)
528

    
529
class FunctionCall(Code):
530
    def __init__(self, function, *args, **kw_args):
531
        '''
532
        @param args [Code|literal-value...] The function's arguments
533
        '''
534
        Code.__init__(self)
535
        
536
        function = as_Function(function)
537
        def filter_(arg): return remove_col_rename(as_Value(arg))
538
        args = map(filter_, args)
539
        args += [NamedArg(k, filter_(v)) for k, v in kw_args.iteritems()]
540
        
541
        self.function = function
542
        self.args = args
543
    
544
    def to_str(self, db):
545
        args_str = ', '.join((v.to_str(db) for v in self.args))
546
        return self.function.to_str(db)+'('+args_str+')'
547

    
548
def wrap_in_func(function, value):
549
    '''Wraps a value inside a function call.
550
    Propagates any column renaming to the returned value.
551
    '''
552
    return wrap(lambda v: FunctionCall(function, v), value)
553

    
554
def unwrap_func_call(func_call, check_name=None):
555
    '''Unwraps any function call to its first argument.
556
    Also removes any column renaming.
557
    '''
558
    func_call = remove_col_rename(func_call)
559
    if not isinstance(func_call, FunctionCall): return func_call
560
    
561
    if check_name != None:
562
        name = func_call.function.name
563
        assert name == None or name == check_name
564
    return func_call.args[0]
565

    
566
#### Definitions
567

    
568
class FunctionDef(Code):
569
    def __init__(self, function, return_type, body, params=[], modifiers=None):
570
        Code.__init__(self)
571
        
572
        return_type = as_Code(return_type)
573
        body = as_Code(body)
574
        
575
        self.function = function
576
        self.return_type = return_type
577
        self.body = body
578
        self.params = params
579
        self.modifiers = modifiers
580
    
581
    def to_str(self, db):
582
        params_str = (', '.join((p.to_str(db) for p in self.params)))
583
        str_ = '''\
584
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
585
RETURNS '''+self.return_type.to_str(db)+'''
586
LANGUAGE '''+self.body.lang+'''
587
'''
588
        if self.modifiers != None: str_ += self.modifiers+'\n'
589
        str_ += '''\
590
AS $$
591
'''+self.body.to_str(db)+'''
592
$$;
593
'''
594
        return str_
595

    
596
class FunctionParam(TypedCol):
597
    def __init__(self, name, type_, default=None, out=False):
598
        TypedCol.__init__(self, name, type_, default)
599
        
600
        self.out = out
601
    
602
    def to_str(self, db):
603
        str_ = TypedCol.to_str(self, db)
604
        if self.out: str_ = 'OUT '+str_
605
        return str_
606
    
607
    def to_Col(self): return Col(self.name)
608

    
609
### PL/pgSQL
610

    
611
class ReturnQuery(Code):
612
    def __init__(self, query):
613
        Code.__init__(self)
614
        
615
        query = as_Code(query)
616
        
617
        self.query = query
618
    
619
    def to_str(self, db):
620
        return 'RETURN QUERY\n'+strings.indent(self.query.to_str(db))+';\n'
621

    
622
## Exceptions
623

    
624
class BaseExcHandler(BasicObject):
625
    def to_str(self, db, body): raise NotImplementedError()
626
    
627
    def __repr__(self): return self.to_str(mockDb, '<body>')
628

    
629
suppress_exc = 'NULL;\n';
630

    
631
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n';
632

    
633
class ExcHandler(BaseExcHandler):
634
    def __init__(self, exc, handler=None):
635
        if handler != None: handler = as_Code(handler)
636
        
637
        self.exc = exc
638
        self.handler = handler
639
    
640
    def to_str(self, db, body):
641
        body = as_Code(body)
642
        
643
        if self.handler != None:
644
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
645
        else: handler_str = ' '+suppress_exc
646
        
647
        str_ = '''\
648
BEGIN
649
'''+strings.indent(body.to_str(db))+'''\
650
EXCEPTION
651
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
652
END;\
653
'''
654
        return str_
655

    
656
class NestedExcHandler(BaseExcHandler):
657
    def __init__(self, *handlers):
658
        '''
659
        @param handlers Sorted from outermost to innermost
660
        '''
661
        self.handlers = handlers
662
    
663
    def to_str(self, db, body):
664
        for handler in reversed(self.handlers): body = handler.to_str(db, body)
665
        return body
666

    
667
class ExcToWarning(Code):
668
    def __init__(self, return_):
669
        '''
670
        @param return_ Statement to return a default value in case of error
671
        '''
672
        Code.__init__(self)
673
        
674
        return_ = as_Code(return_)
675
        
676
        self.return_ = return_
677
    
678
    def to_str(self, db):
679
        return "RAISE WARNING '%', SQLERRM;\n"+self.return_.to_str(db)
680

    
681
unique_violation_handler = ExcHandler('unique_violation')
682

    
683
# Note doubled "\"s because inside Python string
684
plpythonu_error_handler = ExcHandler('internal_error', '''\
685
-- Handle PL/Python exceptions
686
DECLARE
687
    matches text[] := regexp_matches(SQLERRM,
688
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
689
    exc_name text := matches[1];
690
    msg text := matches[2];
691
BEGIN
692
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
693
    This allows the exception to be parsed like a native exception.
694
    Always raise as data_exception so it goes in the errors table. */
695
    IF exc_name IS NOT NULL THEN
696
        RAISE data_exception USING MESSAGE = msg;
697
    -- Re-raise non-PL/Python exceptions
698
    ELSE
699
        '''+reraise_exc+'''\
700
    END IF;
701
END;
702
''')
703

    
704
def data_exception_handler(handler):
705
    return ExcHandler('data_exception', handler)
706

    
707
row_var = Table('row')
708

    
709
class RowExcIgnore(Code):
710
    def __init__(self, row_type, select_query, with_row, cols=None,
711
        exc_handler=unique_violation_handler, row_var=row_var):
712
        '''
713
        @param row_type Ignored if a custom row_var is used.
714
        @pre If a custom row_var is used, it must already be defined.
715
        '''
716
        Code.__init__(self, lang='plpgsql')
717
        
718
        row_type = as_Code(row_type)
719
        select_query = as_Code(select_query)
720
        with_row = as_Code(with_row)
721
        row_var = as_Table(row_var)
722
        
723
        self.row_type = row_type
724
        self.select_query = select_query
725
        self.with_row = with_row
726
        self.cols = cols
727
        self.exc_handler = exc_handler
728
        self.row_var = row_var
729
    
730
    def to_str(self, db):
731
        if self.cols == None: row_vars = [self.row_var]
732
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
733
        
734
        # Need an EXCEPTION block for each individual row because "When an error
735
        # is caught by an EXCEPTION clause, [...] all changes to persistent
736
        # database state within the block are rolled back."
737
        # This is unfortunate because "A block containing an EXCEPTION clause is
738
        # significantly more expensive to enter and exit than a block without
739
        # one."
740
        # (http://www.postgresql.org/docs/8.3/static/\
741
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
742
        str_ = '''\
743
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
744
'''+strings.indent(self.select_query.to_str(db))+'''\
745
LOOP
746
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
747
END LOOP;
748
'''
749
        if self.row_var == row_var:
750
            str_ = '''\
751
DECLARE
752
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
753
BEGIN
754
'''+strings.indent(str_)+'''\
755
END;
756
'''
757
        return str_
758

    
759
##### Casts
760

    
761
class Cast(FunctionCall):
762
    def __init__(self, type_, value):
763
        type_ = as_Code(type_)
764
        value = as_Value(value)
765
        
766
        self.type_ = type_
767
        self.value = value
768
    
769
    def to_str(self, db):
770
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'
771

    
772
def cast_literal(value):
773
    if not is_literal(value): return value
774
    
775
    if util.is_str(value.value): value = Cast('text', value)
776
    return value
777

    
778
##### Conditions
779

    
780
class NotCond(Code):
781
    def __init__(self, cond):
782
        Code.__init__(self)
783
        
784
        self.cond = cond
785
    
786
    def to_str(self, db): return 'NOT '+self.cond.to_str(db)
787

    
788
class ColValueCond(Code):
789
    def __init__(self, col, value):
790
        Code.__init__(self)
791
        
792
        value = as_ValueCond(value)
793
        
794
        self.col = col
795
        self.value = value
796
    
797
    def to_str(self, db): return self.value.to_str(db, self.col)
798

    
799
def combine_conds(conds, keyword=None):
800
    '''
801
    @param keyword The keyword to add before the conditions, if any
802
    '''
803
    str_ = ''
804
    if keyword != None:
805
        if conds == []: whitespace = ''
806
        elif len(conds) == 1: whitespace = ' '
807
        else: whitespace = '\n'
808
        str_ += keyword+whitespace
809
    
810
    str_ += '\nAND '.join(conds)
811
    return str_
812

    
813
##### Condition column comparisons
814

    
815
class ValueCond(BasicObject):
816
    def __init__(self, value):
817
        value = remove_col_rename(as_Value(value))
818
        
819
        self.value = value
820
    
821
    def to_str(self, db, left_value):
822
        '''
823
        @param left_value The Code object that the condition is being applied on
824
        '''
825
        raise NotImplemented()
826
    
827
    def __repr__(self): return self.to_str(mockDb, '<left_value>')
828

    
829
class CompareCond(ValueCond):
830
    def __init__(self, value, operator='='):
831
        '''
832
        @param operator By default, compares NULL values literally. Use '~=' or
833
            '~!=' to pass NULLs through.
834
        '''
835
        ValueCond.__init__(self, value)
836
        self.operator = operator
837
    
838
    def to_str(self, db, left_value):
839
        left_value = remove_col_rename(as_Col(left_value))
840
        
841
        right_value = self.value
842
        
843
        # Parse operator
844
        operator = self.operator
845
        passthru_null_ref = [False]
846
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
847
        neg_ref = [False]
848
        operator = strings.remove_prefix('!', operator, neg_ref)
849
        equals = operator.endswith('=') # also includes <=, >=
850
        
851
        # Handle nullable columns
852
        check_null = False
853
        if not passthru_null_ref[0]: # NULLs compare equal
854
            try: left_value = ensure_not_null(db, left_value)
855
            except ensure_not_null_excs: # fall back to alternate method
856
                check_null = equals and isinstance(right_value, Col)
857
            else:
858
                if isinstance(left_value, EnsureNotNull):
859
                    right_value = ensure_not_null(db, right_value,
860
                        left_value.type) # apply same function to both sides
861
        
862
        if equals and is_null(right_value): operator = 'IS'
863
        
864
        left = left_value.to_str(db)
865
        right = right_value.to_str(db)
866
        
867
        # Create str
868
        str_ = left+' '+operator+' '+right
869
        if check_null:
870
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
871
        if neg_ref[0]: str_ = 'NOT '+str_
872
        return str_
873

    
874
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
875
assume_literal = object()
876

    
877
def as_ValueCond(value, default_table=assume_literal):
878
    if not isinstance(value, ValueCond):
879
        if default_table is not assume_literal:
880
            value = with_default_table(value, default_table)
881
        return CompareCond(value)
882
    else: return value
883

    
884
##### Joins
885

    
886
join_same = object() # tells Join the left and right columns have the same name
887

    
888
# Tells Join the left and right columns have the same name and are never NULL
889
join_same_not_null = object()
890

    
891
filter_out = object() # tells Join to filter out rows that match the join
892

    
893
class Join(BasicObject):
894
    def __init__(self, table, mapping={}, type_=None):
895
        '''
896
        @param mapping dict(right_table_col=left_table_col, ...)
897
            * if left_table_col is join_same: left_table_col = right_table_col
898
              * Note that right_table_col must be a string
899
            * if left_table_col is join_same_not_null:
900
              left_table_col = right_table_col and both have NOT NULL constraint
901
              * Note that right_table_col must be a string
902
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
903
            * filter_out: equivalent to 'LEFT' with the query filtered by
904
              `table_pkey IS NULL` (indicating no match)
905
        '''
906
        if util.is_str(table): table = Table(table)
907
        assert type_ == None or util.is_str(type_) or type_ is filter_out
908
        
909
        self.table = table
910
        self.mapping = mapping
911
        self.type_ = type_
912
    
913
    def to_str(self, db, left_table_):
914
        def join(entry):
915
            '''Parses non-USING joins'''
916
            right_table_col, left_table_col = entry
917
            
918
            # Switch order (right_table_col is on the left in the comparison)
919
            left = right_table_col
920
            right = left_table_col
921
            left_table = self.table
922
            right_table = left_table_
923
            
924
            # Parse left side
925
            left = with_default_table(left, left_table)
926
            
927
            # Parse special values
928
            left_on_right = Col(left.name, right_table)
929
            if right is join_same: right = left_on_right
930
            elif right is join_same_not_null:
931
                right = CompareCond(left_on_right, '~=')
932
            
933
            # Parse right side
934
            right = as_ValueCond(right, right_table)
935
            
936
            return right.to_str(db, left)
937
        
938
        # Create join condition
939
        type_ = self.type_
940
        joins = self.mapping
941
        if joins == {}: join_cond = None
942
        elif type_ is not filter_out and reduce(operator.and_,
943
            (v is join_same_not_null for v in joins.itervalues())):
944
            # all cols w/ USING, so can use simpler USING syntax
945
            cols = map(to_name_only_col, joins.iterkeys())
946
            join_cond = 'USING ('+(', '.join((c.to_str(db) for c in cols)))+')'
947
        else: join_cond = combine_conds(map(join, joins.iteritems()), 'ON')
948
        
949
        if isinstance(self.table, NamedTable): whitespace = '\n'
950
        else: whitespace = ' '
951
        
952
        # Create join
953
        if type_ is filter_out: type_ = 'LEFT'
954
        str_ = ''
955
        if type_ != None: str_ += type_+' '
956
        str_ += 'JOIN'+whitespace+self.table.to_str(db)
957
        if join_cond != None: str_ += whitespace+join_cond
958
        return str_
959
    
960
    def __repr__(self): return self.to_str(mockDb, '<left_table>')
961

    
962
##### Value exprs
963

    
964
all_cols = CustomCode('*')
965

    
966
default = CustomCode('DEFAULT')
967

    
968
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)
969

    
970
class Coalesce(FunctionCall):
971
    def __init__(self, *args):
972
        FunctionCall.__init__(self, InternalFunction('COALESCE'), *args)
973

    
974
class Nullif(FunctionCall):
975
    def __init__(self, *args):
976
        FunctionCall.__init__(self, InternalFunction('NULLIF'), *args)
977

    
978
null_as_str = Cast('text', 'NULL')
979

    
980
def to_text(value): return Coalesce(Cast('text', value), null_as_str)
981

    
982
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
983
null_sentinels = {
984
    'character varying': r'\N',
985
    'double precision': 'NaN',
986
    'integer': 2147483647,
987
    'text': r'\N',
988
    'timestamp with time zone': 'infinity',
989
    'taxonrank': 'unknown',
990
}
991

    
992
class EnsureNotNull(Coalesce):
993
    def __init__(self, value, type_):
994
        if isinstance(type_, ArrayType): null = []
995
        else: null = null_sentinels[type_]
996
        Coalesce.__init__(self, as_Col(value), Cast(type_, null))
997
        
998
        self.type = type_
999
    
1000
    def to_str(self, db):
1001
        col = self.args[0]
1002
        index_col_ = index_col(col)
1003
        if index_col_ != None: return index_col_.to_str(db)
1004
        return Coalesce.to_str(self, db)
1005

    
1006
#### Arrays
1007

    
1008
class ArrayMerge(FunctionCall):
1009
    def __init__(self, sep, array):
1010
        array = to_Array(array)
1011
        FunctionCall.__init__(self, InternalFunction('array_to_string'), array,
1012
            sep)
1013

    
1014
def merge_not_null(db, sep, values):
1015
    return ArrayMerge(sep, map(to_text, values))
1016

    
1017
##### Table exprs
1018

    
1019
class Values(Code):
1020
    def __init__(self, values):
1021
        '''
1022
        @param values [...]|[[...], ...] Can be one or multiple rows.
1023
        '''
1024
        Code.__init__(self)
1025
        
1026
        rows = values
1027
        if len(values) >= 1 and not lists.is_seq(values[0]): # only one row
1028
            rows = [values]
1029
        for i, row in enumerate(rows):
1030
            rows[i] = map(remove_col_rename, map(as_Value, row))
1031
        
1032
        self.rows = rows
1033
    
1034
    def to_str(self, db):
1035
        return 'VALUES '+(', '.join((Tuple(*r).to_str(db) for r in self.rows)))
1036

    
1037
def NamedValues(name, cols, values):
1038
    '''
1039
    @param cols None|[...]
1040
    @post `cols` will be changed to Col objects with the table set to `name`.
1041
    '''
1042
    table = NamedTable(name, Values(values), cols)
1043
    if cols != None: set_cols_table(table, cols)
1044
    return table
1045

    
1046
##### Database structure
1047

    
1048
def is_nullable(db, value):
1049
    if not is_table_col(value): return is_null(value)
1050
    try: return db.col_info(value).nullable
1051
    except NoUnderlyingTableException: return True # not a table column
1052

    
1053
text_types = set(['character varying', 'text'])
1054

    
1055
def is_text_type(type_): return type_ in text_types
1056

    
1057
def is_text_col(db, col): return is_text_type(db.col_info(col).type)
1058

    
1059
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)
1060

    
1061
def ensure_not_null(db, col, type_=None):
1062
    '''
1063
    @param col If type_ is not set, must have an underlying column.
1064
    @param type_ If set, overrides the underlying column's type and casts the
1065
        column to it if needed.
1066
    @return EnsureNotNull|Col
1067
    @throws ensure_not_null_excs
1068
    '''
1069
    col = remove_col_rename(col)
1070
    
1071
    try: col_type = db.col_info(underlying_col(col)).type
1072
    except NoUnderlyingTableException:
1073
        if type_ == None and is_null(col): raise # NULL has no type
1074
    else:
1075
        if type_ == None: type_ = col_type
1076
        elif type_ != col_type: col = Cast(type_, col)
1077
    
1078
    if is_nullable(db, col):
1079
        try: col = EnsureNotNull(col, type_)
1080
        except KeyError, e:
1081
            # Warn of no null sentinel for type, even if caller catches error
1082
            warnings.warn(UserWarning(exc.str_(e)))
1083
            raise
1084
    
1085
    return col
1086

    
1087
def try_mk_not_null(db, value):
1088
    '''
1089
    Warning: This function does not guarantee that its result is NOT NULL.
1090
    '''
1091
    try: return ensure_not_null(db, value)
1092
    except ensure_not_null_excs: return value
1093

    
1094
##### Expression transforming
1095

    
1096
true_expr = 'true'
1097
false_expr = 'false'
1098

    
1099
true_re = true_expr
1100
false_re = false_expr
1101
bool_re = r'(?:'+true_re+r'|'+false_re+r')'
1102
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
1103

    
1104
def logic_op_re(op, value_re, expr_re=''):
1105
    op_re = ' '+op+' '
1106
    return '(?:'+expr_re+op_re+value_re+'|'+value_re+op_re+expr_re+')'
1107

    
1108
and_false_re = logic_op_re('AND', false_re, atom_re)
1109
and_true_re = logic_op_re('AND', true_re)
1110
or_re = logic_op_re('OR', bool_re)
1111
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
1112

    
1113
def simplify_parens(expr):
1114
    return regexp.sub_nested(r'\(('+atom_re+')\)', r'\1', expr)
1115

    
1116
def simplify_recursive(sub_func, expr):
1117
    '''
1118
    @param sub_func See regexp.sub_recursive() sub_func param
1119
    '''
1120
    return simplify_parens(regexp.sub_recursive(
1121
        lambda s: sub_func(simplify_parens(s)), expr))
1122

    
1123
def simplify_expr(expr):
1124
    def simplify_logic_ops(expr):
1125
        total_n = 0
1126
        expr, n = re.subn(and_false_re, false_expr, expr)
1127
        total_n += n
1128
        expr, n = re.subn(or_and_true_re, r'', expr)
1129
        total_n += n
1130
        return expr, total_n
1131
    
1132
    expr = expr.replace('(NULL IS NULL)', true_expr)
1133
    expr = expr.replace('(NULL IS NOT NULL)', false_expr)
1134
    expr = simplify_recursive(simplify_logic_ops, expr)
1135
    return expr
1136

    
1137
name_re = r'(?:\w+|(?:"[^"]*")+)'
1138

    
1139
def parse_expr_col(str_):
1140
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1141
    if match: str_ = match.group(1)
1142
    return unesc_name(str_)
1143

    
1144
def map_expr(db, expr, mapping, in_cols_found=None):
1145
    '''Replaces output columns with input columns in an expression.
1146
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1147
    '''
1148
    for out, in_ in mapping.iteritems():
1149
        orig_expr = expr
1150
        out = to_name_only_col(out)
1151
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
1152
        
1153
        # Replace out both with and without quotes
1154
        expr = expr.replace(out.to_str(db), in_str)
1155
        expr = re.sub(r'(?<!["\'\.\[])\b'+out.name+r'\b(?!["\'\.=\]])', in_str,
1156
            expr)
1157
        
1158
        if in_cols_found != None and expr != orig_expr: # replaced something
1159
            in_cols_found.append(in_)
1160
    
1161
    return simplify_expr(expr)
(29-29/42)