Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 300 aaronmk
import exc
8 131 aaronmk
import util
9 11 aaronmk
10 832 aaronmk
##### Exceptions
11
12 135 aaronmk
def get_cur_query(cur):
13
    if hasattr(cur, 'query'): return cur.query
14
    elif hasattr(cur, '_last_executed'): return cur._last_executed
15
    else: return None
16 14 aaronmk
17 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
18 135 aaronmk
19 300 aaronmk
class DbException(exc.ExceptionWithCause):
20 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
21 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
22 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
23
24 360 aaronmk
class NameException(DbException): pass
25
26 468 aaronmk
class ExceptionWithColumns(DbException):
27
    def __init__(self, cols, cause=None):
28
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
29
        self.cols = cols
30 11 aaronmk
31 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
32 13 aaronmk
33 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
34 13 aaronmk
35 89 aaronmk
class EmptyRowException(DbException): pass
36
37 832 aaronmk
##### Input validation
38
39 11 aaronmk
def check_name(name):
40
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
41
        +'" may contain only alphanumeric characters and _')
42
43 643 aaronmk
def esc_name(db, name):
44
    module = util.root_module(db)
45 645 aaronmk
    if module == 'psycopg2': return name
46
        # Don't enclose in quotes because this disables case-insensitivity
47 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
48 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
49 643 aaronmk
    return quote + name.replace(quote, '') + quote
50
51 832 aaronmk
##### Querying
52
53 830 aaronmk
def run_raw_query(db, query, params=None):
54 11 aaronmk
    cur = db.cursor()
55
    try: cur.execute(query, params)
56 46 aaronmk
    except Exception, e:
57
        _add_cursor_info(e, cur)
58 11 aaronmk
        raise
59
    return cur
60
61 832 aaronmk
##### Recoverable querying
62 15 aaronmk
63 11 aaronmk
def with_savepoint(db, func):
64
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
65 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
66 11 aaronmk
    try: return_val = func()
67
    except:
68 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
69 11 aaronmk
        raise
70
    else:
71 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
72 11 aaronmk
        return return_val
73
74 830 aaronmk
def run_query(db, query, params=None, recover=None):
75
    if recover == None: recover = False
76
77
    def run(): return run_raw_query(db, query, params)
78
    if recover: return with_savepoint(db, run)
79
    else: return run()
80
81 832 aaronmk
##### Result retrieval
82
83
def col(cur, idx): return cur.description[idx][0]
84
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86
87
def row(cur): return rows(cur).next()
88
89
def value(cur): return row(cur)[0]
90
91
def values(cur): return iter(lambda: value(cur), None)
92
93
def value_or_none(cur):
94
    try: return value(cur)
95
    except StopIteration: return None
96
97
##### Basic queries
98
99 830 aaronmk
def select(db, table, fields, conds, limit=None, recover=None):
100 135 aaronmk
    assert limit == None or type(limit) == int
101 15 aaronmk
    check_name(table)
102
    map(check_name, fields)
103
    map(check_name, conds.keys())
104 11 aaronmk
    def cond(entry):
105 13 aaronmk
        col, value = entry
106 644 aaronmk
        cond_ = esc_name(db, col)+' '
107 11 aaronmk
        if value == None: cond_ += 'IS'
108
        else: cond_ += '='
109
        cond_ += ' %s'
110
        return cond_
111 644 aaronmk
    query = ('SELECT ' + ', '.join([esc_name(db, field) for field in fields])
112
        + ' FROM '+esc_name(db, table))
113 89 aaronmk
    if conds != {}:
114
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
115 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
116 830 aaronmk
    return run_query(db, query, conds.values(), recover)
117 11 aaronmk
118 830 aaronmk
def insert(db, table, row, recover=None):
119 11 aaronmk
    check_name(table)
120 13 aaronmk
    cols = row.keys()
121 15 aaronmk
    map(check_name, cols)
122 89 aaronmk
    query = 'INSERT INTO '+table
123
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
124
        +', '.join(['%s']*len(cols))+')'
125
    else: query += ' DEFAULT VALUES'
126 830 aaronmk
    return run_query(db, query, row.values(), recover)
127 11 aaronmk
128 135 aaronmk
def last_insert_id(db):
129
    module = util.root_module(db)
130
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
131
    elif module == 'MySQLdb': return db.insert_id()
132
    else: return None
133 13 aaronmk
134 832 aaronmk
def truncate(db, table):
135
    check_name(table)
136
    return run_query(db, 'TRUNCATE '+table+' CASCADE')
137
138
##### Database structure queries
139
140
def pkey(db, cache, table, recover=None):
141
    '''Assumed to be first column in table'''
142
    check_name(table)
143
    if table not in cache:
144
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
145 834 aaronmk
            recover=recover), 0)
146 832 aaronmk
    return cache[table]
147
148 464 aaronmk
def constraint_cols(db, table, constraint):
149
    check_name(table)
150
    check_name(constraint)
151
    module = util.root_module(db)
152
    if module == 'psycopg2':
153
        return list(values(run_query(db, '''\
154
SELECT attname
155
FROM pg_constraint
156
JOIN pg_class ON pg_class.oid = conrelid
157
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
158
WHERE
159
    relname = %(table)s
160
    AND conname = %(constraint)s
161
ORDER BY attnum
162
''',
163
            {'table': table, 'constraint': constraint})))
164
    else: raise NotImplementedError("Can't list constraint columns for "+module+
165
        ' database')
166
167 832 aaronmk
def tables(db):
168
    module = util.root_module(db)
169
    if module == 'psycopg2':
170
        return values(run_query(db, "SELECT tablename from pg_tables "
171
            "WHERE schemaname = 'public' ORDER BY tablename"))
172
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
173
    else: raise NotImplementedError("Can't list tables for "+module+' database')
174 830 aaronmk
175 833 aaronmk
##### Database management
176
177
def empty_db(db):
178
    for table in tables(db): truncate(db, table)
179
180 832 aaronmk
##### Heuristic queries
181
182 14 aaronmk
def try_insert(db, table, row):
183 830 aaronmk
    '''Recovers from errors'''
184
    try: return insert(db, table, row, recover=True)
185 46 aaronmk
    except Exception, e:
186
        msg = str(e)
187 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
188
            r'"(([^\W_]+)_[^"]+)"', msg)
189
        if match:
190
            constraint, table = match.groups()
191
            try: cols = constraint_cols(db, table, constraint)
192
            except NotImplementedError: raise e
193 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
194 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
195
            'constraint', msg)
196 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
197 13 aaronmk
        raise # no specific exception raised
198 11 aaronmk
199 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
200 830 aaronmk
    '''Recovers from errors'''
201 471 aaronmk
    try:
202
        row_ct = try_insert(db, table, row).rowcount
203
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
204
        return last_insert_id(db)
205
    except DuplicateKeyException, e:
206 830 aaronmk
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols),
207
            recover=True))
208 471 aaronmk
209 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
210 830 aaronmk
    '''Recovers from errors'''
211
    try: return value(select(db, table, [pkey], row, 1, recover=True))
212 14 aaronmk
    except StopIteration:
213 40 aaronmk
        if not create: raise
214 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
215 131 aaronmk
216 832 aaronmk
##### Database connections
217
218 131 aaronmk
db_engines = {
219
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
220
    'PostgreSQL': ('psycopg2', {}),
221
}
222
223 360 aaronmk
DatabaseErrors_set = set([DbException])
224
DatabaseErrors = tuple(DatabaseErrors_set)
225 342 aaronmk
226
def _add_module(module):
227
    DatabaseErrors_set.add(module.DatabaseError)
228
    global DatabaseErrors
229
    DatabaseErrors = tuple(DatabaseErrors_set)
230
231 646 aaronmk
def connect(db_config, serializable=True):
232 131 aaronmk
    db_config = db_config.copy() # don't modify input!
233 342 aaronmk
    module_name, mappings = db_engines[db_config.pop('engine')]
234
    module = __import__(module_name)
235
    _add_module(module)
236 330 aaronmk
    for orig, new in mappings.iteritems():
237
        try: util.rename_key(db_config, orig, new)
238
        except KeyError: pass
239 646 aaronmk
    db = module.connect(**db_config)
240
    if serializable:
241 830 aaronmk
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
242 646 aaronmk
    return db
243
244
def db_config_str(db_config):
245
    return db_config['engine']+' database '+db_config['database']