Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
    '''Returns the text of the query stored on a cursor.
    
    The cursor's `query` attribute is preferred, then `_last_executed`.
    @return The attribute's value, or None if the cursor has neither attribute
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur):
    '''Appends the query stored on a cursor to an exception's message.
    
    @param e The exception to annotate (passed to exc.add_msg())
    @param cur The cursor whose query is looked up with get_cur_query()
    '''
    query = get_cur_query(cur)
    # get_cur_query() returns None when the cursor exposes no query text.
    # 'query: '+None would raise a TypeError that replaces the error being
    # annotated, so in that case the message is left unchanged.
    if query is not None: exc.add_msg(e, 'query: '+query)
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the exceptions defined in this module'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The exception that led to this one, if any
        @param cur If given, the cursor whose query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        # Identity test: `cur != None` would go through the cursor object's own
        # comparison methods
        if cur is not None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException that records the columns it applies to'''
    
    def __init__(self, cols, cause=None):
        '''
        @param cols The column names; joined with ', ' to form the message and
            stored as self.cols
        @param cause The exception that led to this one, if any
        '''
        col_list = ', '.join(cols)
        DbException.__init__(self, 'columns: ' + col_list, cause)
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a unique constraint'''
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a not-null constraint'''
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException):
    '''A DbException subclass that adds no behavior of its own.
    
    NOTE(review): not raised anywhere in the visible part of this module;
    presumably signals a row with no values -- confirm against callers.
    '''
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning):
    '''Warning category used by select() for a SELECT without a WHERE, LIMIT,
    or OFFSET clause'''
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur):
    '''Returns an iterator over the first element of each entry in
    cur.description'''
    description = cur.description
    return (column[0] for column in description)
50
51
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() for each item and stops
    when it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
52
53
def consume_rows(cur):
    '''Fetches every remaining row of a cursor, discarding the rows, so that
    the result will be cached'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
56
57
def next_row(cur):
    '''Returns the next row of a cursor.
    
    @throws StopIteration if there are no more rows
    '''
    return next(rows(cur))
58
59
def row(cur):
    '''Returns the next row of a cursor, then fetches and discards every row
    after it.
    
    @throws StopIteration if there are no more rows
    '''
    first = next_row(cur)
    consume_rows(cur) # exhaust the cursor before returning
    return first
63
64
def next_value(cur):
    '''Returns the first column of the next row of a cursor'''
    fetched = next_row(cur)
    return fetched[0]
65
66
def value(cur):
    '''Returns the first column of the next row of a cursor, discarding all the
    rows after it (see row())'''
    fetched = row(cur)
    return fetched[0]
67
68
def values(cur):
    '''Wraps next_value(cur) in iters.func_iter().
    
    NOTE(review): presumably yields the first column of each remaining row
    until next_value() raises StopIteration -- confirm against iters.func_iter
    '''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
69
70
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when the
    cursor has no more rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
73
74 1869 aaronmk
##### Database connections
75 1849 aaronmk
76 1926 aaronmk
# NOTE(review): presumably the keys a db_config dict may contain -- confirm
db_config_names = ['engine', 'host', 'user', 'password', 'database']

# Engine name -> (name of the module to import, {db_config key: name of the
# corresponding keyword arg of that module's connect()})
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', dict()),
)

# Exception classes that count as database errors. _add_module() adds each
# loaded engine module's DatabaseError and rebuilds the tuple.
DatabaseErrors_set = set((DbException,))
DatabaseErrors = tuple(DatabaseErrors_set)
85
86
def _add_module(module):
    '''Registers an engine module's DatabaseError class in DatabaseErrors_set
    and rebuilds the module-level DatabaseErrors tuple from that set'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
90
91
def db_config_str(db_config):
    '''Returns a short description of a db_config: its engine and database
    names'''
    engine = db_config['engine']
    database = db_config['database']
    return engine+' database '+database
93
94 1909 aaronmk
def _query_lookup(query, params):
    '''Returns the (query, hashable params) pair used as a key of a
    query_results dict'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
95 1894 aaronmk
96 1901 aaronmk
def log_debug_none(msg):
    '''The default debug logger: ignores the message'''
    return None
97
98 1849 aaronmk
class DbConn:
    '''A database connection that connects lazily and can cache query results.
    
    The engine's connection object is created the first time the `db` attribute
    is read (see _db()). Pickling drops the connection and the debug callback
    (see __getstate__()).
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
        caching=False):
        '''
        @param db_config dict Must contain an 'engine' key naming an entry of
            db_engines. The other entries are passed as keyword args to the
            engine module's connect(), after renaming keys per db_engines.
        @param serializable Whether to set the SERIALIZABLE transaction
            isolation level right after connecting
        @param log_debug Callback that takes a message str
        @param caching Whether run_query() honors cacheable=True
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        self.caching = caching
        
        self.__db = None # the engine's connection; created by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute'''
        # NOTE(review): presumably guards against recursion when __dict__ is
        # not set yet -- confirm
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name) # name the missing attribute
    
    def __getstate__(self):
        '''Returns the state to pickle, without the callback and connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the engine's connection, connecting first if needed'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # key not in this db_config
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection. self.__db is already set, so the `db`
            # lookup inside run_query() does not connect again.
            if self.serializable: self.run_query(
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps an engine cursor and records each query's outcome in the
        connection's query_results once all its rows have been fetched'''
        
        def __init__(self, outer):
            '''@param outer The DbConn to create the cursor on'''
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor.
            
            If it raises, the exception is cached as the query's result and
            then re-raised.
            '''
            # Any query containing the text INSERT counts as an insert
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                # Record the query text before the error handler runs, so that
                # a cached failure stores this execution's query
                try: return_value = self.inner.execute(query, params)
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Returns the wrapped cursor's next row, recording it; caches the
            recorded rows once the wrapped cursor returns None'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores a CacheCursor for the current query in query_results'''
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result through execute() and fetchone()'''
            
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises a cached exception, or restarts the row replay.
                The args are ignored.'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None after the last one'''
                return next(self.iter, None)
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        
        @param params The query params passed to the cursor's execute()
        @param cacheable Whether the result may be read from and stored in
            self.query_results. Ignored unless self.caching is set.
        @return A CacheCursor on a cache hit, a DbCursor on a cache miss, or
            the engine's own cursor for a non-cacheable query
        '''
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the finally clause works if creating the cursor fails
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query, params)
            except Exception as e:
                _add_cursor_info(e, cur)
                raise
        finally:
            # log_debug is None after unpickling (see __getstate__())
            if (self.log_debug is not None
                and self.log_debug != log_debug_none): # only compute if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                # Fall back to the query as passed in when there is no cursor
                # or the cursor has no query text
                logged_query = None
                if cur is not None: logged_query = get_cur_query(cur)
                if logged_query is None: logged_query = query
                self.log_debug(cache_status+': '
                    +strings.one_line(logged_query))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Returns whether query_results has an entry for this query+params'''
        return _query_lookup(query, params) in self.query_results
220 1849 aaronmk
221 1869 aaronmk
connect = DbConn # alias: connect(...) constructs a DbConn with the same args
222
223 1919 aaronmk
##### Input validation
224
225
def check_name(name):
    '''Validates that a name consists only of word characters.
    
    @throws NameException if the name contains a character matching \\W
    '''
    invalid_char = re.search(r'\W', name)
    if invalid_char is None: return
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
228
229
def esc_name_by_module(module, name, preserve_case=False):
    '''Quotes a name for the database that a driver module connects to.
    
    @param module str 'psycopg2' or 'MySQLdb'
    @param preserve_case For psycopg2, whether to quote the name; otherwise it
        is returned unchanged
    @return The name with any occurrences of the quote character removed,
        enclosed in the quote character
    @throws NotImplementedError for any other module
    '''
    if module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        if not preserve_case: return name
        quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    
    stripped = name.replace(quote, '')
    return quote + stripped + quote
237
238
def esc_name_by_engine(engine, name, **kw_args):
    '''Like esc_name_by_module(), but takes a key of db_engines instead of a
    module name. For kw_args, see esc_name_by_module().'''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
240
241
def esc_name(db, name, **kw_args):
    '''Like esc_name_by_module(), but takes the module name from
    util.root_module(db.db). For kw_args, see esc_name_by_module().'''
    module_name = util.root_module(db.db)
    return esc_name_by_module(module_name, name, **kw_args)
243
244 1968 aaronmk
def qual_name(db, schema, table):
    '''Returns the escaped schema name and escaped table name joined by a ".".
    Both are escaped with preserve_case=True.'''
    def quoted(name): return esc_name(db, name, preserve_case=True)
    return '.'.join([quoted(schema), quoted(table)])
247 1968 aaronmk
248 832 aaronmk
##### Querying
249
250 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Calls db.run_query() with the given args and returns its result.
    For args, see DbConn.run_query()'''
    runner = db.run_query
    return runner(*args, **kw_args)
253 11 aaronmk
254 832 aaronmk
##### Recoverable querying
255 15 aaronmk
256 11 aaronmk
def with_savepoint(db, func):
    '''Calls func() inside a savepoint.
    
    If func() raises anything, rolls back to the savepoint and re-raises;
    otherwise releases the savepoint.
    @return func()'s return value
    '''
    name = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+name)
    try: result = func()
    except:
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+name)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+name)
    return result
266
267 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    
    @param recover Whether to wrap the query in a savepoint (see
        with_savepoint()). None means False.
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    '''
    if recover is None: recover = False
    
    def execute(): return run_raw_query(db, query, params, cacheable)
    
    # don't need savepoint if cached
    needs_savepoint = recover and not db.is_cached(query, params)
    if not needs_savepoint: return execute()
    return with_savepoint(db, execute)
274 830 aaronmk
275 832 aaronmk
##### Basic queries
276
277 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True, table_is_esc=False):
    '''Runs a SELECT on one table, filtering on column values.
    
    @param fields Use None to select all fields in the table
    @param conds dict|None Column name -> required value. A None value is
        compared with IS instead of =.
    @param limit int|None The value of the LIMIT clause
    @param start int|None The value of the OFFSET clause. 0 adds no clause.
    @param recover See run_query()
    @param cacheable See run_query()
    @param table_is_esc Whether the table name has already been escaped
    @return The cursor returned by run_query()
    '''
    if conds is None: conds = {}
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    # Explicit loops: map() used only for its side effects does not run the
    # checks where map() is lazy
    if not table_is_esc: check_name(table)
    if fields is not None:
        for field in fields: check_name(field)
    for col in conds: check_name(col)
    
    def cond(col, value):
        if value is None: operator = 'IS'
        else: operator = '='
        return esc_name(db, col)+' '+operator+' %s'
    
    # Fix one order so that the placeholders and the params line up
    cond_items = list(conds.items())
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    # An already-escaped table name must not be escaped a second time:
    # esc_name_by_module() strips the quote characters it would contain
    if table_is_esc: query += ' FROM '+table
    else: query += ' FROM '+esc_name(db, table)
    
    missing = True
    if cond_items:
        query += ' WHERE '+' AND '.join(
            [cond(col, value) for col, value in cond_items])
        missing = False
    if limit is not None:
        query += ' LIMIT '+str(limit)
        missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    params = [value for col, value in cond_items]
    return run_query(db, query, params, recover, cacheable)
314 11 aaronmk
315 1961 aaronmk
# Sentinel row value: tells insert() to use the default value for a column
# (insert() emits DEFAULT for it instead of a query param)
default = object()
316
317 1960 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True,
    table_is_esc=False):
    '''Inserts one row into a table.
    
    @param row A sequence of values, or a dict of column name -> value. Use
        the module-level `default` object as a value to insert that column's
        default.
    @param returning str|None An inserted column (such as pkey) to return
    @param recover See run_query()
    @param cacheable See run_query()
    @param table_is_esc Whether the table name has already been escaped
    @return The cursor returned by run_query()
    '''
    if not table_is_esc: check_name(table)
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
        # Explicit loop: map() used only for its side effects does not run the
        # checks where map() is lazy
        for col in cols: check_name(col)
    row = list(row)
    
    # Check for special values
    labels = []
    values = []
    for value in row:
        # Identity test: `default` is a sentinel object, and == would invoke
        # the value's own comparison method
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    query = 'INSERT INTO '+table
    if values:
        if cols is not None: query += ' ('+', '.join(cols)+')'
        query += ' VALUES ('+(', '.join(labels))+')'
    else: query += ' DEFAULT VALUES' # no values other than defaults
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    return run_query(db, query, values, recover, cacheable)
352 11 aaronmk
353 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the connection's last insert.
    
    @param db A DbConn
    @return lastval() under psycopg2, the connection's insert_id() under
        MySQLdb, and None for any other module
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the engine's connection (db.db). The DbConn
    # wrapper resolves no attribute other than `db`, so db.insert_id() raised
    # an AttributeError.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
358 13 aaronmk
359 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on the schema-qualified table.
    
    @return The cursor returned by run_query()
    '''
    qualified_table = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+qualified_table+' CASCADE')
361 832 aaronmk
362
##### Database structure queries
363
364 1850 aaronmk
def pkey(db, table, recover=None):
    '''Returns the name of a table's pkey, which is assumed to be the first
    column in the table.
    
    @param recover See run_query()
    '''
    check_name(table)
    # A zero-row SELECT is enough to get the column names
    cur = select(db, table, limit=0, recover=recover)
    return next(col_names(cur))
368 832 aaronmk
369 853 aaronmk
def index_cols(db, table, index):
    '''Returns a list of the names of an index's columns.
    
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    
    @throws NotImplementedError for modules other than psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The first SELECT matches attributes listed in indkey; the second matches
    # attribute numbers extracted from the text of indexprs
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    cur = run_query(db, query, params, cacheable=True)
    return list(values(cur))
411 853 aaronmk
412 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns a list of the names of a constraint's columns, ordered by
    attnum.
    
    @throws NotImplementedError for modules other than psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
430
431 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Returns an iterator over the names of the tables matching a pattern.
    
    @param schema The schema to list. Used only by the psycopg2 query.
    @param table_like A LIKE pattern that the table names must match
    @throws NotImplementedError for modules other than psycopg2 and MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    
    return values(run_query(db, query, params, cacheable=True))
448 830 aaronmk
449 833 aaronmk
##### Database management
450
451 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()'''
    table_names = tables(db, schema, **kw_args)
    for table in table_names:
        truncate(db, table, schema)
454 833 aaronmk
455 832 aaronmk
##### Heuristic queries
456
457 1554 aaronmk
def try_insert(db, table, row, returning=None):
    '''Inserts a row inside a savepoint, translating recognized errors.
    Recovers from errors.
    
    @throws DuplicateKeyException if the error message reports a unique
        constraint violation and the constraint's columns can be looked up
    @throws NullValueException if the error message reports a not-null
        constraint violation
    '''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        
        # The constraint name, and the part of it before the first _, which is
        # taken to be the name of the constraint's table
        dup_match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if dup_match is not None:
            constraint = dup_match.group(1)
            constraint_table = dup_match.group(2)
            try: cols = index_cols(db, constraint_table, constraint)
            except NotImplementedError: raise e # the original error
            raise DuplicateKeyException(cols, e)
        
        null_match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if null_match is not None:
            raise NullValueException([null_match.group(1)], e)
        
        raise # no specific exception raised
473 11 aaronmk
474 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
    '''Inserts a row, or finds the existing row that it duplicates.
    Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    
    @param pkey The name of the column whose value to return
    @param row_ct_ref list|None If given, the cursor's rowcount is added to
        its first element, unless the rowcount is negative
    @return The pkey value of the inserted or existing row
    '''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref is not None:
            if cur.rowcount >= 0: row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row by the columns of the violated constraint
        dup_conds = util.dict_subset_right_join(row, e.cols)
        return value(select(db, table, [pkey], dup_conds, recover=True))
485 471 aaronmk
486 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey value of the first row matching `row`.
    Recovers from errors.
    
    @param create Whether to put() the row if no row matches
    @param row_ct_ref See put()
    @throws StopIteration if no row matches and create is not set
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise