Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
25 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
26 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
27 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException): pass
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException): pass
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning): pass
45
46 832 aaronmk
##### Input validation
47
48 11 aaronmk
def check_name(name):
49
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
50
        +'" may contain only alphanumeric characters and _')
51
52 643 aaronmk
def esc_name(db, name):
53 1849 aaronmk
    module = util.root_module(db.db)
54 645 aaronmk
    if module == 'psycopg2': return name
55
        # Don't enclose in quotes because this disables case-insensitivity
56 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
57 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
58 643 aaronmk
    return quote + name.replace(quote, '') + quote
59
60 1869 aaronmk
##### Database connections
61 1849 aaronmk
62 1869 aaronmk
db_engines = {
63
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
64
    'PostgreSQL': ('psycopg2', {}),
65
}
66
67
DatabaseErrors_set = set([DbException])
68
DatabaseErrors = tuple(DatabaseErrors_set)
69
70
def _add_module(module):
71
    DatabaseErrors_set.add(module.DatabaseError)
72
    global DatabaseErrors
73
    DatabaseErrors = tuple(DatabaseErrors_set)
74
75
def db_config_str(db_config):
76
    return db_config['engine']+' database '+db_config['database']
77
78 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
79 1894 aaronmk
80 1901 aaronmk
log_debug_none = lambda msg: None
81
82 1849 aaronmk
class DbConn:
83 1901 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
84 1869 aaronmk
        self.db_config = db_config
85
        self.serializable = serializable
86 1901 aaronmk
        self.log_debug = log_debug
87 1869 aaronmk
88
        self.__db = None
89 1889 aaronmk
        self.query_results = {}
90 1869 aaronmk
91
    def __getattr__(self, name):
92
        if name == '__dict__': raise Exception('getting __dict__')
93
        if name == 'db': return self._db()
94
        else: raise AttributeError()
95
96
    def __getstate__(self):
97
        state = copy.copy(self.__dict__) # shallow copy
98 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
99 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
100
        return state
101
102
    def _db(self):
103
        if self.__db == None:
104
            # Process db_config
105
            db_config = self.db_config.copy() # don't modify input!
106
            module_name, mappings = db_engines[db_config.pop('engine')]
107
            module = __import__(module_name)
108
            _add_module(module)
109
            for orig, new in mappings.iteritems():
110
                try: util.rename_key(db_config, orig, new)
111
                except KeyError: pass
112
113
            # Connect
114
            self.__db = module.connect(**db_config)
115
116
            # Configure connection
117
            if self.serializable: run_raw_query(self,
118
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
119
120
        return self.__db
121 1889 aaronmk
122 1891 aaronmk
    class DbCursor(Proxy):
123 1899 aaronmk
        def __init__(self, outer, cache_results):
124 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
125 1899 aaronmk
            if cache_results: self.query_results = outer.query_results
126
            else: self.query_results = None
127 1894 aaronmk
            self.query_lookup = None
128 1891 aaronmk
            self.result = []
129 1889 aaronmk
130 1894 aaronmk
        def execute(self, query, params=None):
131
            self.query_lookup = _query_lookup(query, params)
132 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
133
            except Exception, e:
134
                self.result = e # cache the exception as the result
135
                self._cache_result()
136
                raise
137
            finally: self.query = get_cur_query(self.inner)
138 1894 aaronmk
            return return_value
139
140 1891 aaronmk
        def fetchone(self):
141
            row = self.inner.fetchone()
142 1899 aaronmk
            if row != None: self.result.append(row)
143
            # otherwise, fetched all rows
144 1904 aaronmk
            else: self._cache_result()
145
            return row
146
147
        def _cache_result(self):
148 1906 aaronmk
            is_insert = self._is_insert()
149
            # For inserts, only cache exceptions since inserts are not
150
            # idempotent, but an invalid insert will always be invalid
151
            if self.query_results != None and (not is_insert
152
                or isinstance(self.result, Exception)):
153
154 1894 aaronmk
                assert self.query_lookup != None
155 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
156
                    util.dict_subset(dicts.AttrsDictView(self),
157
                    ['query', 'result', 'rowcount', 'description']))
158 1906 aaronmk
159
        def _is_insert(self): return self.query.upper().find('INSERT') >= 0
160
161 1916 aaronmk
        class CacheCursor:
162
            def __init__(self, cached_result): self.__dict__ = cached_result
163
164
            def execute(self):
165
                if isinstance(self.result, Exception): raise self.result
166
                # otherwise, result is a rows list
167
                self.iter = iter(self.result)
168
169
            def fetchone(self):
170
                try: return self.iter.next()
171
                except StopIteration: return None
172 1891 aaronmk
173 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
174
        query_lookup = _query_lookup(query, params)
175 1903 aaronmk
        used_cache = False
176
        try:
177
            try:
178
                if not cacheable: raise KeyError
179 1916 aaronmk
                cur = self.query_results[query_lookup]
180 1903 aaronmk
                used_cache = True
181
            except KeyError:
182
                cur = self.DbCursor(self, cacheable)
183
                try: cur.execute(query, params)
184
                except Exception, e:
185
                    _add_cursor_info(e, cur)
186
                    raise
187 1916 aaronmk
            else: cur.execute()
188 1903 aaronmk
        finally:
189
            if self.log_debug != log_debug_none: # only compute msg if needed
190
                if used_cache: cache_status = 'Cache hit'
191
                elif cacheable: cache_status = 'Cache miss'
192
                else: cache_status = 'Non-cacheable'
193
                self.log_debug(cache_status+': '+strings.one_line(cur.query))
194
195
        return cur
196 1914 aaronmk
197
    def is_cached(self, query, params=None):
198
        return _query_lookup(query, params) in self.query_results
199 1849 aaronmk
200 1869 aaronmk
connect = DbConn
201
202 832 aaronmk
##### Querying
203
204 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
205
    '''For args, see DbConn.run_query()'''
206
    return db.run_query(*args, **kw_args)
207 11 aaronmk
208 832 aaronmk
##### Recoverable querying
209 15 aaronmk
210 11 aaronmk
def with_savepoint(db, func):
211 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
212 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
213 11 aaronmk
    try: return_val = func()
214
    except:
215 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
216 11 aaronmk
        raise
217
    else:
218 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
219 11 aaronmk
        return return_val
220
221 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
222 830 aaronmk
    if recover == None: recover = False
223
224 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
225 1914 aaronmk
    if recover and not db.is_cached(query, params):
226
        return with_savepoint(db, run)
227
    else: return run() # don't need savepoint if cached
228 830 aaronmk
229 832 aaronmk
##### Result retrieval
230
231 1135 aaronmk
def col_names(cur): return (col[0] for col in cur.description)
232 832 aaronmk
233
def rows(cur): return iter(lambda: cur.fetchone(), None)
234
235 1893 aaronmk
def next_row(cur): return rows(cur).next()
236 832 aaronmk
237 1893 aaronmk
def row(cur):
238
    row_iter = rows(cur)
239
    row_ = row_iter.next()
240
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
241
    return row_
242
243
def next_value(cur): return next_row(cur)[0]
244
245 832 aaronmk
def value(cur): return row(cur)[0]
246
247 1893 aaronmk
def values(cur): return iters.func_iter(lambda: next_value(cur))
248 832 aaronmk
249
def value_or_none(cur):
250
    try: return value(cur)
251
    except StopIteration: return None
252
253
##### Basic queries
254
255 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
256 1894 aaronmk
    recover=None, cacheable=True):
257 1135 aaronmk
    '''@param fields Use None to select all fields in the table'''
258
    if conds == None: conds = {}
259 135 aaronmk
    assert limit == None or type(limit) == int
260 865 aaronmk
    assert start == None or type(start) == int
261 15 aaronmk
    check_name(table)
262 1135 aaronmk
    if fields != None: map(check_name, fields)
263 15 aaronmk
    map(check_name, conds.keys())
264 865 aaronmk
265 11 aaronmk
    def cond(entry):
266 13 aaronmk
        col, value = entry
267 644 aaronmk
        cond_ = esc_name(db, col)+' '
268 11 aaronmk
        if value == None: cond_ += 'IS'
269
        else: cond_ += '='
270
        cond_ += ' %s'
271
        return cond_
272 1135 aaronmk
    query = 'SELECT '
273
    if fields == None: query += '*'
274
    else: query += ', '.join([esc_name(db, field) for field in fields])
275
    query += ' FROM '+esc_name(db, table)
276 865 aaronmk
277
    missing = True
278 89 aaronmk
    if conds != {}:
279
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
280 865 aaronmk
        missing = False
281
    if limit != None: query += ' LIMIT '+str(limit); missing = False
282
    if start != None:
283
        if start != 0: query += ' OFFSET '+str(start)
284
        missing = False
285
    if missing: warnings.warn(DbWarning(
286
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
287
288 1905 aaronmk
    return run_query(db, query, conds.values(), recover, cacheable)
289 11 aaronmk
290 1905 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True):
291 1554 aaronmk
    '''@param returning str|None An inserted column (such as pkey) to return'''
292 11 aaronmk
    check_name(table)
293 13 aaronmk
    cols = row.keys()
294 15 aaronmk
    map(check_name, cols)
295 89 aaronmk
    query = 'INSERT INTO '+table
296 1554 aaronmk
297 89 aaronmk
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
298
        +', '.join(['%s']*len(cols))+')'
299
    else: query += ' DEFAULT VALUES'
300 1554 aaronmk
301
    if returning != None:
302
        check_name(returning)
303
        query += ' RETURNING '+returning
304
305 1905 aaronmk
    return run_query(db, query, row.values(), recover, cacheable)
306 11 aaronmk
307 135 aaronmk
def last_insert_id(db):
308 1849 aaronmk
    module = util.root_module(db.db)
309 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
310
    elif module == 'MySQLdb': return db.insert_id()
311
    else: return None
312 13 aaronmk
313 832 aaronmk
def truncate(db, table):
314
    check_name(table)
315 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
316 832 aaronmk
317
##### Database structure queries
318
319 1850 aaronmk
def pkey(db, table, recover=None):
320 832 aaronmk
    '''Assumed to be first column in table'''
321
    check_name(table)
322 1915 aaronmk
    return col_names(run_query(db,
323
        'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
324 832 aaronmk
325 853 aaronmk
def index_cols(db, table, index):
326
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
327
    automatically created. When you don't know whether something is a UNIQUE
328
    constraint or a UNIQUE index, use this function.'''
329
    check_name(table)
330
    check_name(index)
331 1909 aaronmk
    module = util.root_module(db.db)
332
    if module == 'psycopg2':
333
        return list(values(run_query(db, '''\
334 853 aaronmk
SELECT attname
335 866 aaronmk
FROM
336
(
337
        SELECT attnum, attname
338
        FROM pg_index
339
        JOIN pg_class index ON index.oid = indexrelid
340
        JOIN pg_class table_ ON table_.oid = indrelid
341
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
342
        WHERE
343
            table_.relname = %(table)s
344
            AND index.relname = %(index)s
345
    UNION
346
        SELECT attnum, attname
347
        FROM
348
        (
349
            SELECT
350
                indrelid
351
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
352
                    AS indkey
353
            FROM pg_index
354
            JOIN pg_class index ON index.oid = indexrelid
355
            JOIN pg_class table_ ON table_.oid = indrelid
356
            WHERE
357
                table_.relname = %(table)s
358
                AND index.relname = %(index)s
359
        ) s
360
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
361
) s
362 853 aaronmk
ORDER BY attnum
363
''',
364 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
365
    else: raise NotImplementedError("Can't list index columns for "+module+
366
        ' database')
367 853 aaronmk
368 464 aaronmk
def constraint_cols(db, table, constraint):
369
    check_name(table)
370
    check_name(constraint)
371 1849 aaronmk
    module = util.root_module(db.db)
372 464 aaronmk
    if module == 'psycopg2':
373
        return list(values(run_query(db, '''\
374
SELECT attname
375
FROM pg_constraint
376
JOIN pg_class ON pg_class.oid = conrelid
377
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
378
WHERE
379
    relname = %(table)s
380
    AND conname = %(constraint)s
381
ORDER BY attnum
382
''',
383
            {'table': table, 'constraint': constraint})))
384
    else: raise NotImplementedError("Can't list constraint columns for "+module+
385
        ' database')
386
387 832 aaronmk
def tables(db):
388 1849 aaronmk
    module = util.root_module(db.db)
389 832 aaronmk
    if module == 'psycopg2':
390
        return values(run_query(db, "SELECT tablename from pg_tables "
391
            "WHERE schemaname = 'public' ORDER BY tablename"))
392
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
393
    else: raise NotImplementedError("Can't list tables for "+module+' database')
394 830 aaronmk
395 833 aaronmk
##### Database management
396
397
def empty_db(db):
398
    for table in tables(db): truncate(db, table)
399
400 832 aaronmk
##### Heuristic queries
401
402 1554 aaronmk
def try_insert(db, table, row, returning=None):
403 830 aaronmk
    '''Recovers from errors'''
404 1554 aaronmk
    try: return insert(db, table, row, returning, recover=True)
405 46 aaronmk
    except Exception, e:
406
        msg = str(e)
407 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
408
            r'"(([^\W_]+)_[^"]+)"', msg)
409
        if match:
410
            constraint, table = match.groups()
411 854 aaronmk
            try: cols = index_cols(db, table, constraint)
412 465 aaronmk
            except NotImplementedError: raise e
413 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
414 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
415
            'constraint', msg)
416 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
417 13 aaronmk
        raise # no specific exception raised
418 11 aaronmk
419 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
420 1554 aaronmk
    '''Recovers from errors.
421
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
422 471 aaronmk
    try:
423 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
424
        if row_ct_ref != None and cur.rowcount >= 0:
425
            row_ct_ref[0] += cur.rowcount
426
        return value(cur)
427 471 aaronmk
    except DuplicateKeyException, e:
428 1069 aaronmk
        return value(select(db, table, [pkey],
429
            util.dict_subset_right_join(row, e.cols), recover=True))
430 471 aaronmk
431 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
432 830 aaronmk
    '''Recovers from errors'''
433
    try: return value(select(db, table, [pkey], row, 1, recover=True))
434 14 aaronmk
    except StopIteration:
435 40 aaronmk
        if not create: raise
436 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row