Project

General

Profile

1
# Database access
2

    
3
import random
4
import re
5
import sys
6

    
7
import ex
8
import util
9

    
10
def _add_cursor_info(e, cur): ex.add_msg(e, 'query: '+cur.query)
11

    
12
class NameException(Exception): pass
13

    
14
class DbException(ex.ExceptionWithCause):
15
    def __init__(self, msg, cause=None, cur=None):
16
        ex.ExceptionWithCause.__init__(self, msg, cause)
17
        if cur != None: _add_cursor_info(self, cur)
18

    
19
class ExceptionWithColumn(DbException):
20
    def __init__(self, col, cause=None):
21
        DbException.__init__(self, 'column: '+col, cause)
22
        self.col = col
23

    
24
class DuplicateKeyException(ExceptionWithColumn): pass
25

    
26
class NullValueException(ExceptionWithColumn): pass
27

    
28
class EmptyRowException(DbException): pass
29

    
30
def check_name(name):
31
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
32
        +'" may contain only alphanumeric characters and _')
33

    
34
def run_query(db, query, params=None):
35
    cur = db.cursor()
36
    try: cur.execute(query, params)
37
    except Exception, e:
38
        _add_cursor_info(e, cur)
39
        raise
40
    return cur
41

    
42
def col(cur, idx): return cur.description[idx][0]
43

    
44
def row(cur): return iter(lambda: cur.fetchone(), None).next()
45

    
46
def value(cur): return row(cur)[0]
47

    
48
def with_savepoint(db, func):
49
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
50
    run_query(db, 'SAVEPOINT '+savepoint)
51
    try: return_val = func()
52
    except:
53
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
54
        raise
55
    else:
56
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
57
        return return_val
58

    
59
def select(db, table, fields, conds, limit=None):
60
    assert type(limit) == int
61
    check_name(table)
62
    map(check_name, fields)
63
    map(check_name, conds.keys())
64
    def cond(entry):
65
        col, value = entry
66
        cond_ = col+' '
67
        if value == None: cond_ += 'IS'
68
        else: cond_ += '='
69
        cond_ += ' %s'
70
        return cond_
71
    query = 'SELECT '+', '.join(fields)+' FROM '+table
72
    if conds != {}:
73
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
74
    if limit != None: query += ' LIMIT '+str(limit)
75
    return run_query(db, query, conds.values())
76

    
77
def insert(db, table, row):
78
    check_name(table)
79
    cols = row.keys()
80
    map(check_name, cols)
81
    query = 'INSERT INTO '+table
82
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
83
        +', '.join(['%s']*len(cols))+')'
84
    else: query += ' DEFAULT VALUES'
85
    return run_query(db, query, row.values())
86

    
87
def last_insert_id(db): return value(run_query(db, 'SELECT lastval()'))
88

    
89
def try_insert(db, table, row):
90
    try: return with_savepoint(db, lambda: insert(db, table, row))
91
    except Exception, e:
92
        msg = str(e)
93
        match = re.search(r'duplicate key value violates unique constraint "'
94
            +table+'_(\w+)_index"', msg)
95
        if match: raise DuplicateKeyException(match.group(1), e)
96
        match = re.search(r'null value in column "(\w+)" violates not-null '
97
            'constraint', msg)
98
        if match: raise NullValueException(match.group(1), e)
99
        raise # no specific exception raised
100

    
101
def pkey(db, cache, table): # Assumed to be first column in table
102
    check_name(table)
103
    if table not in cache:
104
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
105
    return cache[table]
106

    
107
def get(db, table, row, pkey, create=False, row_ct_ref=None):
108
    try: return value(select(db, table, [pkey], row, 1))
109
    except StopIteration:
110
        if not create: raise
111
        # Insert new row
112
        try:
113
            row_ct = try_insert(db, table, row).rowcount
114
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
115
            return last_insert_id(db)
116
        except DuplicateKeyException, e:
117
            return value(select(db, table, [pkey], {e.col: row[e.col]}))
118

    
119
db_engines = {
120
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
121
    'PostgreSQL': ('psycopg2', {}),
122
}
123

    
124
def connect(db_config):
125
    db_config = db_config.copy() # don't modify input!
126
    module, mappings = db_engines[db_config.pop('engine')]
127
    for orig, new in mappings.iteritems(): util.rename_key(db_config, orig, new)
128
    return __import__(module).connect(**db_config)
(5-5/10)