Project

General

Profile

1 11 aaronmk
# Database access
2
3
import random
4
import re
5
import sys
6
7
import ex_util
8
9
class NameException(Exception): pass
10
11
class DuplicateKeyException(Exception): pass
12
13
def check_name(name):
14
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
15
        +'" may contain only alphanumeric characters and _')
16
17
def run_query(db, query, params=None):
18
    cur = db.cursor()
19
    try: cur.execute(query, params)
20
    except Exception, ex:
21
        ex_util.add_msg(ex, 'query: '+cur.query)
22
        raise
23
    return cur
24
25
def row(cur): return cur.fetchone()
26
27
def value(cur): return row(cur)[0]
28
29
def with_savepoint(db, func):
30
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
31
    run_query(db, 'SAVEPOINT '+savepoint)
32
    try: return_val = func()
33
    except:
34
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
35
        raise
36
    else:
37
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
38
        return return_val
39
40
def select(db, table, fields, conds):
41
    for field in fields: check_name(field)
42
    for key in conds.keys(): check_name(key)
43
    def cond(entry):
44
        key, value = entry
45
        cond_ = key+' '
46
        if value == None: cond_ += 'IS'
47
        else: cond_ += '='
48
        cond_ += ' %s'
49
        return cond_
50
    return run_query(db, 'SELECT '+', '.join(fields)+' FROM '+table+' WHERE '
51
        +' AND '.join(map(cond, conds.iteritems())), conds.values())
52
53
def insert(db, table, row, get_id=False):
54
    check_name(table)
55
    keys = row.keys()
56
    for key in keys: check_name(key)
57
    query = 'INSERT INTO '+table+' ('+', '.join(keys)+') VALUES ('\
58
        +', '.join(['%s']*len(keys))+')'
59
    if get_id: query += ' RETURNING lastval()'
60
    cur = run_query(db, query, row.values())
61
    if get_id: return value(cur)
62
63
def insert_ignore(db, table, row, get_id=False):
64
    try: return with_savepoint(db, lambda: insert(db, table, row, get_id))
65
    except Exception, ex:
66
        match = re.search(r'duplicate key value violates unique constraint "'
67
            +table+'_(\w+)_index"', str(ex))
68
        if match: raise DuplicateKeyException(match.group(1))
69
        else: raise
70
71
def insert_or_get(db, table, row, pkey):
72
    try: return insert_ignore(db, table, row, True)
73
    except DuplicateKeyException, ex:
74
        dup_key = str(ex)
75
        return value(select(db, table, [pkey], {dup_key: row[dup_key]}))