Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
    '''Returns the text of the query most recently run on a cursor.
    The attribute `query` is tried first, then `_last_executed`.
    @return The attribute's value, or None if the cursor has neither attribute
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur):
    '''Appends the cursor's last query to an exception's message.
    @param e The exception to annotate (passed to exc.add_msg())
    @param cur The cursor whose query text is looked up with get_cur_query()
    '''
    query = get_cur_query(cur)
    # get_cur_query() can return None; concatenating None would raise a
    # TypeError that hides the exception being annotated
    if query is not None: exc.add_msg(e, 'query: '+query)
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for this module's database exceptions'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, a cursor whose query is appended to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        # Identity test: `!=` would invoke the cursor object's own comparison
        if cur is not None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
    pass
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A database exception that concerns specific columns.
    The column names are kept in the `cols` attribute.'''
    def __init__(self, cols, cause=None):
        '''
        @param cols The names of the affected columns
        @param cause The underlying exception, if any
        '''
        col_list = ', '.join(cols)
        DbException.__init__(self, 'columns: ' + col_list, cause)
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a unique constraint.
    `cols` holds the columns of the violated index.'''
    pass
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a not-null constraint.
    `cols` holds the name of the offending column.'''
    pass
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException):
    '''A database exception concerning an empty row.
    NOTE(review): not raised in the visible part of this module -- confirm the
    exact meaning against its callers.'''
    pass
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning):
    '''Warning category for this module (used by select())'''
    pass
45
46 1869 aaronmk
##### Database connections
47 1849 aaronmk
48 1869 aaronmk
# Maps each supported engine name to a tuple of:
# - the name of the Python module to import for it
# - renames applied to db_config keys before connecting (old key -> new key)
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', {}),
)
52
53
# Exception classes that count as database errors. _add_module() extends the
# set with a driver module's DatabaseError and rebuilds the tuple.
DatabaseErrors_set = set()
DatabaseErrors_set.add(DbException)
DatabaseErrors = tuple(DatabaseErrors_set)
55
56
def _add_module(module):
    '''Registers a driver module's DatabaseError class as a database error.
    Rebuilds the module-level DatabaseErrors tuple from DatabaseErrors_set.
    @param module A module object with a DatabaseError attribute
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
60
61
def db_config_str(db_config):
    '''Builds a short description of a database config.
    @param db_config A dict with at least the keys 'engine' and 'database'
    @return A string of the form "<engine> database <database>"
    '''
    engine = db_config['engine']
    database = db_config['database']
    return ' database '.join((engine, database))
63
64 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is cached.
    @return A (query, params) tuple, with params passed through
        dicts.make_hashable()
    '''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
65 1894 aaronmk
66 1901 aaronmk
def log_debug_none(msg):
    '''A debug logger that discards every message'''
    return None
67
68 1849 aaronmk
class DbConn:
    '''A lazily-opened database connection with an optional query result cache.
    The driver connection is available as the `db` attribute and is only
    created the first time it is accessed.'''

    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        '''
        @param db_config dict with an 'engine' key (a key of db_engines); the
            remaining entries are passed to the driver module's connect()
        @param serializable Whether to set the SERIALIZABLE isolation level
            right after connecting
        @param log_debug Callback that is passed a message string
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug

        self.__db = None
        self.query_results = {}

    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        # Name the missing attribute so the error is diagnosable
        else: raise AttributeError(name)

    def __getstate__(self):
        '''Returns the state to pickle, without the debug callback and the
        connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def _db(self):
        '''Returns the driver connection, opening it on first use'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass

            # Connect
            self.__db = module.connect(**db_config)

            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')

        return self.__db

    class DbCursor(Proxy):
        '''Wraps a driver cursor and records the fetched rows (or the raised
        exception) so that the result can be stored in the query cache'''
        def __init__(self, outer, cache_results):
            '''
            @param outer The DbConn that supplies the connection and the cache
            @param cache_results Whether to store results in outer's cache
            '''
            Proxy.__init__(self, outer.db.cursor())
            if cache_results: self.query_results = outer.query_results
            else: self.query_results = None
            self.query_lookup = None
            self.result = []

        def execute(self, query, params=None):
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                # Record the query text first, because _cache_result() reads
                # self.query
                self.query = get_cur_query(self.inner)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            return return_value

        def fetchone(self):
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            '''Stores the recorded result in the cache, if caching is on'''
            is_insert = self._is_insert()
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        def _is_insert(self):
            '''Whether the query text contains INSERT (case-insensitive)'''
            query = self.query
            # get_cur_query() may have returned None for the query text
            return query is not None and query.upper().find('INSERT') >= 0

        class CacheCursor:
            '''Replays a cached result through execute() and fetchone()'''
            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                try: return next(self.iter)
                except StopIteration: return None

    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query, serving it from the cache when possible.
        @param cacheable Whether to look the result up in, and store it in,
            the query cache
        @return A cursor: a cached one on a cache hit, otherwise a new DbCursor
        '''
        query_lookup = _query_lookup(query, params)
        used_cache = False
        cur = None # stays None if creating the cursor itself fails
        try:
            try:
                if not cacheable: raise KeyError
                cur = self.query_results[query_lookup]
                used_cache = True
            except KeyError:
                cur = self.DbCursor(self, cacheable)
                try: cur.execute(query, params)
                except Exception as e:
                    _add_cursor_info(e, cur)
                    raise
            else: cur.execute()
        finally:
            log_debug = self.log_debug
            # log_debug is None on an unpickled DbConn (see __getstate__())
            if log_debug is not None and log_debug is not log_debug_none:
                # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                # Fall back to the query as passed in when the cursor can't
                # supply the text, so that logging never masks the real error
                logged_query = query
                if cur is not None and cur.query is not None:
                    logged_query = cur.query
                log_debug(cache_status+': '+strings.one_line(logged_query))

        return cur

    def is_cached(self, query, params=None):
        '''Whether a result for this query and params is in the cache'''
        return _query_lookup(query, params) in self.query_results
185 1849 aaronmk
186 1869 aaronmk
# Alias: connect(db_config, ...) creates a DbConn
connect = DbConn
187
188 1919 aaronmk
##### Input validation
189
190
def check_name(name):
    '''Validates that a name consists only of word characters.
    @raise NameException if the name contains a non-word character
    '''
    bad_char = re.search(r'\W', name)
    if bad_char is None: return
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
193
194
def esc_name_by_module(module, name, preserve_case=False):
    '''Escapes an identifier for the database driven by a given module.
    @param module The driver module's name ('psycopg2' or 'MySQLdb')
    @param preserve_case For psycopg2, whether to double-quote the name. When
        false, the name is returned as-is.
    @return The name with the quote character removed from it and then
        wrapped around it, or the unchanged name (see preserve_case)
    @raise NotImplementedError for any other module
    '''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        if not preserve_case: return name
        quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    stripped = name.replace(quote, '')
    return quote + stripped + quote
202
203
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a database engine.
    @param engine A key of db_engines
    For kw_args, see esc_name_by_module()
    '''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
205
206
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the database behind a connection.
    @param db An object whose `db` attribute is the driver connection
    For kw_args, see esc_name_by_module()
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
208
209 832 aaronmk
##### Querying
210
211 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Delegates to db.run_query().
    For args, see DbConn.run_query()'''
    run = db.run_query
    return run(*args, **kw_args)
214 11 aaronmk
215 832 aaronmk
##### Recoverable querying
216 15 aaronmk
217 11 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint.
    The savepoint is rolled back if func() raises, and released otherwise.
    @param func A no-argument callable
    @return func()'s return value
    '''
    name = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+name)
    try: result = func()
    except:
        # Undo func()'s partial work, then let the error propagate
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+name)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+name)
    return result
227
228 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param recover Whether to wrap the query in a savepoint (see
        with_savepoint()). None means False.
    @param cacheable Passed on to db.run_query()
    @return The cursor returned by run_raw_query()
    '''
    def run(): return run_raw_query(db, query, params, cacheable)

    # don't need savepoint if cached
    needs_savepoint = bool(recover) and not db.is_cached(query, params)
    if needs_savepoint: return with_savepoint(db, run)
    return run()
235 830 aaronmk
236 832 aaronmk
##### Result retrieval
237
238 1135 aaronmk
def col_names(cur):
    '''@return An iterator over the first element of each entry in
        cur.description'''
    description = cur.description
    return (col[0] for col in description)
239 832 aaronmk
240
def rows(cur):
    '''@return An iterator that calls cur.fetchone() until it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
241
242 1893 aaronmk
def next_row(cur):
    '''Fetches the next row of a result.
    @raise StopIteration if there are no more rows
    '''
    # builtin next() instead of the Python 2-only .next() method
    return next(rows(cur))
243 832 aaronmk
244 1893 aaronmk
def row(cur):
    '''Fetches the next row of a result and then drains the remaining rows.
    @raise StopIteration if there are no rows left
    '''
    row_iter = rows(cur)
    # builtin next() instead of the Python 2-only .next() method
    row_ = next(row_iter)
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
    return row_
249
250
def next_value(cur):
    '''@return The first column of the next row (see next_row())'''
    next_row_ = next_row(cur)
    return next_row_[0]
251
252 832 aaronmk
def value(cur):
    '''@return The first column of the row returned by row()'''
    row_ = row(cur)
    return row_[0]
253
254 1893 aaronmk
def values(cur):
    '''@return iters.func_iter() over successive next_value(cur) calls'''
    def get_next(): return next_value(cur)
    return iters.func_iter(get_next)
255 832 aaronmk
256
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
259
260
##### Basic queries
261
262 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True):
    '''Runs a SELECT on one table.
    Warns with a DbWarning if there is no WHERE, LIMIT, or OFFSET clause.
    @param fields Use None to select all fields in the table
    @param conds dict mapping column names to the values they must equal. A
        None value produces an IS comparison instead of =.
    @param limit int|None Value of the LIMIT clause
    @param start int|None Value of the OFFSET clause (omitted when 0)
    @param recover,cacheable See run_query()
    @return The cursor returned by run_query()
    '''
    if conds is None: conds = {}
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    check_name(table)
    # Explicit loops: map() must not be used for its side effects, because a
    # lazy map() would never run the checks
    if fields is not None:
        for field in fields: check_name(field)
    for col in conds: check_name(col)

    def cond(entry):
        col, value = entry
        cond_ = esc_name(db, col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    query += ' FROM '+esc_name(db, table)

    missing = True
    if conds:
        query += ' WHERE '+' AND '.join(
            [cond(entry) for entry in conds.items()])
        missing = False
    if limit is not None: query += ' LIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    # The values are in the same order as the items() used for the WHERE clause
    return run_query(db, query, list(conds.values()), recover, cacheable)
296 11 aaronmk
297 1905 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True):
    '''Runs an INSERT of one row.
    @param row dict mapping column names to values. An empty dict inserts
        DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param recover,cacheable See run_query()
    @return The cursor returned by run_query()
    '''
    check_name(table)
    cols = list(row.keys())
    # Explicit loop: map() must not be used for its side effects, because a
    # lazy map() would never run the checks
    for col in cols: check_name(col)
    query = 'INSERT INTO '+table

    if cols: query += ' ('+', '.join(cols)+') VALUES ('\
        +', '.join(['%s']*len(cols))+')'
    else: query += ' DEFAULT VALUES'

    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning

    # The values are in the same order as the keys() used for the column list
    return run_query(db, query, list(row.values()), recover, cacheable)
313 11 aaronmk
314 135 aaronmk
def last_insert_id(db):
    '''Looks up the ID generated by the last insert.
    @return `SELECT lastval()` for psycopg2, the connection's insert_id() for
        MySQLdb, and None for any other driver
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the driver connection (db.db): DbConn itself only
    # resolves the `db` attribute and raises AttributeError for anything else
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
319 13 aaronmk
320 832 aaronmk
def truncate(db, table):
    '''Empties a table with TRUNCATE ... CASCADE.
    @return The cursor returned by run_raw_query()
    '''
    check_name(table)
    statement = 'TRUNCATE '+table+' CASCADE'
    return run_raw_query(db, statement)
323 832 aaronmk
324
##### Database structure queries
325
326 1850 aaronmk
def pkey(db, table, recover=None):
    '''Assumed to be first column in table.
    @param recover See run_query()
    @return The name of the table's first column
    '''
    check_name(table)
    # LIMIT 0 fetches no rows; only the column description is needed
    cur = run_query(db, 'SELECT * FROM '+table+' LIMIT 0', recover=recover)
    # builtin next() instead of the Python 2-only .next() method
    return next(col_names(cur))
331 832 aaronmk
332 853 aaronmk
def index_cols(db, table, index):
    '''Lists the columns of an index, ordered by attribute number.
    Also works for UNIQUE constraints, because a UNIQUE index is automatically
    created for them. Use this function when you don't know whether something
    is a UNIQUE constraint or a UNIQUE index.
    @return A list of column names
    @raise NotImplementedError for drivers other than psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    # The first SELECT uses the index's indkey columns; the second extracts
    # the :varattno numbers found in indexprs
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
374 853 aaronmk
375 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Lists the columns of a constraint, ordered by attribute number.
    @return A list of column names
    @raise NotImplementedError for drivers other than psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
393
394 832 aaronmk
def tables(db):
    '''Lists the database's tables.
    For psycopg2, only tables in the `public` schema are listed.
    @return An iterator over the table names (see values())
    @raise NotImplementedError for drivers other than psycopg2 and MySQLdb
    '''
    module = util.root_module(db.db)
    if module == 'MySQLdb': query = 'SHOW TABLES'
    elif module == 'psycopg2':
        query = ("SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename")
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query))
401 830 aaronmk
402 833 aaronmk
##### Database management
403
404
def empty_db(db):
    '''Truncates every table returned by tables()'''
    table_names = tables(db)
    for name in table_names: truncate(db, name)
406
407 832 aaronmk
##### Heuristic queries
408
409 1554 aaronmk
def try_insert(db, table, row, returning=None):
    '''Recovers from errors.
    Runs insert() with recover=True and translates recognized error messages
    into this module's exceptions.
    @raise DuplicateKeyException if the message reports a unique constraint
        violation and the constraint's columns can be listed
    @raise NullValueException if the message reports a not-null violation
    Any other exception is re-raised unchanged.
    '''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        # The table name is taken to be the constraint name's prefix up to the
        # first underscore
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised
425 11 aaronmk
426 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
    '''Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    @param pkey The column whose value is returned
    @param row_ct_ref None|list whose first element is incremented by the
        insert's rowcount (when that is not negative)
    @return The pkey value of the inserted row or, on a duplicate key, of the
        existing row that matches on the duplicated columns
    '''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look the existing row up by the columns of the violated index
        return value(select(db, table, [pkey],
            util.dict_subset_right_join(row, e.cols), recover=True))
437 471 aaronmk
438 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors.
    Looks up the pkey value of the row whose columns match `row`.
    @param create Whether to insert the row with put() if it is not found
    @param row_ct_ref Passed on to put()
    @raise StopIteration if the row is not found and create is false
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise