Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 862 aaronmk
import strings
10 131 aaronmk
import util
11 11 aaronmk
12 832 aaronmk
##### Exceptions
13
14 135 aaronmk
def get_cur_query(cur):
15
    if hasattr(cur, 'query'): return cur.query
16
    elif hasattr(cur, '_last_executed'): return cur._last_executed
17
    else: return None
18 14 aaronmk
19 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
20 135 aaronmk
21 300 aaronmk
class DbException(exc.ExceptionWithCause):
22 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
23 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
24 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
25
26 360 aaronmk
class NameException(DbException): pass
27
28 468 aaronmk
class ExceptionWithColumns(DbException):
29
    def __init__(self, cols, cause=None):
30
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
31
        self.cols = cols
32 11 aaronmk
33 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
34 13 aaronmk
35 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
36 13 aaronmk
37 89 aaronmk
class EmptyRowException(DbException): pass
38
39 865 aaronmk
##### Warnings
40
41
class DbWarning(UserWarning): pass
42
43 832 aaronmk
##### Input validation
44
45 11 aaronmk
def check_name(name):
46
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
47
        +'" may contain only alphanumeric characters and _')
48
49 643 aaronmk
def esc_name(db, name):
50
    module = util.root_module(db)
51 645 aaronmk
    if module == 'psycopg2': return name
52
        # Don't enclose in quotes because this disables case-insensitivity
53 643 aaronmk
    elif module == 'MySQLdb': quote = '`'
54 645 aaronmk
    else: raise NotImplementedError("Can't escape name for "+module+' database')
55 643 aaronmk
    return quote + name.replace(quote, '') + quote
56
57 832 aaronmk
##### Querying
58
59 830 aaronmk
def run_raw_query(db, query, params=None):
60 11 aaronmk
    cur = db.cursor()
61
    try: cur.execute(query, params)
62 46 aaronmk
    except Exception, e:
63
        _add_cursor_info(e, cur)
64 11 aaronmk
        raise
65 867 aaronmk
    if run_raw_query.debug:
66
        sys.stderr.write(strings.one_line(get_cur_query(cur))+'\n')
67 11 aaronmk
    return cur
68
69 832 aaronmk
##### Recoverable querying
70 15 aaronmk
71 11 aaronmk
def with_savepoint(db, func):
72
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
73 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
74 11 aaronmk
    try: return_val = func()
75
    except:
76 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
77 11 aaronmk
        raise
78
    else:
79 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
80 11 aaronmk
        return return_val
81
82 830 aaronmk
def run_query(db, query, params=None, recover=None):
83
    if recover == None: recover = False
84
85
    def run(): return run_raw_query(db, query, params)
86
    if recover: return with_savepoint(db, run)
87
    else: return run()
88
89 832 aaronmk
##### Result retrieval
90
91
def col(cur, idx): return cur.description[idx][0]
92
93
def rows(cur): return iter(lambda: cur.fetchone(), None)
94
95
def row(cur): return rows(cur).next()
96
97
def value(cur): return row(cur)[0]
98
99
def values(cur): return iter(lambda: value(cur), None)
100
101
def value_or_none(cur):
102
    try: return value(cur)
103
    except StopIteration: return None
104
105
##### Basic queries
106
107 863 aaronmk
def select(db, table, fields, conds, limit=None, start=None, recover=None):
108 135 aaronmk
    assert limit == None or type(limit) == int
109 865 aaronmk
    assert start == None or type(start) == int
110 15 aaronmk
    check_name(table)
111
    map(check_name, fields)
112
    map(check_name, conds.keys())
113 865 aaronmk
114 11 aaronmk
    def cond(entry):
115 13 aaronmk
        col, value = entry
116 644 aaronmk
        cond_ = esc_name(db, col)+' '
117 11 aaronmk
        if value == None: cond_ += 'IS'
118
        else: cond_ += '='
119
        cond_ += ' %s'
120
        return cond_
121 644 aaronmk
    query = ('SELECT ' + ', '.join([esc_name(db, field) for field in fields])
122
        + ' FROM '+esc_name(db, table))
123 865 aaronmk
124
    missing = True
125 89 aaronmk
    if conds != {}:
126
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
127 865 aaronmk
        missing = False
128
    if limit != None: query += ' LIMIT '+str(limit); missing = False
129
    if start != None:
130
        if start != 0: query += ' OFFSET '+str(start)
131
        missing = False
132
    if missing: warnings.warn(DbWarning(
133
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
134
135 830 aaronmk
    return run_query(db, query, conds.values(), recover)
136 11 aaronmk
137 830 aaronmk
def insert(db, table, row, recover=None):
138 11 aaronmk
    check_name(table)
139 13 aaronmk
    cols = row.keys()
140 15 aaronmk
    map(check_name, cols)
141 89 aaronmk
    query = 'INSERT INTO '+table
142
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
143
        +', '.join(['%s']*len(cols))+')'
144
    else: query += ' DEFAULT VALUES'
145 830 aaronmk
    return run_query(db, query, row.values(), recover)
146 11 aaronmk
147 135 aaronmk
def last_insert_id(db):
148
    module = util.root_module(db)
149
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
150
    elif module == 'MySQLdb': return db.insert_id()
151
    else: return None
152 13 aaronmk
153 832 aaronmk
def truncate(db, table):
154
    check_name(table)
155 869 aaronmk
    return run_raw_query(db, 'TRUNCATE '+table+' CASCADE')
156 832 aaronmk
157
##### Database structure queries
158
159
def pkey(db, cache, table, recover=None):
160
    '''Assumed to be first column in table'''
161
    check_name(table)
162
    if table not in cache:
163
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
164 834 aaronmk
            recover=recover), 0)
165 832 aaronmk
    return cache[table]
166
167 853 aaronmk
def index_cols(db, table, index):
168
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
169
    automatically created. When you don't know whether something is a UNIQUE
170
    constraint or a UNIQUE index, use this function.'''
171
    check_name(table)
172
    check_name(index)
173
    module = util.root_module(db)
174
    if module == 'psycopg2':
175
        return list(values(run_query(db, '''\
176
SELECT attname
177 866 aaronmk
FROM
178
(
179
        SELECT attnum, attname
180
        FROM pg_index
181
        JOIN pg_class index ON index.oid = indexrelid
182
        JOIN pg_class table_ ON table_.oid = indrelid
183
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
184
        WHERE
185
            table_.relname = %(table)s
186
            AND index.relname = %(index)s
187
    UNION
188
        SELECT attnum, attname
189
        FROM
190
        (
191
            SELECT
192
                indrelid
193
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
194
                    AS indkey
195
            FROM pg_index
196
            JOIN pg_class index ON index.oid = indexrelid
197
            JOIN pg_class table_ ON table_.oid = indrelid
198
            WHERE
199
                table_.relname = %(table)s
200
                AND index.relname = %(index)s
201
        ) s
202
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
203
) s
204 853 aaronmk
ORDER BY attnum
205
''',
206
            {'table': table, 'index': index})))
207
    else: raise NotImplementedError("Can't list index columns for "+module+
208
        ' database')
209
210 464 aaronmk
def constraint_cols(db, table, constraint):
211
    check_name(table)
212
    check_name(constraint)
213
    module = util.root_module(db)
214
    if module == 'psycopg2':
215
        return list(values(run_query(db, '''\
216
SELECT attname
217
FROM pg_constraint
218
JOIN pg_class ON pg_class.oid = conrelid
219
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
220
WHERE
221
    relname = %(table)s
222
    AND conname = %(constraint)s
223
ORDER BY attnum
224
''',
225
            {'table': table, 'constraint': constraint})))
226
    else: raise NotImplementedError("Can't list constraint columns for "+module+
227
        ' database')
228
229 832 aaronmk
def tables(db):
230
    module = util.root_module(db)
231
    if module == 'psycopg2':
232
        return values(run_query(db, "SELECT tablename from pg_tables "
233
            "WHERE schemaname = 'public' ORDER BY tablename"))
234
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
235
    else: raise NotImplementedError("Can't list tables for "+module+' database')
236 830 aaronmk
237 833 aaronmk
##### Database management
238
239
def empty_db(db):
240
    for table in tables(db): truncate(db, table)
241
242 832 aaronmk
##### Heuristic queries
243
244 14 aaronmk
def try_insert(db, table, row):
245 830 aaronmk
    '''Recovers from errors'''
246
    try: return insert(db, table, row, recover=True)
247 46 aaronmk
    except Exception, e:
248
        msg = str(e)
249 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
250
            r'"(([^\W_]+)_[^"]+)"', msg)
251
        if match:
252
            constraint, table = match.groups()
253 854 aaronmk
            try: cols = index_cols(db, table, constraint)
254 465 aaronmk
            except NotImplementedError: raise e
255 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
256 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
257
            'constraint', msg)
258 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
259 13 aaronmk
        raise # no specific exception raised
260 11 aaronmk
261 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
262 830 aaronmk
    '''Recovers from errors'''
263 471 aaronmk
    try:
264
        row_ct = try_insert(db, table, row).rowcount
265
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
266
        return last_insert_id(db)
267
    except DuplicateKeyException, e:
268 830 aaronmk
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols),
269
            recover=True))
270 471 aaronmk
271 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
272 830 aaronmk
    '''Recovers from errors'''
273
    try: return value(select(db, table, [pkey], row, 1, recover=True))
274 14 aaronmk
    except StopIteration:
275 40 aaronmk
        if not create: raise
276 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
277 131 aaronmk
278 832 aaronmk
##### Database connections
279
280 131 aaronmk
db_engines = {
281
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
282
    'PostgreSQL': ('psycopg2', {}),
283
}
284
285 360 aaronmk
DatabaseErrors_set = set([DbException])
286
DatabaseErrors = tuple(DatabaseErrors_set)
287 342 aaronmk
288
def _add_module(module):
289
    DatabaseErrors_set.add(module.DatabaseError)
290
    global DatabaseErrors
291
    DatabaseErrors = tuple(DatabaseErrors_set)
292
293 646 aaronmk
def connect(db_config, serializable=True):
294 131 aaronmk
    db_config = db_config.copy() # don't modify input!
295 342 aaronmk
    module_name, mappings = db_engines[db_config.pop('engine')]
296
    module = __import__(module_name)
297
    _add_module(module)
298 330 aaronmk
    for orig, new in mappings.iteritems():
299
        try: util.rename_key(db_config, orig, new)
300
        except KeyError: pass
301 646 aaronmk
    db = module.connect(**db_config)
302
    if serializable:
303 830 aaronmk
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
304 646 aaronmk
    return db
305
306
def db_config_str(db_config):
307
    return db_config['engine']+' database '+db_config['database']