Project

General

Profile

1 2211 aaronmk
# SQL code generation
2
3
import sql
4 2222 aaronmk
import strings
5 2227 aaronmk
import util
6 2211 aaronmk
7 2219 aaronmk
##### SQL code objects
8
9 2211 aaronmk
class Code:
10
    def to_str(self, db): raise NotImplemented()
11 2228 aaronmk
12
    def __str__(self): return str(self.__dict__)
13 2211 aaronmk
14 2256 aaronmk
class CustomCode:
15
    def __init__(self, str_): self.str_ = str_
16
17
    def to_str(self, db): return self.str_
18
19 2216 aaronmk
class Literal(Code):
20 2211 aaronmk
    def __init__(self, value): self.value = value
21 2213 aaronmk
22
    def to_str(self, db): return db.esc_value(self.value)
23 2211 aaronmk
24 2216 aaronmk
def is_null(value): return isinstance(value, Literal) and value.value == None
25
26 2211 aaronmk
class Table(Code):
27
    def __init__(self, name, schema=None):
28
        '''
29
        @param schema str|None (for no schema)
30
        '''
31
        self.name = name
32
        self.schema = schema
33
34
    def to_str(self, db): return sql.qual_name(db, self.schema, self.name)
35
36 2219 aaronmk
def as_Table(table):
37
    if table == None or isinstance(table, Table): return table
38
    elif isinstance(table, tuple):
39
        schema, table = table
40
        return Table(table, schema)
41
    else: return Table(table)
42
43 2211 aaronmk
class Col(Code):
44
    def __init__(self, name, table=None):
45
        '''
46
        @param table Table|None (for no table)
47
        '''
48 2241 aaronmk
        if util.is_str(table): table = Table(table)
49 2211 aaronmk
        assert table == None or isinstance(table, Table)
50
51
        self.name = name
52
        self.table = table
53
54
    def to_str(self, db):
55
        str_ = ''
56
        if self.table != None: str_ += self.table.to_str(db)+'.'
57
        str_ += sql.esc_name(db, self.name)
58
        return str_
59
60 2260 aaronmk
def as_Col(col, table=None):
61
    if col == None or isinstance(col, Code): return col
62
    else: return Col(col, table)
63
64 2229 aaronmk
class NamedCode(Code):
65
    def __init__(self, name, code):
66
        if not isinstance(code, Code): code = Literal(code)
67
68
        self.name = name
69
        self.code = code
70
71
    def to_str(self, db):
72
        return self.code.to_str(db)+' AS '+sql.esc_name(db, self.name)
73
74 2259 aaronmk
##### Parameterized SQL code objects
75
76 2214 aaronmk
class ValueCond:
77 2213 aaronmk
    def __init__(self, value):
78 2225 aaronmk
        if not isinstance(value, Code): value = Literal(value)
79 2213 aaronmk
80
        self.value = value
81 2214 aaronmk
82 2216 aaronmk
    def to_str(self, db, left_value):
83 2214 aaronmk
        '''
84 2216 aaronmk
        @param left_value The Code object that the condition is being applied on
85 2214 aaronmk
        '''
86
        raise NotImplemented()
87 2228 aaronmk
88
    def __str__(self): return str(self.__dict__)
89 2211 aaronmk
90
class CompareCond(ValueCond):
91
    def __init__(self, value, operator='='):
92 2222 aaronmk
        '''
93
        @param operator By default, compares NULL values literally. Use '~=' or
94
            '~!=' to pass NULLs through.
95
        '''
96 2211 aaronmk
        ValueCond.__init__(self, value)
97
        self.operator = operator
98
99 2216 aaronmk
    def to_str(self, db, left_value):
100
        if not isinstance(left_value, Code): left_value = Col(left_value)
101
102 2222 aaronmk
        right_value = self.value
103
        left = left_value.to_str(db)
104
        right = right_value.to_str(db)
105
106
        # Parse operator
107 2216 aaronmk
        operator = self.operator
108 2222 aaronmk
        passthru_null_ref = [False]
109
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
110
        neg_ref = [False]
111
        operator = strings.remove_prefix('!', operator, neg_ref)
112
        equals = operator.endswith('=')
113
        if equals and is_null(self.value): operator = 'IS'
114
115
        # Create str
116
        str_ = left+' '+operator+' '+right
117
        if equals and not passthru_null_ref[0] and isinstance(right_value, Col):
118
            str_ += ' OR ('+left+' IS NULL AND '+right+' IS NULL)'
119
        if neg_ref[0]: str_ = 'NOT ('+str_+')'
120
        return str_
121 2216 aaronmk
122 2260 aaronmk
# Tells as_ValueCond() to assume a non-ValueCond is a literal value
123
assume_literal = object()
124
125
def as_ValueCond(value, default_table=assume_literal):
126
    if not isinstance(value, ValueCond):
127
        if default_table is not assume_literal:
128
            value = as_Col(value, default_table)
129
        return CompareCond(value)
130 2216 aaronmk
    else: return value
131 2219 aaronmk
132 2260 aaronmk
join_using = object() # tells Join to join the column with USING
133
134
filter_out = object() # tells Join to filter out rows that match the join
135
136
class Join(Code):
137
    def __init__(self, table, mapping, type_=None):
138
        '''
139
        @param mapping dict(right_table_col=left_table_col, ...)
140
            * if left_table_col is join_using: left_table_col = right_table_col
141
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
142
            * filter_out: equivalent to 'LEFT' with the query filtered by
143
              `table_pkey IS NULL` (indicating no match)
144
        '''
145
        if util.is_str(table): table = Table(table)
146
        assert type_ == None or util.is_str(type_) or type_ is filter_out
147
148
        self.table = table
149
        self.mapping = mapping
150
        self.type_ = type_
151
152
    def to_str(self, db, left_table):
153
        def join(entry):
154
            '''Parses non-USING joins'''
155
            right_table_col, left_table_col = entry
156
157
            # Parse special values
158
            if left_table_col is join_using: left_table_col = right_table_col
159
160
            cond = as_ValueCond(right_table_col, self.table)
161
            return cond.to_str(db, as_Col(left_table_col, left_table))
162
163
        # Create join condition and determine join type
164
        if reduce(operator.and_, (v is join_using for v in joins.itervalues())):
165
            # all cols w/ USING, so can use simpler USING syntax
166
            join_cond = 'USING ('+(', '.join(joins.iterkeys()))+')'
167
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.iteritems())))
168
169
        # Create join
170
        type_ = self.type_
171
        if type_ is filter_out: type_ = 'LEFT'
172
        return type_+' JOIN '+table+' '+join_cond
173
174 2219 aaronmk
##### Old-style format support
175
176
def unescape_table(table):
177
    '''Currently only works with PostgreSQL.'''
178
    if table == None: return table
179
180
    assert table.count('.') <= 1
181
    parts = tuple((v.replace('"', '') for v in table.split('"."', 2)))
182
    if len(parts) == 1: parts, = parts
183
    return parts
184
185 2262 aaronmk
def table2sql_gen(table, table_is_esc=False):
186
    '''Converts old-style (tuple-based) tables to sql_gen-compatible values.
187
    @param table_is_esc If False, assumes any table name is not escaped or that
188
        re-escaping it will produce the same value.
189
    '''
190
    if util.is_str(table) and table_is_esc: table = unescape_table(table)
191
    return as_Table(table)
192
193 2223 aaronmk
def col2sql_gen(col, default_table=None, table_is_esc=False):
194 2262 aaronmk
    '''Converts old-style (tuple-based) columns to sql_gen-compatible values.
195 2223 aaronmk
    @param table_is_esc If False, assumes any table name is not escaped or that
196
        re-escaping it will produce the same value.
197
    '''
198
    if isinstance(col, Col): return col # already in sql_gen form
199
200
    table = default_table
201
    if isinstance(col, tuple): table, col = col
202
    if table_is_esc: table = unescape_table(table)
203
    return Col(col, as_Table(table))
204
205 2227 aaronmk
def value2sql_gen(value, default_table=None, table_is_esc=False,
206
    assume_col=False):
207 2219 aaronmk
    '''Converts old-style (tuple-based) values to sql_gen-compatible values.
208
    @param table_is_esc If False, assumes any table name is not escaped or that
209
        re-escaping it will produce the same value.
210
    '''
211
    if isinstance(value, Code): return value # already in sql_gen form
212
213 2227 aaronmk
    is_tuple = isinstance(value, tuple)
214
    if is_tuple and len(value) == 1: return Literal(value[0])
215
    if is_tuple or (assume_col and util.is_str(value)):
216
        return col2sql_gen(value, default_table, table_is_esc)
217
    else: return Literal(value)
218 2225 aaronmk
219 2237 aaronmk
def cond2sql_gen(value, default_table=None, table_is_esc=False,
220
    assume_col=False):
221 2225 aaronmk
    '''Converts old-style (tuple-based) conditions to sql_gen-compatible values.
222
    @param table_is_esc If False, assumes any table name is not escaped or that
223
        re-escaping it will produce the same value.
224
    '''
225
    if isinstance(value, ValueCond): return value # already in sql_gen form
226
227 2237 aaronmk
    return as_ValueCond(value2sql_gen(value, default_table, table_is_esc,
228
        assume_col))
229 2261 aaronmk
230
def join2sql_gen(value, table_is_esc=False):
231
    '''Converts old-style (tuple-based) joins to sql_gen-compatible values.
232
    @param table_is_esc If False, assumes any table name is not escaped or that
233
        re-escaping it will produce the same value.
234
    '''
235
    if isinstance(value, Join): return value # already in sql_gen form
236
237
    assert isinstance(value, tuple)
238
    table, joins = value
239
    if table_is_esc: table = unescape_table(table)
240
    return Join(table, joins)