Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7 46 aaronmk
import ex
8 11 aaronmk
9 46 aaronmk
def _add_cursor_info(e, cur): ex.add_msg(e, 'query: '+cur.query)
10 14 aaronmk
11 11 aaronmk
class NameException(Exception): pass
12
13 46 aaronmk
class DbException(ex.ExceptionWithCause):
14 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
15 46 aaronmk
        ex.ExceptionWithCause.__init__(self, msg, cause)
16 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
17
18
class ExceptionWithColumn(DbException):
19 13 aaronmk
    def __init__(self, col, cause=None):
20 14 aaronmk
        DbException.__init__(self, 'column: '+col, cause)
21 13 aaronmk
        self.col = col
22 11 aaronmk
23 13 aaronmk
class DuplicateKeyException(ExceptionWithColumn): pass
24
25
class NullValueException(ExceptionWithColumn): pass
26
27 89 aaronmk
class EmptyRowException(DbException): pass
28
29 11 aaronmk
def check_name(name):
30
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
31
        +'" may contain only alphanumeric characters and _')
32
33
def run_query(db, query, params=None):
34
    cur = db.cursor()
35
    try: cur.execute(query, params)
36 46 aaronmk
    except Exception, e:
37
        _add_cursor_info(e, cur)
38 11 aaronmk
        raise
39
    return cur
40
41 15 aaronmk
def col(cur, idx): return cur.description[idx][0]
42
43 14 aaronmk
def row(cur): return iter(lambda: cur.fetchone(), None).next()
44 11 aaronmk
45
def value(cur): return row(cur)[0]
46
47
def with_savepoint(db, func):
48
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
49
    run_query(db, 'SAVEPOINT '+savepoint)
50
    try: return_val = func()
51
    except:
52
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
53
        raise
54
    else:
55
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
56
        return return_val
57
58 92 aaronmk
def select(db, table, fields, conds, limit=None):
59
    assert type(limit) == int
60 15 aaronmk
    check_name(table)
61
    map(check_name, fields)
62
    map(check_name, conds.keys())
63 11 aaronmk
    def cond(entry):
64 13 aaronmk
        col, value = entry
65
        cond_ = col+' '
66 11 aaronmk
        if value == None: cond_ += 'IS'
67
        else: cond_ += '='
68
        cond_ += ' %s'
69
        return cond_
70 89 aaronmk
    query = 'SELECT '+', '.join(fields)+' FROM '+table
71
    if conds != {}:
72
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
73 92 aaronmk
    if limit != None: query += ' LIMIT '+str(limit)
74 89 aaronmk
    return run_query(db, query, conds.values())
75 11 aaronmk
76 13 aaronmk
def insert(db, table, row):
77 11 aaronmk
    check_name(table)
78 13 aaronmk
    cols = row.keys()
79 15 aaronmk
    map(check_name, cols)
80 89 aaronmk
    query = 'INSERT INTO '+table
81
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
82
        +', '.join(['%s']*len(cols))+')'
83
    else: query += ' DEFAULT VALUES'
84
    return run_query(db, query, row.values())
85 11 aaronmk
86 13 aaronmk
def last_insert_id(db): return value(run_query(db, 'SELECT lastval()'))
87
88 14 aaronmk
def try_insert(db, table, row):
89 13 aaronmk
    try: return with_savepoint(db, lambda: insert(db, table, row))
90 46 aaronmk
    except Exception, e:
91
        msg = str(e)
92 11 aaronmk
        match = re.search(r'duplicate key value violates unique constraint "'
93 13 aaronmk
            +table+'_(\w+)_index"', msg)
94 46 aaronmk
        if match: raise DuplicateKeyException(match.group(1), e)
95 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
96
            'constraint', msg)
97 46 aaronmk
        if match: raise NullValueException(match.group(1), e)
98 13 aaronmk
        raise # no specific exception raised
99 11 aaronmk
100 126 aaronmk
def pkey(db, cache, table): # Assumed to be first column in table
101 15 aaronmk
    check_name(table)
102 126 aaronmk
    if table not in cache:
103
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
104
    return cache[table]
105 15 aaronmk
106 40 aaronmk
def get(db, table, row, pkey, create=False, row_ct_ref=None):
107 92 aaronmk
    try: return value(select(db, table, [pkey], row, 1))
108 14 aaronmk
    except StopIteration:
109 40 aaronmk
        if not create: raise
110
        # Insert new row
111 14 aaronmk
        try:
112
            row_ct = try_insert(db, table, row).rowcount
113
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
114
            return last_insert_id(db)
115 46 aaronmk
        except DuplicateKeyException, e:
116
            return value(select(db, table, [pkey], {e.col: row[e.col]}))