Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
25 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
26 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
27 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException): pass
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException): pass
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning): pass
45
46 1930 aaronmk
##### Result retrieval
47
48
def col_names(cur): return (col[0] for col in cur.description)
49
50
def rows(cur): return iter(lambda: cur.fetchone(), None)
51
52
def consume_rows(cur):
53
    '''Used to fetch all rows so result will be cached'''
54
    iters.consume_iter(rows(cur))
55
56
def next_row(cur): return rows(cur).next()
57
58
def row(cur):
59
    row_ = next_row(cur)
60
    consume_rows(cur)
61
    return row_
62
63
def next_value(cur): return next_row(cur)[0]
64
65
def value(cur): return row(cur)[0]
66
67
def values(cur): return iters.func_iter(lambda: next_value(cur))
68
69
def value_or_none(cur):
70
    try: return value(cur)
71
    except StopIteration: return None
72
73 1869 aaronmk
##### Database connections
74 1849 aaronmk
75 1926 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database']
76
77 1869 aaronmk
db_engines = {
78
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
79
    'PostgreSQL': ('psycopg2', {}),
80
}
81
82
DatabaseErrors_set = set([DbException])
83
DatabaseErrors = tuple(DatabaseErrors_set)
84
85
def _add_module(module):
86
    DatabaseErrors_set.add(module.DatabaseError)
87
    global DatabaseErrors
88
    DatabaseErrors = tuple(DatabaseErrors_set)
89
90
def db_config_str(db_config):
91
    return db_config['engine']+' database '+db_config['database']
92
93 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
94 1894 aaronmk
95 1901 aaronmk
log_debug_none = lambda msg: None
96
97 1849 aaronmk
class DbConn:
98 1901 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
99 1869 aaronmk
        self.db_config = db_config
100
        self.serializable = serializable
101 1901 aaronmk
        self.log_debug = log_debug
102 1869 aaronmk
103
        self.__db = None
104 1889 aaronmk
        self.query_results = {}
105 1869 aaronmk
106
    def __getattr__(self, name):
107
        if name == '__dict__': raise Exception('getting __dict__')
108
        if name == 'db': return self._db()
109
        else: raise AttributeError()
110
111
    def __getstate__(self):
112
        state = copy.copy(self.__dict__) # shallow copy
113 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
114 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
115
        return state
116
117
    def _db(self):
118
        if self.__db == None:
119
            # Process db_config
120
            db_config = self.db_config.copy() # don't modify input!
121
            module_name, mappings = db_engines[db_config.pop('engine')]
122
            module = __import__(module_name)
123
            _add_module(module)
124
            for orig, new in mappings.iteritems():
125
                try: util.rename_key(db_config, orig, new)
126
                except KeyError: pass
127
128
            # Connect
129
            self.__db = module.connect(**db_config)
130
131
            # Configure connection
132
            if self.serializable: run_raw_query(self,
133
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
134
135
        return self.__db
136 1889 aaronmk
137 1891 aaronmk
    class DbCursor(Proxy):
138 1927 aaronmk
        def __init__(self, outer):
139 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
140 1927 aaronmk
            self.query_results = outer.query_results
141 1894 aaronmk
            self.query_lookup = None
142 1891 aaronmk
            self.result = []
143 1889 aaronmk
144 1894 aaronmk
        def execute(self, query, params=None):
145 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
146 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
147 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
148
            except Exception, e:
149
                self.result = e # cache the exception as the result
150
                self._cache_result()
151
                raise
152
            finally: self.query = get_cur_query(self.inner)
153 1930 aaronmk
            # Fetch all rows so result will be cached
154
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
155 1894 aaronmk
            return return_value
156
157 1891 aaronmk
        def fetchone(self):
158
            row = self.inner.fetchone()
159 1899 aaronmk
            if row != None: self.result.append(row)
160
            # otherwise, fetched all rows
161 1904 aaronmk
            else: self._cache_result()
162
            return row
163
164
        def _cache_result(self):
165 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
166
            # idempotent, but an invalid insert will always be invalid
167 1930 aaronmk
            if self.query_results != None and (not self._is_insert
168 1906 aaronmk
                or isinstance(self.result, Exception)):
169
170 1894 aaronmk
                assert self.query_lookup != None
171 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
172
                    util.dict_subset(dicts.AttrsDictView(self),
173
                    ['query', 'result', 'rowcount', 'description']))
174 1906 aaronmk
175 1916 aaronmk
        class CacheCursor:
176
            def __init__(self, cached_result): self.__dict__ = cached_result
177
178 1927 aaronmk
            def execute(self, *args, **kw_args):
179 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
180
                # otherwise, result is a rows list
181
                self.iter = iter(self.result)
182
183
            def fetchone(self):
184
                try: return self.iter.next()
185
                except StopIteration: return None
186 1891 aaronmk
187 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
188 1903 aaronmk
        used_cache = False
189
        try:
190 1927 aaronmk
            # Get cursor
191
            if cacheable:
192
                query_lookup = _query_lookup(query, params)
193
                try:
194
                    cur = self.query_results[query_lookup]
195
                    used_cache = True
196
                except KeyError: cur = self.DbCursor(self)
197
            else: cur = self.db.cursor()
198
199
            # Run query
200
            try: cur.execute(query, params)
201
            except Exception, e:
202
                _add_cursor_info(e, cur)
203
                raise
204 1903 aaronmk
        finally:
205
            if self.log_debug != log_debug_none: # only compute msg if needed
206
                if used_cache: cache_status = 'Cache hit'
207
                elif cacheable: cache_status = 'Cache miss'
208
                else: cache_status = 'Non-cacheable'
209 1927 aaronmk
                self.log_debug(cache_status+': '
210
                    +strings.one_line(get_cur_query(cur)))
211 1903 aaronmk
212
        return cur
213 1914 aaronmk
214
    def is_cached(self, query, params=None):
215
        return _query_lookup(query, params) in self.query_results
216 1849 aaronmk
217 1869 aaronmk
connect = DbConn
218
219 1919 aaronmk
##### Input validation
220
221
def check_name(name):
222
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
223
        +'" may contain only alphanumeric characters and _')
224
225
def esc_name_by_module(module, name, preserve_case=False):
226
    if module == 'psycopg2':
227
        if preserve_case: quote = '"'
228
        # Don't enclose in quotes because this disables case-insensitivity
229
        else: return name
230
    elif module == 'MySQLdb': quote = '`'
231
    else: raise NotImplementedError("Can't escape name for "+module+' database')
232
    return quote + name.replace(quote, '') + quote
233
234
def esc_name_by_engine(engine, name, **kw_args):
235
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
236
237
def esc_name(db, name, **kw_args):
238
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
239
240 832 aaronmk
##### Querying
241
242 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
243
    '''For args, see DbConn.run_query()'''
244
    return db.run_query(*args, **kw_args)
245 11 aaronmk
246 832 aaronmk
##### Recoverable querying
247 15 aaronmk
248 11 aaronmk
def with_savepoint(db, func):
249 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
250 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
251 11 aaronmk
    try: return_val = func()
252
    except:
253 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
254 11 aaronmk
        raise
255
    else:
256 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
257 11 aaronmk
        return return_val
258
259 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
260 830 aaronmk
    if recover == None: recover = False
261
262 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
263 1914 aaronmk
    if recover and not db.is_cached(query, params):
264
        return with_savepoint(db, run)
265
    else: return run() # don't need savepoint if cached
266 830 aaronmk
267 832 aaronmk
##### Basic queries
268
269 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
270 1894 aaronmk
    recover=None, cacheable=True):
271 1135 aaronmk
    '''@param fields Use None to select all fields in the table'''
272
    if conds == None: conds = {}
273 135 aaronmk
    assert limit == None or type(limit) == int
274 865 aaronmk
    assert start == None or type(start) == int
275 15 aaronmk
    check_name(table)
276 1135 aaronmk
    if fields != None: map(check_name, fields)
277 15 aaronmk
    map(check_name, conds.keys())
278 865 aaronmk
279 11 aaronmk
    def cond(entry):
280 13 aaronmk
        col, value = entry
281 644 aaronmk
        cond_ = esc_name(db, col)+' '
282 11 aaronmk
        if value == None: cond_ += 'IS'
283
        else: cond_ += '='
284
        cond_ += ' %s'
285
        return cond_
286 1135 aaronmk
    query = 'SELECT '
287
    if fields == None: query += '*'
288
    else: query += ', '.join([esc_name(db, field) for field in fields])
289
    query += ' FROM '+esc_name(db, table)
290 865 aaronmk
291
    missing = True
292 89 aaronmk
    if conds != {}:
293
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
294 865 aaronmk
        missing = False
295
    if limit != None: query += ' LIMIT '+str(limit); missing = False
296
    if start != None:
297
        if start != 0: query += ' OFFSET '+str(start)
298
        missing = False
299
    if missing: warnings.warn(DbWarning(
300
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
301
302 1905 aaronmk
    return run_query(db, query, conds.values(), recover, cacheable)
303 11 aaronmk
304 1905 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True):
305 1554 aaronmk
    '''@param returning str|None An inserted column (such as pkey) to return'''
306 11 aaronmk
    check_name(table)
307 13 aaronmk
    cols = row.keys()
308 15 aaronmk
    map(check_name, cols)
309 89 aaronmk
    query = 'INSERT INTO '+table
310 1554 aaronmk
311 89 aaronmk
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
312
        +', '.join(['%s']*len(cols))+')'
313
    else: query += ' DEFAULT VALUES'
314 1554 aaronmk
315
    if returning != None:
316
        check_name(returning)
317
        query += ' RETURNING '+returning
318
319 1905 aaronmk
    return run_query(db, query, row.values(), recover, cacheable)
320 11 aaronmk
321 135 aaronmk
def last_insert_id(db):
322 1849 aaronmk
    module = util.root_module(db.db)
323 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
324
    elif module == 'MySQLdb': return db.insert_id()
325
    else: return None
326 13 aaronmk
327 832 aaronmk
def truncate(db, table):
328
    check_name(table)
329 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
330 832 aaronmk
331
##### Database structure queries
332
333 1850 aaronmk
def pkey(db, table, recover=None):
334 832 aaronmk
    '''Assumed to be first column in table'''
335
    check_name(table)
336 1929 aaronmk
    return col_names(select(db, table, limit=0, recover=recover)).next()
337 832 aaronmk
338 853 aaronmk
def index_cols(db, table, index):
339
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
340
    automatically created. When you don't know whether something is a UNIQUE
341
    constraint or a UNIQUE index, use this function.'''
342
    check_name(table)
343
    check_name(index)
344 1909 aaronmk
    module = util.root_module(db.db)
345
    if module == 'psycopg2':
346
        return list(values(run_query(db, '''\
347 853 aaronmk
SELECT attname
348 866 aaronmk
FROM
349
(
350
        SELECT attnum, attname
351
        FROM pg_index
352
        JOIN pg_class index ON index.oid = indexrelid
353
        JOIN pg_class table_ ON table_.oid = indrelid
354
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
355
        WHERE
356
            table_.relname = %(table)s
357
            AND index.relname = %(index)s
358
    UNION
359
        SELECT attnum, attname
360
        FROM
361
        (
362
            SELECT
363
                indrelid
364
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
365
                    AS indkey
366
            FROM pg_index
367
            JOIN pg_class index ON index.oid = indexrelid
368
            JOIN pg_class table_ ON table_.oid = indrelid
369
            WHERE
370
                table_.relname = %(table)s
371
                AND index.relname = %(index)s
372
        ) s
373
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
374
) s
375 853 aaronmk
ORDER BY attnum
376
''',
377 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
378
    else: raise NotImplementedError("Can't list index columns for "+module+
379
        ' database')
380 853 aaronmk
381 464 aaronmk
def constraint_cols(db, table, constraint):
382
    check_name(table)
383
    check_name(constraint)
384 1849 aaronmk
    module = util.root_module(db.db)
385 464 aaronmk
    if module == 'psycopg2':
386
        return list(values(run_query(db, '''\
387
SELECT attname
388
FROM pg_constraint
389
JOIN pg_class ON pg_class.oid = conrelid
390
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
391
WHERE
392
    relname = %(table)s
393
    AND conname = %(constraint)s
394
ORDER BY attnum
395
''',
396
            {'table': table, 'constraint': constraint})))
397
    else: raise NotImplementedError("Can't list constraint columns for "+module+
398
        ' database')
399
400 832 aaronmk
def tables(db):
401 1849 aaronmk
    module = util.root_module(db.db)
402 832 aaronmk
    if module == 'psycopg2':
403
        return values(run_query(db, "SELECT tablename from pg_tables "
404
            "WHERE schemaname = 'public' ORDER BY tablename"))
405
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
406
    else: raise NotImplementedError("Can't list tables for "+module+' database')
407 830 aaronmk
408 833 aaronmk
##### Database management
409
410
def empty_db(db):
411
    for table in tables(db): truncate(db, table)
412
413 832 aaronmk
##### Heuristic queries
414
415 1554 aaronmk
def try_insert(db, table, row, returning=None):
416 830 aaronmk
    '''Recovers from errors'''
417 1554 aaronmk
    try: return insert(db, table, row, returning, recover=True)
418 46 aaronmk
    except Exception, e:
419
        msg = str(e)
420 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
421
            r'"(([^\W_]+)_[^"]+)"', msg)
422
        if match:
423
            constraint, table = match.groups()
424 854 aaronmk
            try: cols = index_cols(db, table, constraint)
425 465 aaronmk
            except NotImplementedError: raise e
426 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
427 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
428
            'constraint', msg)
429 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
430 13 aaronmk
        raise # no specific exception raised
431 11 aaronmk
432 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
433 1554 aaronmk
    '''Recovers from errors.
434
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
435 471 aaronmk
    try:
436 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
437
        if row_ct_ref != None and cur.rowcount >= 0:
438
            row_ct_ref[0] += cur.rowcount
439
        return value(cur)
440 471 aaronmk
    except DuplicateKeyException, e:
441 1069 aaronmk
        return value(select(db, table, [pkey],
442
            util.dict_subset_right_join(row, e.cols), recover=True))
443 471 aaronmk
444 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
445 830 aaronmk
    '''Recovers from errors'''
446
    try: return value(select(db, table, [pkey], row, 1, recover=True))
447 14 aaronmk
    except StopIteration:
448 40 aaronmk
        if not create: raise
449 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row