Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7
import ex_util
8
9 14 aaronmk
def _add_cursor_info(ex, cur): ex_util.add_msg(ex, 'query: '+cur.query)
10
11 11 aaronmk
class NameException(Exception): pass
12
13 14 aaronmk
class DbException(ex_util.ExceptionWithCause):
14
    def __init__(self, msg, cause=None, cur=None):
15
        ex_util.ExceptionWithCause.__init__(self, msg, cause)
16
        if cur != None: _add_cursor_info(self, cur)
17
18
class ExceptionWithColumn(DbException):
19 13 aaronmk
    def __init__(self, col, cause=None):
20 14 aaronmk
        DbException.__init__(self, 'column: '+col, cause)
21 13 aaronmk
        self.col = col
22 11 aaronmk
23 13 aaronmk
class DuplicateKeyException(ExceptionWithColumn): pass
24
25
class NullValueException(ExceptionWithColumn): pass
26
27 11 aaronmk
def check_name(name):
28
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
29
        +'" may contain only alphanumeric characters and _')
30
31
def run_query(db, query, params=None):
32
    cur = db.cursor()
33
    try: cur.execute(query, params)
34
    except Exception, ex:
35 14 aaronmk
        _add_cursor_info(ex, cur)
36 11 aaronmk
        raise
37
    return cur
38
39 15 aaronmk
def col(cur, idx): return cur.description[idx][0]
40
41 14 aaronmk
def row(cur): return iter(lambda: cur.fetchone(), None).next()
42 11 aaronmk
43
def value(cur): return row(cur)[0]
44
45
def with_savepoint(db, func):
46
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
47
    run_query(db, 'SAVEPOINT '+savepoint)
48
    try: return_val = func()
49
    except:
50
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
51
        raise
52
    else:
53
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
54
        return return_val
55
56
def select(db, table, fields, conds):
57 15 aaronmk
    check_name(table)
58
    map(check_name, fields)
59
    map(check_name, conds.keys())
60 11 aaronmk
    def cond(entry):
61 13 aaronmk
        col, value = entry
62
        cond_ = col+' '
63 11 aaronmk
        if value == None: cond_ += 'IS'
64
        else: cond_ += '='
65
        cond_ += ' %s'
66
        return cond_
67
    return run_query(db, 'SELECT '+', '.join(fields)+' FROM '+table+' WHERE '
68
        +' AND '.join(map(cond, conds.iteritems())), conds.values())
69
70 13 aaronmk
def insert(db, table, row):
71 11 aaronmk
    check_name(table)
72 13 aaronmk
    cols = row.keys()
73 15 aaronmk
    map(check_name, cols)
74 13 aaronmk
    return run_query(db, 'INSERT INTO '+table+' ('+', '.join(cols)
75
        +') VALUES ('+', '.join(['%s']*len(cols))+')', row.values())
76 11 aaronmk
77 13 aaronmk
def last_insert_id(db): return value(run_query(db, 'SELECT lastval()'))
78
79 14 aaronmk
def try_insert(db, table, row):
80 13 aaronmk
    try: return with_savepoint(db, lambda: insert(db, table, row))
81 11 aaronmk
    except Exception, ex:
82 13 aaronmk
        msg = str(ex)
83 11 aaronmk
        match = re.search(r'duplicate key value violates unique constraint "'
84 13 aaronmk
            +table+'_(\w+)_index"', msg)
85
        if match: raise DuplicateKeyException(match.group(1), ex)
86
        match = re.search(r'null value in column "(\w+)" violates not-null '
87
            'constraint', msg)
88
        if match: raise NullValueException(match.group(1), ex)
89
        raise # no specific exception raised
90 11 aaronmk
91 15 aaronmk
def pkey(db, table): # Assumed to be first column in table
92
    check_name(table)
93
    return col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
94
95 13 aaronmk
def insert_or_get(db, table, row, pkey, row_ct_ref=None):
96 14 aaronmk
    try: return value(select(db, table, [pkey], row))
97
    except StopIteration:
98
        try:
99
            row_ct = try_insert(db, table, row).rowcount
100
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
101
            return last_insert_id(db)
102
        except DuplicateKeyException, ex:
103
            return value(select(db, table, [pkey], {ex.col: row[ex.col]}))