Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import sys
6
import warnings
7

    
8
import exc
9
import iters
10
from Proxy import Proxy
11
import rand
12
import strings
13
import util
14

    
15
##### Exceptions
16

    
17
def get_cur_query(cur):
18
    if hasattr(cur, 'query'): return cur.query
19
    elif hasattr(cur, '_last_executed'): return cur._last_executed
20
    else: return None
21

    
22
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
23

    
24
class DbException(exc.ExceptionWithCause):
25
    def __init__(self, msg, cause=None, cur=None):
26
        exc.ExceptionWithCause.__init__(self, msg, cause)
27
        if cur != None: _add_cursor_info(self, cur)
28

    
29
class NameException(DbException): pass
30

    
31
class ExceptionWithColumns(DbException):
32
    def __init__(self, cols, cause=None):
33
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
34
        self.cols = cols
35

    
36
class DuplicateKeyException(ExceptionWithColumns): pass
37

    
38
class NullValueException(ExceptionWithColumns): pass
39

    
40
class EmptyRowException(DbException): pass
41

    
42
##### Warnings
43

    
44
class DbWarning(UserWarning): pass
45

    
46
##### Input validation
47

    
48
def check_name(name):
49
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
50
        +'" may contain only alphanumeric characters and _')
51

    
52
def esc_name(db, name):
53
    module = util.root_module(db.db)
54
    if module == 'psycopg2': return name
55
        # Don't enclose in quotes because this disables case-insensitivity
56
    elif module == 'MySQLdb': quote = '`'
57
    else: raise NotImplementedError("Can't escape name for "+module+' database')
58
    return quote + name.replace(quote, '') + quote
59

    
60
##### Database connections
61

    
62
db_engines = {
63
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
64
    'PostgreSQL': ('psycopg2', {}),
65
}
66

    
67
DatabaseErrors_set = set([DbException])
68
DatabaseErrors = tuple(DatabaseErrors_set)
69

    
70
def _add_module(module):
71
    DatabaseErrors_set.add(module.DatabaseError)
72
    global DatabaseErrors
73
    DatabaseErrors = tuple(DatabaseErrors_set)
74

    
75
def db_config_str(db_config):
76
    return db_config['engine']+' database '+db_config['database']
77

    
78
def _query_lookup(query, params): return (query, util.cast(tuple, params))
79

    
80
class DbConn:
81
    def __init__(self, db_config, serializable=True, debug=False):
82
        self.db_config = db_config
83
        self.serializable = serializable
84
        self.debug = debug
85
        
86
        self.__db = None
87
        self.pkeys = {}
88
        self.index_cols = {}
89
        self.query_results = {}
90
    
91
    def __getattr__(self, name):
92
        if name == '__dict__': raise Exception('getting __dict__')
93
        if name == 'db': return self._db()
94
        else: raise AttributeError()
95
    
96
    def __getstate__(self):
97
        state = copy.copy(self.__dict__) # shallow copy
98
        state['_DbConn__db'] = None # don't pickle the connection
99
        return state
100
    
101
    def _db(self):
102
        if self.__db == None:
103
            # Process db_config
104
            db_config = self.db_config.copy() # don't modify input!
105
            module_name, mappings = db_engines[db_config.pop('engine')]
106
            module = __import__(module_name)
107
            _add_module(module)
108
            for orig, new in mappings.iteritems():
109
                try: util.rename_key(db_config, orig, new)
110
                except KeyError: pass
111
            
112
            # Connect
113
            self.__db = module.connect(**db_config)
114
            
115
            # Configure connection
116
            if self.serializable: run_raw_query(self,
117
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
118
        
119
        return self.__db
120
    
121
    class DbCursor(Proxy):
122
        def __init__(self, outer):
123
            Proxy.__init__(self, outer.db.cursor())
124
            self.outer = outer
125
            self.query_lookup = None
126
            self.result = []
127
        
128
        def execute(self, query, params=None):
129
            self.query_lookup = _query_lookup(query, params)
130
            return_value = self.inner.execute(query, params)
131
            self.query = get_cur_query(self.inner)
132
            return return_value
133
        
134
        def fetchone(self):
135
            row = self.inner.fetchone()
136
            if row == None: # fetched all rows
137
                assert self.query_lookup != None
138
                pass #self.outer.query_results[self.query_lookup] = (self.query,
139
                    #self.result)
140
            else: self.result.append(row)
141
            return row
142
    
143
    class CacheCursor:
144
        def __init__(self, query, result):
145
            self.query = query
146
            self.rowcount = len(result)
147
            self.iter = iter(result)
148
        
149
        def fetchone(self):
150
            try: return self.iter.next()
151
            except StopIteration: return None
152
    
153
    def run_query(self, query, params=None, cacheable=False):
154
        query_lookup = _query_lookup(query, params)
155
        try: actual_query, result = self.query_results[query_lookup]
156
        except KeyError:
157
            cur = self.DbCursor(self)
158
            try: cur.execute(query, params)
159
            except Exception, e:
160
                _add_cursor_info(e, cur)
161
                raise
162
            if self.debug:
163
                sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
164
            return cur
165
        else: return self.CacheCursor(actual_query, result)
166

    
167
connect = DbConn
168

    
169
##### Querying
170

    
171
def run_raw_query(db, *args, **kw_args):
172
    '''For args, see DbConn.run_query()'''
173
    return db.run_query(*args, **kw_args)
174

    
175
##### Recoverable querying
176

    
177
def with_savepoint(db, func):
178
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
179
    run_raw_query(db, 'SAVEPOINT '+savepoint)
180
    try: return_val = func()
181
    except:
182
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
183
        raise
184
    else:
185
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
186
        return return_val
187

    
188
def run_query(db, query, params=None, recover=None, cacheable=False):
189
    if recover == None: recover = False
190
    
191
    def run(): return run_raw_query(db, query, params, cacheable)
192
    if recover: return with_savepoint(db, run)
193
    else: return run()
194

    
195
##### Result retrieval
196

    
197
def col_names(cur): return (col[0] for col in cur.description)
198

    
199
def rows(cur): return iter(lambda: cur.fetchone(), None)
200

    
201
def next_row(cur): return rows(cur).next()
202

    
203
def row(cur):
204
    row_iter = rows(cur)
205
    row_ = row_iter.next()
206
    iters.consume_iter(row_iter) # fetch all rows so result will be cached
207
    return row_
208

    
209
def next_value(cur): return next_row(cur)[0]
210

    
211
def value(cur): return row(cur)[0]
212

    
213
def values(cur): return iters.func_iter(lambda: next_value(cur))
214

    
215
def value_or_none(cur):
216
    try: return value(cur)
217
    except StopIteration: return None
218

    
219
##### Basic queries
220

    
221
def select(db, table, fields=None, conds=None, limit=None, start=None,
222
    recover=None, cacheable=True):
223
    '''@param fields Use None to select all fields in the table'''
224
    if conds == None: conds = {}
225
    assert limit == None or type(limit) == int
226
    assert start == None or type(start) == int
227
    check_name(table)
228
    if fields != None: map(check_name, fields)
229
    map(check_name, conds.keys())
230
    
231
    def cond(entry):
232
        col, value = entry
233
        cond_ = esc_name(db, col)+' '
234
        if value == None: cond_ += 'IS'
235
        else: cond_ += '='
236
        cond_ += ' %s'
237
        return cond_
238
    query = 'SELECT '
239
    if fields == None: query += '*'
240
    else: query += ', '.join([esc_name(db, field) for field in fields])
241
    query += ' FROM '+esc_name(db, table)
242
    
243
    missing = True
244
    if conds != {}:
245
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
246
        missing = False
247
    if limit != None: query += ' LIMIT '+str(limit); missing = False
248
    if start != None:
249
        if start != 0: query += ' OFFSET '+str(start)
250
        missing = False
251
    if missing: warnings.warn(DbWarning(
252
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
253
    
254
    return run_query(db, query, conds.values(), cacheable, recover)
255

    
256
def insert(db, table, row, returning=None, recover=None):
257
    '''@param returning str|None An inserted column (such as pkey) to return'''
258
    check_name(table)
259
    cols = row.keys()
260
    map(check_name, cols)
261
    query = 'INSERT INTO '+table
262
    
263
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
264
        +', '.join(['%s']*len(cols))+')'
265
    else: query += ' DEFAULT VALUES'
266
    
267
    if returning != None:
268
        check_name(returning)
269
        query += ' RETURNING '+returning
270
    
271
    return run_query(db, query, row.values(), recover)
272

    
273
def last_insert_id(db):
274
    module = util.root_module(db.db)
275
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
276
    elif module == 'MySQLdb': return db.insert_id()
277
    else: return None
278

    
279
def truncate(db, table):
280
    check_name(table)
281
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
282

    
283
##### Database structure queries
284

    
285
def pkey(db, table, recover=None):
286
    '''Assumed to be first column in table'''
287
    check_name(table)
288
    if table not in db.pkeys:
289
        db.pkeys[table] = col_names(run_query(db,
290
            'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
291
    return db.pkeys[table]
292

    
293
def index_cols(db, table, index):
294
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
295
    automatically created. When you don't know whether something is a UNIQUE
296
    constraint or a UNIQUE index, use this function.'''
297
    check_name(table)
298
    check_name(index)
299
    lookup = (table, index)
300
    if lookup not in db.index_cols:
301
        module = util.root_module(db.db)
302
        if module == 'psycopg2':
303
            db.index_cols[lookup] = list(values(run_query(db, '''\
304
SELECT attname
305
FROM
306
(
307
        SELECT attnum, attname
308
        FROM pg_index
309
        JOIN pg_class index ON index.oid = indexrelid
310
        JOIN pg_class table_ ON table_.oid = indrelid
311
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
312
        WHERE
313
            table_.relname = %(table)s
314
            AND index.relname = %(index)s
315
    UNION
316
        SELECT attnum, attname
317
        FROM
318
        (
319
            SELECT
320
                indrelid
321
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
322
                    AS indkey
323
            FROM pg_index
324
            JOIN pg_class index ON index.oid = indexrelid
325
            JOIN pg_class table_ ON table_.oid = indrelid
326
            WHERE
327
                table_.relname = %(table)s
328
                AND index.relname = %(index)s
329
        ) s
330
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
331
) s
332
ORDER BY attnum
333
''',
334
                {'table': table, 'index': index})))
335
        else: raise NotImplementedError("Can't list index columns for "+module+
336
            ' database')
337
    return db.index_cols[lookup]
338

    
339
def constraint_cols(db, table, constraint):
340
    check_name(table)
341
    check_name(constraint)
342
    module = util.root_module(db.db)
343
    if module == 'psycopg2':
344
        return list(values(run_query(db, '''\
345
SELECT attname
346
FROM pg_constraint
347
JOIN pg_class ON pg_class.oid = conrelid
348
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
349
WHERE
350
    relname = %(table)s
351
    AND conname = %(constraint)s
352
ORDER BY attnum
353
''',
354
            {'table': table, 'constraint': constraint})))
355
    else: raise NotImplementedError("Can't list constraint columns for "+module+
356
        ' database')
357

    
358
def tables(db):
359
    module = util.root_module(db.db)
360
    if module == 'psycopg2':
361
        return values(run_query(db, "SELECT tablename from pg_tables "
362
            "WHERE schemaname = 'public' ORDER BY tablename"))
363
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
364
    else: raise NotImplementedError("Can't list tables for "+module+' database')
365

    
366
##### Database management
367

    
368
def empty_db(db):
369
    for table in tables(db): truncate(db, table)
370

    
371
##### Heuristic queries
372

    
373
def try_insert(db, table, row, returning=None):
374
    '''Recovers from errors'''
375
    try: return insert(db, table, row, returning, recover=True)
376
    except Exception, e:
377
        msg = str(e)
378
        match = re.search(r'duplicate key value violates unique constraint '
379
            r'"(([^\W_]+)_[^"]+)"', msg)
380
        if match:
381
            constraint, table = match.groups()
382
            try: cols = index_cols(db, table, constraint)
383
            except NotImplementedError: raise e
384
            else: raise DuplicateKeyException(cols, e)
385
        match = re.search(r'null value in column "(\w+)" violates not-null '
386
            'constraint', msg)
387
        if match: raise NullValueException([match.group(1)], e)
388
        raise # no specific exception raised
389

    
390
def put(db, table, row, pkey, row_ct_ref=None):
391
    '''Recovers from errors.
392
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
393
    try:
394
        cur = try_insert(db, table, row, pkey)
395
        if row_ct_ref != None and cur.rowcount >= 0:
396
            row_ct_ref[0] += cur.rowcount
397
        return value(cur)
398
    except DuplicateKeyException, e:
399
        return value(select(db, table, [pkey],
400
            util.dict_subset_right_join(row, e.cols), recover=True))
401

    
402
def get(db, table, row, pkey, row_ct_ref=None, create=False):
403
    '''Recovers from errors'''
404
    try: return value(select(db, table, [pkey], row, 1, recover=True))
405
    except StopIteration:
406
        if not create: raise
407
        return put(db, table, row, pkey, row_ct_ref) # insert new row
(22-22/33)