Project

General

Profile

1
# Database access
2

    
3
import random
4
import re
5
import sys
6

    
7
import exc
8
import util
9

    
10
def get_cur_query(cur):
11
    if hasattr(cur, 'query'): return cur.query
12
    elif hasattr(cur, '_last_executed'): return cur._last_executed
13
    else: return None
14

    
15
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
16

    
17
class NameException(Exception): pass
18

    
19
class DbException(exc.ExceptionWithCause):
20
    def __init__(self, msg, cause=None, cur=None):
21
        exc.ExceptionWithCause.__init__(self, msg, cause)
22
        if cur != None: _add_cursor_info(self, cur)
23

    
24
class ExceptionWithColumn(DbException):
25
    def __init__(self, col, cause=None):
26
        DbException.__init__(self, 'column: '+col, cause)
27
        self.col = col
28

    
29
class DuplicateKeyException(ExceptionWithColumn): pass
30

    
31
class NullValueException(ExceptionWithColumn): pass
32

    
33
class EmptyRowException(DbException): pass
34

    
35
def check_name(name):
36
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
37
        +'" may contain only alphanumeric characters and _')
38

    
39
def run_query(db, query, params=None):
40
    cur = db.cursor()
41
    try: cur.execute(query, params)
42
    except Exception, e:
43
        _add_cursor_info(e, cur)
44
        raise
45
    return cur
46

    
47
def col(cur, idx): return cur.description[idx][0]
48

    
49
def rows(cur): return iter(lambda: cur.fetchone(), None)
50

    
51
def row(cur): return rows(cur).next()
52

    
53
def value(cur): return row(cur)[0]
54

    
55
def value_or_none(cur):
56
    try: return value(cur)
57
    except StopIteration: return None
58

    
59
def with_savepoint(db, func):
60
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
61
    run_query(db, 'SAVEPOINT '+savepoint)
62
    try: return_val = func()
63
    except:
64
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
65
        raise
66
    else:
67
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
68
        return return_val
69

    
70
def select(db, table, fields, conds, limit=None):
71
    assert limit == None or type(limit) == int
72
    check_name(table)
73
    map(check_name, fields)
74
    map(check_name, conds.keys())
75
    def cond(entry):
76
        col, value = entry
77
        cond_ = col+' '
78
        if value == None: cond_ += 'IS'
79
        else: cond_ += '='
80
        cond_ += ' %s'
81
        return cond_
82
    query = 'SELECT '+', '.join(fields)+' FROM '+table
83
    if conds != {}:
84
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
85
    if limit != None: query += ' LIMIT '+str(limit)
86
    return run_query(db, query, conds.values())
87

    
88
def insert(db, table, row):
89
    check_name(table)
90
    cols = row.keys()
91
    map(check_name, cols)
92
    query = 'INSERT INTO '+table
93
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
94
        +', '.join(['%s']*len(cols))+')'
95
    else: query += ' DEFAULT VALUES'
96
    return run_query(db, query, row.values())
97

    
98
def last_insert_id(db):
99
    module = util.root_module(db)
100
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
101
    elif module == 'MySQLdb': return db.insert_id()
102
    else: return None
103

    
104
def try_insert(db, table, row):
105
    try: return with_savepoint(db, lambda: insert(db, table, row))
106
    except Exception, e:
107
        msg = str(e)
108
        match = re.search(r'duplicate key value violates unique constraint "'
109
            +table+'_(\w+)_index"', msg)
110
        if match: raise DuplicateKeyException(match.group(1), e)
111
        match = re.search(r'null value in column "(\w+)" violates not-null '
112
            'constraint', msg)
113
        if match: raise NullValueException(match.group(1), e)
114
        raise # no specific exception raised
115

    
116
def pkey(db, cache, table): # Assumed to be first column in table
117
    check_name(table)
118
    if table not in cache:
119
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
120
    return cache[table]
121

    
122
def get(db, table, row, pkey, create=False, row_ct_ref=None):
123
    try: return value(select(db, table, [pkey], row, 1))
124
    except StopIteration:
125
        if not create: raise
126
        # Insert new row
127
        try:
128
            row_ct = try_insert(db, table, row).rowcount
129
            if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
130
            return last_insert_id(db)
131
        except DuplicateKeyException, e:
132
            return value(select(db, table, [pkey], {e.col: row[e.col]}))
133

    
134
db_engines = {
135
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
136
    'PostgreSQL': ('psycopg2', {}),
137
}
138

    
139
DatabaseErrors_set = set()
140
DatabaseErrors = ()
141

    
142
def _add_module(module):
143
    DatabaseErrors_set.add(module.DatabaseError)
144
    global DatabaseErrors
145
    DatabaseErrors = tuple(DatabaseErrors_set)
146

    
147
def connect(db_config):
148
    db_config = db_config.copy() # don't modify input!
149
    module_name, mappings = db_engines[db_config.pop('engine')]
150
    module = __import__(module_name)
151
    _add_module(module)
152
    for orig, new in mappings.iteritems():
153
        try: util.rename_key(db_config, orig, new)
154
        except KeyError: pass
155
    return module.connect(**db_config)
(5-5/11)