Revision 832
Added by Aaron Marcuse-Kubitza about 13 years ago
lib/sql.py | ||
---|---|---|
7 | 7 |
import exc |
8 | 8 |
import util |
9 | 9 |
|
10 |
##### Exceptions |
|
11 |
|
|
10 | 12 |
def get_cur_query(cur): |
11 | 13 |
if hasattr(cur, 'query'): return cur.query |
12 | 14 |
elif hasattr(cur, '_last_executed'): return cur._last_executed |
... | ... | |
32 | 34 |
|
33 | 35 |
class EmptyRowException(DbException): pass |
34 | 36 |
|
37 |
##### Input validation |
|
38 |
|
|
35 | 39 |
def check_name(name): |
36 | 40 |
if re.search(r'\W', name) != None: raise NameException('Name "'+name |
37 | 41 |
+'" may contain only alphanumeric characters and _') |
... | ... | |
44 | 48 |
else: raise NotImplementedError("Can't escape name for "+module+' database') |
45 | 49 |
return quote + name.replace(quote, '') + quote |
46 | 50 |
|
51 |
##### Querying |
|
52 |
|
|
47 | 53 |
def run_raw_query(db, query, params=None): |
48 | 54 |
cur = db.cursor() |
49 | 55 |
try: cur.execute(query, params) |
... | ... | |
52 | 58 |
raise |
53 | 59 |
return cur |
54 | 60 |
|
55 |
def col(cur, idx): return cur.description[idx][0]
|
|
61 |
##### Recoverable querying
|
|
56 | 62 |
|
57 |
def rows(cur): return iter(lambda: cur.fetchone(), None) |
|
58 |
|
|
59 |
def row(cur): return rows(cur).next() |
|
60 |
|
|
61 |
def value(cur): return row(cur)[0] |
|
62 |
|
|
63 |
def values(cur): return iter(lambda: value(cur), None) |
|
64 |
|
|
65 |
def value_or_none(cur): |
|
66 |
try: return value(cur) |
|
67 |
except StopIteration: return None |
|
68 |
|
|
69 | 63 |
def with_savepoint(db, func): |
70 | 64 |
savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique |
71 | 65 |
run_raw_query(db, 'SAVEPOINT '+savepoint) |
... | ... | |
84 | 78 |
if recover: return with_savepoint(db, run) |
85 | 79 |
else: return run() |
86 | 80 |
|
81 |
##### Result retrieval |
|
82 |
|
|
83 |
def col(cur, idx): return cur.description[idx][0] |
|
84 |
|
|
85 |
def rows(cur): return iter(lambda: cur.fetchone(), None) |
|
86 |
|
|
87 |
def row(cur): return rows(cur).next() |
|
88 |
|
|
89 |
def value(cur): return row(cur)[0] |
|
90 |
|
|
91 |
def values(cur): return iter(lambda: value(cur), None) |
|
92 |
|
|
93 |
def value_or_none(cur): |
|
94 |
try: return value(cur) |
|
95 |
except StopIteration: return None |
|
96 |
|
|
97 |
##### Basic queries |
|
98 |
|
|
87 | 99 |
def select(db, table, fields, conds, limit=None, recover=None): |
88 | 100 |
assert limit == None or type(limit) == int |
89 | 101 |
check_name(table) |
... | ... | |
119 | 131 |
elif module == 'MySQLdb': return db.insert_id() |
120 | 132 |
else: return None |
121 | 133 |
|
134 |
def truncate(db, table): |
|
135 |
check_name(table) |
|
136 |
return run_query(db, 'TRUNCATE '+table+' CASCADE') |
|
137 |
|
|
138 |
##### Database structure queries |
|
139 |
|
|
140 |
def pkey(db, cache, table, recover=None): |
|
141 |
'''Assumed to be first column in table''' |
|
142 |
check_name(table) |
|
143 |
if table not in cache: |
|
144 |
cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0', |
|
145 |
recover), 0) |
|
146 |
return cache[table] |
|
147 |
|
|
122 | 148 |
def constraint_cols(db, table, constraint): |
123 | 149 |
check_name(table) |
124 | 150 |
check_name(constraint) |
... | ... | |
138 | 164 |
else: raise NotImplementedError("Can't list constraint columns for "+module+ |
139 | 165 |
' database') |
140 | 166 |
|
141 |
def pkey(db, cache, table, recover=None):
|
|
142 |
'''Assumed to be first column in table'''
|
|
143 |
check_name(table)
|
|
144 |
if table not in cache:
|
|
145 |
cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
|
|
146 |
recover), 0)
|
|
147 |
return cache[table]
|
|
167 |
def tables(db):
|
|
168 |
module = util.root_module(db)
|
|
169 |
if module == 'psycopg2':
|
|
170 |
return values(run_query(db, "SELECT tablename from pg_tables "
|
|
171 |
"WHERE schemaname = 'public' ORDER BY tablename"))
|
|
172 |
elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES'))
|
|
173 |
else: raise NotImplementedError("Can't list tables for "+module+' database')
|
|
148 | 174 |
|
175 |
##### Heuristic queries |
|
176 |
|
|
149 | 177 |
def try_insert(db, table, row): |
150 | 178 |
'''Recovers from errors''' |
151 | 179 |
try: return insert(db, table, row, recover=True) |
... | ... | |
180 | 208 |
if not create: raise |
181 | 209 |
return put(db, table, row, pkey, row_ct_ref) # insert new row |
182 | 210 |
|
211 |
##### Database management |
|
183 | 212 |
|
184 |
def truncate(db, table): |
|
185 |
check_name(table) |
|
186 |
return run_query(db, 'TRUNCATE '+table+' CASCADE') |
|
187 |
|
|
188 |
def tables(db): |
|
189 |
module = util.root_module(db) |
|
190 |
if module == 'psycopg2': |
|
191 |
return values(run_query(db, "SELECT tablename from pg_tables " |
|
192 |
"WHERE schemaname = 'public' ORDER BY tablename")) |
|
193 |
elif module == 'MySQLdb': return values(run_query(db, 'SHOW TABLES')) |
|
194 |
else: raise NotImplementedError("Can't list tables for "+module+' database') |
|
195 |
|
|
196 | 213 |
def empty_db(db): |
197 | 214 |
for table in tables(db): truncate(db, table) |
198 | 215 |
|
216 |
##### Database connections |
|
217 |
|
|
199 | 218 |
db_engines = { |
200 | 219 |
'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}), |
201 | 220 |
'PostgreSQL': ('psycopg2', {}), |
Also available in: Unified diff
sql.py: Added documentation labels to each section