Project

General

Profile

1
# Database access
2

    
3
import random
4
import re
5
import sys
6

    
7
import ex_util
8

    
9
class NameException(Exception): pass
10

    
11
class ExceptionWithColumn(ex_util.ExceptionWithCause):
12
    def __init__(self, col, cause=None):
13
        ex_util.ExceptionWithCause.__init__(self, 'column: '+col, cause)
14
        self.col = col
15

    
16
class DuplicateKeyException(ExceptionWithColumn): pass
17

    
18
class NullValueException(ExceptionWithColumn): pass
19

    
20
def check_name(name):
21
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
22
        +'" may contain only alphanumeric characters and _')
23

    
24
def run_query(db, query, params=None):
25
    cur = db.cursor()
26
    try: cur.execute(query, params)
27
    except Exception, ex:
28
        ex_util.add_msg(ex, 'query: '+cur.query)
29
        raise
30
    return cur
31

    
32
def row(cur): return cur.fetchone()
33

    
34
def value(cur): return row(cur)[0]
35

    
36
def with_savepoint(db, func):
37
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
38
    run_query(db, 'SAVEPOINT '+savepoint)
39
    try: return_val = func()
40
    except:
41
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
42
        raise
43
    else:
44
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
45
        return return_val
46

    
47
def select(db, table, fields, conds):
48
    for field in fields: check_name(field)
49
    for col in conds.keys(): check_name(col)
50
    def cond(entry):
51
        col, value = entry
52
        cond_ = col+' '
53
        if value == None: cond_ += 'IS'
54
        else: cond_ += '='
55
        cond_ += ' %s'
56
        return cond_
57
    return run_query(db, 'SELECT '+', '.join(fields)+' FROM '+table+' WHERE '
58
        +' AND '.join(map(cond, conds.iteritems())), conds.values())
59

    
60
def insert(db, table, row):
61
    check_name(table)
62
    cols = row.keys()
63
    for col in cols: check_name(col)
64
    return run_query(db, 'INSERT INTO '+table+' ('+', '.join(cols)
65
        +') VALUES ('+', '.join(['%s']*len(cols))+')', row.values())
66

    
67
def last_insert_id(db): return value(run_query(db, 'SELECT lastval()'))
68

    
69
def insert_ignore(db, table, row):
70
    try: return with_savepoint(db, lambda: insert(db, table, row))
71
    except Exception, ex:
72
        msg = str(ex)
73
        match = re.search(r'duplicate key value violates unique constraint "'
74
            +table+'_(\w+)_index"', msg)
75
        if match: raise DuplicateKeyException(match.group(1), ex)
76
        match = re.search(r'null value in column "(\w+)" violates not-null '
77
            'constraint', msg)
78
        if match: raise NullValueException(match.group(1), ex)
79
        raise # no specific exception raised
80

    
81
def insert_or_get(db, table, row, pkey, row_ct_ref=None):
82
    try:
83
        row_ct = insert_ignore(db, table, row).rowcount
84
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
85
        return last_insert_id(db)
86
    except DuplicateKeyException, ex:
87
        return value(select(db, table, [pkey], {ex.col: row[ex.col]}))
(1-1/8)