Project

General

Profile

1
# SQL code generation
2

    
3
import sql
4
import strings
5

    
6
##### SQL code objects
7

    
8
class Code:
9
    def to_str(self, db): raise NotImplemented()
10

    
11
class Literal(Code):
12
    def __init__(self, value): self.value = value
13
    
14
    def to_str(self, db): return db.esc_value(self.value)
15

    
16
def is_null(value): return isinstance(value, Literal) and value.value == None
17

    
18
class Table(Code):
19
    def __init__(self, name, schema=None):
20
        '''
21
        @param schema str|None (for no schema)
22
        '''
23
        self.name = name
24
        self.schema = schema
25
    
26
    def to_str(self, db): return sql.qual_name(db, self.schema, self.name)
27

    
28
def as_Table(table):
29
    if table == None or isinstance(table, Table): return table
30
    elif isinstance(table, tuple):
31
        schema, table = table
32
        return Table(table, schema)
33
    else: return Table(table)
34

    
35
class Col(Code):
36
    def __init__(self, name, table=None):
37
        '''
38
        @param table Table|None (for no table)
39
        '''
40
        assert table == None or isinstance(table, Table)
41
        
42
        self.name = name
43
        self.table = table
44
    
45
    def to_str(self, db):
46
        str_ = ''
47
        if self.table != None: str_ += self.table.to_str(db)+'.'
48
        str_ += sql.esc_name(db, self.name)
49
        return str_
50

    
51
class ValueCond:
52
    def __init__(self, value):
53
        if not isinstance(value, Code): value = Literal(value)
54
        
55
        self.value = value
56
    
57
    def to_str(self, db, left_value):
58
        '''
59
        @param left_value The Code object that the condition is being applied on
60
        '''
61
        raise NotImplemented()
62

    
63
class CompareCond(ValueCond):
64
    def __init__(self, value, operator='='):
65
        '''
66
        @param operator By default, compares NULL values literally. Use '~=' or
67
            '~!=' to pass NULLs through.
68
        '''
69
        ValueCond.__init__(self, value)
70
        self.operator = operator
71
    
72
    def to_str(self, db, left_value):
73
        if not isinstance(left_value, Code): left_value = Col(left_value)
74
        
75
        right_value = self.value
76
        left = left_value.to_str(db)
77
        right = right_value.to_str(db)
78
        
79
        # Parse operator
80
        operator = self.operator
81
        passthru_null_ref = [False]
82
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
83
        neg_ref = [False]
84
        operator = strings.remove_prefix('!', operator, neg_ref)
85
        equals = operator.endswith('=')
86
        if equals and is_null(self.value): operator = 'IS'
87
        
88
        # Create str
89
        str_ = left+' '+operator+' '+right
90
        if equals and not passthru_null_ref[0] and isinstance(right_value, Col):
91
            str_ += ' OR ('+left+' IS NULL AND '+right+' IS NULL)'
92
        if neg_ref[0]: str_ = 'NOT ('+str_+')'
93
        return str_
94

    
95
def as_ValueCond(value):
96
    if not isinstance(value, ValueCond): return CompareCond(value)
97
    else: return value
98

    
99
##### Old-style format support
100

    
101
def unescape_table(table):
102
    '''Currently only works with PostgreSQL.'''
103
    if table == None: return table
104
    
105
    assert table.count('.') <= 1
106
    parts = tuple((v.replace('"', '') for v in table.split('"."', 2)))
107
    if len(parts) == 1: parts, = parts
108
    return parts
109

    
110
def col2sql_gen(col, default_table=None, table_is_esc=False):
111
    '''Converts old-style (tuple-based) columns to sql_gen-compatible columns.
112
    @param table_is_esc If False, assumes any table name is not escaped or that
113
        re-escaping it will produce the same value.
114
    '''
115
    if isinstance(col, Col): return col # already in sql_gen form
116
    
117
    table = default_table
118
    if isinstance(col, tuple): table, col = col
119
    if table_is_esc: table = unescape_table(table)
120
    return Col(col, as_Table(table))
121

    
122
def value2sql_gen(value, default_table=None, table_is_esc=False):
123
    '''Converts old-style (tuple-based) values to sql_gen-compatible values.
124
    @param table_is_esc If False, assumes any table name is not escaped or that
125
        re-escaping it will produce the same value.
126
    '''
127
    if isinstance(value, Code): return value # already in sql_gen form
128
    
129
    if isinstance(value, tuple) and len(value) == 1: return Literal(value[0])
130
    else: return col2sql_gen(value, default_table, table_is_esc)
131

    
132
def cond2sql_gen(value, default_table=None, table_is_esc=False):
133
    '''Converts old-style (tuple-based) conditions to sql_gen-compatible values.
134
    @param table_is_esc If False, assumes any table name is not escaped or that
135
        re-escaping it will produce the same value.
136
    '''
137
    if isinstance(value, ValueCond): return value # already in sql_gen form
138
    
139
    return as_ValueCond(value2sql_gen(value, default_table, table_is_esc))
(23-23/34)