Project

General

Profile

1
# SQL code generation
2

    
3
import sql
4
import strings
5
import util
6

    
7
##### SQL code objects
8

    
9
class Code:
10
    def to_str(self, db): raise NotImplemented()
11
    
12
    def __str__(self): return str(self.__dict__)
13

    
14
class CustomCode(Code):
15
    def __init__(self, str_): self.str_ = str_
16
    
17
    def to_str(self, db): return self.str_
18

    
19
class Literal(Code):
20
    def __init__(self, value): self.value = value
21
    
22
    def to_str(self, db): return db.esc_value(self.value)
23

    
24
def is_null(value): return isinstance(value, Literal) and value.value == None
25

    
26
class Table(Code):
27
    def __init__(self, name, schema=None):
28
        '''
29
        @param schema str|None (for no schema)
30
        '''
31
        self.name = name
32
        self.schema = schema
33
    
34
    def to_str(self, db): return sql.qual_name(db, self.schema, self.name)
35

    
36
def as_Table(table):
37
    if table == None or isinstance(table, Code): return table
38
    elif isinstance(table, tuple):
39
        schema, table = table
40
        return Table(table, schema)
41
    else: return Table(table)
42

    
43
class Col(Code):
44
    def __init__(self, name, table=None):
45
        '''
46
        @param table Table|None (for no table)
47
        '''
48
        if util.is_str(table): table = Table(table)
49
        assert table == None or isinstance(table, Table)
50
        
51
        self.name = name
52
        self.table = table
53
    
54
    def to_str(self, db):
55
        str_ = ''
56
        if self.table != None: str_ += self.table.to_str(db)+'.'
57
        str_ += sql.esc_name(db, self.name)
58
        return str_
59

    
60
def as_Col(col, table=None):
61
    if col == None or isinstance(col, Code): return col
62
    else: return Col(col, table)
63

    
64
class NamedCode(Code):
65
    def __init__(self, name, code):
66
        if not isinstance(code, Code): code = Literal(code)
67
        
68
        self.name = name
69
        self.code = code
70
    
71
    def to_str(self, db):
72
        return self.code.to_str(db)+' AS '+sql.esc_name(db, self.name)
73

    
74
##### Parameterized SQL code objects
75

    
76
class ValueCond:
77
    def __init__(self, value):
78
        if not isinstance(value, Code): value = Literal(value)
79
        
80
        self.value = value
81
    
82
    def to_str(self, db, left_value):
83
        '''
84
        @param left_value The Code object that the condition is being applied on
85
        '''
86
        raise NotImplemented()
87
    
88
    def __str__(self): return str(self.__dict__)
89

    
90
class CompareCond(ValueCond):
91
    def __init__(self, value, operator='='):
92
        '''
93
        @param operator By default, compares NULL values literally. Use '~=' or
94
            '~!=' to pass NULLs through.
95
        '''
96
        ValueCond.__init__(self, value)
97
        self.operator = operator
98
    
99
    def to_str(self, db, left_value):
100
        if not isinstance(left_value, Code): left_value = Col(left_value)
101
        
102
        right_value = self.value
103
        left = left_value.to_str(db)
104
        right = right_value.to_str(db)
105
        
106
        # Parse operator
107
        operator = self.operator
108
        passthru_null_ref = [False]
109
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
110
        neg_ref = [False]
111
        operator = strings.remove_prefix('!', operator, neg_ref)
112
        equals = operator.endswith('=')
113
        if equals and is_null(self.value): operator = 'IS'
114
        
115
        # Create str
116
        str_ = left+' '+operator+' '+right
117
        if equals and not passthru_null_ref[0] and isinstance(right_value, Col):
118
            str_ += ' OR ('+left+' IS NULL AND '+right+' IS NULL)'
119
        if neg_ref[0]: str_ = 'NOT ('+str_+')'
120
        return str_
121

    
122
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
123
assume_literal = object()
124

    
125
def as_ValueCond(value, default_table=assume_literal):
126
    if not isinstance(value, ValueCond):
127
        if default_table is not assume_literal:
128
            value = as_Col(value, default_table)
129
        return CompareCond(value)
130
    else: return value
131

    
132
join_using = object() # tells Join to join the column with USING
133

    
134
filter_out = object() # tells Join to filter out rows that match the join
135

    
136
class Join(Code):
137
    def __init__(self, table, mapping, type_=None):
138
        '''
139
        @param mapping dict(right_table_col=left_table_col, ...)
140
            * if left_table_col is join_using: left_table_col = right_table_col
141
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
142
            * filter_out: equivalent to 'LEFT' with the query filtered by
143
              `table_pkey IS NULL` (indicating no match)
144
        '''
145
        if util.is_str(table): table = Table(table)
146
        assert type_ == None or util.is_str(type_) or type_ is filter_out
147
        
148
        self.table = table
149
        self.mapping = mapping
150
        self.type_ = type_
151
    
152
    def to_str(self, db, left_table):
153
        def join(entry):
154
            '''Parses non-USING joins'''
155
            right_table_col, left_table_col = entry
156
            
157
            # Parse special values
158
            if left_table_col is join_using: left_table_col = right_table_col
159
            
160
            cond = as_ValueCond(right_table_col, self.table)
161
            return cond.to_str(db, as_Col(left_table_col, left_table))
162
        
163
        # Create join condition
164
        type_ = self.type_
165
        if type_ is not filter_out and reduce(operator.and_,
166
            (v is join_using for v in joins.itervalues())):
167
            # all cols w/ USING, so can use simpler USING syntax
168
            join_cond = 'USING ('+(', '.join(joins.iterkeys()))+')'
169
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.iteritems())))
170
        
171
        # Create join
172
        if type_ is filter_out: type_ = 'LEFT'
173
        str_ = ''
174
        if type_ != None: str_ += type_+' '
175
        str_ += 'JOIN '+table+' '+join_cond
176
        return str_
177

    
178
##### Old-style format support
179

    
180
def unescape_table(table):
181
    '''Currently only works with PostgreSQL.'''
182
    if table == None: return table
183
    
184
    assert table.count('.') <= 1
185
    parts = tuple((v.replace('"', '') for v in table.split('"."', 2)))
186
    if len(parts) == 1: parts, = parts
187
    return parts
188

    
189
def table2sql_gen(table, table_is_esc=False):
190
    '''Converts old-style (tuple-based) tables to sql_gen-compatible values.
191
    @param table_is_esc If False, assumes any table name is not escaped or that
192
        re-escaping it will produce the same value.
193
    '''
194
    if util.is_str(table) and table_is_esc: table = unescape_table(table)
195
    return as_Table(table)
196

    
197
def col2sql_gen(col, default_table=None, table_is_esc=False):
198
    '''Converts old-style (tuple-based) columns to sql_gen-compatible values.
199
    @param table_is_esc If False, assumes any table name is not escaped or that
200
        re-escaping it will produce the same value.
201
    '''
202
    if isinstance(col, Col): return col # already in sql_gen form
203
    
204
    table = default_table
205
    if isinstance(col, tuple): table, col = col
206
    return Col(col, table2sql_gen(table, table_is_esc))
207

    
208
def value2sql_gen(value, default_table=None, table_is_esc=False,
209
    assume_col=False):
210
    '''Converts old-style (tuple-based) values to sql_gen-compatible values.
211
    @param table_is_esc If False, assumes any table name is not escaped or that
212
        re-escaping it will produce the same value.
213
    '''
214
    if isinstance(value, Code): return value # already in sql_gen form
215
    
216
    is_tuple = isinstance(value, tuple)
217
    if is_tuple and len(value) == 1: return Literal(value[0])
218
    if is_tuple or (assume_col and util.is_str(value)):
219
        return col2sql_gen(value, default_table, table_is_esc)
220
    else: return Literal(value)
221

    
222
def cond2sql_gen(value, default_table=None, table_is_esc=False,
223
    assume_col=False):
224
    '''Converts old-style (tuple-based) conditions to sql_gen-compatible values.
225
    @param table_is_esc If False, assumes any table name is not escaped or that
226
        re-escaping it will produce the same value.
227
    '''
228
    if isinstance(value, ValueCond): return value # already in sql_gen form
229
    
230
    return as_ValueCond(value2sql_gen(value, default_table, table_is_esc,
231
        assume_col))
232

    
233
def join2sql_gen(value, table_is_esc=False):
234
    '''Converts old-style (tuple-based) joins to sql_gen-compatible values.
235
    @param table_is_esc If False, assumes any table name is not escaped or that
236
        re-escaping it will produce the same value.
237
    '''
238
    if isinstance(value, Join): return value # already in sql_gen form
239
    
240
    assert isinstance(value, tuple)
241
    table, joins = value
242
    return Join(table2sql_gen(table, table_is_esc), joins)
(23-23/34)