Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
    '''Returns the text of the last query run on a cursor, if available.
    @param cur A cursor-like object
    @return cur.query if that attribute exists, otherwise cur._last_executed
        if that exists, otherwise None'''
    # Drivers expose the last query under different attribute names; try each
    # in priority order
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur):
    '''Adds the cursor's last query to an exception's message.
    @param e The exception to annotate (via exc.add_msg())
    @param cur The cursor whose last query should be reported
    Does nothing when the cursor has no query to report, so that annotating an
    error can never itself raise and hide the original error.'''
    query = get_cur_query(cur)
    if query is not None: exc.add_msg(e, 'query: '+str(query))
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module.'''
    def __init__(self, msg, cause=None, cur=None):
        '''@param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, a cursor whose last query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        if cur is not None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing disallowed characters.'''
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException that concerns a specific set of columns.
    The columns are available to handlers as the `cols` attribute.'''
    def __init__(self, cols, cause=None):
        '''@param cols The names (strs) of the affected columns
        @param cause The underlying exception, if any'''
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when the database reports a duplicate key.
    `cols` holds the columns of the violated unique index.'''
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when the database reports a not-null violation.
    `cols` holds the offending column.'''
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException):
    '''A DbException subtype for empty rows.
    NOTE(review): not raised anywhere in the visible part of this module;
    presumably used by callers -- confirm.'''
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning):
    '''Warning category for suspicious-but-legal database usage, e.g. the
    unbounded SELECT reported by select().'''
45
46 832 aaronmk
##### Input validation
47
48 11 aaronmk
def check_name(name):
    '''Validates a name (table, column, index, ...) before it is spliced into
    SQL text.
    @param name str The name to validate
    @raise NameException if the name contains anything other than
        alphanumeric characters and _'''
    if re.search(r'\W', name) is not None:
        raise NameException('Name "'+name
            +'" may contain only alphanumeric characters and _')
51
52 643 aaronmk
def esc_name(db, name):
    '''Escapes a name for use in a query, according to the connection's driver.
    @param db An object whose `db` attribute is the underlying connection
    @param name str The name to escape
    @return For psycopg2, the name unchanged; for MySQLdb, the name wrapped in
        backticks, with any embedded backticks removed
    @raise NotImplementedError for any other driver module'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        return name
    if module != 'MySQLdb':
        raise NotImplementedError("Can't escape name for "+module+' database')
    quote = '`'
    return quote + name.replace(quote, '') + quote
59
60 1869 aaronmk
##### Database connections
61 1849 aaronmk
62 1869 aaronmk
# Maps each supported engine name to (driver module name, key renames), where
# the renames translate db_config keys to that driver's connect() keyword names
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
66
67
# All exception classes that count as database errors. The set is the mutable
# registry; the tuple is a snapshot of it in the form `except` clauses accept.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
69
70
def _add_module(module):
    '''Registers a driver module's DatabaseError class as a database error.
    @param module A DB driver module exposing a DatabaseError attribute
    Rebinds the module-level DatabaseErrors tuple, so code that copied the old
    tuple (rather than reading it from this module) will not see the addition.
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
74
75
def db_config_str(db_config):
    '''Returns a short human-readable description of a database config.
    @param db_config dict with (at least) 'engine' and 'database' str entries
    @return str of the form "<engine> database <database>"'''
    return db_config['engine']+' database '+db_config['database']
77
78 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is cached.
    @return tuple (query text, params passed through dicts.make_hashable())
    '''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
79 1894 aaronmk
80 1901 aaronmk
def log_debug_none(msg):
    '''No-op debug logger: ignores the message. Used as the default log_debug
    callback, and compared against to detect that logging is off.'''
    return None
81
82 1849 aaronmk
class DbConn:
    '''A database connection that is opened lazily and can cache query results.
    The underlying driver connection is created the first time the `db`
    attribute is read. Results of cacheable queries are stored in
    `query_results`, keyed by _query_lookup(query, params).'''
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        '''@param db_config dict with an 'engine' entry (a key of db_engines)
            plus the keyword args for the driver's connect()
        @param serializable Whether to switch the new connection to the
            SERIALIZABLE isolation level
        @param log_debug Function taking a message str, called once per query
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        
        self.__db = None # the driver connection; created on demand by _db()
        self.pkeys = {} # table name -> pkey column name; filled in by pkey()
        self.query_results = {} # query lookup -> (query, result, rowcount)
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute. Only invoked for
        attributes that normal lookup did not find.'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the picklable state: everything except the connection.'''
        state = copy.copy(self.__dict__) # shallow copy
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the driver connection, connecting first if needed.'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                # A config need not contain every renameable key
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            # (self.__db is already set, so the query below doesn't recurse)
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a driver cursor, recording the rows fetched (or the exception
        raised) so that the outcome can be stored in the query cache.'''
        def __init__(self, outer, cache_results):
            '''@param outer The owning DbConn
            @param cache_results Whether to store results in outer's cache'''
            Proxy.__init__(self, outer.db.cursor())
            if cache_results: self.query_results = outer.query_results
            else: self.query_results = None
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor. A raised exception is
            cached as the query's result before being re-raised.'''
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                # _cache_result() reads self.query, so set it before caching
                self.query = get_cur_query(self.inner)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            return return_value
        
        def fetchone(self):
            '''Fetches the next row, accumulating it in self.result. Once the
            rows run out (None), the accumulated rows are cached.'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores (query, result, rowcount) in the cache, if caching is on
            and the result is safe to reuse.'''
            if self.query_results is None: return # caching is off
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self._is_insert() and not isinstance(self.result, Exception):
                return
            
            assert self.query_lookup is not None
            self.query_results[self.query_lookup] = (self.query,
                self.result, self.rowcount)
        
        def _is_insert(self):
            '''Whether the query text contains INSERT (case-insensitively).
            This matches the word anywhere, so it errs towards not caching.'''
            query = self.query
            if query is None and self.query_lookup is not None:
                # The driver reported no query: use the text given to execute()
                query = self.query_lookup[0]
            return query is not None and query.upper().find('INSERT') >= 0
    
    class CacheCursor:
        '''Replays a cached query outcome through the subset of the cursor
        interface used by this module (query, rowcount, execute, fetchone).'''
        def __init__(self, query, result, rowcount):
            '''@param result The cached rows list, or the cached exception'''
            self.query = query
            self.result = result
            self.rowcount = rowcount
        
        def execute(self):
            '''Re-raises a cached exception, or rewinds to the first row.'''
            if isinstance(self.result, Exception): raise self.result
            # otherwise, result is a rows list
            self.iter = iter(self.result)
        
        def fetchone(self):
            '''Returns the next cached row, or None when there are no more.'''
            return next(self.iter, None)
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query, answering from the cache when allowed and possible.
        @param query str The query text, with driver-style placeholders
        @param params The values for the placeholders
        @param cacheable Whether the result may be read from/stored in the
            cache
        @return A cursor: a CacheCursor on a cache hit, otherwise a DbCursor
        Exceptions from the driver are annotated with the query and re-raised.
        '''
        query_lookup = _query_lookup(query, params)
        used_cache = False
        cur = None # so the finally clause is safe if creating the cursor fails
        try:
            cached_result = None
            # Cache entries are tuples, so None reliably means "not cached"
            if cacheable: cached_result = self.query_results.get(query_lookup)
            
            if cached_result is not None:
                used_cache = True
                cur = self.CacheCursor(*cached_result)
                cur.execute()
            else:
                cur = self.DbCursor(self, cacheable)
                try: cur.execute(query, params)
                except Exception as e:
                    _add_cursor_info(e, cur)
                    raise
        finally:
            if self.log_debug != log_debug_none: # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                # Prefer the query as the cursor reports it; fall back to the
                # requested text if there is no cursor or it reports nothing
                logged_query = None
                if cur is not None: logged_query = cur.query
                if logged_query is None: logged_query = query
                self.log_debug(cache_status+': '
                    +strings.one_line(logged_query))
        
        return cur
200 1849 aaronmk
201 1869 aaronmk
# Alias so that callers can write connect(db_config, ...) to get a DbConn
connect = DbConn
202
203 832 aaronmk
##### Querying
204
205 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query on a connection, without any savepoint handling.
    For args, see DbConn.run_query()
    @param db The connection object whose run_query() is called
    @return Whatever db.run_query() returns'''
    return db.run_query(*args, **kw_args)
208 11 aaronmk
209 832 aaronmk
##### Recoverable querying
210 15 aaronmk
211 11 aaronmk
def with_savepoint(db, func):
    '''Calls func() inside a savepoint, so that a failure can be undone
    without aborting the enclosing transaction.
    @param db The connection to create the savepoint on
    @param func Function of no args to call
    @return func()'s return value
    On any exception from func(), rolls back to the savepoint and re-raises;
    on success, releases the savepoint.'''
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+savepoint)
    try: return_val = func()
    except:
        # Bare except is deliberate: roll back whatever went wrong, then let
        # the original exception continue
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
        raise
    else:
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
        return return_val
221
222 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param query, params, cacheable See DbConn.run_query()
    @param recover If true, the query runs within with_savepoint() so that an
        error can be recovered from. None means False.
    @return The cursor for the query'''
    if recover is None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    
    if recover: return with_savepoint(db, run)
    return run()
228
229 832 aaronmk
##### Result retrieval
230
231 1135 aaronmk
def col_names(cur):
    '''Returns a generator of the column names of a cursor's result, taken
    from the first element of each entry of cur.description.'''
    description = cur.description
    return (column[0] for column in description)
232 832 aaronmk
233
def rows(cur):
    '''Returns an iterator over a cursor's remaining rows. Each step calls
    cur.fetchone(); iteration ends when that returns None.'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
234
235 1893 aaronmk
def next_row(cur):
    '''Returns the cursor's next row.
    @raise StopIteration if there are no more rows'''
    # next() builtin instead of the .next() method, which only exists on
    # Python 2 iterators
    return next(rows(cur))
236 832 aaronmk
237 1893 aaronmk
def row(cur):
    '''Returns the cursor's next row, then reads and discards the rest.
    @raise StopIteration if there are no more rows'''
    row_iter = rows(cur)
    row_ = next(row_iter)
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
    return row_
242
243
def next_value(cur):
    '''Returns the first column of the cursor's next row.'''
    first_col = next_row(cur)[0]
    return first_col
244
245 832 aaronmk
def value(cur):
    '''Returns the first column of the cursor's next row, via row() (so the
    remaining rows are consumed as well).'''
    first_col = row(cur)[0]
    return first_col
246
247 1893 aaronmk
def values(cur):
    '''Returns an iterator of the first column of each remaining row.
    Built with iters.func_iter() -- presumably it calls the function until
    StopIteration is raised; confirm in iters.'''
    def get_next(): return next_value(cur)
    return iters.func_iter(get_next)
248 832 aaronmk
249
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no more rows.'''
    try: return value(cur)
    except StopIteration: return None
252
253
##### Basic queries
254
255 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True):
    '''Builds and runs a SELECT on a single table.
    @param fields Use None to select all fields in the table
    @param conds dict of column name -> required value. A None value is
        compared with IS, any other value with =. Conditions are ANDed.
    @param limit int|None Value for the LIMIT clause
    @param start int|None Value for the OFFSET clause (0 adds no clause)
    @param recover, cacheable See run_query()
    @return The cursor for the query
    @raise NameException if the table or any field/column name is invalid
    Issues a DbWarning if none of conds, limit, and start was given.'''
    if conds is None: conds = {}
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    # Take the conditions once, so the WHERE clause and the params list are
    # guaranteed to be in the same order
    cond_items = list(conds.items())
    
    # Explicit loops rather than map(), which is lazy on Python 3 and would
    # silently skip the validation there
    check_name(table)
    if fields is not None:
        for field in fields: check_name(field)
    for col, _ in cond_items: check_name(col)
    
    def cond(entry):
        col, value = entry
        if value is None: op = 'IS'
        else: op = '='
        return esc_name(db, col)+' '+op+' %s'
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    query += ' FROM '+esc_name(db, table)
    
    missing = True # whether the query has no clause restricting its rows
    if cond_items:
        query += ' WHERE '+' AND '.join([cond(item) for item in cond_items])
        missing = False
    if limit is not None: query += ' LIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    params = [value for _, value in cond_items]
    return run_query(db, query, params, recover, cacheable)
289 11 aaronmk
290 1905 aaronmk
def insert(db, table, row, returning=None, recover=None, cacheable=True):
    '''Builds and runs an INSERT of one row.
    @param row dict of column name -> value. An empty dict inserts
        DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param recover, cacheable See run_query()
    @return The cursor for the query
    @raise NameException if the table or any column name is invalid'''
    check_name(table)
    # Take the entries once, so columns and values are in the same order
    items = list(row.items())
    cols = [col for col, _ in items]
    # Explicit loop rather than map(), which is lazy on Python 3 and would
    # silently skip the validation there
    for col in cols: check_name(col)
    query = 'INSERT INTO '+table
    
    if cols: query += ' ('+', '.join(cols)+') VALUES ('\
        +', '.join(['%s']*len(cols))+')'
    else: query += ' DEFAULT VALUES'
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    params = [value for _, value in items]
    return run_query(db, query, params, recover, cacheable)
306 11 aaronmk
307 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the connection's most recent insert.
    @param db A connection wrapper whose `db` attribute is the driver
        connection
    @return For psycopg2, the result of SELECT lastval(); for MySQLdb, the
        driver connection's insert_id(); for any other driver, None'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() is a method of the driver connection, not of the wrapper
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
312 13 aaronmk
313 832 aaronmk
def truncate(db, table):
    '''Runs TRUNCATE ... CASCADE on a table.
    @raise NameException if the table name is invalid
    @return The cursor for the query'''
    check_name(table)
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
316 832 aaronmk
317
##### Database structure queries
318
319 1850 aaronmk
def pkey(db, table, recover=None):
    '''Returns the name of a table's primary key column.
    Assumed to be first column in table
    @param recover See run_query()
    The answer is remembered in db.pkeys, so each table is queried only once.
    '''
    check_name(table)
    if table not in db.pkeys:
        # A zero-row SELECT still provides the column names
        cur = run_query(db, 'SELECT * FROM '+table+' LIMIT 0', recover=recover)
        db.pkeys[table] = next(col_names(cur))
    return db.pkeys[table]
326 832 aaronmk
327 853 aaronmk
def index_cols(db, table, index):
    '''Returns the names of the columns covered by an index, as a list.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @raise NameException if the table or index name is invalid
    @raise NotImplementedError for drivers other than psycopg2'''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The first branch of the UNION reads the columns listed in indkey;
        # the second extracts column numbers from indexprs
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
369 853 aaronmk
370 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the names of the columns covered by a constraint, as a list
    ordered by column number.
    @raise NameException if the table or constraint name is invalid
    @raise NotImplementedError for drivers other than psycopg2'''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
388
389 832 aaronmk
def tables(db):
    '''Returns an iterator of the database's table names.
    For psycopg2, these are the tables of the public schema, sorted by name;
    for MySQLdb, the output of SHOW TABLES.
    @raise NotImplementedError for any other driver'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return values(run_query(db, "SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename"))
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
396 830 aaronmk
397 833 aaronmk
##### Database management
398
399
def empty_db(db):
    '''Truncates every table listed by tables().'''
    for table in tables(db): truncate(db, table)
401
402 832 aaronmk
##### Heuristic queries
403
404 1554 aaronmk
def try_insert(db, table, row, returning=None):
    '''Inserts a row, translating recognized database errors into specific
    exceptions.
    Recovers from errors
    @param returning See insert()
    @return The cursor for the insert
    @raise DuplicateKeyException if the error message reports a unique
        constraint violation; its cols are the constraint's index columns
    @raise NullValueException if the error message reports a not-null
        violation; its cols hold the offending column
    Any other error is re-raised unchanged.'''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        # The table name is taken to be the part of the constraint name before
        # its first underscore
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            # Can't look up the columns, so report the original error instead
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised
420 11 aaronmk
421 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
    '''Inserts a row and returns its pkey; if the row violates a unique key,
    returns the pkey of the existing row instead.
    Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    @param pkey str The name of the pkey column to return
    @param row_ct_ref None|list Its first element is incremented by the
        insert's rowcount, when that is non-negative'''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row by the columns of the violated key
        # (presumably dict_subset_right_join() picks those entries of row --
        # confirm in util)
        return value(select(db, table, [pkey],
            util.dict_subset_right_join(row, e.cols), recover=True))
432 471 aaronmk
433 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the row matching the given column values.
    Recovers from errors
    @param row dict of column name -> value, used as the select conditions
    @param pkey str The name of the pkey column to return
    @param row_ct_ref See put()
    @param create Whether to insert the row (via put()) if it isn't found
    @raise StopIteration if no row matches and create is false'''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row