Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
    '''Returns the text of the last query run on a cursor, if it exposes one.
    @param cur A cursor-like object
    @return The cursor's `query` attribute if it has one, otherwise its
        `_last_executed` attribute if it has one, otherwise None
    '''
    # The query is looked up under two attribute names (presumably one per
    # supported driver -- confirm); the first one present wins
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur):
    '''Adds the cursor's last query, if available, to an exception's message.
    @param e The exception to add the message to (via exc.add_msg())
    @param cur The cursor whose query should be reported
    '''
    query = get_cur_query(cur)
    # get_cur_query() returns None when the cursor exposes no query text.
    # Concatenating None to a str would raise a TypeError here, which would
    # mask the exception that is being annotated.
    if query is not None: exc.add_msg(e, 'query: '+query)
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module.
    @param msg The error message
    @param cause The underlying exception, if any (passed on to
        exc.ExceptionWithCause)
    @param cur If given, a cursor whose last query is added to the message
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause)
        if cur is not None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing a disallowed character'''
    pass
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A database exception that concerns a list of columns.
    @param cols The names of the columns involved. These are listed in the
        message and stored in self.cols.
    @param cause The underlying exception, if any
    '''
    def __init__(self, cols, cause=None):
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a unique constraint.
    self.cols contains the columns of that constraint.'''
    pass
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a not-null constraint.
    self.cols contains the column that was null.'''
    pass
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException):
    '''A database exception about an empty row.
    NOTE(review): not raised anywhere in the visible code; presumably used by
    callers elsewhere -- confirm.'''
    pass
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning):
    '''The category of the warnings issued by this module (see select())'''
    pass
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur):
    '''Returns a generator of the column names in a cursor's result.
    The name is the first item of each entry of cur.description.'''
    description = cur.description
    return (col_desc[0] for col_desc in description)
50
51
def rows(cur):
    '''Returns an iterator over a cursor's remaining rows.
    Each step calls cur.fetchone(); iteration ends when it returns None.'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
52
53
def consume_rows(cur):
    '''Fetches and discards all of a cursor's remaining rows.
    Used to fetch all rows so result will be cached.'''
    iters.consume_iter(rows(cur))
56
57
def next_row(cur):
    '''Returns a cursor's next row.
    @raise StopIteration If the cursor has no more rows
    '''
    # next() works on both Python 2.6+ and 3, unlike the .next() method
    return next(rows(cur))
58
59
def row(cur):
    '''Returns a cursor's next row, then fetches and discards the rest.
    @raise StopIteration If the cursor has no more rows
    '''
    row_ = next_row(cur)
    consume_rows(cur) # fetch the remaining rows (see consume_rows())
    return row_
63
64
def next_value(cur):
    '''Returns the first column of a cursor's next row (see next_row())'''
    first_row = next_row(cur)
    return first_row[0]
65
66
def value(cur):
    '''Returns the first column of the row returned by row()'''
    only_row = row(cur)
    return only_row[0]
67
68
def values(cur):
    '''Wraps next_value() for this cursor in iters.func_iter().
    Presumably this iterates over the first column of each remaining row --
    confirm against iters.func_iter().'''
    def next_(): return next_value(cur)
    return iters.func_iter(next_)
69
70
def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no more rows'''
    try: return value(cur)
    except StopIteration: return None
73
74 1869 aaronmk
##### Database connections
75 1849 aaronmk
76 1926 aaronmk
# Names of db_config settings (presumably the keys a db_config dict may
# contain -- not referenced elsewhere in the visible code)
db_config_names = ['engine', 'host', 'user', 'password', 'database']

# Maps each engine name to a tuple of:
# - the name of the driver module to import
# - a dict mapping db_config keys to the names they are renamed to before the
#   config is passed to the driver's connect()
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
82
83
# The exception classes that count as database errors. Initially just
# DbException; _add_module() adds the DatabaseError of each driver it is given.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set) # tuple form of the set
85
86
def _add_module(module):
    '''Adds a driver module's DatabaseError class to DatabaseErrors.
    @param module A module with a DatabaseError attribute
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild the tuple form
90
91
def db_config_str(db_config):
    '''Returns a short description of a db_config.
    @param db_config dict Must have 'engine' and 'database' entries
    @return str "<engine> database <database>"
    '''
    return db_config['engine']+' database '+db_config['database']
93
94 1909 aaronmk
def _query_lookup(query, params):
    '''Returns the key used for a query in a DbConn's query_results cache.
    @return tuple (query, params made hashable by dicts.make_hashable())
    '''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
95 1894 aaronmk
96 1901 aaronmk
def log_debug_none(msg):
    '''The default debug logger of DbConn: ignores the message'''
    return None
97
98 1849 aaronmk
class DbConn:
    '''A database connection that connects on first use and can cache the
    results of queries.
    
    The driver connection is available as `self.db` and is created the first
    time that attribute is accessed. Cached results are stored in
    self.query_results under the key _query_lookup(query, params).
    
    @param db_config dict The connection settings. Must contain 'engine' (a key
        of db_engines). The other entries are passed to the driver's connect()
        after being renamed according to db_engines.
    @param serializable Whether to set the SERIALIZABLE transaction isolation
        level after connecting
    @param log_debug A function taking a message, called by run_query() for
        each query
    '''
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        
        self.__db = None # the driver connection; set by _db()
        self.query_results = {} # the query cache
    
    def __getattr__(self, name):
        # Python calls this only when the normal attribute lookup fails
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the state to pickle, without the unpicklable members'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the driver connection, connecting first if needed'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                # presumably raised when db_config has no `orig` entry
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
        
        return self.__db
    
    class DbCursor(Proxy):
        '''A proxy for a driver cursor that records the rows fetched and
        stores the finished result in the connection's query cache'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result (a rows list or an exception) through
            the execute()/fetchone() cursor methods'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        @param params The query's params, passed to the cursor's execute()
        @param cacheable Whether to use a cached result if there is one, and
            otherwise to run the query with a DbCursor so its result can be
            cached
        @return The cursor the query was executed on (a cached cursor on a
            cache hit)
        '''
        used_cache = False
        # Define cur up front so the `finally` block can't fail with a
        # NameError (hiding the real error) if creating the cursor raises
        cur = None
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query, params)
            except Exception as e:
                _add_cursor_info(e, cur)
                raise
        finally:
            log_debug = self.log_debug
            # log_debug is None on an unpickled DbConn (see __getstate__())
            if log_debug is not None and log_debug != log_debug_none:
                # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                # If the cursor has no query text (or there is no cursor),
                # log the query that was requested instead
                query_ = get_cur_query(cur)
                if query_ is None: query_ = query
                log_debug(cache_status+': '+strings.one_line(query_))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Returns whether the query cache has a result for this query'''
        return _query_lookup(query, params) in self.query_results
217 1849 aaronmk
218 1869 aaronmk
# Function-style alias: connect(...) takes DbConn's constructor args
connect = DbConn
219
220 1919 aaronmk
##### Input validation
221
222
def check_name(name):
    '''Checks that a name contains only alphanumeric characters and _.
    @param name str The name (e.g. of a table or column) to check
    @raise NameException If the name contains any other character
    '''
    if re.search(r'\W', name) is not None: raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
225
226
def esc_name_by_module(module, name, preserve_case=False):
    '''Escapes a name for use in a query, for a given driver module.
    @param module str The driver module's name: 'psycopg2' or 'MySQLdb'
    @param preserve_case For psycopg2, whether to enclose the name in quotes.
        MySQLdb names are always enclosed in backticks.
    @return The name enclosed in the quote char, with any occurrences of the
        quote char in the name removed. For psycopg2 without preserve_case,
        the name unchanged.
    @raise NotImplementedError For any other module
    '''
    if module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        if not preserve_case: return name
        quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return quote + name.replace(quote, '') + quote
234
235
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a name for a given engine.
    @param engine str A key of db_engines
    For the other args, see esc_name_by_module()
    '''
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
237
238
def esc_name(db, name, **kw_args):
    '''Escapes a name for the driver of a given connection.
    @param db A connection whose `db` attribute is the driver connection. The
        driver is identified by util.root_module() (presumably the name of the
        top-level module it comes from -- confirm).
    For the other args, see esc_name_by_module()
    '''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
240
241 1968 aaronmk
def qual_name(db, schema, table):
    '''Returns the escaped, schema-qualified name of a table.
    @return str <schema>.<table>, each part escaped with preserve_case=True
    '''
    def esc_name_(name): return esc_name(db, name, preserve_case=True)
    return esc_name_(schema)+'.'+esc_name_(table)
244 1968 aaronmk
245 832 aaronmk
##### Querying
246
247 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query on a connection by calling its run_query() method.
    For args, see DbConn.run_query()
    @return Whatever db.run_query() returns
    '''
    return db.run_query(*args, **kw_args)
250 11 aaronmk
251 832 aaronmk
##### Recoverable querying
252 15 aaronmk
253 11 aaronmk
def with_savepoint(db, func):
    '''Runs a function inside a savepoint.
    If the function raises, rolls back to the savepoint and re-raises the
    exception. Otherwise, releases the savepoint.
    @param func A function taking no args
    @return The function's return value
    '''
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+savepoint)
    try: return_val = func()
    except: # any exception at all: roll back, then pass it on unchanged
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
        raise
    else:
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
        return return_val
263
264 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param recover Whether to run the query inside a savepoint (see
        with_savepoint()). None is the same as False.
    For the other args and the return value, see DbConn.run_query()
    '''
    if recover is None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    if recover and not db.is_cached(query, params):
        return with_savepoint(db, run)
    else: return run() # don't need savepoint if cached
271 830 aaronmk
272 832 aaronmk
##### Basic queries
273
274 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True, table_is_esc=False):
    '''
    Runs a SELECT query built from the given parts.
    Issues a DbWarning if the query has no WHERE, LIMIT, or OFFSET clause.
    @param fields Use None to select all fields in the table
    @param conds dict|None Maps column names to the values they must have.
        The conditions are joined with AND. A None value is matched with IS
        instead of =.
    @param limit int|None The LIMIT, if any
    @param start int|None The OFFSET, if any (0 adds no OFFSET clause)
    @param table_is_esc Whether the table name has already been escaped
    For recover and cacheable, see run_query()
    @return The query's cursor
    '''
    if conds is None: conds = {}
    # Exact type checks, because limit and start are put directly in the query
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    # Check the names with loops rather than map(), so that the checks are
    # not skipped where map() is lazy (Python 3)
    if not table_is_esc: check_name(table)
    if fields is not None:
        for field in fields: check_name(field)
    for col in conds: check_name(col)
    
    def cond(entry):
        col, value_ = entry
        cond_ = esc_name(db, col)+' '
        if value_ is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    # An already-escaped table name must not be escaped again (escaping
    # removes the quote chars that are already in the name)
    if not table_is_esc: table = esc_name(db, table)
    query += ' FROM '+table
    
    missing = True
    if conds:
        query += ' WHERE '+' AND '.join([cond(entry)
            for entry in conds.items()])
        missing = False
    if limit is not None:
        query += ' LIMIT '+str(limit)
        missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    # The values are in the same order as the conditions built from items()
    return run_query(db, query, list(conds.values()), recover, cacheable)
311 11 aaronmk
312 1961 aaronmk
# Sentinel row value: insert() emits the DEFAULT keyword instead of a `%s`
# placeholder for each row value that equals this object
default = object() # tells insert() to use the default value for a column
313
314 1960 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True,
    table_is_esc=False):
    '''
    Runs an INSERT query for one row.
    @param row Either a sequence of values (the query then lists no column
        names), or a dict mapping column names to values. Use `default` as a
        value to insert the column's default. If there are no non-default
        values, DEFAULT VALUES is inserted.
    @param returning str|None An inserted column (such as pkey) to return
    @param table_is_esc Whether the table name has already been escaped
    For recover and cacheable, see run_query()
    @return The query's cursor
    '''
    if not table_is_esc: check_name(table)
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values() # same order as keys()
        # A loop rather than map(), so that the checks are not skipped where
        # map() is lazy (Python 3)
        for col in cols: check_name(col)
    row = list(row)
    
    # Check for special values
    labels = []
    params = []
    for value_ in row:
        # Identity test: `default` is a sentinel, and == could invoke a
        # custom __eq__() of the value
        if value_ is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            params.append(value_)
    
    # Build query
    query = 'INSERT INTO '+table
    if params:
        if cols is not None: query += ' ('+', '.join(cols)+')'
        query += ' VALUES ('+(', '.join(labels))+')'
    else: query += ' DEFAULT VALUES'
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    return run_query(db, query, params, recover, cacheable)
349 11 aaronmk
350 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert on a connection.
    @return For psycopg2, the result of `SELECT lastval()`. For MySQLdb, the
        driver connection's insert_id(). For any other driver, None.
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() is a method of the driver connection (db.db). DbConn itself
    # only provides the `db` attribute, so db.insert_id() would raise an
    # AttributeError.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
355 13 aaronmk
356 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema The schema the table is in
    @return The query's cursor
    '''
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
358 832 aaronmk
359
##### Database structure queries
360
361 1850 aaronmk
def pkey(db, table, recover=None):
    '''Returns the name of a table's primary key column.
    Assumed to be first column in table.
    For recover, see run_query()
    '''
    check_name(table)
    # A LIMIT 0 query returns no rows but still provides the column names.
    # next() works on both Python 2.6+ and 3, unlike the .next() method.
    return next(col_names(select(db, table, limit=0, recover=recover)))
365 832 aaronmk
366 853 aaronmk
def index_cols(db, table, index):
    '''Returns the names of the columns in an index, ordered by attnum.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return list The column names
    @raise NotImplementedError If the driver is not psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The first SELECT gets the columns listed in indkey. The second gets
        # the columns referenced in indexprs (the `:varattno <n>` entries).
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
408 853 aaronmk
409 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the names of the columns in a constraint, ordered by attnum.
    @return list The column names
    @raise NotImplementedError If the driver is not psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
427
428 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Lists the tables whose names match a LIKE pattern.
    @param schema The schema to look in. Only the psycopg2 query uses this;
        the MySQLdb query does not reference it.
    @param table_like The LIKE pattern that the table names must match
    @return The table names, as returned by values()
    @raise NotImplementedError If the driver is not psycopg2 or MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        return values(run_query(db, '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
''',
            params, cacheable=True))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
445 830 aaronmk
446 833 aaronmk
##### Database management
447
448 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for a schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
451 833 aaronmk
452 832 aaronmk
##### Heuristic queries
453
454 1554 aaronmk
def try_insert(db, table, row, returning=None):
    '''Inserts a row, translating recognized errors into specific exceptions.
    Recovers from errors (the insert is run with recover=True).
    For the args and the return value, see insert()
    @raise DuplicateKeyException If the error message reports a violated
        unique constraint. Its cols are the constraint's columns.
    @raise NullValueException If the error message reports a violated
        not-null constraint. Its cols are the column that was null.
    Any other exception is re-raised unchanged.
    '''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            # The table to look the constraint up in is taken from the
            # constraint's name: the part before the first _
            constraint, constraint_table = match.groups()
            try: cols = index_cols(db, constraint_table, constraint)
            # `raise e`, because a bare raise would raise NotImplementedError
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised
470 11 aaronmk
471 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
    '''Inserts a row and returns its pkey, or returns the pkey of the existing
    row if the insert violates a unique constraint.
    Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    @param pkey str The name of the column whose value to return
    @param row_ct_ref list|None If given, the inserted row count is added to
        its first element (unless the cursor's rowcount is negative)
    '''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Find the existing row using the row's values for the columns of
        # the violated constraint
        return value(select(db, table, [pkey],
            util.dict_subset_right_join(row, e.cols), recover=True))
482 471 aaronmk
483 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the first row matching the given column values.
    Recovers from errors.
    @param row dict The column values to match (used as select()'s conds)
    @param pkey str The name of the column whose value to return
    @param row_ct_ref See put()
    @param create Whether to insert the row, using put(), if it isn't found
    @raise StopIteration If no row is found and not create
    '''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row