Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import warnings
6

    
7
import exc
8
import dicts
9
import iters
10
from Proxy import Proxy
11
import rand
12
import strings
13
import util
14

    
15
##### Exceptions
16

    
17
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21

    
22
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23

    
24
class DbException(exc.ExceptionWithCause):
25
    def __init__(self, msg, cause=None, cur=None):
26
        exc.ExceptionWithCause.__init__(self, msg, cause)
27
        if cur != None: _add_cursor_info(self, cur)
28

    
29
class NameException(DbException): pass
30

    
31
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35

    
36
class DuplicateKeyException(ExceptionWithColumns): pass
37

    
38
class NullValueException(ExceptionWithColumns): pass
39

    
40
class EmptyRowException(DbException): pass
41

    
42
##### Warnings
43

    
44
class DbWarning(UserWarning): pass
45

    
46
##### Database connections
47

    
48
db_config_names = ['engine', 'host', 'user', 'password', 'database']
49

    
50
db_engines = {
51
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
52
    'PostgreSQL': ('psycopg2', {}),
53
}
54

    
55
DatabaseErrors_set = set([DbException])
56
DatabaseErrors = tuple(DatabaseErrors_set)
57

    
58
def _add_module(module):
59
    DatabaseErrors_set.add(module.DatabaseError)
60
    global DatabaseErrors
61
    DatabaseErrors = tuple(DatabaseErrors_set)
62

    
63
def db_config_str(db_config):
64
    return db_config['engine']+' database '+db_config['database']
65

    
66
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
67

    
68
log_debug_none = lambda msg: None
69

    
70
class DbConn:
71
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
72
        self.db_config = db_config
73
        self.serializable = serializable
74
        self.log_debug = log_debug
75
        
76
        self.__db = None
77
        self.query_results = {}
78
    
79
    def __getattr__(self, name):
80
        if name == '__dict__': raise Exception('getting __dict__')
81
        if name == 'db': return self._db()
82
        else: raise AttributeError()
83
    
84
    def __getstate__(self):
85
        state = copy.copy(self.__dict__) # shallow copy
86
        state['log_debug'] = None # don't pickle the debug callback
87
        state['_DbConn__db'] = None # don't pickle the connection
88
        return state
89
    
90
    def _db(self):
91
        if self.__db == None:
92
            # Process db_config
93
            db_config = self.db_config.copy() # don't modify input!
94
            module_name, mappings = db_engines[db_config.pop('engine')]
95
            module = __import__(module_name)
96
            _add_module(module)
97
            for orig, new in mappings.iteritems():
98
                try: util.rename_key(db_config, orig, new)
99
                except KeyError: pass
100
            
101
            # Connect
102
            self.__db = module.connect(**db_config)
103
            
104
            # Configure connection
105
            if self.serializable: run_raw_query(self,
106
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
107
        
108
        return self.__db
109
    
110
    class DbCursor(Proxy):
111
        def __init__(self, outer, cache_results):
112
            Proxy.__init__(self, outer.db.cursor())
113
            if cache_results: self.query_results = outer.query_results
114
            else: self.query_results = None
115
            self.query_lookup = None
116
            self.result = []
117
        
118
        def execute(self, query, params=None):
119
            self.query_lookup = _query_lookup(query, params)
120
            try: return_value = self.inner.execute(query, params)
121
            except Exception, e:
122
                self.result = e # cache the exception as the result
123
                self._cache_result()
124
                raise
125
            finally: self.query = get_cur_query(self.inner)
126
            return return_value
127
        
128
        def fetchone(self):
129
            row = self.inner.fetchone()
130
            if row != None: self.result.append(row)
131
            # otherwise, fetched all rows
132
            else: self._cache_result()
133
            return row
134
        
135
        def _cache_result(self):
136
            is_insert = self._is_insert()
137
            # For inserts, only cache exceptions since inserts are not
138
            # idempotent, but an invalid insert will always be invalid
139
            if self.query_results != None and (not is_insert
140
                or isinstance(self.result, Exception)):
141
                
142
                assert self.query_lookup != None
143
                self.query_results[self.query_lookup] = self.CacheCursor(
144
                    util.dict_subset(dicts.AttrsDictView(self),
145
                    ['query', 'result', 'rowcount', 'description']))
146
        
147
        def _is_insert(self): return self.query.upper().find('INSERT') >= 0
148
        
149
        class CacheCursor:
150
            def __init__(self, cached_result): self.__dict__ = cached_result
151
            
152
            def execute(self):
153
                if isinstance(self.result, Exception): raise self.result
154
                # otherwise, result is a rows list
155
                self.iter = iter(self.result)
156
            
157
            def fetchone(self):
158
                try: return self.iter.next()
159
                except StopIteration: return None
160
    
161
    def run_query(self, query, params=None, cacheable=False):
162
        query_lookup = _query_lookup(query, params)
163
        used_cache = False
164
        try:
165
            try:
166
                if not cacheable: raise KeyError
167
                cur = self.query_results[query_lookup]
168
                used_cache = True
169
            except KeyError:
170
                cur = self.DbCursor(self, cacheable)
171
                try: cur.execute(query, params)
172
                except Exception, e:
173
                    _add_cursor_info(e, cur)
174
                    raise
175
            else: cur.execute()
176
        finally:
177
            if self.log_debug != log_debug_none: # only compute msg if needed
178
                if used_cache: cache_status = 'Cache hit'
179
                elif cacheable: cache_status = 'Cache miss'
180
                else: cache_status = 'Non-cacheable'
181
                self.log_debug(cache_status+': '+strings.one_line(cur.query))
182
        
183
        return cur
184
    
185
    def is_cached(self, query, params=None):
186
        return _query_lookup(query, params) in self.query_results
187

    
188
connect = DbConn
189

    
190
##### Input validation
191

    
192
def check_name(name):
193
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
194
        +'" may contain only alphanumeric characters and _')
195

    
196
def esc_name_by_module(module, name, preserve_case=False):
197
    if module == 'psycopg2':
198
        if preserve_case: quote = '"'
199
        # Don't enclose in quotes because this disables case-insensitivity
200
        else: return name
201
    elif module == 'MySQLdb': quote = '`'
202
    else: raise NotImplementedError("Can't escape name for "+module+' database')
203
    return quote + name.replace(quote, '') + quote
204

    
205
def esc_name_by_engine(engine, name, **kw_args):
206
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
207

    
208
def esc_name(db, name, **kw_args):
209
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
210

    
211
##### Querying
212

    
213
def run_raw_query(db, *args, **kw_args):
214
    '''For args, see DbConn.run_query()'''
215
    return db.run_query(*args, **kw_args)
216

    
217
##### Recoverable querying
218

    
219
def with_savepoint(db, func):
220
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
221
    run_raw_query(db, 'SAVEPOINT '+savepoint)
222
    try: return_val = func()
223
    except:
224
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
225
        raise
226
    else:
227
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
228
        return return_val
229

    
230
def run_query(db, query, params=None, recover=None, cacheable=False):
231
    if recover == None: recover = False
232
    
233
    def run(): return run_raw_query(db, query, params, cacheable)
234
    if recover and not db.is_cached(query, params):
235
        return with_savepoint(db, run)
236
    else: return run() # don't need savepoint if cached
237

    
238
##### Result retrieval
239

    
240
def col_names(cur): return (col[0] for col in cur.description)
241

    
242
def rows(cur): return iter(lambda: cur.fetchone(), None)
243

    
244
def next_row(cur): return rows(cur).next()
245

    
246
def row(cur):
247
    row_iter = rows(cur)
248
    row_ = row_iter.next()
249
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
250
    return row_
251

    
252
def next_value(cur): return next_row(cur)[0]
253

    
254
def value(cur): return row(cur)[0]
255

    
256
def values(cur): return iters.func_iter(lambda: next_value(cur))
257

    
258
def value_or_none(cur):
259
    try: return value(cur)
260
    except StopIteration: return None
261

    
262
##### Basic queries
263

    
264
def select(db, table, fields=None, conds=None, limit=None, start=None,
265
    recover=None, cacheable=True):
266
    '''@param fields Use None to select all fields in the table'''
267
    if conds == None: conds = {}
268
    assert limit == None or type(limit) == int
269
    assert start == None or type(start) == int
270
    check_name(table)
271
    if fields != None: map(check_name, fields)
272
    map(check_name, conds.keys())
273
    
274
    def cond(entry):
275
        col, value = entry
276
        cond_ = esc_name(db, col)+' '
277
        if value == None: cond_ += 'IS'
278
        else: cond_ += '='
279
        cond_ += ' %s'
280
        return cond_
281
    query = 'SELECT '
282
    if fields == None: query += '*'
283
    else: query += ', '.join([esc_name(db, field) for field in fields])
284
    query += ' FROM '+esc_name(db, table)
285
    
286
    missing = True
287
    if conds != {}:
288
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
289
        missing = False
290
    if limit != None: query += ' LIMIT '+str(limit); missing = False
291
    if start != None:
292
        if start != 0: query += ' OFFSET '+str(start)
293
        missing = False
294
    if missing: warnings.warn(DbWarning(
295
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
296
    
297
    return run_query(db, query, conds.values(), recover, cacheable)
298

    
299
def insert(db, table, row, returning=None, recover=None, cacheable=True):
300
    '''@param returning str|None An inserted column (such as pkey) to return'''
301
    check_name(table)
302
    cols = row.keys()
303
    map(check_name, cols)
304
    query = 'INSERT INTO '+table
305
    
306
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
307
        +', '.join(['%s']*len(cols))+')'
308
    else: query += ' DEFAULT VALUES'
309
    
310
    if returning != None:
311
        check_name(returning)
312
        query += ' RETURNING '+returning
313
    
314
    return run_query(db, query, row.values(), recover, cacheable)
315

    
316
def last_insert_id(db):
317
    module = util.root_module(db.db)
318
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
319
    elif module == 'MySQLdb': return db.insert_id()
320
    else: return None
321

    
322
def truncate(db, table):
323
    check_name(table)
324
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
325

    
326
##### Database structure queries
327

    
328
def pkey(db, table, recover=None):
329
    '''Assumed to be first column in table'''
330
    check_name(table)
331
    return col_names(run_query(db,
332
        'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
333

    
334
def index_cols(db, table, index):
335
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
336
    automatically created. When you don't know whether something is a UNIQUE
337
    constraint or a UNIQUE index, use this function.'''
338
    check_name(table)
339
    check_name(index)
340
    module = util.root_module(db.db)
341
    if module == 'psycopg2':
342
        return list(values(run_query(db, '''\
343
SELECT attname
344
FROM
345
(
346
        SELECT attnum, attname
347
        FROM pg_index
348
        JOIN pg_class index ON index.oid = indexrelid
349
        JOIN pg_class table_ ON table_.oid = indrelid
350
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
351
        WHERE
352
            table_.relname = %(table)s
353
            AND index.relname = %(index)s
354
    UNION
355
        SELECT attnum, attname
356
        FROM
357
        (
358
            SELECT
359
                indrelid
360
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
361
                    AS indkey
362
            FROM pg_index
363
            JOIN pg_class index ON index.oid = indexrelid
364
            JOIN pg_class table_ ON table_.oid = indrelid
365
            WHERE
366
                table_.relname = %(table)s
367
                AND index.relname = %(index)s
368
        ) s
369
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
370
) s
371
ORDER BY attnum
372
''',
373
            {'table': table, 'index': index}, cacheable=True)))
374
    else: raise NotImplementedError("Can't list index columns for "+module+
375
        ' database')
376

    
377
def constraint_cols(db, table, constraint):
378
    check_name(table)
379
    check_name(constraint)
380
    module = util.root_module(db.db)
381
    if module == 'psycopg2':
382
        return list(values(run_query(db, '''\
383
SELECT attname
384
FROM pg_constraint
385
JOIN pg_class ON pg_class.oid = conrelid
386
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
387
WHERE
388
    relname = %(table)s
389
    AND conname = %(constraint)s
390
ORDER BY attnum
391
''',
392
            {'table': table, 'constraint': constraint})))
393
    else: raise NotImplementedError("Can't list constraint columns for "+module+
394
        ' database')
395

    
396
def tables(db):
397
    module = util.root_module(db.db)
398
    if module == 'psycopg2':
399
        return values(run_query(db, "SELECT tablename from pg_tables "
400
            "WHERE schemaname = 'public' ORDER BY tablename"))
401
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
402
    else: raise NotImplementedError("Can't list tables for "+module+' database')
403

    
404
##### Database management
405

    
406
def empty_db(db):
407
    for table in tables(db): truncate(db, table)
408

    
409
##### Heuristic queries
410

    
411
def try_insert(db, table, row, returning=None):
412
    '''Recovers from errors'''
413
    try: return insert(db, table, row, returning, recover=True)
414
    except Exception, e:
415
        msg = str(e)
416
        match = re.search(r'duplicate key value violates unique constraint '
417
            r'"(([^\W_]+)_[^"]+)"', msg)
418
        if match:
419
            constraint, table = match.groups()
420
            try: cols = index_cols(db, table, constraint)
421
            except NotImplementedError: raise e
422
            else: raise DuplicateKeyException(cols, e)
423
        match = re.search(r'null value in column "(\w+)" violates not-null '
424
            'constraint', msg)
425
        if match: raise NullValueException([match.group(1)], e)
426
        raise # no specific exception raised
427

    
428
def put(db, table, row, pkey, row_ct_ref=None):
429
    '''Recovers from errors.
430
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
431
    try:
432
        cur = try_insert(db, table, row, pkey)
433
        if row_ct_ref != None and cur.rowcount >= 0:
434
            row_ct_ref[0] += cur.rowcount
435
        return value(cur)
436
    except DuplicateKeyException, e:
437
        return value(select(db, table, [pkey],
438
            util.dict_subset_right_join(row, e.cols), recover=True))
439

    
440
def get(db, table, row, pkey, row_ct_ref=None, create=False):
441
    '''Recovers from errors'''
442
    try: return value(select(db, table, [pkey], row, 1, recover=True))
443
    except StopIteration:
444
        if not create: raise
445
        return put(db, table, row, pkey, row_ct_ref) # insert new row
(22-22/33)