Project

General

Profile

1
# Database access
2

    
3
import random
4
import re
5
import sys
6

    
7
import exc
8
import util
9

    
10
def get_cur_query(cur):
    '''Returns the query stored on a cursor, if any.
    Looks at the cursor's `query` attribute first, then `_last_executed`.
    @return The value of the first of those attributes that exists, or None if
        the cursor has neither
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr):
            return getattr(cur, attr)
    return None
14

    
15
def _add_cursor_info(e, cur):
    '''Adds the cursor's query to an exception's message via exc.add_msg().
    @param e The exception to annotate
    @param cur The cursor whose query should be reported
    '''
    # str() because get_cur_query() returns None for a cursor that exposes no
    # query; concatenating None would raise a TypeError that hides the error
    # being annotated
    exc.add_msg(e, 'query: '+str(get_cur_query(cur)))
16

    
17
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module.
    @param msg The error message
    @param cause Passed through to exc.ExceptionWithCause
    @param cur If given, this cursor's query is added to the message
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause)
        # Identity test rather than `!= None`, which would go through any
        # comparison operators the cursor class defines
        if cur is not None: _add_cursor_info(self, cur)
21

    
22
class NameException(DbException):
    '''Raised by check_name() for a name containing characters other than
    alphanumerics and _'''
23

    
24
class ExceptionWithColumns(DbException):
    '''A DbException that concerns specific columns.
    @param cols The column names; joined with ', ' to build the message and
        stored in self.cols
    @param cause Passed through to DbException
    '''
    def __init__(self, cols, cause=None):
        msg = 'columns: ' + ', '.join(cols)
        DbException.__init__(self, msg, cause)
        self.cols = cols
28

    
29
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when the database reports a duplicate key'''
30

    
31
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when the database reports a null value in a
    not-null column; cols contains that column's name'''
32

    
33
class EmptyRowException(DbException):
    '''DbException subclass for empty rows.
    NOTE(review): nothing in the visible part of this module raises it.
    '''
34

    
35
def check_name(name):
    '''Ensures that a name consists only of word characters.
    @raise NameException if the name contains any non-word character
    '''
    has_invalid_char = re.search(r'\W', name) is not None
    if has_invalid_char:
        raise NameException('Name "'+name
            +'" may contain only alphanumeric characters and _')
38

    
39
def esc_name(db, name):
    '''Escapes a name for use in a query, according to the connection's module.
    @return For psycopg2, the name unchanged; for MySQLdb, the name with any
        backticks removed and then enclosed in backticks
    @raise NotImplementedError for any other database module
    '''
    module = util.root_module(db)
    if module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        return name
    if module != 'MySQLdb':
        raise NotImplementedError("Can't escape name for "+module+' database')
    quote = '`'
    return quote + name.replace(quote, '') + quote
46

    
47
def run_raw_query(db, query, params=None):
    '''Runs a query on a new cursor of the connection.
    @param params Passed to the cursor's execute() along with the query
    @return The cursor the query was run on
    @raise Exception whatever execute() raises, re-raised after the cursor's
        query has been added to it
    '''
    cur = db.cursor()
    try: cur.execute(query, params)
    # `as` form: the older comma form is a syntax error in Python 3
    except Exception as e:
        _add_cursor_info(e, cur)
        raise
    return cur
54

    
55
def col(cur, idx):
    '''Returns the first item of the idx-th entry of cur.description'''
    col_desc = cur.description[idx]
    return col_desc[0]
56

    
57
def rows(cur):
    '''Iterates over a cursor's remaining rows.
    @return An iterator that calls cur.fetchone() until it returns None
    '''
    return iter(cur.fetchone, None)
58

    
59
def row(cur):
    '''Returns the cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    return next(rows(cur))
60

    
61
def value(cur):
    '''Returns the first item of the cursor's next row.
    @raise StopIteration (from row()) if there are no more rows
    '''
    next_row = row(cur)
    return next_row[0]
62

    
63
def values(cur):
    '''Iterates over the first item of each of a cursor's remaining rows.
    @return A lazy iterator; a None (NULL) value is yielded like any other
    '''
    # The None sentinel is applied to the rows (fetchone() returns None when
    # the rows run out), not to the values, so that a NULL value in the first
    # column does not end the iteration early
    return (row_[0] for row_ in iter(cur.fetchone, None))
64

    
65
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no more rows'''
    try:
        result = value(cur)
    except StopIteration:
        result = None
    return result
68

    
69
def with_savepoint(db, func):
    '''Calls func() inside a savepoint.
    The savepoint is released if func() returns, and rolled back to if func()
    raises.
    @param func Called with no arguments
    @return func()'s return value
    @raise Whatever func() raises, re-raised after the rollback
    '''
    name = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
    run_raw_query(db, 'SAVEPOINT '+name)
    try:
        result = func()
    except:
        # Undo func()'s changes, then let the error propagate
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+name)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+name)
    return result
79

    
80
def run_query(db, query, params=None, recover=None):
    '''Runs a query, optionally inside a savepoint.
    @param recover If true, the query is run using with_savepoint(); None is
        treated as False
    @return The cursor returned by run_raw_query()
    '''
    if not recover:
        return run_raw_query(db, query, params)
    return with_savepoint(db, lambda: run_raw_query(db, query, params))
86

    
87
def select(db, table, fields, conds, limit=None, recover=None):
    '''Runs a SELECT of the given fields from one table.
    @param fields Names of the columns to select
    @param conds dict of column name to required value; a None value is
        compared using IS instead of =
    @param limit Optional int maximum number of rows
    @param recover Passed through to run_query()
    @return The cursor returned by run_query()
    @raise NameException (from check_name()) if any name is invalid
    '''
    assert limit == None or type(limit) == int
    check_name(table)
    # Explicit loops rather than map(): the checks are run only for their side
    # effect, and a lazy map() would silently skip them
    for field in fields: check_name(field)
    for cond_col in conds: check_name(cond_col)
    
    def cond(entry):
        cond_col, cond_value = entry
        if cond_value is None: operator = 'IS'
        else: operator = '='
        return esc_name(db, cond_col)+' '+operator+' %s'
    
    query = ('SELECT ' + ', '.join([esc_name(db, field) for field in fields])
        + ' FROM '+esc_name(db, table))
    if conds != {}:
        query += ' WHERE '+' AND '.join([cond(entry)
            for entry in conds.items()])
    if limit != None: query += ' LIMIT '+str(limit)
    # items() and values() walk the unmodified dict in the same order, so the
    # params line up with the placeholders
    return run_query(db, query, list(conds.values()), recover)
105

    
106
def insert(db, table, row, recover=None):
    '''Inserts one row into a table.
    @param row dict of column name to value; if empty, the row is inserted
        with DEFAULT VALUES
    @param recover Passed through to run_query()
    @return The cursor returned by run_query()
    @raise NameException (from check_name()) if any name is invalid
    '''
    check_name(table)
    cols = list(row.keys())
    # Explicit loop rather than map(): the checks are run only for their side
    # effect, and a lazy map() would silently skip them
    for col_ in cols: check_name(col_)
    query = 'INSERT INTO '+table
    if cols:
        query += (' ('+', '.join(cols)+') VALUES ('
            +', '.join(['%s']*len(cols))+')')
    else: query += ' DEFAULT VALUES'
    # keys() and values() walk the unmodified dict in the same order, so the
    # params line up with the column list
    return run_query(db, query, list(row.values()), recover)
115

    
116
def last_insert_id(db):
    '''Returns the ID generated by the connection's last insert.
    @return For psycopg2, the result of SELECT lastval(); for MySQLdb,
        db.insert_id(); for any other module, None
    '''
    module = util.root_module(db)
    if module == 'MySQLdb':
        return db.insert_id()
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    return None
121

    
122
def constraint_cols(db, table, constraint):
    '''Lists the names of the columns in a table's constraint, ordered by
    attnum.
    @raise NameException (from check_name()) if either name is invalid
    @raise NotImplementedError if the database module is not psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
140

    
141
def pkey(db, cache, table, recover=None):
    '''Assumed to be first column in table
    @param cache dict of table name to pkey name; used for lookups and updated
        with any newly found pkey
    @param recover Passed through to run_query()
    @return The name of the table's first column
    @raise NameException (from check_name()) if the table name is invalid
    '''
    check_name(table)
    if table not in cache:
        # recover must be passed by keyword: run_query()'s third positional
        # parameter is params, not recover
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
            recover=recover), 0)
    return cache[table]
148

    
149
def try_insert(db, table, row):
    '''Recovers from errors
    Runs insert() with recover=True and translates recognized error messages
    into this module's exceptions.
    @return The cursor returned by insert()
    @raise DuplicateKeyException if the error message reports a unique
        constraint violation and the constraint's columns could be listed
    @raise NullValueException if the error message reports a null value in a
        not-null column
    @raise Exception the original error, in every other case
    '''
    try: return insert(db, table, row, recover=True)
    except Exception as e:
        msg = str(e)
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            # The table name is taken from the constraint name: it is the part
            # before the first _
            constraint, table = match.groups()
            try: cols = constraint_cols(db, table, constraint)
            except NotImplementedError: raise e
            # Without any columns the caller couldn't identify the duplicated
            # row, so report the original error instead
            if not cols: raise e
            # Pass the whole list: ExceptionWithColumns expects a sequence of
            # column names, not a single name
            raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised
165

    
166
def put(db, table, row, pkey, row_ct_ref=None):
    '''Recovers from errors
    Inserts a row, falling back to looking up the existing row if the insert
    hits a duplicate key.
    @param pkey Name of the column whose value is returned for an existing row
    @param row_ct_ref If given, its item 0 is incremented by the insert
        cursor's rowcount (when that is not negative)
    @return last_insert_id() after a successful insert; otherwise the pkey
        value of the row selected using the duplicated columns
    '''
    try:
        row_ct = try_insert(db, table, row).rowcount
        if row_ct_ref is not None and row_ct >= 0: row_ct_ref[0] += row_ct
        return last_insert_id(db)
    # `as` form: the older comma form is a syntax error in Python 3
    except DuplicateKeyException as e:
        # presumably dict_subset() keeps only row's entries for e.cols --
        # confirm against util
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols),
            recover=True))
175

    
176
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors
    Looks up the pkey value of the first row matching the given values.
    @param row dict of column name to value, used as the select conditions
    @param create If true, a missing row is added using put()
    @return The pkey value of the matching (or newly added) row
    @raise StopIteration if no row matches and create is false
    '''
    try:
        matches = select(db, table, [pkey], row, 1, recover=True)
        return value(matches)
    except StopIteration:
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
182

    
183

    
184
def truncate(db, table):
    '''Runs TRUNCATE ... CASCADE on a table.
    @return The cursor returned by run_query()
    @raise NameException (from check_name()) if the table name is invalid
    '''
    check_name(table)
    query = 'TRUNCATE '+table+' CASCADE'
    return run_query(db, query)
187

    
188
def tables(db):
    '''Lists the database's tables.
    For psycopg2, these are the tables of the public schema, sorted by name.
    @return An iterator over the table names, as produced by values()
    @raise NotImplementedError if the module is neither psycopg2 nor MySQLdb
    '''
    module = util.root_module(db)
    if module == 'psycopg2':
        query = ("SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename")
    elif module == 'MySQLdb':
        query = 'SHOW TABLES'
    else:
        raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query))
195

    
196
def empty_db(db):
    '''Truncates every table listed by tables()'''
    for table_name in tables(db):
        truncate(db, table_name)
198

    
199
# Maps each supported engine name to a tuple of the name of the module to
# import and a dict of db_config key renames to apply before connecting
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', dict()),
)
203

    
204
# The exception classes that count as database errors. The set is the
# modifiable master copy; the tuple is rebuilt from it whenever it changes.
DatabaseErrors_set = set()
DatabaseErrors_set.add(DbException)
DatabaseErrors = tuple(DatabaseErrors_set)
206

    
207
def _add_module(module):
    '''Registers a database module's DatabaseError class in
    DatabaseErrors_set and rebuilds the DatabaseErrors tuple from the set'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
211

    
212
def connect(db_config, serializable=True):
    '''Opens a connection to the database described by db_config.
    @param db_config dict containing 'engine' (a key of db_engines) plus the
        settings to pass to the engine module's connect(); keys are renamed
        according to the engine's mappings first
    @param serializable If true, SET TRANSACTION ISOLATION LEVEL SERIALIZABLE
        is run on the new connection
    @return The new connection
    '''
    settings = db_config.copy() # don't modify input!
    engine = settings.pop('engine')
    module_name, mappings = db_engines[engine]
    module = __import__(module_name)
    _add_module(module)
    for orig, new in mappings.items():
        try:
            util.rename_key(settings, orig, new)
        except KeyError:
            pass
    db = module.connect(**settings)
    if serializable:
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
    return db
224

    
225
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>"'''
    parts = [db_config['engine'], db_config['database']]
    return ' database '.join(parts)
(8-8/14)