# Database access
2

    
3
import copy
4
import re
5
import warnings
6

    
7
import exc
8
import dicts
9
import iters
10
from Proxy import Proxy
11
import rand
12
import strings
13
import util
14

    
15
##### Exceptions
16

    
17
def get_cur_query(cur):
    '''Returns the text of the query a cursor last ran, if it exposes one.
    
    Checks the cursor's `query` attribute first, then `_last_executed`
    (presumably the psycopg2 and MySQLdb spellings, respectively).
    @return The first of those attributes that exists, or None if neither does
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
21

    
22
def _add_cursor_info(e, cur):
    '''Passes the cursor's last query to exc.add_msg() as extra info for e.
    
    @param e The exception to annotate
    @param cur The cursor whose query should be reported
    '''
    query = get_cur_query(cur)
    # get_cur_query() returns None when the cursor exposes no query text.
    # Concatenating None would raise a TypeError that replaces the exception
    # being annotated, so substitute a placeholder instead.
    if query is None: query = '(not available)'
    exc.add_msg(e, 'query: '+query)
23

    
24
class DbException(exc.ExceptionWithCause):
    '''Base class for the database errors defined in this module.
    
    @param msg The error message
    @param cause The underlying exception, if any
    @param cur If given, the cursor whose last query is attached to the error
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause)
        # Only annotate with query info when a cursor was supplied
        if cur != None:
            _add_cursor_info(self, cur)
28

    
29
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
30

    
31
class ExceptionWithColumns(DbException):
    '''A database error that concerns specific columns.
    
    @param cols The names of the columns involved; also stored in self.cols
    @param cause The underlying exception, if any
    '''
    def __init__(self, cols, cause=None):
        msg = 'columns: ' + ', '.join(cols)
        DbException.__init__(self, msg, cause)
        self.cols = cols
35

    
36
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a unique constraint'''
37

    
38
class NullValueException(ExceptionWithColumns):
    '''Raised by try_insert() when an insert violates a not-null constraint'''
39

    
40
class EmptyRowException(DbException):
    '''Database error for an empty row.
    NOTE(review): nothing in the visible part of this file raises it'''
41

    
42
##### Warnings
43

    
44
class DbWarning(UserWarning):
    '''Warning category for this module; select() issues it for a SELECT
    that has no WHERE, LIMIT, or OFFSET clause'''
45

    
46
##### Result retrieval
47

    
48
def col_names(cur):
    '''Returns a generator over the first element of each cur.description
    entry (the column name, per the DB-API)'''
    description = cur.description
    return (entry[0] for entry in description)
49

    
50
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() until it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
51

    
52
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached.
    Hands the cursor's row iterator to iters.consume_iter(), which presumably
    exhausts it.'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
55

    
56
def next_row(cur):
    '''Returns the cursor's next row.
    @raise StopIteration if fetchone() reports no more rows'''
    return next(rows(cur))
57

    
58
def row(cur):
    '''Returns the cursor's next row, then fetches the rest so that the whole
    result is consumed.
    @raise StopIteration if there is no next row'''
    first = next_row(cur)
    consume_rows(cur) # runs only if a row was found
    return first
62

    
63
def next_value(cur):
    '''Returns the first column of the cursor's next row'''
    next_row_ = next_row(cur)
    return next_row_[0]
64

    
65
def value(cur):
    '''Returns the first column of the row returned by row(), which also
    consumes the remaining rows'''
    only_row = row(cur)
    return only_row[0]
66

    
67
def values(cur):
    '''Returns iters.func_iter() over a callable producing next_value(cur);
    presumably this iterates the first column of every remaining row'''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
68

    
69
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
72

    
73
##### Database connections
74

    
75
# Names of the db_config settings (presumably the keys callers may supply;
# 'engine' and 'database' are the ones read directly in this file)
db_config_names = ['engine', 'host', 'user', 'password', 'database']
76

    
77
# Maps each supported engine name to a (DB-API module name, renames) pair,
# where renames maps a db_config key to the keyword that module's connect()
# is given instead
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', dict()),
)
81

    
82
# Exception classes treated as database errors. Starts with this module's own
# base class; _add_module() adds each driver module's DatabaseError.
DatabaseErrors_set = set((DbException,))
# Tuple form of the set above; rebuilt by _add_module() whenever the set grows
DatabaseErrors = tuple(DatabaseErrors_set)
84

    
85
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class and rebuilds the
    module-level DatabaseErrors tuple from the updated set'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
89

    
90
def db_config_str(db_config):
    '''Returns a short description of the form "<engine> database <database>"
    @raise KeyError if db_config lacks 'engine' or 'database'
    '''
    return ' database '.join((db_config['engine'], db_config['database']))
92

    
93
def _query_lookup(query, params):
    '''Builds the key under which a query's result is stored in
    DbConn.query_results: the query text plus dicts.make_hashable(params)'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
94

    
95
def log_debug_none(msg):
    '''No-op debug logger; DbConn compares against it to skip building
    log messages'''
    return None
96

    
97
class DbConn:
    '''A lazily-opened database connection with optional query result caching.
    
    The underlying DB-API connection is created the first time the `db`
    attribute is read. Results of cacheable queries are stored in
    self.query_results, keyed by _query_lookup(query, params).
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none):
        '''
        @param db_config dict with an 'engine' entry (a key of db_engines);
            the remaining entries are passed to the engine module's connect()
        @param serializable Whether to run SET TRANSACTION ISOLATION LEVEL
            SERIALIZABLE right after connecting
        @param log_debug Callback given one debug message per query run
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        
        self.__db = None # opened on demand by _db()
        self.query_results = {} # query lookup key -> CacheCursor
    
    def __getattr__(self, name):
        '''Supplies the lazily-connected `db` attribute.
        @raise AttributeError for any other missing attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        # Include the name so the error says which attribute was missing
        else: raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the state to pickle, without the unpicklable members'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the DB-API connection, opening and configuring it on first
        use'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a real cursor and records the rows fetched through it, so
        the complete result can be stored in the connection's cache'''
        
        def __init__(self, outer):
            '''@param outer The DbConn that owns this cursor'''
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor; a raised exception is
            cached as the query's result before being re-raised'''
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches one row, recording it; once the wrapped cursor reports
            no more rows, the recorded result is cached'''
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores a CacheCursor for the current query in query_results'''
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result through the execute()/fetchone() part
            of the cursor interface'''
            
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises a cached exception, or rewinds the cached rows'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None when exhausted'''
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        
        @param params The query's parameters, if any
        @param cacheable Whether to reuse a cached result for the same query
            and params, and to cache the result if there is none yet
        @return A CacheCursor on a cache hit, a DbCursor on a cache miss, or a
            plain cursor for a non-cacheable query
        '''
        used_cache = False
        # Bound up front so the logging in `finally` works even if obtaining
        # the cursor raises; otherwise the resulting NameError would replace
        # the real error
        cur = None
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query, params)
            except Exception as e:
                _add_cursor_info(e, cur)
                raise
        finally:
            log_debug = self.log_debug
            # log_debug is None after unpickling (see __getstate__()), which
            # must be treated as "no logging" rather than called.
            # Only compute msg if needed.
            if log_debug != None and log_debug != log_debug_none:
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                logged_query = get_cur_query(cur)
                # Fall back to the text passed in when the cursor has no
                # query to report (including when no cursor was obtained)
                if logged_query == None: logged_query = query
                log_debug(cache_status+': '+strings.one_line(logged_query))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Returns whether a result is cached for this query and params'''
        return _query_lookup(query, params) in self.query_results
216

    
217
# Alias so that a connection can be created with connect(db_config, ...)
connect = DbConn
218

    
219
##### Input validation
220

    
221
def check_name(name):
    '''Ensures a name consists only of word characters, so that it can be
    placed in SQL text.
    @raise NameException if name contains any non-word character
    '''
    if re.search(r'\W', name) is None: return
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
224

    
225
def esc_name_by_module(module, name, preserve_case=False):
    '''Quotes an identifier in the style of the given DB-API module.
    
    @param module 'psycopg2' or 'MySQLdb'
    @param preserve_case For psycopg2, whether to quote the name; unquoted
        names are returned unchanged
    @return The name with any embedded quote characters removed and wrapped in
        the module's quote character
    @raise NotImplementedError for any other module
    '''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2':
        # Don't enclose in quotes because this disables case-insensitivity
        if not preserve_case: return name
        quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    stripped = name.replace(quote, '')
    return quote + stripped + quote
233

    
234
def esc_name_by_engine(engine, name, **kw_args):
    '''Like esc_name_by_module(), but takes an engine name (a key of
    db_engines) instead of a module name'''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
236

    
237
def esc_name(db, name, **kw_args):
    '''Like esc_name_by_module(), using the module that util.root_module()
    reports for db's underlying connection'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
239

    
240
##### Querying
241

    
242
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on the connection, without savepoint recovery.
    For args, see DbConn.run_query()'''
    cur = db.run_query(*args, **kw_args)
    return cur
245

    
246
##### Recoverable querying
247

    
248
def with_savepoint(db, func):
    '''Calls func() inside a savepoint.
    
    The savepoint is released if func() returns, and rolled back to if it
    raises; the exception is then re-raised.
    @return func()'s return value
    '''
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+savepoint)
    try: return_val = func()
    except:
        # Undo whatever func() did, whatever the error, then propagate it
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
    return return_val
258

    
259
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint so that an error can be
    recovered from.
    
    @param recover Whether to wrap the query in a savepoint; None means False
    @param cacheable Passed to DbConn.run_query()
    @return The query's cursor
    '''
    if recover == None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    # don't need savepoint if cached
    needs_savepoint = bool(recover) and not db.is_cached(query, params)
    if not needs_savepoint: return run()
    return with_savepoint(db, run)
266

    
267
##### Basic queries
268

    
269
def select(db, table, fields=None, conds=None, limit=None, start=None,
    recover=None, cacheable=True):
    '''Runs a SELECT on one table and returns its cursor.
    
    @param fields Use None to select all fields in the table
    @param conds dict of column name -> value; the conditions are ANDed
        together, and a None value is compared using IS instead of =
    @param limit int|None Value for the LIMIT clause
    @param start int|None Row offset; 0 adds no OFFSET clause
    @param recover Passed to run_query()
    @param cacheable Passed to run_query()
    @raise AssertionError if limit or start is neither None nor an int
    @raise NameException if the table, a field, or a condition column fails
        check_name()
    '''
    if conds == None: conds = {}
    # limit and start are concatenated into the SQL text, so they are checked
    # with an explicit raise: unlike an assert statement, it is not removed
    # when Python runs with -O
    for clause, clause_value in (('limit', limit), ('start', start)):
        if clause_value != None and type(clause_value) != int:
            raise AssertionError(clause+' must be None or an int')
    check_name(table)
    # Plain loops rather than map(), so the checks run even where map() is lazy
    if fields != None:
        for field in fields: check_name(field)
    for col in conds.keys(): check_name(col)
    
    def cond(entry):
        '''Builds the placeholder comparison for one (column, value) pair'''
        col, val = entry
        cond_ = esc_name(db, col)+' '
        if val == None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    query = 'SELECT '
    if fields == None: query += '*'
    else: query += ', '.join([esc_name(db, field) for field in fields])
    query += ' FROM '+esc_name(db, table)
    
    # Track whether any clause restricts the result, to warn if none does
    missing = True
    if conds != {}:
        query += ' WHERE '+' AND '.join([cond(entry)
            for entry in conds.items()])
        missing = False
    if limit != None: query += ' LIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return run_query(db, query, conds.values(), recover, cacheable)
303

    
304
def insert(db, table, row, returning=None, recover=None, cacheable=True):
    '''Inserts one row into a table and returns the cursor.
    
    @param row dict of column name -> value; {} inserts DEFAULT VALUES
    @param returning str|None An inserted column (such as pkey) to return
    @param recover Passed to run_query()
    @param cacheable Passed to run_query()
    @raise NameException if the table, a column, or returning fails
        check_name()
    '''
    check_name(table)
    cols = row.keys()
    for col in cols: check_name(col)
    
    query = 'INSERT INTO '+table
    if row != {}:
        placeholders = ', '.join(['%s']*len(cols))
        query += ' ('+', '.join(cols)+') VALUES ('+placeholders+')'
    else: query += ' DEFAULT VALUES'
    
    if returning != None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    return run_query(db, query, row.values(), recover, cacheable)
320

    
321
def last_insert_id(db):
    '''Returns the ID generated by the most recent insert on this connection.
    
    @param db The DbConn to query
    @return The result of SELECT lastval() under psycopg2, the connection's
        insert_id() under MySQLdb, or None for any other module
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the underlying MySQLdb connection. DbConn itself
    # only provides `db`, so calling it on the wrapper raised AttributeError.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
326

    
327
def truncate(db, table):
    '''Runs TRUNCATE ... CASCADE on the table and returns the cursor.
    @raise NameException if table fails check_name()'''
    check_name(table)
    query = 'TRUNCATE '+table+' CASCADE'
    return run_raw_query(db, query)
330

    
331
##### Database structure queries
332

    
333
def pkey(db, table, recover=None):
    '''Returns the name of the table's primary key.
    Assumed to be first column in table'''
    check_name(table)
    # A zero-row SELECT is enough to get the column descriptions
    cur = select(db, table, limit=0, recover=recover)
    return next(col_names(cur))
337

    
338
def index_cols(db, table, index):
    '''Returns the names of the columns an index covers, ordered by attnum.
    
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    
    The query unions the columns listed in the index's indkey with the column
    numbers parsed out of its indexprs.
    @raise NotImplementedError if the connection is not a psycopg2 one
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
380

    
381
def constraint_cols(db, table, constraint):
    '''Returns the names of the columns a constraint covers, ordered by attnum.
    @raise NotImplementedError if the connection is not a psycopg2 one
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
399

    
400
def tables(db):
    '''Returns an iterator over the database's table names (for psycopg2,
    those in the public schema, sorted by name).
    @raise NotImplementedError if the connection's module is not supported
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        query = ("SELECT tablename from pg_tables "
            "WHERE schemaname = 'public' ORDER BY tablename")
    elif module == 'MySQLdb': query = 'SHOW TABLES'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query))
407

    
408
##### Database management
409

    
410
def empty_db(db):
    '''Truncates every table that tables() lists'''
    for table in tables(db):
        truncate(db, table)
412

    
413
##### Heuristic queries
414

    
415
def try_insert(db, table, row, returning=None):
    '''Inserts a row, translating recognized errors into specific exceptions.
    Recovers from errors.
    
    @raise DuplicateKeyException if the error message reports a unique
        constraint violation and the constraint's columns can be looked up
    @raise NullValueException if the error message reports a not-null
        violation
    Any other error is re-raised unchanged.
    '''
    try: return insert(db, table, row, returning, recover=True)
    except Exception as e:
        msg = str(e)
        
        dup_match = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if dup_match:
            # The table name is taken from the part of the constraint name
            # before its first underscore
            constraint, table = dup_match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        
        null_match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if null_match: raise NullValueException([null_match.group(1)], e)
        
        raise # no specific exception raised
431

    
432
def put(db, table, row, pkey, row_ct_ref=None):
    '''Inserts a row and returns its pkey; if the row duplicates an existing
    one, returns the existing row's pkey instead.
    Recovers from errors.
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)
    
    @param row_ct_ref If given, a one-element list whose entry is increased
        by the insert's rowcount when that is non-negative
    '''
    try:
        cur = try_insert(db, table, row, pkey)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look the existing row up by the columns that collided
        key_conds = util.dict_subset_right_join(row, e.cols)
        existing = select(db, table, [pkey], key_conds, recover=True)
        return value(existing)
443

    
444
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the first row matching `row`.
    Recovers from errors.
    
    @param create Whether to insert the row (via put()) if none matches
    @param row_ct_ref Passed to put()
    @raise StopIteration if no row matches and create is false
    '''
    try:
        matches = select(db, table, [pkey], row, 1, recover=True)
        return value(matches)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
(22-22/33)