Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import sys
6
import warnings
7

    
8
import exc
9
import rand
10
import strings
11
import util
12

    
13
##### Exceptions
14

    
15
def get_cur_query(cur):
16
    if hasattr(cur, 'query'): return cur.query
17
    elif hasattr(cur, '_last_executed'): return cur._last_executed
18
    else: return None
19

    
20
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
21

    
22
class DbException(exc.ExceptionWithCause):
23
    def __init__(self, msg, cause=None, cur=None):
24
        exc.ExceptionWithCause.__init__(self, msg, cause)
25
        if cur != None: _add_cursor_info(self, cur)
26

    
27
class NameException(DbException): pass
28

    
29
class ExceptionWithColumns(DbException):
30
    def __init__(self, cols, cause=None):
31
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
32
        self.cols = cols
33

    
34
class DuplicateKeyException(ExceptionWithColumns): pass
35

    
36
class NullValueException(ExceptionWithColumns): pass
37

    
38
class EmptyRowException(DbException): pass
39

    
40
##### Warnings
41

    
42
class DbWarning(UserWarning): pass
43

    
44
##### Input validation
45

    
46
def check_name(name):
47
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
48
        +'" may contain only alphanumeric characters and _')
49

    
50
def esc_name(db, name):
51
    module = util.root_module(db.db)
52
    if module == 'psycopg2': return name
53
        # Don't enclose in quotes because this disables case-insensitivity
54
    elif module == 'MySQLdb': quote = '`'
55
    else: raise NotImplementedError("Can't escape name for "+module+' database')
56
    return quote + name.replace(quote, '') + quote
57

    
58
##### Database connections
59

    
60
db_engines = {
61
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
62
    'PostgreSQL': ('psycopg2', {}),
63
}
64

    
65
DatabaseErrors_set = set([DbException])
66
DatabaseErrors = tuple(DatabaseErrors_set)
67

    
68
def _add_module(module):
69
    DatabaseErrors_set.add(module.DatabaseError)
70
    global DatabaseErrors
71
    DatabaseErrors = tuple(DatabaseErrors_set)
72

    
73
def db_config_str(db_config):
74
    return db_config['engine']+' database '+db_config['database']
75

    
76
class DbConn:
77
    def __init__(self, db_config, serializable=True):
78
        self.db_config = db_config
79
        self.serializable = serializable
80
        
81
        self.__db = None
82
        self.pkeys = {}
83
        self.index_cols = {}
84
    
85
    def __getattr__(self, name):
86
        if name == '__dict__': raise Exception('getting __dict__')
87
        if name == 'db': return self._db()
88
        else: raise AttributeError()
89
    
90
    def __getstate__(self):
91
        state = copy.copy(self.__dict__) # shallow copy
92
        state['_DbConn__db'] = None # don't pickle the connection
93
        return state
94
    
95
    def _db(self):
96
        if self.__db == None:
97
            # Process db_config
98
            db_config = self.db_config.copy() # don't modify input!
99
            module_name, mappings = db_engines[db_config.pop('engine')]
100
            module = __import__(module_name)
101
            _add_module(module)
102
            for orig, new in mappings.iteritems():
103
                try: util.rename_key(db_config, orig, new)
104
                except KeyError: pass
105
            
106
            # Connect
107
            self.__db = module.connect(**db_config)
108
            
109
            # Configure connection
110
            if self.serializable: run_raw_query(self,
111
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
112
        
113
        return self.__db
114

    
115
connect = DbConn
116

    
117
##### Querying
118

    
119
def run_raw_query(db, query, params=None):
120
    cur = db.db.cursor()
121
    try: cur.execute(query, params)
122
    except Exception, e:
123
        _add_cursor_info(e, cur)
124
        raise
125
    if run_raw_query.debug:
126
        sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
127
    return cur
128

    
129
##### Recoverable querying
130

    
131
def with_savepoint(db, func):
132
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
133
    run_raw_query(db, 'SAVEPOINT '+savepoint)
134
    try: return_val = func()
135
    except:
136
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
137
        raise
138
    else:
139
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
140
        return return_val
141

    
142
def run_query(db, query, params=None, recover=None):
143
    if recover == None: recover = False
144
    
145
    def run(): return run_raw_query(db, query, params)
146
    if recover: return with_savepoint(db, run)
147
    else: return run()
148

    
149
##### Result retrieval
150

    
151
def col_names(cur): return (col[0] for col in cur.description)
152

    
153
def rows(cur): return iter(lambda: cur.fetchone(), None)
154

    
155
def row(cur): return rows(cur).next()
156

    
157
def value(cur): return row(cur)[0]
158

    
159
def values(cur): return iter(lambda: value(cur), None)
160

    
161
def value_or_none(cur):
162
    try: return value(cur)
163
    except StopIteration: return None
164

    
165
##### Basic queries
166

    
167
def select(db, table, fields=None, conds=None, limit=None, start=None,
168
    recover=None):
169
    '''@param fields Use None to select all fields in the table'''
170
    if conds == None: conds = {}
171
    assert limit == None or type(limit) == int
172
    assert start == None or type(start) == int
173
    check_name(table)
174
    if fields != None: map(check_name, fields)
175
    map(check_name, conds.keys())
176
    
177
    def cond(entry):
178
        col, value = entry
179
        cond_ = esc_name(db, col)+' '
180
        if value == None: cond_ += 'IS'
181
        else: cond_ += '='
182
        cond_ += ' %s'
183
        return cond_
184
    query = 'SELECT '
185
    if fields == None: query += '*'
186
    else: query += ', '.join([esc_name(db, field) for field in fields])
187
    query += ' FROM '+esc_name(db, table)
188
    
189
    missing = True
190
    if conds != {}:
191
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
192
        missing = False
193
    if limit != None: query += ' LIMIT '+str(limit); missing = False
194
    if start != None:
195
        if start != 0: query += ' OFFSET '+str(start)
196
        missing = False
197
    if missing: warnings.warn(DbWarning(
198
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
199
    
200
    return run_query(db, query, conds.values(), recover)
201

    
202
def insert(db, table, row, returning=None, recover=None):
203
    '''@param returning str|None An inserted column (such as pkey) to return'''
204
    check_name(table)
205
    cols = row.keys()
206
    map(check_name, cols)
207
    query = 'INSERT INTO '+table
208
    
209
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
210
        +', '.join(['%s']*len(cols))+')'
211
    else: query += ' DEFAULT VALUES'
212
    
213
    if returning != None:
214
        check_name(returning)
215
        query += ' RETURNING '+returning
216
    
217
    return run_query(db, query, row.values(), recover)
218

    
219
def last_insert_id(db):
220
    module = util.root_module(db.db)
221
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
222
    elif module == 'MySQLdb': return db.insert_id()
223
    else: return None
224

    
225
def truncate(db, table):
226
    check_name(table)
227
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
228

    
229
##### Database structure queries
230

    
231
def pkey(db, table, recover=None):
232
    '''Assumed to be first column in table'''
233
    check_name(table)
234
    if table not in db.pkeys:
235
        db.pkeys[table] = col_names(run_query(db,
236
            'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
237
    return db.pkeys[table]
238

    
239
def index_cols(db, table, index):
240
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
241
    automatically created. When you don't know whether something is a UNIQUE
242
    constraint or a UNIQUE index, use this function.'''
243
    check_name(table)
244
    check_name(index)
245
    lookup = (table, index)
246
    if lookup not in db.index_cols:
247
        module = util.root_module(db.db)
248
        if module == 'psycopg2':
249
            db.index_cols[lookup] = list(values(run_query(db, '''\
250
SELECT attname
251
FROM
252
(
253
        SELECT attnum, attname
254
        FROM pg_index
255
        JOIN pg_class index ON index.oid = indexrelid
256
        JOIN pg_class table_ ON table_.oid = indrelid
257
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
258
        WHERE
259
            table_.relname = %(table)s
260
            AND index.relname = %(index)s
261
    UNION
262
        SELECT attnum, attname
263
        FROM
264
        (
265
            SELECT
266
                indrelid
267
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
268
                    AS indkey
269
            FROM pg_index
270
            JOIN pg_class index ON index.oid = indexrelid
271
            JOIN pg_class table_ ON table_.oid = indrelid
272
            WHERE
273
                table_.relname = %(table)s
274
                AND index.relname = %(index)s
275
        ) s
276
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
277
) s
278
ORDER BY attnum
279
''',
280
                {'table': table, 'index': index})))
281
        else: raise NotImplementedError("Can't list index columns for "+module+
282
            ' database')
283
    return db.index_cols[lookup]
284

    
285
def constraint_cols(db, table, constraint):
286
    check_name(table)
287
    check_name(constraint)
288
    module = util.root_module(db.db)
289
    if module == 'psycopg2':
290
        return list(values(run_query(db, '''\
291
SELECT attname
292
FROM pg_constraint
293
JOIN pg_class ON pg_class.oid = conrelid
294
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
295
WHERE
296
    relname = %(table)s
297
    AND conname = %(constraint)s
298
ORDER BY attnum
299
''',
300
            {'table': table, 'constraint': constraint})))
301
    else: raise NotImplementedError("Can't list constraint columns for "+module+
302
        ' database')
303

    
304
def tables(db):
305
    module = util.root_module(db.db)
306
    if module == 'psycopg2':
307
        return values(run_query(db, "SELECT tablename from pg_tables "
308
            "WHERE schemaname = 'public' ORDER BY tablename"))
309
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
310
    else: raise NotImplementedError("Can't list tables for "+module+' database')
311

    
312
##### Database management
313

    
314
def empty_db(db):
315
    for table in tables(db): truncate(db, table)
316

    
317
##### Heuristic queries
318

    
319
def try_insert(db, table, row, returning=None):
320
    '''Recovers from errors'''
321
    try: return insert(db, table, row, returning, recover=True)
322
    except Exception, e:
323
        msg = str(e)
324
        match = re.search(r'duplicate key value violates unique constraint '
325
            r'"(([^\W_]+)_[^"]+)"', msg)
326
        if match:
327
            constraint, table = match.groups()
328
            try: cols = index_cols(db, table, constraint)
329
            except NotImplementedError: raise e
330
            else: raise DuplicateKeyException(cols, e)
331
        match = re.search(r'null value in column "(\w+)" violates not-null '
332
            'constraint', msg)
333
        if match: raise NullValueException([match.group(1)], e)
334
        raise # no specific exception raised
335

    
336
def put(db, table, row, pkey, row_ct_ref=None):
337
    '''Recovers from errors.
338
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
339
    try:
340
        cur = try_insert(db, table, row, pkey)
341
        if row_ct_ref != None and cur.rowcount >= 0:
342
            row_ct_ref[0] += cur.rowcount
343
        return value(cur)
344
    except DuplicateKeyException, e:
345
        return value(select(db, table, [pkey],
346
            util.dict_subset_right_join(row, e.cols), recover=True))
347

    
348
def get(db, table, row, pkey, row_ct_ref=None, create=False):
349
    '''Recovers from errors'''
350
    try: return value(select(db, table, [pkey], row, 1, recover=True))
351
    except StopIteration:
352
        if not create: raise
353
        return put(db, table, row, pkey, row_ct_ref) # insert new row
(22-22/33)