Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 300 aaronmk
import exc
8 131 aaronmk
import util
9 11 aaronmk
10 135 aaronmk
def get_cur_query(cur):
11
    if hasattr(cur, 'query'): return cur.query
12
    elif hasattr(cur, '_last_executed'): return cur._last_executed
13
    else: return None
14 14 aaronmk
15 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
16 135 aaronmk
17 300 aaronmk
class DbException(exc.ExceptionWithCause):
18 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
19 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
20 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
21
22 360 aaronmk
class NameException(DbException): pass
23
24 468 aaronmk
class ExceptionWithColumns(DbException):
25
    def __init__(self, cols, cause=None):
26
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
27
        self.cols = cols
28 11 aaronmk
29 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
30 13 aaronmk
31 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
32 13 aaronmk
33 89 aaronmk
class EmptyRowException(DbException): pass
34
35 11 aaronmk
def check_name(name):
36
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
37
        +'" may contain only alphanumeric characters and _')
38
39
def run_query(db, query, params=None):
40
    cur = db.cursor()
41
    try: cur.execute(query, params)
42 46 aaronmk
    except Exception, e:
43
        _add_cursor_info(e, cur)
44 11 aaronmk
        raise
45
    return cur
46
47 15 aaronmk
def col(cur, idx): return cur.description[idx][0]
48
49 135 aaronmk
def rows(cur): return iter(lambda: cur.fetchone(), None)
50 11 aaronmk
51 135 aaronmk
def row(cur): return rows(cur).next()
52
53 11 aaronmk
def value(cur): return row(cur)[0]
54
55 399 aaronmk
def values(cur): return iter(lambda: value(cur), None)
56
57 140 aaronmk
def value_or_none(cur):
58
    try: return value(cur)
59
    except StopIteration: return None
60
61 11 aaronmk
def with_savepoint(db, func):
62
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
63
    run_query(db, 'SAVEPOINT '+savepoint)
64
    try: return_val = func()
65
    except:
66
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
67
        raise
68
    else:
69
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
70
        return return_val
71
72 92 aaronmk
def select(db, table, fields, conds, limit=None):
73 135 aaronmk
    assert limit == None or type(limit) == int
74 15 aaronmk
    check_name(table)
75
    map(check_name, fields)
76
    map(check_name, conds.keys())
77 11 aaronmk
    def cond(entry):
78 13 aaronmk
        col, value = entry
79
        cond_ = col+' '
80 11 aaronmk
        if value == None: cond_ += 'IS'
81
        else: cond_ += '='
82
        cond_ += ' %s'
83
        return cond_
84 89 aaronmk
    query = 'SELECT '+', '.join(fields)+' FROM '+table
85
    if conds != {}:
86
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
87 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
88 89 aaronmk
    return run_query(db, query, conds.values())
89 11 aaronmk
90 13 aaronmk
def insert(db, table, row):
91 11 aaronmk
    check_name(table)
92 13 aaronmk
    cols = row.keys()
93 15 aaronmk
    map(check_name, cols)
94 89 aaronmk
    query = 'INSERT INTO '+table
95
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
96
        +', '.join(['%s']*len(cols))+')'
97
    else: query += ' DEFAULT VALUES'
98
    return run_query(db, query, row.values())
99 11 aaronmk
100 135 aaronmk
def last_insert_id(db):
101
    module = util.root_module(db)
102
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
103
    elif module == 'MySQLdb': return db.insert_id()
104
    else: return None
105 13 aaronmk
106 464 aaronmk
def constraint_cols(db, table, constraint):
107
    check_name(table)
108
    check_name(constraint)
109
    module = util.root_module(db)
110
    if module == 'psycopg2':
111
        return list(values(run_query(db, '''\
112
SELECT attname
113
FROM pg_constraint
114
JOIN pg_class ON pg_class.oid = conrelid
115
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
116
WHERE
117
    relname = %(table)s
118
    AND conname = %(constraint)s
119
ORDER BY attnum
120
''',
121
            {'table': table, 'constraint': constraint})))
122
    else: raise NotImplementedError("Can't list constraint columns for "+module+
123
        ' database')
124
125 14 aaronmk
def try_insert(db, table, row):
126 13 aaronmk
    try: return with_savepoint(db, lambda: insert(db, table, row))
127 46 aaronmk
    except Exception, e:
128
        msg = str(e)
129 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
130
            r'"(([^\W_]+)_[^"]+)"', msg)
131
        if match:
132
            constraint, table = match.groups()
133
            try: cols = constraint_cols(db, table, constraint)
134
            except NotImplementedError: raise e
135
            else: raise DuplicateKeyException(cols[0], e)
136 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
137
            'constraint', msg)
138 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
139 13 aaronmk
        raise # no specific exception raised
140 11 aaronmk
141 126 aaronmk
def pkey(db, cache, table): # Assumed to be first column in table
142 15 aaronmk
    check_name(table)
143 126 aaronmk
    if table not in cache:
144
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
145
    return cache[table]
146 15 aaronmk
147 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
148
    try:
149
        row_ct = try_insert(db, table, row).rowcount
150
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
151
        return last_insert_id(db)
152
    except DuplicateKeyException, e:
153
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols)))
154
155 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
156 92 aaronmk
    try: return value(select(db, table, [pkey], row, 1))
157 14 aaronmk
    except StopIteration:
158 40 aaronmk
        if not create: raise
159 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
160 131 aaronmk
161 471 aaronmk
162 399 aaronmk
def truncate(db, table):
163
    check_name(table)
164
    return run_query(db, 'TRUNCATE '+table+' CASCADE')
165
166
def tables(db):
167
    module = util.root_module(db)
168
    if module == 'psycopg2':
169
        return values(run_query(db, "SELECT tablename from pg_tables "
170
            "WHERE schemaname = 'public' ORDER BY tablename"))
171
    elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
172
    else: raise NotImplementedError("Can't list tables for "+module+' database')
173
174
def empty_db(db):
175
    for table in tables(db): truncate(db, table)
176
177 131 aaronmk
db_engines = {
178
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
179
    'PostgreSQL': ('psycopg2', {}),
180
}
181
182 360 aaronmk
DatabaseErrors_set = set([DbException])
183
DatabaseErrors = tuple(DatabaseErrors_set)
184 342 aaronmk
185
def _add_module(module):
186
    DatabaseErrors_set.add(module.DatabaseError)
187
    global DatabaseErrors
188
    DatabaseErrors = tuple(DatabaseErrors_set)
189
190 131 aaronmk
def connect(db_config):
191
    db_config = db_config.copy() # don't modify input!
192 342 aaronmk
    module_name, mappings = db_engines[db_config.pop('engine')]
193
    module = __import__(module_name)
194
    _add_module(module)
195 330 aaronmk
    for orig, new in mappings.iteritems():
196
        try: util.rename_key(db_config, orig, new)
197
        except KeyError: pass
198 342 aaronmk
    return module.connect(**db_config)