# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
from Proxy import Proxy
import rand
import strings
import util

##### Exceptions

def get_cur_query(cur):
    if hasattr(cur, 'query'): return cur.query
    elif hasattr(cur, '_last_executed'): return cur._last_executed
    else: return None

def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+str(get_cur_query(cur)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause)
        if cur != None: _add_cursor_info(self, cur)

class NameException(DbException): pass

class ExceptionWithColumns(DbException):
    def __init__(self, cols, cause=None):
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
        self.cols = cols

class DuplicateKeyException(ExceptionWithColumns): pass

class NullValueException(ExceptionWithColumns): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Database connections

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']

def _query_lookup(query, params): return (query, dicts.make_hashable(params))

log_debug_none = lambda msg: None

class DbConn:
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug

        self.__db = None
        self.query_results = {}

    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()

    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass

            # Connect
            self.__db = module.connect(**db_config)

            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')

        return self.__db

    class DbCursor(Proxy):
        def __init__(self, outer, cache_results):
            Proxy.__init__(self, outer.db.cursor())
            if cache_results: self.query_results = outer.query_results
            else: self.query_results = None
            self.query_lookup = None
            self.result = []

        def execute(self, query, params=None):
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception, e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            return return_value

        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            is_insert = self._is_insert()
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        def _is_insert(self): return self.query.upper().find('INSERT') >= 0

        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None

    def run_query(self, query, params=None, cacheable=False):
        query_lookup = _query_lookup(query, params)
        used_cache = False
        try:
            try:
                if not cacheable: raise KeyError
                cur = self.query_results[query_lookup]
                used_cache = True
            except KeyError:
                cur = self.DbCursor(self, cacheable)
                try: cur.execute(query, params)
                except Exception, e:
                    _add_cursor_info(e, cur)
                    raise
            else: cur.execute()
        finally:
            if self.log_debug != log_debug_none: # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                self.log_debug(cache_status+': '+strings.one_line(cur.query))

        return cur

    def is_cached(self, query, params=None):
        return _query_lookup(query, params) in self.query_results

connect = DbConn
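
# Example usage (an illustrative sketch, not part of this module's API): open a
# connection and run a query through the caching cursor. The database name and
# credentials below are hypothetical.
#
#     db = connect({'engine': 'PostgreSQL', 'database': 'test',
#         'user': 'postgres', 'password': 'secret'})
#     cur = db.run_query('SELECT 1', cacheable=True)
#     print value(cur) # fetching every row also caches the result (see below)
#     db.run_query('SELECT 1', cacheable=True) # now answered from query_results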

##### Input validation

def check_name(name):
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')

def esc_name_by_module(module, name, preserve_case=False):
    if module == 'psycopg2':
        if preserve_case: quote = '"'
        # Don't enclose in quotes because this disables case-insensitivity
        else: return name
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return quote + name.replace(quote, '') + quote

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
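
# Example (illustrative): how identifiers are escaped per driver. MySQLdb names
# are backtick-quoted; psycopg2 names are left unquoted unless preserve_case is
# set, since quoting a name makes it case-sensitive.
#
#     esc_name_by_engine('MySQL', 'order')                          # `order`
#     esc_name_by_engine('PostgreSQL', 'Order')                     # Order
#     esc_name_by_engine('PostgreSQL', 'Order', preserve_case=True) # "Order"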

##### Querying

def run_raw_query(db, *args, **kw_args):
    '''For args, see DbConn.run_query()'''
    return db.run_query(*args, **kw_args)

##### Recoverable querying

def with_savepoint(db, func):
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+savepoint)
    try: return_val = func()
    except:
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
        raise
    else:
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
        return return_val

def run_query(db, query, params=None, recover=None, cacheable=False):
    if recover == None: recover = False

    def run(): return run_raw_query(db, query, params, cacheable)
    if recover and not db.is_cached(query, params):
        return with_savepoint(db, run)
    else: return run() # don't need savepoint if cached
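
# Example (illustrative): run_query(..., recover=True) wraps the statement in a
# savepoint so a failure doesn't abort the enclosing transaction. The same can
# be done directly with with_savepoint(); 'plants' is a hypothetical table.
#
#     def attempt(): return run_raw_query(db, 'SELECT * FROM plants LIMIT 1')
#     cur = with_savepoint(db, attempt) # rolls back to the savepoint on error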

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def next_row(cur): return rows(cur).next()

def row(cur):
    row_iter = rows(cur)
    row_ = row_iter.next()
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None
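
# Example (illustrative): ways to consume a cursor returned by run_query().
# Each line assumes a fresh cursor, since fetching exhausts the result set.
#
#     cols = list(col_names(cur))  # column names from the cursor description
#     for r in rows(cur): print r  # iterate over every row
#     v = value(cur)               # first column of the first row
#     v = value_or_none(cur)       # same, but None when the result is empty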

##### Basic queries

def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True):
    '''@param fields Use None to select all fields in the table'''
    if conds == None: conds = {}
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    check_name(table)
    if fields != None: map(check_name, fields)
    map(check_name, conds.keys())

    def cond(entry):
        col, value = entry
        cond_ = esc_name(db, col)+' '
        if value == None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    query = 'SELECT '
    if fields == None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    query += ' FROM '+esc_name(db, table)

    missing = True
    if conds != {}:
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
        missing = False
    if limit != None: query += ' LIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return run_query(db, query, conds.values(), recover, cacheable)
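
# Example (illustrative) of the SQL that select() builds; 'plants' and its
# columns are hypothetical, and under psycopg2 the names stay unquoted.
#
#     select(db, 'plants', ['id'], {'genus': 'Quercus'}, limit=10)
#     # runs: SELECT id FROM plants WHERE genus = %s LIMIT 10
#     # with params ['Quercus']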

def insert(db, table, row, returning=None, recover=None, cacheable=True):
    '''@param returning str|None An inserted column (such as pkey) to return'''
    check_name(table)
    cols = row.keys()
    map(check_name, cols)
    query = 'INSERT INTO '+table

    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
        +', '.join(['%s']*len(cols))+')'
    else: query += ' DEFAULT VALUES'

    if returning != None:
        check_name(returning)
        query += ' RETURNING '+returning

    return run_query(db, query, row.values(), recover, cacheable)
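
# Example (illustrative) of the SQL that insert() builds; 'plants' is a
# hypothetical table and RETURNING requires PostgreSQL.
#
#     insert(db, 'plants', {'genus': 'Quercus'}, returning='id')
#     # runs: INSERT INTO plants (genus) VALUES (%s) RETURNING id
#     # with params ['Quercus']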

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None

def truncate(db, table):
    check_name(table)
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')

##### Database structure queries

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    check_name(table)
    return col_names(run_query(db,
        'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

def tables(db):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return values(run_query(db, "SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename"))
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

##### Database management

def empty_db(db):
    for table in tables(db): truncate(db, table)

##### Heuristic queries

def try_insert(db, table, row, returning=None):
    '''Recovers from errors'''
    try: return insert(db, table, row, returning, recover=True)
    except Exception, e:
        msg = str(e)
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised

def put(db, table, row, pkey, row_ct_ref=None):
    '''Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException, e:
        return value(select(db, table, [pkey],
            util.dict_subset_right_join(row, e.cols), recover=True))

def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors'''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
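
# Example usage (illustrative): insert a row if it doesn't already exist and get
# its primary key either way. 'plants' and 'id' are hypothetical, and put()
# requires PostgreSQL because it relies on INSERT ... RETURNING.
#
#     row_ct = [0]
#     plant_id = put(db, 'plants', {'genus': 'Quercus'}, 'id', row_ct)
#     same_id = get(db, 'plants', {'genus': 'Quercus'}, 'id')
#     # row_ct[0] now holds the number of rows actually inserted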