Project

General

Profile

1
# SQL code generation
2

    
3
import sql
4

    
5
##### SQL code objects
6

    
7
class Code:
8
    def to_str(self, db): raise NotImplemented()
9

    
10
class Literal(Code):
11
    def __init__(self, value): self.value = value
12
    
13
    def to_str(self, db): return db.esc_value(self.value)
14

    
15
def is_null(value): return isinstance(value, Literal) and value.value == None
16

    
17
class Table(Code):
18
    def __init__(self, name, schema=None):
19
        '''
20
        @param schema str|None (for no schema)
21
        '''
22
        self.name = name
23
        self.schema = schema
24
    
25
    def to_str(self, db): return sql.qual_name(db, self.schema, self.name)
26

    
27
def as_Table(table):
28
    if table == None or isinstance(table, Table): return table
29
    elif isinstance(table, tuple):
30
        schema, table = table
31
        return Table(table, schema)
32
    else: return Table(table)
33

    
34
class Col(Code):
35
    def __init__(self, name, table=None):
36
        '''
37
        @param table Table|None (for no table)
38
        '''
39
        assert table == None or isinstance(table, Table)
40
        
41
        self.name = name
42
        self.table = table
43
    
44
    def to_str(self, db):
45
        str_ = ''
46
        if self.table != None: str_ += self.table.to_str(db)+'.'
47
        str_ += sql.esc_name(db, self.name)
48
        return str_
49

    
50
class ValueCond:
51
    def __init__(self, value):
52
        if not isinstance(value, Literal): value = Literal(value)
53
        
54
        self.value = value
55
    
56
    def to_str(self, db, left_value):
57
        '''
58
        @param left_value The Code object that the condition is being applied on
59
        '''
60
        raise NotImplemented()
61

    
62
class CompareCond(ValueCond):
63
    def __init__(self, value, operator='='):
64
        ValueCond.__init__(self, value)
65
        self.operator = operator
66
    
67
    def to_str(self, db, left_value):
68
        if not isinstance(left_value, Code): left_value = Col(left_value)
69
        
70
        operator = self.operator
71
        if is_null(self.value): operator = 'IS'
72
        return left_value.to_str(db)+' '+operator+' '+self.value.to_str(db)
73

    
74
def as_ValueCond(value):
75
    if not isinstance(value, ValueCond): return CompareCond(value)
76
    else: return value
77

    
78
##### Old-style format support
79

    
80
def unescape_table(table):
81
    '''Currently only works with PostgreSQL.'''
82
    if table == None: return table
83
    
84
    assert table.count('.') <= 1
85
    parts = tuple((v.replace('"', '') for v in table.split('"."', 2)))
86
    if len(parts) == 1: parts, = parts
87
    return parts
88

    
89
def value2sql_gen(value, default_table=None, table_is_esc=False):
90
    '''Converts old-style (tuple-based) values to sql_gen-compatible values.
91
    @param table_is_esc If False, assumes any table name is not escaped or that
92
        re-escaping it will produce the same value.
93
    '''
94
    if isinstance(value, Code): return value # already in sql_gen form
95
    
96
    if table_is_esc: default_table = unescape_table(default_table)
97
    is_tuple = isinstance(value, tuple)
98
    if is_tuple and len(value) == 1: return Literal(value[0]) # value is literal
99
    elif is_tuple and len(value) == 2: # value is col with table
100
        table, col = value
101
        if table_is_esc: table = unescape_table(table)
102
        return Col(col, as_Table(table))
103
    else: return Col(value, as_Table(default_table)) # value is col name
(23-23/34)