Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import random
5
import re
6
import sys
7 865 aaronmk
import warnings
8 11 aaronmk
9 300 aaronmk
import exc
10 862 aaronmk
import strings
11 131 aaronmk
import util
12 11 aaronmk
13 832 aaronmk
##### Exceptions
14
15 135 aaronmk
def get_cur_query(cur):
16
    if hasattr(cur, 'query'): return cur.query
17
    elif hasattr(cur, '_last_executed'): return cur._last_executed
18
    else: return None
19 14 aaronmk
20 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
21 135 aaronmk
22 300 aaronmk
class DbException(exc.ExceptionWithCause):
23 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
24 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
25 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
26
27 360 aaronmk
class NameException(DbException): pass
28
29 468 aaronmk
class ExceptionWithColumns(DbException):
30
    def __init__(self, cols, cause=None):
31
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
32
        self.cols = cols
33 11 aaronmk
34 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
35 13 aaronmk
36 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
37 13 aaronmk
38 89 aaronmk
class EmptyRowException(DbException): pass
39
40 865 aaronmk
##### Warnings
41
42
class DbWarning(UserWarning): pass
43
44 832 aaronmk
##### Input validation
45
46 11 aaronmk
def check_name(name):
47
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
48
        +'" may contain only alphanumeric characters and _')
49
50 643 aaronmk
def esc_name(db, name):
51 1849 aaronmk
    module = util.root_module(db.db)
52 645 aaronmk
    if module == 'psycopg2': return name
53
        # Don't enclose in quotes because this disables case-insensitivity
54 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
55 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
56 643 aaronmk
    return quote + name.replace(quote, '') + quote
57
58 1869 aaronmk
##### Database connections
59 1849 aaronmk
60 1869 aaronmk
db_engines = {
61
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
62
    'PostgreSQL': ('psycopg2', {}),
63
}
64
65
DatabaseErrors_set = set([DbException])
66
DatabaseErrors = tuple(DatabaseErrors_set)
67
68
def _add_module(module):
69
    DatabaseErrors_set.add(module.DatabaseError)
70
    global DatabaseErrors
71
    DatabaseErrors = tuple(DatabaseErrors_set)
72
73
def db_config_str(db_config):
74
    return db_config['engine']+' database '+db_config['database']
75
76 1849 aaronmk
class DbConn:
77 1869 aaronmk
    def __init__(self, db_config, serializable=True):
78
        self.db_config = db_config
79
        self.serializable = serializable
80
81
        self.__db = None
82 1849 aaronmk
        self.pkeys = {}
83
        self.index_cols = {}
84 1869 aaronmk
85
    def __getattr__(self, name):
86
        if name == '__dict__': raise Exception('getting __dict__')
87
        if name == 'db': return self._db()
88
        else: raise AttributeError()
89
90
    def __getstate__(self):
91
        state = copy.copy(self.__dict__) # shallow copy
92
        state['_DbConn__db'] = None # don't pickle the connection
93
        return state
94
95
    def _db(self):
96
        if self.__db == None:
97
            # Process db_config
98
            db_config = self.db_config.copy() # don't modify input!
99
            module_name, mappings = db_engines[db_config.pop('engine')]
100
            module = __import__(module_name)
101
            _add_module(module)
102
            for orig, new in mappings.iteritems():
103
                try: util.rename_key(db_config, orig, new)
104
                except KeyError: pass
105
106
            # Connect
107
            self.__db = module.connect(**db_config)
108
109
            # Configure connection
110
            if self.serializable: run_raw_query(self,
111
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
112
113
        return self.__db
114 1849 aaronmk
115 1869 aaronmk
connect = DbConn
116
117 832 aaronmk
##### Querying
118
119 830 aaronmk
def run_raw_query(db, query, params=None):
120 1849 aaronmk
    cur = db.db.cursor()
121 11 aaronmk
    try: cur.execute(query, params)
122 46 aaronmk
    except Exception, e:
123
        _add_cursor_info(e, cur)
124 11 aaronmk
        raise
125 867 aaronmk
    if run_raw_query.debug:
126
        sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
127 11 aaronmk
    return cur
128
129 832 aaronmk
##### Recoverable querying
130 15 aaronmk
131 11 aaronmk
def with_savepoint(db, func):
132
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
133 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
134 11 aaronmk
    try: return_val = func()
135
    except:
136 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
137 11 aaronmk
        raise
138
    else:
139 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
140 11 aaronmk
        return return_val
141
142 830 aaronmk
def run_query(db, query, params=None, recover=None):
143
    if recover == None: recover = False
144
145
    def run(): return run_raw_query(db, query, params)
146
    if recover: return with_savepoint(db, run)
147
    else: return run()
148
149 832 aaronmk
##### Result retrieval
150
151 1135 aaronmk
def col_names(cur): return (col[0] for col in cur.description)
152 832 aaronmk
153
def rows(cur): return iter(lambda: cur.fetchone(), None)
154
155
def row(cur): return rows(cur).next()
156
157
def value(cur): return row(cur)[0]
158
159
def values(cur): return iter(lambda: value(cur), None)
160
161
def value_or_none(cur):
162
    try: return value(cur)
163
    except StopIteration: return None
164
165
##### Basic queries
166
167 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
168
    recover=None):
169
    '''@param fields Use None to select all fields in the table'''
170
    if conds == None: conds = {}
171 135 aaronmk
    assert limit == None or type(limit) == int
172 865 aaronmk
    assert start == None or type(start) == int
173 15 aaronmk
    check_name(table)
174 1135 aaronmk
    if fields != None: map(check_name, fields)
175 15 aaronmk
    map(check_name, conds.keys())
176 865 aaronmk
177 11 aaronmk
    def cond(entry):
178 13 aaronmk
        col, value = entry
179 644 aaronmk
        cond_ = esc_name(db, col)+' '
180 11 aaronmk
        if value == None: cond_ += 'IS'
181
        else: cond_ += '='
182
        cond_ += ' %s'
183
        return cond_
184 1135 aaronmk
    query = 'SELECT '
185
    if fields == None: query += '*'
186
    else: query += ', '.join([esc_name(db, field) for field in fields])
187
    query += ' FROM '+esc_name(db, table)
188 865 aaronmk
189
    missing = True
190 89 aaronmk
    if conds != {}:
191
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
192 865 aaronmk
        missing = False
193
    if limit != None: query += ' LIMIT '+str(limit); missing = False
194
    if start != None:
195
        if start != 0: query += ' OFFSET '+str(start)
196
        missing = False
197
    if missing: warnings.warn(DbWarning(
198
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
199
200 830 aaronmk
    return run_query(db, query, conds.values(), recover)
201 11 aaronmk
202 1554 aaronmk
def insert(db, table, row, returning=None, recover=None):
203
    '''@param returning str|None An inserted column (such as pkey) to return'''
204 11 aaronmk
    check_name(table)
205 13 aaronmk
    cols = row.keys()
206 15 aaronmk
    map(check_name, cols)
207 89 aaronmk
    query = 'INSERT INTO '+table
208 1554 aaronmk
209 89 aaronmk
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
210
        +', '.join(['%s']*len(cols))+')'
211
    else: query += ' DEFAULT VALUES'
212 1554 aaronmk
213
    if returning != None:
214
        check_name(returning)
215
        query += ' RETURNING '+returning
216
217 830 aaronmk
    return run_query(db, query, row.values(), recover)
218 11 aaronmk
219 135 aaronmk
def last_insert_id(db):
220 1849 aaronmk
    module = util.root_module(db.db)
221 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
222
    elif module == 'MySQLdb': return db.insert_id()
223
    else: return None
224 13 aaronmk
225 832 aaronmk
def truncate(db, table):
226
    check_name(table)
227 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
228 832 aaronmk
229
##### Database structure queries
230
231 1850 aaronmk
def pkey(db, table, recover=None):
232 832 aaronmk
    '''Assumed to be first column in table'''
233
    check_name(table)
234 1850 aaronmk
    if table not in db.pkeys:
235
        db.pkeys[table] = col_names(run_query(db,
236 1135 aaronmk
            'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
237 1850 aaronmk
    return db.pkeys[table]
238 832 aaronmk
239 853 aaronmk
def index_cols(db, table, index):
240
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
241
    automatically created. When you don't know whether something is a UNIQUE
242
    constraint or a UNIQUE index, use this function.'''
243
    check_name(table)
244
    check_name(index)
245 1852 aaronmk
    lookup = (table, index)
246
    if lookup not in db.index_cols:
247
        module = util.root_module(db.db)
248
        if module == 'psycopg2':
249
            db.index_cols[lookup] = list(values(run_query(db, '''\
250 853 aaronmk
SELECT attname
251 866 aaronmk
FROM
252
(
253
        SELECT attnum, attname
254
        FROM pg_index
255
        JOIN pg_class index ON index.oid = indexrelid
256
        JOIN pg_class table_ ON table_.oid = indrelid
257
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
258
        WHERE
259
            table_.relname = %(table)s
260
            AND index.relname = %(index)s
261
    UNION
262
        SELECT attnum, attname
263
        FROM
264
        (
265
            SELECT
266
                indrelid
267
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
268
                    AS indkey
269
            FROM pg_index
270
            JOIN pg_class index ON index.oid = indexrelid
271
            JOIN pg_class table_ ON table_.oid = indrelid
272
            WHERE
273
                table_.relname = %(table)s
274
                AND index.relname = %(index)s
275
        ) s
276
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
277
) s
278 853 aaronmk
ORDER BY attnum
279
''',
280 1852 aaronmk
                {'table': table, 'index': index})))
281
        else: raise NotImplementedError("Can't list index columns for "+module+
282
            ' database')
283
    return db.index_cols[lookup]
284 853 aaronmk
285 464 aaronmk
def constraint_cols(db, table, constraint):
286
    check_name(table)
287
    check_name(constraint)
288 1849 aaronmk
    module = util.root_module(db.db)
289 464 aaronmk
    if module == 'psycopg2':
290
        return list(values(run_query(db, '''\
291
SELECT attname
292
FROM pg_constraint
293
JOIN pg_class ON pg_class.oid = conrelid
294
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
295
WHERE
296
    relname = %(table)s
297
    AND conname = %(constraint)s
298
ORDER BY attnum
299
''',
300
            {'table': table, 'constraint': constraint})))
301
    else: raise NotImplementedError("Can't list constraint columns for "+module+
302
        ' database')
303
304 832 aaronmk
def tables(db):
305 1849 aaronmk
    module = util.root_module(db.db)
306 832 aaronmk
    if module == 'psycopg2':
307
        return values(run_query(db, "SELECT tablename from pg_tables "
308
            "WHERE schemaname = 'public' ORDER BY tablename"))
309
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
310
    else: raise NotImplementedError("Can't list tables for "+module+' database')
311 830 aaronmk
312 833 aaronmk
##### Database management
313
314
def empty_db(db):
315
    for table in tables(db): truncate(db, table)
316
317 832 aaronmk
##### Heuristic queries
318
319 1554 aaronmk
def try_insert(db, table, row, returning=None):
320 830 aaronmk
    '''Recovers from errors'''
321 1554 aaronmk
    try: return insert(db, table, row, returning, recover=True)
322 46 aaronmk
    except Exception, e:
323
        msg = str(e)
324 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
325
            r'"(([^\W_]+)_[^"]+)"', msg)
326
        if match:
327
            constraint, table = match.groups()
328 854 aaronmk
            try: cols = index_cols(db, table, constraint)
329 465 aaronmk
            except NotImplementedError: raise e
330 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
331 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
332
            'constraint', msg)
333 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
334 13 aaronmk
        raise # no specific exception raised
335 11 aaronmk
336 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
337 1554 aaronmk
    '''Recovers from errors.
338
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
339 471 aaronmk
    try:
340 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
341
        if row_ct_ref != None and cur.rowcount >= 0:
342
            row_ct_ref[0] += cur.rowcount
343
        return value(cur)
344 471 aaronmk
    except DuplicateKeyException, e:
345 1069 aaronmk
        return value(select(db, table, [pkey],
346
            util.dict_subset_right_join(row, e.cols), recover=True))
347 471 aaronmk
348 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
349 830 aaronmk
    '''Recovers from errors'''
350
    try: return value(select(db, table, [pkey], row, 1, recover=True))
351 14 aaronmk
    except StopIteration:
352 40 aaronmk
        if not create: raise
353 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row