Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 46 aaronmk
import ex
8 131 aaronmk
import util
9 11 aaronmk
10 135 aaronmk
def get_cur_query(cur):
11
    if hasattr(cur, 'query'): return cur.query
12
    elif hasattr(cur, '_last_executed'): return cur._last_executed
13
    else: return None
14 14 aaronmk
15 135 aaronmk
def _add_cursor_info(e, cur): ex.add_msg(e, 'query: '+get_cur_query(cur))
16
17 11 aaronmk
class NameException(Exception): pass
18
19 46 aaronmk
class DbException(ex.ExceptionWithCause):
20 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
21 46 aaronmk
        ex.ExceptionWithCause.__init__(self, msg, cause)
22 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
23
24
class ExceptionWithColumn(DbException):
25 13 aaronmk
    def __init__(self, col, cause=None):
26 14 aaronmk
        DbException.__init__(self, 'column: '+col, cause)
27 13 aaronmk
        self.col = col
28 11 aaronmk
29 13 aaronmk
class DuplicateKeyException(ExceptionWithColumn): pass
30
31
class NullValueException(ExceptionWithColumn): pass
32
33 89 aaronmk
class EmptyRowException(DbException): pass
34
35 11 aaronmk
def check_name(name):
36
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
37
        +'" may contain only alphanumeric characters and _')
38
39
def run_query(db, query, params=None):
40
    cur = db.cursor()
41
    try: cur.execute(query, params)
42 46 aaronmk
    except Exception, e:
43
        _add_cursor_info(e, cur)
44 11 aaronmk
        raise
45
    return cur
46
47 15 aaronmk
def col(cur, idx): return cur.description[idx][0]
48
49 135 aaronmk
def rows(cur): return iter(lambda: cur.fetchone(), None)
50 11 aaronmk
51 135 aaronmk
def row(cur): return rows(cur).next()
52
53 11 aaronmk
def value(cur): return row(cur)[0]
54
55
def with_savepoint(db, func):
56
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
57
    run_query(db, 'SAVEPOINT '+savepoint)
58
    try: return_val = func()
59
    except:
60
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
61
        raise
62
    else:
63
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
64
        return return_val
65
66 92 aaronmk
def select(db, table, fields, conds, limit=None):
67 135 aaronmk
    assert limit == None or type(limit) == int
68 15 aaronmk
    check_name(table)
69
    map(check_name, fields)
70
    map(check_name, conds.keys())
71 11 aaronmk
    def cond(entry):
72 13 aaronmk
        col, value = entry
73
        cond_ = col+' '
74 11 aaronmk
        if value == None: cond_ += 'IS'
75
        else: cond_ += '='
76
        cond_ += ' %s'
77
        return cond_
78 89 aaronmk
    query = 'SELECT '+', '.join(fields)+' FROM '+table
79
    if conds != {}:
80
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
81 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
82 89 aaronmk
    return run_query(db, query, conds.values())
83 11 aaronmk
84 13 aaronmk
def insert(db, table, row):
85 11 aaronmk
    check_name(table)
86 13 aaronmk
    cols = row.keys()
87 15 aaronmk
    map(check_name, cols)
88 89 aaronmk
    query = 'INSERT INTO '+table
89
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
90
        +', '.join(['%s']*len(cols))+')'
91
    else: query += ' DEFAULT VALUES'
92
    return run_query(db, query, row.values())
93 11 aaronmk
94 135 aaronmk
def last_insert_id(db):
95
    module = util.root_module(db)
96
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
97
    elif module == 'MySQLdb': return db.insert_id()
98
    else: return None
99 13 aaronmk
100 14 aaronmk
def try_insert(db, table, row):
101 13 aaronmk
    try: return with_savepoint(db, lambda: insert(db, table, row))
102 46 aaronmk
    except Exception, e:
103
        msg = str(e)
104 11 aaronmk
        match = re.search(r'duplicate key value violates unique constraint "'
105 13 aaronmk
            +table+'_(\w+)_index"', msg)
106 46 aaronmk
        if match: raise DuplicateKeyException(match.group(1), e)
107 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
108
            'constraint', msg)
109 46 aaronmk
        if match: raise NullValueException(match.group(1), e)
110 13 aaronmk
        raise # no specific exception raised
111 11 aaronmk
112 126 aaronmk
def pkey(db, cache, table): # Assumed to be first column in table
113 15 aaronmk
    check_name(table)
114 126 aaronmk
    if table not in cache:
115
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
116
    return cache[table]
117 15 aaronmk
118 40 aaronmk
def get(db, table, row, pkey, create=False, row_ct_ref=None):
119 92 aaronmk
    try: return value(select(db, table, [pkey], row, 1))
120 14 aaronmk
    except StopIteration:
121 40 aaronmk
        if not create: raise
122
        # Insert new row
123 14 aaronmk
        try:
124
            row_ct = try_insert(db, table, row).rowcount
125
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
126
            return last_insert_id(db)
127 46 aaronmk
        except DuplicateKeyException, e:
128
            return value(select(db, table, [pkey], {e.col: row[e.col]}))
129 131 aaronmk
130
db_engines = {
131
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
132
    'PostgreSQL': ('psycopg2', {}),
133
}
134
135
def connect(db_config):
136
    db_config = db_config.copy() # don't modify input!
137
    module, mappings = db_engines[db_config.pop('engine')]
138
    for orig, new in mappings.iteritems(): util.rename_key(db_config, orig, new)
139
    return __import__(module).connect(**db_config)