Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
25 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
26 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
27 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException): pass
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException): pass
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning): pass
45
46 1869 aaronmk
##### Database connections
47 1849 aaronmk
48 1926 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database']
49
50 1869 aaronmk
db_engines = {
51
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
52
    'PostgreSQL': ('psycopg2', {}),
53
}
54
55
DatabaseErrors_set = set([DbException])
56
DatabaseErrors = tuple(DatabaseErrors_set)
57
58
def _add_module(module):
59
    DatabaseErrors_set.add(module.DatabaseError)
60
    global DatabaseErrors
61
    DatabaseErrors = tuple(DatabaseErrors_set)
62
63
def db_config_str(db_config):
64
    return db_config['engine']+' database '+db_config['database']
65
66 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
67 1894 aaronmk
68 1901 aaronmk
log_debug_none = lambda msg: None
69
70 1849 aaronmk
class DbConn:
71 1901 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
72 1869 aaronmk
        self.db_config = db_config
73
        self.serializable = serializable
74 1901 aaronmk
        self.log_debug = log_debug
75 1869 aaronmk
76
        self.__db = None
77 1889 aaronmk
        self.query_results = {}
78 1869 aaronmk
79
    def __getattr__(self, name):
80
        if name == '__dict__': raise Exception('getting __dict__')
81
        if name == 'db': return self._db()
82
        else: raise AttributeError()
83
84
    def __getstate__(self):
85
        state = copy.copy(self.__dict__) # shallow copy
86 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
87 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
88
        return state
89
90
    def _db(self):
91
        if self.__db == None:
92
            # Process db_config
93
            db_config = self.db_config.copy() # don't modify input!
94
            module_name, mappings = db_engines[db_config.pop('engine')]
95
            module = __import__(module_name)
96
            _add_module(module)
97
            for orig, new in mappings.iteritems():
98
                try: util.rename_key(db_config, orig, new)
99
                except KeyError: pass
100
101
            # Connect
102
            self.__db = module.connect(**db_config)
103
104
            # Configure connection
105
            if self.serializable: run_raw_query(self,
106
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
107
108
        return self.__db
109 1889 aaronmk
110 1891 aaronmk
    class DbCursor(Proxy):
111 1927 aaronmk
        def __init__(self, outer):
112 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
113 1927 aaronmk
            self.query_results = outer.query_results
114 1894 aaronmk
            self.query_lookup = None
115 1891 aaronmk
            self.result = []
116 1889 aaronmk
117 1894 aaronmk
        def execute(self, query, params=None):
118
            self.query_lookup = _query_lookup(query, params)
119 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
120
            except Exception, e:
121
                self.result = e # cache the exception as the result
122
                self._cache_result()
123
                raise
124
            finally: self.query = get_cur_query(self.inner)
125 1894 aaronmk
            return return_value
126
127 1891 aaronmk
        def fetchone(self):
128
            row = self.inner.fetchone()
129 1899 aaronmk
            if row != None: self.result.append(row)
130
            # otherwise, fetched all rows
131 1904 aaronmk
            else: self._cache_result()
132
            return row
133
134
        def _cache_result(self):
135 1906 aaronmk
            is_insert = self._is_insert()
136
            # For inserts, only cache exceptions since inserts are not
137
            # idempotent, but an invalid insert will always be invalid
138
            if self.query_results != None and (not is_insert
139
                or isinstance(self.result, Exception)):
140
141 1894 aaronmk
                assert self.query_lookup != None
142 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
143
                    util.dict_subset(dicts.AttrsDictView(self),
144
                    ['query', 'result', 'rowcount', 'description']))
145 1906 aaronmk
146
        def _is_insert(self): return self.query.upper().find('INSERT') >= 0
147
148 1916 aaronmk
        class CacheCursor:
149
            def __init__(self, cached_result): self.__dict__ = cached_result
150
151 1927 aaronmk
            def execute(self, *args, **kw_args):
152 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
153
                # otherwise, result is a rows list
154
                self.iter = iter(self.result)
155
156
            def fetchone(self):
157
                try: return self.iter.next()
158
                except StopIteration: return None
159 1891 aaronmk
160 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
161 1903 aaronmk
        used_cache = False
162
        try:
163 1927 aaronmk
            # Get cursor
164
            if cacheable:
165
                query_lookup = _query_lookup(query, params)
166
                try:
167
                    cur = self.query_results[query_lookup]
168
                    used_cache = True
169
                except KeyError: cur = self.DbCursor(self)
170
            else: cur = self.db.cursor()
171
172
            # Run query
173
            try: cur.execute(query, params)
174
            except Exception, e:
175
                _add_cursor_info(e, cur)
176
                raise
177 1903 aaronmk
        finally:
178
            if self.log_debug != log_debug_none: # only compute msg if needed
179
                if used_cache: cache_status = 'Cache hit'
180
                elif cacheable: cache_status = 'Cache miss'
181
                else: cache_status = 'Non-cacheable'
182 1927 aaronmk
                self.log_debug(cache_status+': '
183
                    +strings.one_line(get_cur_query(cur)))
184 1903 aaronmk
185
        return cur
186 1914 aaronmk
187
    def is_cached(self, query, params=None):
188
        return _query_lookup(query, params) in self.query_results
189 1849 aaronmk
190 1869 aaronmk
connect = DbConn
191
192 1919 aaronmk
##### Input validation
193
194
def check_name(name):
195
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
196
        +'" may contain only alphanumeric characters and _')
197
198
def esc_name_by_module(module, name, preserve_case=False):
199
    if module == 'psycopg2':
200
        if preserve_case: quote = '"'
201
        # Don't enclose in quotes because this disables case-insensitivity
202
        else: return name
203
    elif module == 'MySQLdb': quote = '`'
204
    else: raise NotImplementedError("Can't escape name for "+module+' database')
205
    return quote + name.replace(quote, '') + quote
206
207
def esc_name_by_engine(engine, name, **kw_args):
208
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
209
210
def esc_name(db, name, **kw_args):
211
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
212
213 832 aaronmk
##### Querying
214
215 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
216
    '''For args, see DbConn.run_query()'''
217
    return db.run_query(*args, **kw_args)
218 11 aaronmk
219 832 aaronmk
##### Recoverable querying
220 15 aaronmk
221 11 aaronmk
def with_savepoint(db, func):
222 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
223 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
224 11 aaronmk
    try: return_val = func()
225
    except:
226 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
227 11 aaronmk
        raise
228
    else:
229 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
230 11 aaronmk
        return return_val
231
232 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
233 830 aaronmk
    if recover == None: recover = False
234
235 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
236 1914 aaronmk
    if recover and not db.is_cached(query, params):
237
        return with_savepoint(db, run)
238
    else: return run() # don't need savepoint if cached
239 830 aaronmk
240 832 aaronmk
##### Result retrieval
241
242 1135 aaronmk
def col_names(cur): return (col[0] for col in cur.description)
243 832 aaronmk
244
def rows(cur): return iter(lambda: cur.fetchone(), None)
245
246 1893 aaronmk
def next_row(cur): return rows(cur).next()
247 832 aaronmk
248 1893 aaronmk
def row(cur):
249
    row_iter = rows(cur)
250
    row_ = row_iter.next()
251
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
252
    return row_
253
254
def next_value(cur): return next_row(cur)[0]
255
256 832 aaronmk
def value(cur): return row(cur)[0]
257
258 1893 aaronmk
def values(cur): return iters.func_iter(lambda: next_value(cur))
259 832 aaronmk
260
def value_or_none(cur):
261
    try: return value(cur)
262
    except StopIteration: return None
263
264
##### Basic queries
265
266 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
267 1894 aaronmk
    recover=None, cacheable=True):
268 1135 aaronmk
    '''@param fields Use None to select all fields in the table'''
269
    if conds == None: conds = {}
270 135 aaronmk
    assert limit == None or type(limit) == int
271 865 aaronmk
    assert start == None or type(start) == int
272 15 aaronmk
    check_name(table)
273 1135 aaronmk
    if fields != None: map(check_name, fields)
274 15 aaronmk
    map(check_name, conds.keys())
275 865 aaronmk
276 11 aaronmk
    def cond(entry):
277 13 aaronmk
        col, value = entry
278 644 aaronmk
        cond_ = esc_name(db, col)+' '
279 11 aaronmk
        if value == None: cond_ += 'IS'
280
        else: cond_ += '='
281
        cond_ += ' %s'
282
        return cond_
283 1135 aaronmk
    query = 'SELECT '
284
    if fields == None: query += '*'
285
    else: query += ', '.join([esc_name(db, field) for field in fields])
286
    query += ' FROM '+esc_name(db, table)
287 865 aaronmk
288
    missing = True
289 89 aaronmk
    if conds != {}:
290
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
291 865 aaronmk
        missing = False
292
    if limit != None: query += ' LIMIT '+str(limit); missing = False
293
    if start != None:
294
        if start != 0: query += ' OFFSET '+str(start)
295
        missing = False
296
    if missing: warnings.warn(DbWarning(
297
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
298
299 1905 aaronmk
    return run_query(db, query, conds.values(), recover, cacheable)
300 11 aaronmk
301 1905 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True):
302 1554 aaronmk
    '''@param returning str|None An inserted column (such as pkey) to return'''
303 11 aaronmk
    check_name(table)
304 13 aaronmk
    cols = row.keys()
305 15 aaronmk
    map(check_name, cols)
306 89 aaronmk
    query = 'INSERT INTO '+table
307 1554 aaronmk
308 89 aaronmk
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
309
        +', '.join(['%s']*len(cols))+')'
310
    else: query += ' DEFAULT VALUES'
311 1554 aaronmk
312
    if returning != None:
313
        check_name(returning)
314
        query += ' RETURNING '+returning
315
316 1905 aaronmk
    return run_query(db, query, row.values(), recover, cacheable)
317 11 aaronmk
318 135 aaronmk
def last_insert_id(db):
319 1849 aaronmk
    module = util.root_module(db.db)
320 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
321
    elif module == 'MySQLdb': return db.insert_id()
322
    else: return None
323 13 aaronmk
324 832 aaronmk
def truncate(db, table):
325
    check_name(table)
326 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
327 832 aaronmk
328
##### Database structure queries
329
330 1850 aaronmk
def pkey(db, table, recover=None):
331 832 aaronmk
    '''Assumed to be first column in table'''
332
    check_name(table)
333 1929 aaronmk
    return col_names(select(db, table, limit=0, recover=recover)).next()
334 832 aaronmk
335 853 aaronmk
def index_cols(db, table, index):
336
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
337
    automatically created. When you don't know whether something is a UNIQUE
338
    constraint or a UNIQUE index, use this function.'''
339
    check_name(table)
340
    check_name(index)
341 1909 aaronmk
    module = util.root_module(db.db)
342
    if module == 'psycopg2':
343
        return list(values(run_query(db, '''\
344 853 aaronmk
SELECT attname
345 866 aaronmk
FROM
346
(
347
        SELECT attnum, attname
348
        FROM pg_index
349
        JOIN pg_class index ON index.oid = indexrelid
350
        JOIN pg_class table_ ON table_.oid = indrelid
351
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
352
        WHERE
353
            table_.relname = %(table)s
354
            AND index.relname = %(index)s
355
    UNION
356
        SELECT attnum, attname
357
        FROM
358
        (
359
            SELECT
360
                indrelid
361
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
362
                    AS indkey
363
            FROM pg_index
364
            JOIN pg_class index ON index.oid = indexrelid
365
            JOIN pg_class table_ ON table_.oid = indrelid
366
            WHERE
367
                table_.relname = %(table)s
368
                AND index.relname = %(index)s
369
        ) s
370
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
371
) s
372 853 aaronmk
ORDER BY attnum
373
''',
374 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
375
    else: raise NotImplementedError("Can't list index columns for "+module+
376
        ' database')
377 853 aaronmk
378 464 aaronmk
def constraint_cols(db, table, constraint):
379
    check_name(table)
380
    check_name(constraint)
381 1849 aaronmk
    module = util.root_module(db.db)
382 464 aaronmk
    if module == 'psycopg2':
383
        return list(values(run_query(db, '''\
384
SELECT attname
385
FROM pg_constraint
386
JOIN pg_class ON pg_class.oid = conrelid
387
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
388
WHERE
389
    relname = %(table)s
390
    AND conname = %(constraint)s
391
ORDER BY attnum
392
''',
393
            {'table': table, 'constraint': constraint})))
394
    else: raise NotImplementedError("Can't list constraint columns for "+module+
395
        ' database')
396
397 832 aaronmk
def tables(db):
398 1849 aaronmk
    module = util.root_module(db.db)
399 832 aaronmk
    if module == 'psycopg2':
400
        return values(run_query(db, "SELECT tablename from pg_tables "
401
            "WHERE schemaname = 'public' ORDER BY tablename"))
402
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
403
    else: raise NotImplementedError("Can't list tables for "+module+' database')
404 830 aaronmk
405 833 aaronmk
##### Database management
406
407
def empty_db(db):
408
    for table in tables(db): truncate(db, table)
409
410 832 aaronmk
##### Heuristic queries
411
412 1554 aaronmk
def try_insert(db, table, row, returning=None):
413 830 aaronmk
    '''Recovers from errors'''
414 1554 aaronmk
    try: return insert(db, table, row, returning, recover=True)
415 46 aaronmk
    except Exception, e:
416
        msg = str(e)
417 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
418
            r'"(([^\W_]+)_[^"]+)"', msg)
419
        if match:
420
            constraint, table = match.groups()
421 854 aaronmk
            try: cols = index_cols(db, table, constraint)
422 465 aaronmk
            except NotImplementedError: raise e
423 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
424 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
425
            'constraint', msg)
426 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
427 13 aaronmk
        raise # no specific exception raised
428 11 aaronmk
429 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
430 1554 aaronmk
    '''Recovers from errors.
431
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
432 471 aaronmk
    try:
433 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
434
        if row_ct_ref != None and cur.rowcount >= 0:
435
            row_ct_ref[0] += cur.rowcount
436
        return value(cur)
437 471 aaronmk
    except DuplicateKeyException, e:
438 1069 aaronmk
        return value(select(db, table, [pkey],
439
            util.dict_subset_right_join(row, e.cols), recover=True))
440 471 aaronmk
441 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
442 830 aaronmk
    '''Recovers from errors'''
443
    try: return value(select(db, table, [pkey], row, 1, recover=True))
444 14 aaronmk
    except StopIteration:
445 40 aaronmk
        if not create: raise
446 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row