Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5
import sys
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1893 aaronmk
import iters
10 1889 aaronmk
from Proxy import Proxy
11 1872 aaronmk
import rand
12 862 aaronmk
import strings
13 131 aaronmk
import util
14 11 aaronmk
15 832 aaronmk
##### Exceptions
16
17 135 aaronmk
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21 14 aaronmk
22 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23 135 aaronmk
24 300 aaronmk
class DbException(exc.ExceptionWithCause):
25 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
26 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
27 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
28
29 360 aaronmk
class NameException(DbException): pass
30
31 468 aaronmk
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35 11 aaronmk
36 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
37 13 aaronmk
38 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
39 13 aaronmk
40 89 aaronmk
class EmptyRowException(DbException): pass
41
42 865 aaronmk
##### Warnings
43
44
class DbWarning(UserWarning): pass
45
46 832 aaronmk
##### Input validation
47
48 11 aaronmk
def check_name(name):
49
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
50
        +'" may contain only alphanumeric characters and _')
51
52 643 aaronmk
def esc_name(db, name):
53 1849 aaronmk
    module = util.root_module(db.db)
54 645 aaronmk
    if module == 'psycopg2': return name
55
        # Don't enclose in quotes because this disables case-insensitivity
56 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
57 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
58 643 aaronmk
    return quote + name.replace(quote, '') + quote
59
60 1869 aaronmk
##### Database connections
61 1849 aaronmk
62 1869 aaronmk
db_engines = {
63
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
64
    'PostgreSQL': ('psycopg2', {}),
65
}
66
67
DatabaseErrors_set = set([DbException])
68
DatabaseErrors = tuple(DatabaseErrors_set)
69
70
def _add_module(module):
71
    DatabaseErrors_set.add(module.DatabaseError)
72
    global DatabaseErrors
73
    DatabaseErrors = tuple(DatabaseErrors_set)
74
75
def db_config_str(db_config):
76
    return db_config['engine']+' database '+db_config['database']
77
78 1894 aaronmk
def _query_lookup(query, params): return (query, util.cast(tuple, params))
79
80 1849 aaronmk
class DbConn:
81 1889 aaronmk
    def __init__(self, db_config, serializable=True, debug=False):
82 1869 aaronmk
        self.db_config = db_config
83
        self.serializable = serializable
84 1889 aaronmk
        self.debug = debug
85 1869 aaronmk
86
        self.__db = None
87 1849 aaronmk
        self.pkeys = {}
88
        self.index_cols = {}
89 1889 aaronmk
        self.query_results = {}
90 1869 aaronmk
91
    def __getattr__(self, name):
92
        if name == '__dict__': raise Exception('getting __dict__')
93
        if name == 'db': return self._db()
94
        else: raise AttributeError()
95
96
    def __getstate__(self):
97
        state = copy.copy(self.__dict__) # shallow copy
98
        state['_DbConn__db'] = None # don't pickle the connection
99
        return state
100
101
    def _db(self):
102
        if self.__db == None:
103
            # Process db_config
104
            db_config = self.db_config.copy() # don't modify input!
105
            module_name, mappings = db_engines[db_config.pop('engine')]
106
            module = __import__(module_name)
107
            _add_module(module)
108
            for orig, new in mappings.iteritems():
109
                try: util.rename_key(db_config, orig, new)
110
                except KeyError: pass
111
112
            # Connect
113
            self.__db = module.connect(**db_config)
114
115
            # Configure connection
116
            if self.serializable: run_raw_query(self,
117
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
118
119
        return self.__db
120 1889 aaronmk
121 1891 aaronmk
    class DbCursor(Proxy):
122
        def __init__(self, outer):
123
            Proxy.__init__(self, outer.db.cursor())
124
            self.outer = outer
125 1894 aaronmk
            self.query_lookup = None
126 1891 aaronmk
            self.result = []
127 1889 aaronmk
128 1894 aaronmk
        def execute(self, query, params=None):
129
            self.query_lookup = _query_lookup(query, params)
130
            return_value = self.inner.execute(query, params)
131
            self.query = get_cur_query(self.inner)
132
            return return_value
133
134 1891 aaronmk
        def fetchone(self):
135
            row = self.inner.fetchone()
136
            if row == None: # fetched all rows
137 1894 aaronmk
                assert self.query_lookup != None
138
                pass #self.outer.query_results[self.query_lookup] = (self.query,
139
                    #self.result)
140 1891 aaronmk
            else: self.result.append(row)
141
            return row
142 1889 aaronmk
143 1891 aaronmk
    class CacheCursor:
144 1894 aaronmk
        def __init__(self, query, result):
145
            self.query = query
146
            self.rowcount = len(result)
147
            self.iter = iter(result)
148 1891 aaronmk
149
        def fetchone(self):
150
            try: return self.iter.next()
151
            except StopIteration: return None
152
153 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
154
        query_lookup = _query_lookup(query, params)
155
        try: actual_query, result = self.query_results[query_lookup]
156 1891 aaronmk
        except KeyError:
157 1894 aaronmk
            cur = self.DbCursor(self)
158
            try: cur.execute(query, params)
159 1891 aaronmk
            except Exception, e:
160
                _add_cursor_info(e, cur)
161
                raise
162
            if self.debug:
163
                sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
164
            return cur
165 1894 aaronmk
        else: return self.CacheCursor(actual_query, result)
166 1849 aaronmk
167 1869 aaronmk
connect = DbConn
168
169 832 aaronmk
##### Querying
170
171 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
172
    '''For args, see DbConn.run_query()'''
173
    return db.run_query(*args, **kw_args)
174 11 aaronmk
175 832 aaronmk
##### Recoverable querying
176 15 aaronmk
177 11 aaronmk
def with_savepoint(db, func):
178 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
179 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
180 11 aaronmk
    try: return_val = func()
181
    except:
182 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
183 11 aaronmk
        raise
184
    else:
185 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
186 11 aaronmk
        return return_val
187
188 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
189 830 aaronmk
    if recover == None: recover = False
190
191 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
192 830 aaronmk
    if recover: return with_savepoint(db, run)
193
    else: return run()
194
195 832 aaronmk
##### Result retrieval
196
197 1135 aaronmk
def col_names(cur): return (col[0] for col in cur.description)
198 832 aaronmk
199
def rows(cur): return iter(lambda: cur.fetchone(), None)
200
201 1893 aaronmk
def next_row(cur): return rows(cur).next()
202 832 aaronmk
203 1893 aaronmk
def row(cur):
204
    row_iter = rows(cur)
205
    row_ = row_iter.next()
206
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
207
    return row_
208
209
def next_value(cur): return next_row(cur)[0]
210
211 832 aaronmk
def value(cur): return row(cur)[0]
212
213 1893 aaronmk
def values(cur): return iters.func_iter(lambda: next_value(cur))
214 832 aaronmk
215
def value_or_none(cur):
216
    try: return value(cur)
217
    except StopIteration: return None
218
219
##### Basic queries
220
221 1135 aaronmk
def select(db, table, fields=None, conds=None, limit=None, start=None,
222 1894 aaronmk
    recover=None, cacheable=True):
223 1135 aaronmk
    '''@param fields Use None to select all fields in the table'''
224
    if conds == None: conds = {}
225 135 aaronmk
    assert limit == None or type(limit) == int
226 865 aaronmk
    assert start == None or type(start) == int
227 15 aaronmk
    check_name(table)
228 1135 aaronmk
    if fields != None: map(check_name, fields)
229 15 aaronmk
    map(check_name, conds.keys())
230 865 aaronmk
231 11 aaronmk
    def cond(entry):
232 13 aaronmk
        col, value = entry
233 644 aaronmk
        cond_ = esc_name(db, col)+' '
234 11 aaronmk
        if value == None: cond_ += 'IS'
235
        else: cond_ += '='
236
        cond_ += ' %s'
237
        return cond_
238 1135 aaronmk
    query = 'SELECT '
239
    if fields == None: query += '*'
240
    else: query += ', '.join([esc_name(db, field) for field in fields])
241
    query += ' FROM '+esc_name(db, table)
242 865 aaronmk
243
    missing = True
244 89 aaronmk
    if conds != {}:
245
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
246 865 aaronmk
        missing = False
247
    if limit != None: query += ' LIMIT '+str(limit); missing = False
248
    if start != None:
249
        if start != 0: query += ' OFFSET '+str(start)
250
        missing = False
251
    if missing: warnings.warn(DbWarning(
252
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
253
254 1894 aaronmk
    return run_query(db, query, conds.values(), cacheable, recover)
255 11 aaronmk
256 1554 aaronmk
def insert(db, table, row, returning=None, recover=None):
257
    '''@param returning str|None An inserted column (such as pkey) to return'''
258 11 aaronmk
    check_name(table)
259 13 aaronmk
    cols = row.keys()
260 15 aaronmk
    map(check_name, cols)
261 89 aaronmk
    query = 'INSERT INTO '+table
262 1554 aaronmk
263 89 aaronmk
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
264
        +', '.join(['%s']*len(cols))+')'
265
    else: query += ' DEFAULT VALUES'
266 1554 aaronmk
267
    if returning != None:
268
        check_name(returning)
269
        query += ' RETURNING '+returning
270
271 830 aaronmk
    return run_query(db, query, row.values(), recover)
272 11 aaronmk
273 135 aaronmk
def last_insert_id(db):
274 1849 aaronmk
    module = util.root_module(db.db)
275 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
276
    elif module == 'MySQLdb': return db.insert_id()
277
    else: return None
278 13 aaronmk
279 832 aaronmk
def truncate(db, table):
280
    check_name(table)
281 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
282 832 aaronmk
283
##### Database structure queries
284
285 1850 aaronmk
def pkey(db, table, recover=None):
286 832 aaronmk
    '''Assumed to be first column in table'''
287
    check_name(table)
288 1850 aaronmk
    if table not in db.pkeys:
289
        db.pkeys[table] = col_names(run_query(db,
290 1135 aaronmk
            'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
291 1850 aaronmk
    return db.pkeys[table]
292 832 aaronmk
293 853 aaronmk
def index_cols(db, table, index):
294
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
295
    automatically created. When you don't know whether something is a UNIQUE
296
    constraint or a UNIQUE index, use this function.'''
297
    check_name(table)
298
    check_name(index)
299 1852 aaronmk
    lookup = (table, index)
300
    if lookup not in db.index_cols:
301
        module = util.root_module(db.db)
302
        if module == 'psycopg2':
303
            db.index_cols[lookup] = list(values(run_query(db, '''\
304 853 aaronmk
SELECT attname
305 866 aaronmk
FROM
306
(
307
        SELECT attnum, attname
308
        FROM pg_index
309
        JOIN pg_class index ON index.oid = indexrelid
310
        JOIN pg_class table_ ON table_.oid = indrelid
311
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
312
        WHERE
313
            table_.relname = %(table)s
314
            AND index.relname = %(index)s
315
    UNION
316
        SELECT attnum, attname
317
        FROM
318
        (
319
            SELECT
320
                indrelid
321
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
322
                    AS indkey
323
            FROM pg_index
324
            JOIN pg_class index ON index.oid = indexrelid
325
            JOIN pg_class table_ ON table_.oid = indrelid
326
            WHERE
327
                table_.relname = %(table)s
328
                AND index.relname = %(index)s
329
        ) s
330
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
331
) s
332 853 aaronmk
ORDER BY attnum
333
''',
334 1852 aaronmk
                {'table': table, 'index': index})))
335
        else: raise NotImplementedError("Can't list index columns for "+module+
336
            ' database')
337
    return db.index_cols[lookup]
338 853 aaronmk
339 464 aaronmk
def constraint_cols(db, table, constraint):
340
    check_name(table)
341
    check_name(constraint)
342 1849 aaronmk
    module = util.root_module(db.db)
343 464 aaronmk
    if module == 'psycopg2':
344
        return list(values(run_query(db, '''\
345
SELECT attname
346
FROM pg_constraint
347
JOIN pg_class ON pg_class.oid = conrelid
348
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
349
WHERE
350
    relname = %(table)s
351
    AND conname = %(constraint)s
352
ORDER BY attnum
353
''',
354
            {'table': table, 'constraint': constraint})))
355
    else: raise NotImplementedError("Can't list constraint columns for "+module+
356
        ' database')
357
358 832 aaronmk
def tables(db):
359 1849 aaronmk
    module = util.root_module(db.db)
360 832 aaronmk
    if module == 'psycopg2':
361
        return values(run_query(db, "SELECT tablename from pg_tables "
362
            "WHERE schemaname = 'public' ORDER BY tablename"))
363
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
364
    else: raise NotImplementedError("Can't list tables for "+module+' database')
365 830 aaronmk
366 833 aaronmk
##### Database management
367
368
def empty_db(db):
369
    for table in tables(db): truncate(db, table)
370
371 832 aaronmk
##### Heuristic queries
372
373 1554 aaronmk
def try_insert(db, table, row, returning=None):
374 830 aaronmk
    '''Recovers from errors'''
375 1554 aaronmk
    try: return insert(db, table, row, returning, recover=True)
376 46 aaronmk
    except Exception, e:
377
        msg = str(e)
378 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
379
            r'"(([^\W_]+)_[^"]+)"', msg)
380
        if match:
381
            constraint, table = match.groups()
382 854 aaronmk
            try: cols = index_cols(db, table, constraint)
383 465 aaronmk
            except NotImplementedError: raise e
384 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
385 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
386
            'constraint', msg)
387 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
388 13 aaronmk
        raise # no specific exception raised
389 11 aaronmk
390 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
391 1554 aaronmk
    '''Recovers from errors.
392
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
393 471 aaronmk
    try:
394 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
395
        if row_ct_ref != None and cur.rowcount >= 0:
396
            row_ct_ref[0] += cur.rowcount
397
        return value(cur)
398 471 aaronmk
    except DuplicateKeyException, e:
399 1069 aaronmk
        return value(select(db, table, [pkey],
400
            util.dict_subset_right_join(row, e.cols), recover=True))
401 471 aaronmk
402 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
403 830 aaronmk
    '''Recovers from errors'''
404
    try: return value(select(db, table, [pkey], row, 1, recover=True))
405 14 aaronmk
    except StopIteration:
406 40 aaronmk
        if not create: raise
407 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row