Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import warnings
6

    
7
import exc
8
import dicts
9
import iters
10
from Proxy import Proxy
11
import rand
12
import strings
13
import util
14

    
15
##### Exceptions
16

    
17
def get_cur_query(cur):
    '''Gets the text of the last query run on a cursor, if the cursor exposes it.
    Checks the `query` attribute first, then `_last_executed`.
    @param cur A cursor-like object
    @return The value of the first of those attributes that exists, or None if
        the cursor has neither
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
21

    
22
def _add_cursor_info(e, cur):
    '''Appends the cursor's last query to an exception's message.
    @param e The exception to annotate (via exc.add_msg())
    @param cur The cursor whose query should be reported
    '''
    query = get_cur_query(cur)
    # get_cur_query() returns None when the cursor exposes no query text;
    # concatenating that would raise a TypeError that hides e itself
    if query is not None: exc.add_msg(e, 'query: '+query)
23

    
24
class DbException(exc.ExceptionWithCause):
    '''Base class for the database exceptions defined in this module'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the cursor whose last query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        if cur is not None: _add_cursor_info(self, cur)
28

    
29
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
30

    
31
class ExceptionWithColumns(DbException):
    '''A DbException that records the columns involved in the error'''
    
    def __init__(self, cols, cause=None):
        '''
        @param cols Sequence of column name strs; also stored as self.cols
        @param cause The underlying exception, if any
        '''
        msg = 'columns: ' + ', '.join(cols)
        DbException.__init__(self, msg, cause)
        self.cols = cols
35

    
36
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a unique constraint.
    self.cols holds the columns of the violated index.'''
37

    
38
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a not-null constraint.
    self.cols holds the offending column.'''
39

    
40
class EmptyRowException(DbException):
    '''A DbException subtype for empty rows.
    NOTE(review): not raised anywhere in this part of the file; confirm the
    intended use against its callers.'''
41

    
42
##### Warnings
43

    
44
class DbWarning(UserWarning):
    '''Warning category used for database-related warnings (see select())'''
45

    
46
##### Database connections
47

    
48
# Names of the entries of a db_config dict. 'engine' selects an entry of
# db_engines; presumably the others are the connection settings -- confirm
# against the code that builds db_config.
db_config_names = ['engine', 'host', 'user', 'password', 'database']
49

    
50
# Maps an engine name to (DB-API module name, {db_config key: the name that
# module's connect() uses for it}). See DbConn._db().
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
54

    
55
# Exception classes that count as database errors. _add_module() adds each
# loaded DB module's DatabaseError and rebuilds the tuple, which is the form
# usable in an `except` clause.
DatabaseErrors_set = set((DbException,))
DatabaseErrors = tuple(DatabaseErrors_set)
57

    
58
def _add_module(module):
    '''Registers a DB module's DatabaseError class in DatabaseErrors.
    @param module A module with a DatabaseError attribute
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild from the updated set
62

    
63
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>".
    @param db_config dict with str 'engine' and 'database' entries
    '''
    engine = db_config['engine']
    database = db_config['database']
    return engine+' database '+database
65

    
66
def _query_lookup(query, params):
    '''Builds the key under which a query's result is cached.
    @return (query, dicts.make_hashable(params))
    '''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
67

    
68
def log_debug_none(msg):
    '''Debug-logging callback that discards the message.
    DbConn compares its log_debug against this function to skip building
    debug messages.'''
    return None
69

    
70
class DbConn:
    '''A lazily-opened database connection with a per-connection query cache.
    
    The underlying DB-API connection is created the first time the `db`
    attribute is accessed (see _db()). Results of cacheable queries are stored
    in `query_results`, keyed by _query_lookup(query, params).
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        '''
        @param db_config dict Must contain 'engine' (a key of db_engines); the
            other entries are passed to the DB module's connect()
        @param serializable Whether to run SET TRANSACTION ISOLATION LEVEL
            SERIALIZABLE right after connecting
        @param log_debug Callback taking a message str, called after each query
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        
        self.__db = None # the connection; created on demand by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute (the DB-API connection).
        Only invoked for attributes that are not found normally.'''
        # NOTE(review): presumably guards against recursion when the instance
        # has no __dict__ yet (e.g. while unpickling) -- confirm
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        # Include the name so the error says which attribute is missing
        raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the picklable state: everything except the debug callback
        and the open connection, which are both replaced by None'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the DB-API connection, opening and configuring it on first
        use'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename the keys to what this module's connect() expects
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # key was not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a real cursor, recording the fetched rows (or the exception
        raised by execute()) so the result can be stored in the query cache'''
        
        def __init__(self, outer):
            '''@param outer The DbConn that the cursor belongs to'''
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None # cache key; set by execute()
            self.result = [] # rows fetched so far, or the exception raised
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor. If it fails, the exception
            is cached as the query's result and then re-raised.'''
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                # _cache_result() reads self.query via _is_insert(), so it
                # must be set before caching rather than only in finally
                self.query = get_cur_query(self.inner)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            return return_value
        
        def fetchone(self):
            '''Fetches the next row, recording it. Once the rows run out (None
            is returned), the recorded result is cached.'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores the recorded result in the query cache, if appropriate'''
            is_insert = self._is_insert()
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            should_cache = (not is_insert
                or isinstance(self.result, Exception))
            if self.query_results is not None and should_cache:
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        def _is_insert(self):
            '''Whether the query text contains INSERT (case-insensitively)'''
            return 'INSERT' in self.query.upper()
        
        class CacheCursor:
            '''Replays a cached result through the cursor methods used by this
            module (execute() and fetchone())'''
            
            def __init__(self, cached_result):
                '''@param cached_result dict Becomes the instance's attributes
                    (query, result, rowcount, description)'''
                self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises a cached exception, or else rewinds to the first
                cached row. The arguments are ignored.'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None when there are no more'''
                return next(self.iter, None)
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        @param query str The SQL, with DB-API parameter placeholders
        @param params The query parameters, if any
        @param cacheable Whether to reuse a cached result for the same query
            and params, and to record the result for later reuse. Otherwise a
            plain cursor of the connection is used.
        @return The cursor: a CacheCursor on a cache hit, a DbCursor on a cache
            miss, or the connection's own cursor if not cacheable
        '''
        used_cache = False
        cur = None # lets the finally clause run even if no cursor was created
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query, params)
            except Exception as e:
                _add_cursor_info(e, cur)
                raise
        finally:
            log_debug = self.log_debug
            # only compute msg if needed; log_debug is None after unpickling
            # (see __getstate__())
            if log_debug is not None and log_debug != log_debug_none:
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                logged_query = get_cur_query(cur)
                # fall back to the unformatted query if the cursor has none
                if logged_query is None: logged_query = query
                log_debug(cache_status+': '+strings.one_line(logged_query))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Whether a result is cached for this query and params'''
        return _query_lookup(query, params) in self.query_results
189

    
190
connect = DbConn # function-style alias: connect(db_config, ...) creates a DbConn
191

    
192
##### Input validation
193

    
194
def check_name(name):
    '''Verifies that a name consists only of word characters (matches no \\W).
    @param name str The table/column/index name to validate
    @raise NameException if the name contains any other character
    '''
    if re.search(r'\W', name) is not None:
        raise NameException('Name "'+name
            +'" may contain only alphanumeric characters and _')
197

    
198
def esc_name_by_module(module, name, preserve_case=False):
    '''Quotes an identifier for the given DB-API module.
    @param module str 'psycopg2' or 'MySQLdb'
    @param name str The identifier; any occurrences of the quote character
        are removed from it
    @param preserve_case Only affects psycopg2: if false, the name is returned
        unquoted
    @raise NotImplementedError for any other module
    '''
    if module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        if not preserve_case: return name
        quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return quote + name.replace(quote, '') + quote
206

    
207
def esc_name_by_engine(engine, name, **kw_args):
    '''Quotes an identifier for the given engine (a key of db_engines).
    For kw_args, see esc_name_by_module()'''
    module = db_engines[engine][0] # the engine's DB-API module name
    return esc_name_by_module(module, name, **kw_args)
209

    
210
def esc_name(db, name, **kw_args):
    '''Quotes an identifier for the database that db is connected to.
    @param db DbConn
    For kw_args, see esc_name_by_module()'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
212

    
213
##### Querying
214

    
215
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on the connection, without a savepoint.
    For args, see DbConn.run_query()'''
    return db.run_query(*args, **kw_args)
218

    
219
##### Recoverable querying
220

    
221
def with_savepoint(db, func):
    '''Calls func() inside a savepoint, so that a failure within it can be
    rolled back without aborting the enclosing transaction.
    @param func Callable taking no arguments
    @return func()'s return value
    '''
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+savepoint)
    try: return_val = func()
    except: # deliberately catches everything: always rolls back and re-raises
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
    return return_val
231

    
232
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param recover Whether to wrap the query in a savepoint so that the
        transaction survives a failure. None means False.
    For the other params, see DbConn.run_query()
    @return The query's cursor
    '''
    if recover is None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    
    # don't need savepoint if cached
    if not recover or db.is_cached(query, params): return run()
    return with_savepoint(db, run)
239

    
240
##### Result retrieval
241

    
242
def col_names(cur):
    '''Returns a generator of the column names in cur.description (the first
    item of each entry)'''
    description = cur.description # read now, like the original expression
    return (col_info[0] for col_info in description)
243

    
244
def rows(cur):
    '''Returns an iterator that yields cur.fetchone() until it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
245

    
246
def next_row(cur):
    '''Fetches the cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    return next(rows(cur))
247

    
248
def row(cur):
    '''Fetches the cursor's next row and then exhausts the cursor.
    @raise StopIteration if there are no rows
    '''
    row_iter = rows(cur)
    row_ = next(row_iter)
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
    return row_
253

    
254
def next_value(cur):
    '''Returns the first column of the cursor's next row (see next_row())'''
    first_row = next_row(cur)
    return first_row[0]
255

    
256
def value(cur):
    '''Returns the first column of the cursor's next row, exhausting the
    cursor (see row())'''
    only_row = row(cur)
    return only_row[0]
257

    
258
def values(cur):
    '''Returns iters.func_iter() over the first column of each remaining row'''
    def get_next(): return next_value(cur)
    return iters.func_iter(get_next)
259

    
260
def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no rows'''
    try: return value(cur)
    except StopIteration: return None
263

    
264
##### Basic queries
265

    
266
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True):
    '''Runs a SELECT on one table.
    @param fields Use None to select all fields in the table
    @param conds dict Maps column name -> required value, ANDed together. A
        None value is compared with IS instead of =.
    @param limit int|None Value for the LIMIT clause
    @param start int|None Value for the OFFSET clause (omitted if 0)
    @param recover See run_query()
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    @raise NameException if the table or a field/cond name is invalid
    Warns with a DbWarning if none of conds, limit, and start is given.
    '''
    if conds is None: conds = {}
    # exact type checks: the values are inserted into the query text
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    check_name(table)
    # explicit loops, because map() does not run eagerly on all Pythons
    if fields is not None:
        for field in fields: check_name(field)
    for col in conds.keys(): check_name(col)
    
    def cond(entry):
        col, val = entry
        if val is None: operator = 'IS'
        else: operator = '='
        return esc_name(db, col)+' '+operator+' %s'
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    query += ' FROM '+esc_name(db, table)
    
    missing = True # whether the query lacks every row-restricting clause
    if conds != {}:
        query += ' WHERE '+' AND '.join(
            [cond(entry) for entry in conds.items()])
        missing = False
    if limit is not None:
        query += ' LIMIT '+str(limit)
        missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return run_query(db, query, list(conds.values()), recover, cacheable)
300

    
301
def insert(db, table, row, returning=None, recover=None, cacheable=True):
    '''Runs an INSERT of one row.
    @param row dict Maps column name -> value. If empty, inserts DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param recover See run_query()
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    @raise NameException if the table, a column, or returning is invalid
    '''
    check_name(table)
    cols = list(row.keys())
    # explicit loop, because map() does not run eagerly on all Pythons
    for col in cols: check_name(col)
    query = 'INSERT INTO '+table
    
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
        +', '.join(['%s']*len(cols))+')'
    else: query += ' DEFAULT VALUES'
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    return run_query(db, query, list(row.values()), recover, cacheable)
317

    
318
def last_insert_id(db):
    '''Gets the ID generated by the last insert on this connection.
    @param db DbConn
    @return The ID, or None if the database engine is not supported
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() is a method of the MySQLdb connection (db.db), not of DbConn,
    # whose __getattr__ only provides `db`
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
323

    
324
def truncate(db, table):
    '''Runs TRUNCATE ... CASCADE on a table.
    @raise NameException if the table name is invalid
    @return The query's cursor
    '''
    check_name(table)
    query = 'TRUNCATE '+table+' CASCADE'
    return run_raw_query(db, query)
327

    
328
##### Database structure queries
329

    
330
def pkey(db, table, recover=None):
    '''Gets the name of a table's primary key column.
    Assumed to be first column in table'''
    check_name(table)
    # LIMIT 0 fetches no rows but still provides the column descriptions
    cur = select(db, table, limit=0, recover=recover)
    return next(col_names(cur))
334

    
335
def index_cols(db, table, index):
    '''Lists the columns of an index, ordered by attnum (PostgreSQL only).
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return list of column name strs
    @raise NotImplementedError for any database other than psycopg2's
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The first SELECT handles plain column indexes (indkey); the second finds
    # the columns referenced by expression indexes (via indexprs)
    return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
        {'table': table, 'index': index}, cacheable=True)))
377

    
378
def constraint_cols(db, table, constraint):
    '''Lists the columns of a constraint, ordered by attnum (PostgreSQL only).
    @return list of column name strs
    @raise NotImplementedError for any database other than psycopg2's
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
        {'table': table, 'constraint': constraint})))
396

    
397
def tables(db):
    '''Lists the tables in the database (for PostgreSQL, those in the public
    schema, ordered by name).
    @return Iterator of table name strs (see values())
    @raise NotImplementedError if the database engine is not supported
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        query = ("SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename")
    elif module == 'MySQLdb': query = 'SHOW TABLES'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query))
404

    
405
##### Database management
406

    
407
def empty_db(db):
    '''Truncates every table listed by tables(db)'''
    for table in tables(db): truncate(db, table)
409

    
410
##### Heuristic queries
411

    
412
def try_insert(db, table, row, returning=None):
    '''Inserts a row, translating recognized errors into specific exceptions.
    Recovers from errors (the insert runs inside a savepoint).
    @return The insert's cursor
    @raise DuplicateKeyException if a unique constraint was violated
    @raise NullValueException if a not-null constraint was violated
    Any other error is re-raised unchanged.
    '''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if match:
            # The table is taken from the constraint name (the part before
            # the first _), not from the table param
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        raise # no specific exception raised
428

    
429
def put(db, table, row, pkey, row_ct_ref=None):
    '''Inserts a row, or finds the existing row if it duplicates a unique key.
    Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    @param pkey str The column whose value to return
    @param row_ct_ref list|None If given, the insert's rowcount is added to its
        first element (when the rowcount is non-negative)
    @return The pkey value of the inserted or existing row
    '''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row using only the duplicated key's columns
        return value(select(db, table, [pkey],
            util.dict_subset_right_join(row, e.cols), recover=True))
440

    
441
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Gets the pkey value of the first row matching the given values.
    Recovers from errors.
    @param row dict The column values to match (and to insert, if creating)
    @param pkey str The column whose value to return
    @param row_ct_ref See put()
    @param create Whether to insert the row (using put()) if it isn't found
    @raise StopIteration if the row isn't found and create is false
    '''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
(22-22/33)