Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 300 aaronmk
import exc
8 862 aaronmk
import strings
9 131 aaronmk
import util
10 11 aaronmk
11 832 aaronmk
##### Exceptions
12
13 135 aaronmk
def get_cur_query(cur):
14
    if hasattr(cur, 'query'): return cur.query
15
    elif hasattr(cur, '_last_executed'): return cur._last_executed
16
    else: return None
17 14 aaronmk
18 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
19 135 aaronmk
20 300 aaronmk
class DbException(exc.ExceptionWithCause):
21 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
22 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
23 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
24
25 360 aaronmk
class NameException(DbException): pass
26
27 468 aaronmk
class ExceptionWithColumns(DbException):
28
    def __init__(self, cols, cause=None):
29
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
30
        self.cols = cols
31 11 aaronmk
32 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
33 13 aaronmk
34 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
35 13 aaronmk
36 89 aaronmk
class EmptyRowException(DbException): pass
37
38 832 aaronmk
##### Input validation
39
40 11 aaronmk
def check_name(name):
41
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
42
        +'" may contain only alphanumeric characters and _')
43
44 643 aaronmk
def esc_name(db, name):
45
    module = util.root_module(db)
46 645 aaronmk
    if module == 'psycopg2': return name
47
        # Don't enclose in quotes because this disables case-insensitivity
48 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
49 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
50 643 aaronmk
    return quote + name.replace(quote, '') + quote
51
52 832 aaronmk
##### Querying
53
54 830 aaronmk
def run_raw_query(db, query, params=None):
55 862 aaronmk
    if run_raw_query.debug: sys.stderr.write(strings.one_line(query)+'\n')
56 11 aaronmk
    cur = db.cursor()
57
    try: cur.execute(query, params)
58 46 aaronmk
    except Exception, e:
59
        _add_cursor_info(e, cur)
60 11 aaronmk
        raise
61
    return cur
62
63 832 aaronmk
##### Recoverable querying
64 15 aaronmk
65 11 aaronmk
def with_savepoint(db, func):
66
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
67 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
68 11 aaronmk
    try: return_val = func()
69
    except:
70 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
71 11 aaronmk
        raise
72
    else:
73 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
74 11 aaronmk
        return return_val
75
76 830 aaronmk
def run_query(db, query, params=None, recover=None):
77
    if recover == None: recover = False
78
79
    def run(): return run_raw_query(db, query, params)
80
    if recover: return with_savepoint(db, run)
81
    else: return run()
82
83 832 aaronmk
##### Result retrieval
84
85
def col(cur, idx): return cur.description[idx][0]
86
87
def rows(cur): return iter(lambda: cur.fetchone(), None)
88
89
def row(cur): return rows(cur).next()
90
91
def value(cur): return row(cur)[0]
92
93
def values(cur): return iter(lambda: value(cur), None)
94
95
def value_or_none(cur):
96
    try: return value(cur)
97
    except StopIteration: return None
98
99
##### Basic queries
100
101 830 aaronmk
def select(db, table, fields, conds, limit=None, recover=None):
102 135 aaronmk
    assert limit == None or type(limit) == int
103 15 aaronmk
    check_name(table)
104
    map(check_name, fields)
105
    map(check_name, conds.keys())
106 11 aaronmk
    def cond(entry):
107 13 aaronmk
        col, value = entry
108 644 aaronmk
        cond_ = esc_name(db, col)+' '
109 11 aaronmk
        if value == None: cond_ += 'IS'
110
        else: cond_ += '='
111
        cond_ += ' %s'
112
        return cond_
113 644 aaronmk
    query = ('SELECT ' + ', '.join([esc_name(db, field) for field in fields])
114
        + ' FROM '+esc_name(db, table))
115 89 aaronmk
    if conds != {}:
116
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
117 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
118 830 aaronmk
    return run_query(db, query, conds.values(), recover)
119 11 aaronmk
120 830 aaronmk
def insert(db, table, row, recover=None):
121 11 aaronmk
    check_name(table)
122 13 aaronmk
    cols = row.keys()
123 15 aaronmk
    map(check_name, cols)
124 89 aaronmk
    query = 'INSERT INTO '+table
125
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
126
        +', '.join(['%s']*len(cols))+')'
127
    else: query += ' DEFAULT VALUES'
128 830 aaronmk
    return run_query(db, query, row.values(), recover)
129 11 aaronmk
130 135 aaronmk
def last_insert_id(db):
131
    module = util.root_module(db)
132
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
133
    elif module == 'MySQLdb': return db.insert_id()
134
    else: return None
135 13 aaronmk
136 832 aaronmk
def truncate(db, table):
137
    check_name(table)
138
    return run_query(db, 'TRUNCATE '+table+' CASCADE')
139
140
##### Database structure queries
141
142
def pkey(db, cache, table, recover=None):
143
    '''Assumed to be first column in table'''
144
    check_name(table)
145
    if table not in cache:
146
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
147 834 aaronmk
            recover=recover), 0)
148 832 aaronmk
    return cache[table]
149
150 853 aaronmk
def index_cols(db, table, index):
151
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
152
    automatically created. When you don't know whether something is a UNIQUE
153
    constraint or a UNIQUE index, use this function.'''
154
    check_name(table)
155
    check_name(index)
156
    module = util.root_module(db)
157
    if module == 'psycopg2':
158
        return list(values(run_query(db, '''\
159
SELECT attname
160
FROM pg_index
161
JOIN pg_class index ON index.oid = indexrelid
162
JOIN pg_class table_ ON table_.oid = indrelid
163
JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
164
WHERE
165
    table_.relname = %(table)s
166
    AND index.relname = %(index)s
167
ORDER BY attnum
168
''',
169
            {'table': table, 'index': index})))
170
    else: raise NotImplementedError("Can't list index columns for "+module+
171
        ' database')
172
173 464 aaronmk
def constraint_cols(db, table, constraint):
174
    check_name(table)
175
    check_name(constraint)
176
    module = util.root_module(db)
177
    if module == 'psycopg2':
178
        return list(values(run_query(db, '''\
179
SELECT attname
180
FROM pg_constraint
181
JOIN pg_class ON pg_class.oid = conrelid
182
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
183
WHERE
184
    relname = %(table)s
185
    AND conname = %(constraint)s
186
ORDER BY attnum
187
''',
188
            {'table': table, 'constraint': constraint})))
189
    else: raise NotImplementedError("Can't list constraint columns for "+module+
190
        ' database')
191
192 832 aaronmk
def tables(db):
193
    module = util.root_module(db)
194
    if module == 'psycopg2':
195
        return values(run_query(db, "SELECT tablename from pg_tables "
196
            "WHERE schemaname = 'public' ORDER BY tablename"))
197
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
198
    else: raise NotImplementedError("Can't list tables for "+module+' database')
199 830 aaronmk
200 833 aaronmk
##### Database management
201
202
def empty_db(db):
203
    for table in tables(db): truncate(db, table)
204
205 832 aaronmk
##### Heuristic queries
206
207 14 aaronmk
def try_insert(db, table, row):
208 830 aaronmk
    '''Recovers from errors'''
209
    try: return insert(db, table, row, recover=True)
210 46 aaronmk
    except Exception, e:
211
        msg = str(e)
212 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
213
            r'"(([^\W_]+)_[^"]+)"', msg)
214
        if match:
215
            constraint, table = match.groups()
216 854 aaronmk
            try: cols = index_cols(db, table, constraint)
217 465 aaronmk
            except NotImplementedError: raise e
218 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
219 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
220
            'constraint', msg)
221 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
222 13 aaronmk
        raise # no specific exception raised
223 11 aaronmk
224 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
225 830 aaronmk
    '''Recovers from errors'''
226 471 aaronmk
    try:
227
        row_ct = try_insert(db, table, row).rowcount
228
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
229
        return last_insert_id(db)
230
    except DuplicateKeyException, e:
231 830 aaronmk
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols),
232
            recover=True))
233 471 aaronmk
234 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
235 830 aaronmk
    '''Recovers from errors'''
236
    try: return value(select(db, table, [pkey], row, 1, recover=True))
237 14 aaronmk
    except StopIteration:
238 40 aaronmk
        if not create: raise
239 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
240 131 aaronmk
241 832 aaronmk
##### Database connections
242
243 131 aaronmk
db_engines = {
244
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
245
    'PostgreSQL': ('psycopg2', {}),
246
}
247
248 360 aaronmk
DatabaseErrors_set = set([DbException])
249
DatabaseErrors = tuple(DatabaseErrors_set)
250 342 aaronmk
251
def _add_module(module):
252
    DatabaseErrors_set.add(module.DatabaseError)
253
    global DatabaseErrors
254
    DatabaseErrors = tuple(DatabaseErrors_set)
255
256 646 aaronmk
def connect(db_config, serializable=True):
257 131 aaronmk
    db_config = db_config.copy() # don't modify input!
258 342 aaronmk
    module_name, mappings = db_engines[db_config.pop('engine')]
259
    module = __import__(module_name)
260
    _add_module(module)
261 330 aaronmk
    for orig, new in mappings.iteritems():
262
        try: util.rename_key(db_config, orig, new)
263
        except KeyError: pass
264 646 aaronmk
    db = module.connect(**db_config)
265
    if serializable:
266 830 aaronmk
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
267 646 aaronmk
    return db
268
269
def db_config_str(db_config):
270
    return db_config['engine']+' database '+db_config['database']