Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 46 aaronmk
import ex
8 131 aaronmk
import util
9 11 aaronmk
10 46 aaronmk
def _add_cursor_info(e, cur): ex.add_msg(e, 'query: '+cur.query)
11 14 aaronmk
12 11 aaronmk
class NameException(Exception): pass
13
14 46 aaronmk
class DbException(ex.ExceptionWithCause):
15 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
16 46 aaronmk
        ex.ExceptionWithCause.__init__(self, msg, cause)
17 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
18
19
class ExceptionWithColumn(DbException):
20 13 aaronmk
    def __init__(self, col, cause=None):
21 14 aaronmk
        DbException.__init__(self, 'column: '+col, cause)
22 13 aaronmk
        self.col = col
23 11 aaronmk
24 13 aaronmk
class DuplicateKeyException(ExceptionWithColumn): pass
25
26
class NullValueException(ExceptionWithColumn): pass
27
28 89 aaronmk
class EmptyRowException(DbException): pass
29
30 11 aaronmk
def check_name(name):
31
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
32
        +'" may contain only alphanumeric characters and _')
33
34
def run_query(db, query, params=None):
35
    cur = db.cursor()
36
    try: cur.execute(query, params)
37 46 aaronmk
    except Exception, e:
38
        _add_cursor_info(e, cur)
39 11 aaronmk
        raise
40
    return cur
41
42 15 aaronmk
def col(cur, idx): return cur.description[idx][0]
43
44 14 aaronmk
def row(cur): return iter(lambda: cur.fetchone(), None).next()
45 11 aaronmk
46
def value(cur): return row(cur)[0]
47
48
def with_savepoint(db, func):
49
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
50
    run_query(db, 'SAVEPOINT '+savepoint)
51
    try: return_val = func()
52
    except:
53
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
54
        raise
55
    else:
56
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
57
        return return_val
58
59 92 aaronmk
def select(db, table, fields, conds, limit=None):
60
    assert type(limit) == int
61 15 aaronmk
    check_name(table)
62
    map(check_name, fields)
63
    map(check_name, conds.keys())
64 11 aaronmk
    def cond(entry):
65 13 aaronmk
        col, value = entry
66
        cond_ = col+' '
67 11 aaronmk
        if value == None: cond_ += 'IS'
68
        else: cond_ += '='
69
        cond_ += ' %s'
70
        return cond_
71 89 aaronmk
    query = 'SELECT '+', '.join(fields)+' FROM '+table
72
    if conds != {}:
73
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
74 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
75 89 aaronmk
    return run_query(db, query, conds.values())
76 11 aaronmk
77 13 aaronmk
def insert(db, table, row):
78 11 aaronmk
    check_name(table)
79 13 aaronmk
    cols = row.keys()
80 15 aaronmk
    map(check_name, cols)
81 89 aaronmk
    query = 'INSERT INTO '+table
82
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
83
        +', '.join(['%s']*len(cols))+')'
84
    else: query += ' DEFAULT VALUES'
85
    return run_query(db, query, row.values())
86 11 aaronmk
87 13 aaronmk
def last_insert_id(db): return value(run_query(db, 'SELECT lastval()'))
88
89 14 aaronmk
def try_insert(db, table, row):
90 13 aaronmk
    try: return with_savepoint(db, lambda: insert(db, table, row))
91 46 aaronmk
    except Exception, e:
92
        msg = str(e)
93 11 aaronmk
        match = re.search(r'duplicate key value violates unique constraint "'
94 13 aaronmk
            +table+'_(\w+)_index"', msg)
95 46 aaronmk
        if match: raise DuplicateKeyException(match.group(1), e)
96 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
97
            'constraint', msg)
98 46 aaronmk
        if match: raise NullValueException(match.group(1), e)
99 13 aaronmk
        raise # no specific exception raised
100 11 aaronmk
101 126 aaronmk
def pkey(db, cache, table): # Assumed to be first column in table
102 15 aaronmk
    check_name(table)
103 126 aaronmk
    if table not in cache:
104
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
105
    return cache[table]
106 15 aaronmk
107 40 aaronmk
def get(db, table, row, pkey, create=False, row_ct_ref=None):
108 92 aaronmk
    try: return value(select(db, table, [pkey], row, 1))
109 14 aaronmk
    except StopIteration:
110 40 aaronmk
        if not create: raise
111
        # Insert new row
112 14 aaronmk
        try:
113
            row_ct = try_insert(db, table, row).rowcount
114
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
115
            return last_insert_id(db)
116 46 aaronmk
        except DuplicateKeyException, e:
117
            return value(select(db, table, [pkey], {e.col: row[e.col]}))
118 131 aaronmk
119
db_engines = {
120
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
121
    'PostgreSQL': ('psycopg2', {}),
122
}
123
124
def connect(db_config):
125
    db_config = db_config.copy() # don't modify input!
126
    module, mappings = db_engines[db_config.pop('engine')]
127
    for orig, new in mappings.iteritems(): util.rename_key(db_config, orig, new)
128
    return __import__(module).connect(**db_config)