Project

General

Profile

1
# Database access
2

    
3
import random
4
import re
5
import sys
6
import warnings
7

    
8
import exc
9
import strings
10
import util
11

    
12
##### Exceptions
13

    
14
def get_cur_query(cur):
15
    if hasattr(cur, 'query'): return cur.query
16
    elif hasattr(cur, '_last_executed'): return cur._last_executed
17
    else: return None
18

    
19
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
20

    
21
class DbException(exc.ExceptionWithCause):
22
    def __init__(self, msg, cause=None, cur=None):
23
        exc.ExceptionWithCause.__init__(self, msg, cause)
24
        if cur != None: _add_cursor_info(self, cur)
25

    
26
class NameException(DbException): pass
27

    
28
class ExceptionWithColumns(DbException):
29
    def __init__(self, cols, cause=None):
30
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
31
        self.cols = cols
32

    
33
class DuplicateKeyException(ExceptionWithColumns): pass
34

    
35
class NullValueException(ExceptionWithColumns): pass
36

    
37
class EmptyRowException(DbException): pass
38

    
39
##### Warnings
40

    
41
class DbWarning(UserWarning): pass
42

    
43
##### Input validation
44

    
45
def check_name(name):
46
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
47
        +'" may contain only alphanumeric characters and _')
48

    
49
def esc_name(db, name):
50
    module = util.root_module(db.db)
51
    if module == 'psycopg2': return name
52
        # Don't enclose in quotes because this disables case-insensitivity
53
    elif module == 'MySQLdb': quote = '`'
54
    else: raise NotImplementedError("Can't escape name for "+module+' database')
55
    return quote + name.replace(quote, '') + quote
56

    
57
##### Connection object
58

    
59
class DbConn:
60
    def __init__(self, db):
61
        self.db = db
62
        self.pkeys = {}
63
        self.index_cols = {}
64

    
65
##### Querying
66

    
67
def run_raw_query(db, query, params=None):
68
    cur = db.db.cursor()
69
    try: cur.execute(query, params)
70
    except Exception, e:
71
        _add_cursor_info(e, cur)
72
        raise
73
    if run_raw_query.debug:
74
        sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
75
    return cur
76

    
77
##### Recoverable querying
78

    
79
def with_savepoint(db, func):
80
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
81
    run_raw_query(db, 'SAVEPOINT '+savepoint)
82
    try: return_val = func()
83
    except:
84
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
85
        raise
86
    else:
87
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
88
        return return_val
89

    
90
def run_query(db, query, params=None, recover=None):
91
    if recover == None: recover = False
92
    
93
    def run(): return run_raw_query(db, query, params)
94
    if recover: return with_savepoint(db, run)
95
    else: return run()
96

    
97
##### Result retrieval
98

    
99
def col_names(cur): return (col[0] for col in cur.description)
100

    
101
def rows(cur): return iter(lambda: cur.fetchone(), None)
102

    
103
def row(cur): return rows(cur).next()
104

    
105
def value(cur): return row(cur)[0]
106

    
107
def values(cur): return iter(lambda: value(cur), None)
108

    
109
def value_or_none(cur):
110
    try: return value(cur)
111
    except StopIteration: return None
112

    
113
##### Basic queries
114

    
115
def select(db, table, fields=None, conds=None, limit=None, start=None,
116
    recover=None):
117
    '''@param fields Use None to select all fields in the table'''
118
    if conds == None: conds = {}
119
    assert limit == None or type(limit) == int
120
    assert start == None or type(start) == int
121
    check_name(table)
122
    if fields != None: map(check_name, fields)
123
    map(check_name, conds.keys())
124
    
125
    def cond(entry):
126
        col, value = entry
127
        cond_ = esc_name(db, col)+' '
128
        if value == None: cond_ += 'IS'
129
        else: cond_ += '='
130
        cond_ += ' %s'
131
        return cond_
132
    query = 'SELECT '
133
    if fields == None: query += '*'
134
    else: query += ', '.join([esc_name(db, field) for field in fields])
135
    query += ' FROM '+esc_name(db, table)
136
    
137
    missing = True
138
    if conds != {}:
139
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
140
        missing = False
141
    if limit != None: query += ' LIMIT '+str(limit); missing = False
142
    if start != None:
143
        if start != 0: query += ' OFFSET '+str(start)
144
        missing = False
145
    if missing: warnings.warn(DbWarning(
146
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
147
    
148
    return run_query(db, query, conds.values(), recover)
149

    
150
def insert(db, table, row, returning=None, recover=None):
151
    '''@param returning str|None An inserted column (such as pkey) to return'''
152
    check_name(table)
153
    cols = row.keys()
154
    map(check_name, cols)
155
    query = 'INSERT INTO '+table
156
    
157
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
158
        +', '.join(['%s']*len(cols))+')'
159
    else: query += ' DEFAULT VALUES'
160
    
161
    if returning != None:
162
        check_name(returning)
163
        query += ' RETURNING '+returning
164
    
165
    return run_query(db, query, row.values(), recover)
166

    
167
def last_insert_id(db):
168
    module = util.root_module(db.db)
169
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
170
    elif module == 'MySQLdb': return db.insert_id()
171
    else: return None
172

    
173
def truncate(db, table):
174
    check_name(table)
175
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
176

    
177
##### Database structure queries
178

    
179
def pkey(db, table, recover=None):
180
    '''Assumed to be first column in table'''
181
    check_name(table)
182
    if table not in db.pkeys:
183
        db.pkeys[table] = col_names(run_query(db,
184
            'SELECT * FROM '+table+' LIMIT 0', recover=recover)).next()
185
    return db.pkeys[table]
186

    
187
def index_cols(db, table, index):
188
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
189
    automatically created. When you don't know whether something is a UNIQUE
190
    constraint or a UNIQUE index, use this function.'''
191
    check_name(table)
192
    check_name(index)
193
    lookup = (table, index)
194
    if lookup not in db.index_cols:
195
        module = util.root_module(db.db)
196
        if module == 'psycopg2':
197
            db.index_cols[lookup] = list(values(run_query(db, '''\
198
SELECT attname
199
FROM
200
(
201
        SELECT attnum, attname
202
        FROM pg_index
203
        JOIN pg_class index ON index.oid = indexrelid
204
        JOIN pg_class table_ ON table_.oid = indrelid
205
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
206
        WHERE
207
            table_.relname = %(table)s
208
            AND index.relname = %(index)s
209
    UNION
210
        SELECT attnum, attname
211
        FROM
212
        (
213
            SELECT
214
                indrelid
215
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
216
                    AS indkey
217
            FROM pg_index
218
            JOIN pg_class index ON index.oid = indexrelid
219
            JOIN pg_class table_ ON table_.oid = indrelid
220
            WHERE
221
                table_.relname = %(table)s
222
                AND index.relname = %(index)s
223
        ) s
224
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
225
) s
226
ORDER BY attnum
227
''',
228
                {'table': table, 'index': index})))
229
        else: raise NotImplementedError("Can't list index columns for "+module+
230
            ' database')
231
    return db.index_cols[lookup]
232

    
233
def constraint_cols(db, table, constraint):
234
    check_name(table)
235
    check_name(constraint)
236
    module = util.root_module(db.db)
237
    if module == 'psycopg2':
238
        return list(values(run_query(db, '''\
239
SELECT attname
240
FROM pg_constraint
241
JOIN pg_class ON pg_class.oid = conrelid
242
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
243
WHERE
244
    relname = %(table)s
245
    AND conname = %(constraint)s
246
ORDER BY attnum
247
''',
248
            {'table': table, 'constraint': constraint})))
249
    else: raise NotImplementedError("Can't list constraint columns for "+module+
250
        ' database')
251

    
252
def tables(db):
253
    module = util.root_module(db.db)
254
    if module == 'psycopg2':
255
        return values(run_query(db, "SELECT tablename from pg_tables "
256
            "WHERE schemaname = 'public' ORDER BY tablename"))
257
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
258
    else: raise NotImplementedError("Can't list tables for "+module+' database')
259

    
260
##### Database management
261

    
262
def empty_db(db):
263
    for table in tables(db): truncate(db, table)
264

    
265
##### Heuristic queries
266

    
267
def try_insert(db, table, row, returning=None):
268
    '''Recovers from errors'''
269
    try: return insert(db, table, row, returning, recover=True)
270
    except Exception, e:
271
        msg = str(e)
272
        match = re.search(r'duplicate key value violates unique constraint '
273
            r'"(([^\W_]+)_[^"]+)"', msg)
274
        if match:
275
            constraint, table = match.groups()
276
            try: cols = index_cols(db, table, constraint)
277
            except NotImplementedError: raise e
278
            else: raise DuplicateKeyException(cols, e)
279
        match = re.search(r'null value in column "(\w+)" violates not-null '
280
            'constraint', msg)
281
        if match: raise NullValueException([match.group(1)], e)
282
        raise # no specific exception raised
283

    
284
def put(db, table, row, pkey, row_ct_ref=None):
285
    '''Recovers from errors.
286
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
287
    try:
288
        cur = try_insert(db, table, row, pkey)
289
        if row_ct_ref != None and cur.rowcount >= 0:
290
            row_ct_ref[0] += cur.rowcount
291
        return value(cur)
292
    except DuplicateKeyException, e:
293
        return value(select(db, table, [pkey],
294
            util.dict_subset_right_join(row, e.cols), recover=True))
295

    
296
def get(db, table, row, pkey, row_ct_ref=None, create=False):
297
    '''Recovers from errors'''
298
    try: return value(select(db, table, [pkey], row, 1, recover=True))
299
    except StopIteration:
300
        if not create: raise
301
        return put(db, table, row, pkey, row_ct_ref) # insert new row
302

    
303
##### Database connections
304

    
305
db_engines = {
306
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
307
    'PostgreSQL': ('psycopg2', {}),
308
}
309

    
310
DatabaseErrors_set = set([DbException])
311
DatabaseErrors = tuple(DatabaseErrors_set)
312

    
313
def _add_module(module):
314
    DatabaseErrors_set.add(module.DatabaseError)
315
    global DatabaseErrors
316
    DatabaseErrors = tuple(DatabaseErrors_set)
317

    
318
def connect(db_config, serializable=True):
319
    db_config = db_config.copy() # don't modify input!
320
    module_name, mappings = db_engines[db_config.pop('engine')]
321
    module = __import__(module_name)
322
    _add_module(module)
323
    for orig, new in mappings.iteritems():
324
        try: util.rename_key(db_config, orig, new)
325
        except KeyError: pass
326
    db = DbConn(module.connect(**db_config))
327
    if serializable:
328
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
329
    return db
330

    
331
def db_config_str(db_config):
332
    return db_config['engine']+' database '+db_config['database']
(15-15/26)