Revision 830
Added by Aaron Marcuse-Kubitza almost 13 years ago
lib/sql.py | ||
---|---|---|
44 | 44 |
else: raise NotImplementedError("Can't escape name for "+module+' database') |
45 | 45 |
return quote + name.replace(quote, '') + quote |
46 | 46 |
|
47 |
def run_query(db, query, params=None): |
|
47 |
def run_raw_query(db, query, params=None):
|
|
48 | 48 |
cur = db.cursor() |
49 | 49 |
try: cur.execute(query, params) |
50 | 50 |
except Exception, e: |
... | ... | |
68 | 68 |
|
69 | 69 |
def with_savepoint(db, func): |
70 | 70 |
savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique |
71 |
run_query(db, 'SAVEPOINT '+savepoint) |
|
71 |
run_raw_query(db, 'SAVEPOINT '+savepoint)
|
|
72 | 72 |
try: return_val = func() |
73 | 73 |
except: |
74 |
run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint) |
|
74 |
run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
|
|
75 | 75 |
raise |
76 | 76 |
else: |
77 |
run_query(db, 'RELEASE SAVEPOINT '+savepoint) |
|
77 |
run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
|
|
78 | 78 |
return return_val |
79 | 79 |
|
80 |
def select(db, table, fields, conds, limit=None): |
|
80 |
def run_query(db, query, params=None, recover=None): |
|
81 |
if recover == None: recover = False |
|
82 |
|
|
83 |
def run(): return run_raw_query(db, query, params) |
|
84 |
if recover: return with_savepoint(db, run) |
|
85 |
else: return run() |
|
86 |
|
|
87 |
def select(db, table, fields, conds, limit=None, recover=None): |
|
81 | 88 |
assert limit == None or type(limit) == int |
82 | 89 |
check_name(table) |
83 | 90 |
map(check_name, fields) |
... | ... | |
94 | 101 |
if conds != {}: |
95 | 102 |
query += ' WHERE '+' AND '.join(map(cond, conds.iteritems())) |
96 | 103 |
if limit != None: query += ' LIMIT '+str(limit) |
97 |
return run_query(db, query, conds.values()) |
|
104 |
return run_query(db, query, conds.values(), recover)
|
|
98 | 105 |
|
99 |
def insert(db, table, row): |
|
106 |
def insert(db, table, row, recover=None):
|
|
100 | 107 |
check_name(table) |
101 | 108 |
cols = row.keys() |
102 | 109 |
map(check_name, cols) |
... | ... | |
104 | 111 |
if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\ |
105 | 112 |
+', '.join(['%s']*len(cols))+')' |
106 | 113 |
else: query += ' DEFAULT VALUES' |
107 |
return run_query(db, query, row.values()) |
|
114 |
return run_query(db, query, row.values(), recover)
|
|
108 | 115 |
|
109 | 116 |
def last_insert_id(db): |
110 | 117 |
module = util.root_module(db) |
... | ... | |
131 | 138 |
else: raise NotImplementedError("Can't list constraint columns for "+module+ |
132 | 139 |
' database') |
133 | 140 |
|
141 |
def pkey(db, cache, table, recover=None): |
|
142 |
'''Assumed to be first column in table''' |
|
143 |
check_name(table) |
|
144 |
if table not in cache: |
|
145 |
cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0', |
|
146 |
recover), 0) |
|
147 |
return cache[table] |
|
148 |
|
|
134 | 149 |
def try_insert(db, table, row): |
135 |
try: return with_savepoint(db, lambda: insert(db, table, row)) |
|
150 |
'''Recovers from errors''' |
|
151 |
try: return insert(db, table, row, recover=True) |
|
136 | 152 |
except Exception, e: |
137 | 153 |
msg = str(e) |
138 | 154 |
match = re.search(r'duplicate key value violates unique constraint ' |
... | ... | |
147 | 163 |
if match: raise NullValueException([match.group(1)], e) |
148 | 164 |
raise # no specific exception raised |
149 | 165 |
|
150 |
def pkey(db, cache, table): # Assumed to be first column in table |
|
151 |
check_name(table) |
|
152 |
if table not in cache: |
|
153 |
cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0) |
|
154 |
return cache[table] |
|
155 |
|
|
156 | 166 |
def put(db, table, row, pkey, row_ct_ref=None): |
167 |
'''Recovers from errors''' |
|
157 | 168 |
try: |
158 | 169 |
row_ct = try_insert(db, table, row).rowcount |
159 | 170 |
if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct |
160 | 171 |
return last_insert_id(db) |
161 | 172 |
except DuplicateKeyException, e: |
162 |
return value(select(db, table, [pkey], util.dict_subset(row, e.cols))) |
|
173 |
return value(select(db, table, [pkey], util.dict_subset(row, e.cols), |
|
174 |
recover=True)) |
|
163 | 175 |
|
164 | 176 |
def get(db, table, row, pkey, row_ct_ref=None, create=False): |
165 |
try: return value(select(db, table, [pkey], row, 1)) |
|
177 |
'''Recovers from errors''' |
|
178 |
try: return value(select(db, table, [pkey], row, 1, recover=True)) |
|
166 | 179 |
except StopIteration: |
167 | 180 |
if not create: raise |
168 | 181 |
return put(db, table, row, pkey, row_ct_ref) # insert new row |
... | ... | |
206 | 219 |
except KeyError: pass |
207 | 220 |
db = module.connect(**db_config) |
208 | 221 |
if serializable: |
209 |
run_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE') |
|
222 |
run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
|
|
210 | 223 |
return db |
211 | 224 |
|
212 | 225 |
def db_config_str(db_config): |
Also available in: Unified diff
sql.py: Added ability to recover from database errors so you don't get the error "InternalError: current transaction is aborted, commands ignored until end of transaction block"