Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
19
    if hasattr(cur, 'query'): return cur.query
20
    elif hasattr(cur, '_last_executed'): return cur._last_executed
21
    else: return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
26 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
27 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
28 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException): pass
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
33
    def __init__(self, cols, cause=None):
34
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
35
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException): pass
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning): pass
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur): return (col[0] for col in cur.description)
50
51
def rows(cur): return iter(lambda: cur.fetchone(), None)
52
53
def consume_rows(cur):
54
    '''Used to fetch all rows so result will be cached'''
55
    iters.consume_iter(rows(cur))
56
57
def next_row(cur): return rows(cur).next()
58
59
def row(cur):
60
    row_ = next_row(cur)
61
    consume_rows(cur)
62
    return row_
63
64
def next_value(cur): return next_row(cur)[0]
65
66
def value(cur): return row(cur)[0]
67
68
def values(cur): return iters.func_iter(lambda: next_value(cur))
69
70
def value_or_none(cur):
71
    try: return value(cur)
72
    except StopIteration: return None
73
74 2101 aaronmk
##### Input validation
75
76
def clean_name(name): return re.sub(r'\W', r'', name)
77
78
def check_name(name):
79
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
80
        +'" may contain only alphanumeric characters and _')
81
82
def esc_name_by_module(module, name, ignore_case=False):
83
    if module == 'psycopg2':
84
        if ignore_case:
85
            # Don't enclose in quotes because this disables case-insensitivity
86
            check_name(name)
87
            return name
88
        else: quote = '"'
89
    elif module == 'MySQLdb': quote = '`'
90
    else: raise NotImplementedError("Can't escape name for "+module+' database')
91
    return quote + name.replace(quote, '') + quote
92
93
def esc_name_by_engine(engine, name, **kw_args):
94
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
95
96
def esc_name(db, name, **kw_args):
97
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
98
99
def qual_name(db, schema, table):
100
    def esc_name_(name): return esc_name(db, name)
101
    table = esc_name_(table)
102
    if schema != None: return esc_name_(schema)+'.'+table
103
    else: return table
104
105 1869 aaronmk
##### Database connections
106 1849 aaronmk
107 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
108 1926 aaronmk
109 1869 aaronmk
db_engines = {
110
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
111
    'PostgreSQL': ('psycopg2', {}),
112
}
113
114
DatabaseErrors_set = set([DbException])
115
DatabaseErrors = tuple(DatabaseErrors_set)
116
117
def _add_module(module):
118
    DatabaseErrors_set.add(module.DatabaseError)
119
    global DatabaseErrors
120
    DatabaseErrors = tuple(DatabaseErrors_set)
121
122
def db_config_str(db_config):
123
    return db_config['engine']+' database '+db_config['database']
124
125 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
126 1894 aaronmk
127 1901 aaronmk
log_debug_none = lambda msg: None
128
129 1849 aaronmk
class DbConn:
130 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
131 2050 aaronmk
        caching=True):
132 1869 aaronmk
        self.db_config = db_config
133
        self.serializable = serializable
134 1901 aaronmk
        self.log_debug = log_debug
135 2047 aaronmk
        self.caching = caching
136 1869 aaronmk
137
        self.__db = None
138 1889 aaronmk
        self.query_results = {}
139 1869 aaronmk
140
    def __getattr__(self, name):
141
        if name == '__dict__': raise Exception('getting __dict__')
142
        if name == 'db': return self._db()
143
        else: raise AttributeError()
144
145
    def __getstate__(self):
146
        state = copy.copy(self.__dict__) # shallow copy
147 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
148 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
149
        return state
150
151
    def _db(self):
152
        if self.__db == None:
153
            # Process db_config
154
            db_config = self.db_config.copy() # don't modify input!
155 2097 aaronmk
            schemas = db_config.pop('schemas', None)
156 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
157
            module = __import__(module_name)
158
            _add_module(module)
159
            for orig, new in mappings.iteritems():
160
                try: util.rename_key(db_config, orig, new)
161
                except KeyError: pass
162
163
            # Connect
164
            self.__db = module.connect(**db_config)
165
166
            # Configure connection
167
            if self.serializable: run_raw_query(self,
168
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
169 2101 aaronmk
            if schemas != None:
170
                schemas_ = ''.join((esc_name(self, s)+', '
171
                    for s in schemas.split(',')))
172
                run_raw_query(self, "SELECT set_config('search_path', \
173
%s || current_setting('search_path'), false)", [schemas_])
174 1869 aaronmk
175
        return self.__db
176 1889 aaronmk
177 1891 aaronmk
    class DbCursor(Proxy):
178 1927 aaronmk
        def __init__(self, outer):
179 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
180 1927 aaronmk
            self.query_results = outer.query_results
181 1894 aaronmk
            self.query_lookup = None
182 1891 aaronmk
            self.result = []
183 1889 aaronmk
184 1894 aaronmk
        def execute(self, query, params=None):
185 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
186 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
187 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
188
            except Exception, e:
189
                self.result = e # cache the exception as the result
190
                self._cache_result()
191
                raise
192
            finally: self.query = get_cur_query(self.inner)
193 1930 aaronmk
            # Fetch all rows so result will be cached
194
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
195 1894 aaronmk
            return return_value
196
197 1891 aaronmk
        def fetchone(self):
198
            row = self.inner.fetchone()
199 1899 aaronmk
            if row != None: self.result.append(row)
200
            # otherwise, fetched all rows
201 1904 aaronmk
            else: self._cache_result()
202
            return row
203
204
        def _cache_result(self):
205 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
206
            # idempotent, but an invalid insert will always be invalid
207 1930 aaronmk
            if self.query_results != None and (not self._is_insert
208 1906 aaronmk
                or isinstance(self.result, Exception)):
209
210 1894 aaronmk
                assert self.query_lookup != None
211 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
212
                    util.dict_subset(dicts.AttrsDictView(self),
213
                    ['query', 'result', 'rowcount', 'description']))
214 1906 aaronmk
215 1916 aaronmk
        class CacheCursor:
216
            def __init__(self, cached_result): self.__dict__ = cached_result
217
218 1927 aaronmk
            def execute(self, *args, **kw_args):
219 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
220
                # otherwise, result is a rows list
221
                self.iter = iter(self.result)
222
223
            def fetchone(self):
224
                try: return self.iter.next()
225
                except StopIteration: return None
226 1891 aaronmk
227 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
228 2047 aaronmk
        if not self.caching: cacheable = False
229 1903 aaronmk
        used_cache = False
230
        try:
231 1927 aaronmk
            # Get cursor
232
            if cacheable:
233
                query_lookup = _query_lookup(query, params)
234
                try:
235
                    cur = self.query_results[query_lookup]
236
                    used_cache = True
237
                except KeyError: cur = self.DbCursor(self)
238
            else: cur = self.db.cursor()
239
240
            # Run query
241
            try: cur.execute(query, params)
242
            except Exception, e:
243
                _add_cursor_info(e, cur)
244
                raise
245 1903 aaronmk
        finally:
246
            if self.log_debug != log_debug_none: # only compute msg if needed
247
                if used_cache: cache_status = 'Cache hit'
248
                elif cacheable: cache_status = 'Cache miss'
249
                else: cache_status = 'Non-cacheable'
250 1927 aaronmk
                self.log_debug(cache_status+': '
251
                    +strings.one_line(get_cur_query(cur)))
252 1903 aaronmk
253
        return cur
254 1914 aaronmk
255
    def is_cached(self, query, params=None):
256
        return _query_lookup(query, params) in self.query_results
257 1849 aaronmk
258 1869 aaronmk
connect = DbConn
259
260 832 aaronmk
##### Querying
261
262 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
263 2085 aaronmk
    '''For params, see DbConn.run_query()'''
264 1894 aaronmk
    return db.run_query(*args, **kw_args)
265 11 aaronmk
266 2068 aaronmk
def mogrify(db, query, params):
267
    module = util.root_module(db.db)
268
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
269
    else: raise NotImplementedError("Can't mogrify query for "+module+
270
        ' database')
271
272 832 aaronmk
##### Recoverable querying
273 15 aaronmk
274 11 aaronmk
def with_savepoint(db, func):
275 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
276 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
277 11 aaronmk
    try: return_val = func()
278
    except:
279 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
280 11 aaronmk
        raise
281
    else:
282 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
283 11 aaronmk
        return return_val
284
285 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
286 830 aaronmk
    if recover == None: recover = False
287
288 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
289 1914 aaronmk
    if recover and not db.is_cached(query, params):
290
        return with_savepoint(db, run)
291
    else: return run() # don't need savepoint if cached
292 830 aaronmk
293 832 aaronmk
##### Basic queries
294
295 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
296
    '''Outputs a query to a temp table.
297
    For params, see run_query().
298
    '''
299
    if into == None: return run_query(db, query, params, *args, **kw_args)
300
    else: # place rows in temp table
301
        check_name(into)
302
303
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', *args, **kw_args)
304
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
305
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
306
307 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
308
309 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
310 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
311 1981 aaronmk
    '''
312 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
313
        together, in the form: [table0, (table1, joins_dict), ...]
314 1981 aaronmk
    @param fields Use None to select all fields in the table
315
    @param table_is_esc Whether the table name has already been escaped
316 2054 aaronmk
    @return tuple(query, params)
317 1981 aaronmk
    '''
318 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
319 2058 aaronmk
320 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
321
    tables = tables[:] # don't modify input!
322
    table0 = tables.pop(0) # first table is separate
323
324 1135 aaronmk
    if conds == None: conds = {}
325 135 aaronmk
    assert limit == None or type(limit) == int
326 865 aaronmk
    assert start == None or type(start) == int
327 2120 aaronmk
    if order_by == order_by_pkey:
328 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
329
    if not table_is_esc: table0 = esc_name_(table0)
330 865 aaronmk
331 2056 aaronmk
    params = []
332
333
    def parse_col(field):
334
        '''Parses fields'''
335 2121 aaronmk
        if isinstance(field, tuple): # field is literal value
336 2056 aaronmk
            value, col = field
337
            sql_ = '%s'
338
            params.append(value)
339 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
340
        else: sql_ = esc_name_(field) # field is col name
341 2056 aaronmk
        return sql_
342 11 aaronmk
    def cond(entry):
343 2056 aaronmk
        '''Parses conditions'''
344 13 aaronmk
        col, value = entry
345 2058 aaronmk
        cond_ = esc_name_(col)+' '
346 11 aaronmk
        if value == None: cond_ += 'IS'
347
        else: cond_ += '='
348
        cond_ += ' %s'
349
        return cond_
350 2056 aaronmk
351 1135 aaronmk
    query = 'SELECT '
352
    if fields == None: query += '*'
353 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
354 2121 aaronmk
    query += ' FROM '+table0
355 865 aaronmk
356
    missing = True
357 89 aaronmk
    if conds != {}:
358
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
359 2056 aaronmk
        params += conds.values()
360 865 aaronmk
        missing = False
361 2120 aaronmk
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
362 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
363
    if start != None:
364
        if start != 0: query += ' OFFSET '+str(start)
365
        missing = False
366
    if missing: warnings.warn(DbWarning(
367
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
368
369 2056 aaronmk
    return (query, params)
370 11 aaronmk
371 2054 aaronmk
def select(db, *args, **kw_args):
372
    '''For params, see mk_select() and run_query()'''
373
    recover = kw_args.pop('recover', None)
374
    cacheable = kw_args.pop('cacheable', True)
375
376
    query, params = mk_select(db, *args, **kw_args)
377
    return run_query(db, query, params, recover, cacheable)
378
379 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
380 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
381 1960 aaronmk
    '''
382
    @param returning str|None An inserted column (such as pkey) to return
383 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
384 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
385
        query will be fully cached, not just if it raises an exception.
386 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
387
    '''
388 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
389
    if cols == []: cols = None # no cols (all defaults) = unknown col names
390 1960 aaronmk
    if not table_is_esc: check_name(table)
391 2063 aaronmk
392
    # Build query
393
    query = 'INSERT INTO '+table
394
    if cols != None:
395
        map(check_name, cols)
396
        query += ' ('+', '.join(cols)+')'
397
    query += ' '+select_query
398
399
    if returning != None:
400
        check_name(returning)
401
        query += ' RETURNING '+returning
402
403 2070 aaronmk
    if embeddable:
404
        # Create function
405 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
406
            ['insert', table] + cols)))
407 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
408
        function_query = '''\
409 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
410 2070 aaronmk
    LANGUAGE sql
411
    AS $$'''+mogrify(db, query, params)+''';$$;
412
'''
413
        run_query(db, function_query, cacheable=True)
414
415
        # Return query that uses function
416 2083 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')',
417 2080 aaronmk
            table_is_esc=True) # function alias is required in AS clause
418 2070 aaronmk
419 2066 aaronmk
    return (query, params)
420
421
def insert_select(db, *args, **kw_args):
422 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
423 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
424
    '''
425
    into = kw_args.pop('into', None)
426
    if into != None: kw_args['embeddable'] = True
427 2066 aaronmk
    recover = kw_args.pop('recover', None)
428
    cacheable = kw_args.pop('cacheable', True)
429
430
    query, params = mk_insert_select(db, *args, **kw_args)
431 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
432 2063 aaronmk
433 2066 aaronmk
default = object() # tells insert() to use the default value for a column
434
435 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
436 2085 aaronmk
    '''For params, see insert_select()'''
437 1960 aaronmk
    if lists.is_seq(row): cols = None
438
    else:
439
        cols = row.keys()
440
        row = row.values()
441
    row = list(row) # ensure that "!= []" works
442
443 1961 aaronmk
    # Check for special values
444
    labels = []
445
    values = []
446
    for value in row:
447
        if value == default: labels.append('DEFAULT')
448
        else:
449
            labels.append('%s')
450
            values.append(value)
451
452
    # Build query
453 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
454
    else: query = None
455 1554 aaronmk
456 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
457 11 aaronmk
458 135 aaronmk
def last_insert_id(db):
459 1849 aaronmk
    module = util.root_module(db.db)
460 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
461
    elif module == 'MySQLdb': return db.insert_id()
462
    else: return None
463 13 aaronmk
464 1968 aaronmk
def truncate(db, table, schema='public'):
465
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
466 832 aaronmk
467
##### Database structure queries
468
469 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
470 832 aaronmk
    '''Assumed to be first column in table'''
471 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
472 2084 aaronmk
        table_is_esc=table_is_esc)).next()
473 832 aaronmk
474 853 aaronmk
def index_cols(db, table, index):
475
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
476
    automatically created. When you don't know whether something is a UNIQUE
477
    constraint or a UNIQUE index, use this function.'''
478
    check_name(table)
479
    check_name(index)
480 1909 aaronmk
    module = util.root_module(db.db)
481
    if module == 'psycopg2':
482
        return list(values(run_query(db, '''\
483 853 aaronmk
SELECT attname
484 866 aaronmk
FROM
485
(
486
        SELECT attnum, attname
487
        FROM pg_index
488
        JOIN pg_class index ON index.oid = indexrelid
489
        JOIN pg_class table_ ON table_.oid = indrelid
490
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
491
        WHERE
492
            table_.relname = %(table)s
493
            AND index.relname = %(index)s
494
    UNION
495
        SELECT attnum, attname
496
        FROM
497
        (
498
            SELECT
499
                indrelid
500
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
501
                    AS indkey
502
            FROM pg_index
503
            JOIN pg_class index ON index.oid = indexrelid
504
            JOIN pg_class table_ ON table_.oid = indrelid
505
            WHERE
506
                table_.relname = %(table)s
507
                AND index.relname = %(index)s
508
        ) s
509
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
510
) s
511 853 aaronmk
ORDER BY attnum
512
''',
513 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
514
    else: raise NotImplementedError("Can't list index columns for "+module+
515
        ' database')
516 853 aaronmk
517 464 aaronmk
def constraint_cols(db, table, constraint):
518
    check_name(table)
519
    check_name(constraint)
520 1849 aaronmk
    module = util.root_module(db.db)
521 464 aaronmk
    if module == 'psycopg2':
522
        return list(values(run_query(db, '''\
523
SELECT attname
524
FROM pg_constraint
525
JOIN pg_class ON pg_class.oid = conrelid
526
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
527
WHERE
528
    relname = %(table)s
529
    AND conname = %(constraint)s
530
ORDER BY attnum
531
''',
532
            {'table': table, 'constraint': constraint})))
533
    else: raise NotImplementedError("Can't list constraint columns for "+module+
534
        ' database')
535
536 2096 aaronmk
row_num_col = '_row_num'
537
538 2086 aaronmk
def add_row_num(db, table):
539 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
540
    be the primary key.'''
541 2086 aaronmk
    check_name(table)
542 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
543 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
544 2086 aaronmk
545 1968 aaronmk
def tables(db, schema='public', table_like='%'):
546 1849 aaronmk
    module = util.root_module(db.db)
547 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
548 832 aaronmk
    if module == 'psycopg2':
549 1968 aaronmk
        return values(run_query(db, '''\
550
SELECT tablename
551
FROM pg_tables
552
WHERE
553
    schemaname = %(schema)s
554
    AND tablename LIKE %(table_like)s
555
ORDER BY tablename
556
''',
557
            params, cacheable=True))
558
    elif module == 'MySQLdb':
559
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
560
            cacheable=True))
561 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
562 830 aaronmk
563 833 aaronmk
##### Database management
564
565 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
566
    '''For kw_args, see tables()'''
567
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
568 833 aaronmk
569 832 aaronmk
##### Heuristic queries
570
571 2076 aaronmk
def with_parsed_errors(db, func):
572
    '''Translates known DB errors to typed exceptions'''
573
    try: return func()
574 46 aaronmk
    except Exception, e:
575
        msg = str(e)
576 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
577
            r'"(([^\W_]+)_[^"]+)"', msg)
578
        if match:
579
            constraint, table = match.groups()
580 854 aaronmk
            try: cols = index_cols(db, table, constraint)
581 465 aaronmk
            except NotImplementedError: raise e
582 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
583 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
584
            'constraint', msg)
585 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
586 13 aaronmk
        raise # no specific exception raised
587 11 aaronmk
588 2076 aaronmk
def try_insert(db, table, row, returning=None):
589
    '''Recovers from errors'''
590
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
591
        recover=True))
592
593 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
594 1554 aaronmk
    '''Recovers from errors.
595 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
596
    '''
597 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
598
599 471 aaronmk
    try:
600 2104 aaronmk
        cur = try_insert(db, table, row, pkey_)
601 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
602
            row_ct_ref[0] += cur.rowcount
603
        return value(cur)
604 471 aaronmk
    except DuplicateKeyException, e:
605 2104 aaronmk
        return value(select(db, table, [pkey_],
606 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
607 471 aaronmk
608 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
609 830 aaronmk
    '''Recovers from errors'''
610
    try: return value(select(db, table, [pkey], row, 1, recover=True))
611 14 aaronmk
    except StopIteration:
612 40 aaronmk
        if not create: raise
613 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
614 2078 aaronmk
615 2087 aaronmk
def put_table(db, out_table, out_cols, in_tables, in_cols, pkey,
616
    row_ct_ref=None, table_is_esc=False):
617 2078 aaronmk
    '''Recovers from errors.
618
    Only works under PostgreSQL (uses INSERT RETURNING).
619 2081 aaronmk
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
620 2078 aaronmk
        available
621
    '''
622
    pkeys_table = clean_name(out_table)+'_pkeys'
623
    def insert_():
624
        return insert_select(db, out_table, out_cols,
625 2087 aaronmk
            *mk_select(db, in_tables[0], in_cols, table_is_esc=table_is_esc),
626 2078 aaronmk
            returning=pkey, into=pkeys_table, recover=True,
627
            table_is_esc=table_is_esc)
628
    try:
629
        cur = with_parsed_errors(db, insert_)
630
        if row_ct_ref != None and cur.rowcount >= 0:
631
            row_ct_ref[0] += cur.rowcount
632 2086 aaronmk
633
        # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
634
        add_row_num(db, pkeys_table)
635
636 2081 aaronmk
        return pkeys_table
637 2078 aaronmk
    except DuplicateKeyException, e: raise
638 2115 aaronmk
639
##### Data cleanup
640
641
def cleanup_table(db, table, cols, table_is_esc=False):
642
    def esc_name_(name): return esc_name(db, name)
643
644
    if not table_is_esc: check_name(table)
645
    cols = map(esc_name_, cols)
646
647
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
648
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
649
            for col in cols))),
650
        dict(null0='', null1=r'\N'))