Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
    '''Gets the text of the query most recently run on a cursor.
    The cursor's `query` attribute is preferred, then `_last_executed`.
    @return The attribute's value, or None if the cursor has neither attribute
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
23 14 aaronmk
24 300 aaronmk
def _add_cursor_info(e, cur):
    '''Appends the cursor's last query to an exception's message.
    Does nothing if the cursor can't provide a query: concatenating None would
    raise a TypeError that hides the exception being annotated.
    @param e The exception to annotate (modified via exc.add_msg())
    @param cur The cursor whose query to add
    '''
    query = get_cur_query(cur)
    if query is not None: exc.add_msg(e, 'query: '+query)
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, this cursor's query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        if cur != None:
            _add_cursor_info(self, cur)
30
31 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
32
33 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about specific columns, which are kept in self.cols'''
    
    def __init__(self, cols, cause=None):
        '''
        @param cols Sequence of the names of the columns involved
        @param cause The underlying exception, if any
        '''
        msg = 'columns: ' + ', '.join(cols)
        DbException.__init__(self, msg, cause)
        self.cols = cols
37 11 aaronmk
38 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by with_parsed_errors() for a unique constraint violation.
    self.cols contains the columns of the violated constraint.
    '''
39 13 aaronmk
40 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by with_parsed_errors() for a not-null constraint violation.
    self.cols contains the column that was given a null value.
    '''
41 13 aaronmk
42 89 aaronmk
class EmptyRowException(DbException):
    '''A DbException for an empty row.
    NOTE(review): nothing in this part of the file raises it -- confirm usage.
    '''
43
44 865 aaronmk
##### Warnings
45
46
class DbWarning(UserWarning):
    '''Category for the warnings issued by this module (see mk_select())'''
47
48 1930 aaronmk
##### Result retrieval
49
50
def col_names(cur):
    '''Returns a generator of the column names of a cursor's result, taken from
    the first item of each entry in cur.description'''
    return (desc[0] for desc in cur.description)
51
52
def rows(cur):
    '''Returns an iterator over a cursor's remaining rows.
    Each step calls cur.fetchone(); iteration ends when it returns None.
    '''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
53
54
def consume_rows(cur):
    '''Fetches (and discards) all of a cursor's remaining rows.
    Used to fetch all rows so result will be cached.
    '''
    remaining = rows(cur)
    iters.consume_iter(remaining)
57
58
def next_row(cur):
    '''Fetches the next row of a cursor.
    Raises StopIteration if there are no more rows.
    '''
    row_iter = rows(cur)
    return row_iter.next()
59
60
def row(cur):
    '''Fetches the next row of a cursor, then consumes all the rows after it.
    Raises StopIteration (before consuming anything) if there are no more rows.
    '''
    first = next_row(cur)
    consume_rows(cur) # fetch the rest so the cursor is fully read
    return first
64
65
def next_value(cur):
    '''Fetches the next row of a cursor and returns its first column.
    Raises StopIteration if there are no more rows.
    '''
    next_ = next_row(cur)
    return next_[0]
66
67
def value(cur):
    '''Returns the first column of the next row of a cursor, consuming all the
    rows after it (see row()).
    Raises StopIteration if there are no more rows.
    '''
    first_row = row(cur)
    return first_row[0]
68
69
def values(cur):
    '''Returns an iterator built by iters.func_iter() from a function that
    fetches the first column of the cursor's next row.
    NOTE(review): presumably this iterates over the first column of every
    remaining row -- confirm against iters.func_iter().
    '''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
70
71
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when the
    cursor has no more rows'''
    try:
        return value(cur)
    except StopIteration:
        return None
74
75 2101 aaronmk
##### Input validation
76
77
def clean_name(name):
    '''Removes every non-word character (anything matching \\W) from a name'''
    non_word = re.compile(r'\W')
    return non_word.sub(r'', name)
78
79
def check_name(name):
    '''Validates that a name contains only word characters.
    Raises NameException if it contains any character matching \\W.
    '''
    if re.search(r'\W', name) == None: return # name is valid
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
82
83
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for the given DB driver module.
    @param module Name of the driver module: 'psycopg2' or 'MySQLdb'
    @param ignore_case For psycopg2, validate the name with check_name() and
        return it unquoted instead of quoting it. Has no effect for MySQLdb.
    @return The name enclosed in the driver's quote character, with any
        occurrences of that character removed from the name
    Raises NotImplementedError for any other module.
    '''
    if module == 'psycopg2' and ignore_case:
        # Don't enclose in quotes because this disables case-insensitivity
        check_name(name)
        return name
    
    if module == 'psycopg2': quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    
    unquoted = name.replace(quote, '')
    return quote + unquoted + quote
93
94
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a DB engine, which must be a key of db_engines.
    For kw_args, see esc_name_by_module().
    '''
    module = db_engines[engine][0] # first item is the driver module name
    return esc_name_by_module(module, name, **kw_args)
96
97
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the driver module of a connection's db.db,
    as determined by util.root_module().
    For kw_args, see esc_name_by_module().
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
99
100
def qual_name(db, schema, table):
    '''Returns the escaped table name, prefixed by the escaped schema and a
    "." unless schema is None'''
    table_esc = esc_name(db, table)
    if schema == None: return table_esc
    return esc_name(db, schema)+'.'+table_esc
105
106 1869 aaronmk
##### Database connections
107 1849 aaronmk
108 2097 aaronmk
# Names of the db_config settings
db_config_names = [
    'engine', 'host', 'user', 'password', 'database', 'schemas',
]

# Maps each engine name to (driver module name, dict mapping db_config keys to
# the names under which the driver's connect() takes them)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
114
115
# Exception classes that indicate a database error. _add_module() adds each
# loaded driver module's DatabaseError to the set and rebuilds the tuple.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
117
118
def _add_module(module):
    '''Registers a driver module's DatabaseError class in DatabaseErrors_set and
    rebuilds the module-level DatabaseErrors tuple from the set'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
122
123
def db_config_str(db_config):
    '''Returns a short description of a db_config: "<engine> database <name>"'''
    engine = db_config['engine']
    database = db_config['database']
    return engine+' database '+database
125
126 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is stored in
    DbConn.query_results: the query paired with its params, which are passed
    through dicts.make_hashable()'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
127 1894 aaronmk
128 1901 aaronmk
def log_debug_none(msg):
    '''The default debug logger, which ignores the message.
    DbConn.run_query() compares against this function to avoid building the
    debug message when it would be discarded.
    '''
    return None
129
130 1849 aaronmk
class DbConn:
131 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
132 2050 aaronmk
        caching=True):
133 1869 aaronmk
        self.db_config = db_config
134
        self.serializable = serializable
135 1901 aaronmk
        self.log_debug = log_debug
136 2047 aaronmk
        self.caching = caching
137 1869 aaronmk
138
        self.__db = None
139 1889 aaronmk
        self.query_results = {}
140 1869 aaronmk
141
    def __getattr__(self, name):
142
        if name == '__dict__': raise Exception('getting __dict__')
143
        if name == 'db': return self._db()
144
        else: raise AttributeError()
145
146
    def __getstate__(self):
147
        state = copy.copy(self.__dict__) # shallow copy
148 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
149 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
150
        return state
151
152
    def _db(self):
153
        if self.__db == None:
154
            # Process db_config
155
            db_config = self.db_config.copy() # don't modify input!
156 2097 aaronmk
            schemas = db_config.pop('schemas', None)
157 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
158
            module = __import__(module_name)
159
            _add_module(module)
160
            for orig, new in mappings.iteritems():
161
                try: util.rename_key(db_config, orig, new)
162
                except KeyError: pass
163
164
            # Connect
165
            self.__db = module.connect(**db_config)
166
167
            # Configure connection
168
            if self.serializable: run_raw_query(self,
169
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
170 2101 aaronmk
            if schemas != None:
171
                schemas_ = ''.join((esc_name(self, s)+', '
172
                    for s in schemas.split(',')))
173
                run_raw_query(self, "SELECT set_config('search_path', \
174
%s || current_setting('search_path'), false)", [schemas_])
175 1869 aaronmk
176
        return self.__db
177 1889 aaronmk
178 1891 aaronmk
    class DbCursor(Proxy):
179 1927 aaronmk
        def __init__(self, outer):
180 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
181 1927 aaronmk
            self.query_results = outer.query_results
182 1894 aaronmk
            self.query_lookup = None
183 1891 aaronmk
            self.result = []
184 1889 aaronmk
185 1894 aaronmk
        def execute(self, query, params=None):
186 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
187 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
188 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
189
            except Exception, e:
190
                self.result = e # cache the exception as the result
191
                self._cache_result()
192
                raise
193
            finally: self.query = get_cur_query(self.inner)
194 1930 aaronmk
            # Fetch all rows so result will be cached
195
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
196 1894 aaronmk
            return return_value
197
198 1891 aaronmk
        def fetchone(self):
199
            row = self.inner.fetchone()
200 1899 aaronmk
            if row != None: self.result.append(row)
201
            # otherwise, fetched all rows
202 1904 aaronmk
            else: self._cache_result()
203
            return row
204
205
        def _cache_result(self):
206 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
207
            # idempotent, but an invalid insert will always be invalid
208 1930 aaronmk
            if self.query_results != None and (not self._is_insert
209 1906 aaronmk
                or isinstance(self.result, Exception)):
210
211 1894 aaronmk
                assert self.query_lookup != None
212 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
213
                    util.dict_subset(dicts.AttrsDictView(self),
214
                    ['query', 'result', 'rowcount', 'description']))
215 1906 aaronmk
216 1916 aaronmk
        class CacheCursor:
217
            def __init__(self, cached_result): self.__dict__ = cached_result
218
219 1927 aaronmk
            def execute(self, *args, **kw_args):
220 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
221
                # otherwise, result is a rows list
222
                self.iter = iter(self.result)
223
224
            def fetchone(self):
225
                try: return self.iter.next()
226
                except StopIteration: return None
227 1891 aaronmk
228 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
229 2047 aaronmk
        if not self.caching: cacheable = False
230 1903 aaronmk
        used_cache = False
231
        try:
232 1927 aaronmk
            # Get cursor
233
            if cacheable:
234
                query_lookup = _query_lookup(query, params)
235
                try:
236
                    cur = self.query_results[query_lookup]
237
                    used_cache = True
238
                except KeyError: cur = self.DbCursor(self)
239
            else: cur = self.db.cursor()
240
241
            # Run query
242
            try: cur.execute(query, params)
243
            except Exception, e:
244
                _add_cursor_info(e, cur)
245
                raise
246 1903 aaronmk
        finally:
247
            if self.log_debug != log_debug_none: # only compute msg if needed
248
                if used_cache: cache_status = 'Cache hit'
249
                elif cacheable: cache_status = 'Cache miss'
250
                else: cache_status = 'Non-cacheable'
251 1927 aaronmk
                self.log_debug(cache_status+': '
252
                    +strings.one_line(get_cur_query(cur)))
253 1903 aaronmk
254
        return cur
255 1914 aaronmk
256
    def is_cached(self, query, params=None):
257
        return _query_lookup(query, params) in self.query_results
258 1849 aaronmk
259 1869 aaronmk
connect = DbConn # alias: connect(...) creates a DbConn
260
261 832 aaronmk
##### Querying
262
263 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on the connection, without the savepoint handling
    of the module-level run_query().
    For params, see DbConn.run_query()
    '''
    run = db.run_query
    return run(*args, **kw_args)
266 11 aaronmk
267 2068 aaronmk
def mogrify(db, query, params):
    '''Returns the result of the driver's cursor.mogrify(query, params), on a
    new cursor.
    Only implemented for psycopg2; raises NotImplementedError for other
    drivers.
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
272
273 832 aaronmk
##### Recoverable querying
274 15 aaronmk
275 11 aaronmk
def with_savepoint(db, func):
    '''Calls func() inside a savepoint.
    If func() raises anything, rolls back to the savepoint and re-raises;
    otherwise releases the savepoint.
    @return func()'s return value
    '''
    name = 'savepoint_'+str(rand.rand_int()) # must be unique
    run_raw_query(db, 'SAVEPOINT '+name)
    try:
        result = func()
    except:
        # Undo whatever func() did before propagating the error
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+name)
        raise
    run_raw_query(db, 'RELEASE SAVEPOINT '+name)
    return result
285
286 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param recover Whether to run the query in a savepoint, which is rolled
        back to if the query fails. None is treated as False.
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    '''
    if recover == None: recover = False
    
    # don't need savepoint if cached
    if not recover or db.is_cached(query, params):
        return run_raw_query(db, query, params, cacheable)
    
    return with_savepoint(db,
        lambda: run_raw_query(db, query, params, cacheable))
293 830 aaronmk
294 832 aaronmk
##### Basic queries
295
296 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into Name of the temp table to create from the query's rows. Any
        existing table of that name is dropped first. If None, the query is
        just run normally.
    @return The cursor of the query (if into is None) or of the CREATE TEMP
        TABLE statement
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        check_name(into)
        
        # The DROP statement has no params, but None must still be passed for
        # them: otherwise the positional args (recover, cacheable) would shift
        # into run_query()'s params and recover slots
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', None, *args,
            **kw_args)
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
307
308 2120 aaronmk
# Sentinel that tells mk_select() to order by the pkey
order_by_pkey = object()

# Sentinel that tells mk_select() to join the column with USING
join_using = object()
311
312 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''
    Builds a SELECT statement.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
        Each table is joined to the table before it. If every left_col of a
        table is join_using, the right_cols are joined with USING; otherwise
        each pair is compared for equality, with NULLs matching each other.
        Mixing join_using and named left_cols in one table is not supported.
    @param fields Use None to select all fields in the table. Each field is
        either a column name or a tuple (value, col_name|None) for a literal
        value.
    @param conds dict(col=value, ...) of conditions to AND together. A value of
        None is compared using IS.
    @param limit int|None The max # of rows
    @param start int|None The # of rows to skip
    @param order_by The column to order by, None for no ordering, or
        order_by_pkey to order by the pkey of the first table
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    Raises TypeError if a joined table has no join columns.
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (also allows non-list seqs)
    table0 = tables.pop(0) # first table is separate
    
    if conds == None: conds = {}
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = []
    
    def parse_col(field):
        '''Parses fields'''
        if isinstance(field, tuple): # field is literal value
            value, col = field
            sql_ = '%s'
            params.append(value)
            if col != None: sql_ += ' AS '+esc_name_(col)
        else: sql_ = esc_name_(field) # field is col name
        return sql_
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = esc_name_(col)+' '
        if value == None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    
    query = 'SELECT '
    if fields == None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        query += ' JOIN '+table
        
        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry
            right_col = table+'.'+esc_name_(right_col)
            left_col = left_table+'.'+esc_name_(left_col)
            # The whole condition is parenthesized because AND binds tighter
            # than OR, so ANDing bare conditions together would regroup them
            return ('('+right_col+' = '+left_col
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL))')
        
        if not joins: raise TypeError('No join columns provided for table '
            +table)
        if all(v is join_using for v in joins.itervalues()):
            # all cols w/ USING
            query += ' USING ('+(', '.join(map(esc_name_, joins.iterkeys())))+')'
        else: query += ' ON '+(' AND '.join(map(join, joins.iteritems())))
        
        left_table = table
    
    missing = True
    if conds != {}:
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
        params += conds.values()
        missing = False
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
    if limit != None: query += ' LIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
394 11 aaronmk
395 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT statement with mk_select() and runs it.
    For params, see mk_select() and run_query(). Note that cacheable defaults
    to True here.
    @return The query's cursor
    '''
    # These are for run_query(), so remove them before calling mk_select()
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, **run_opts)
402
403 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''
    Builds an INSERT statement whose rows come from a query.
    @param cols The names of the columns to insert into. None or [] omits the
        column list.
    @param select_query The SQL that provides the rows. None inserts a row of
        DEFAULT VALUES.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return.
        Must be provided if embeddable is set, because it is used to build the
        function's return type.
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        If set, the INSERT is wrapped in a pg_temp function (which is created
        here, by running a query), and a SELECT from that function is returned.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query == None: select_query = 'DEFAULT VALUES'
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols != None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning != None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        name_parts = ['insert', table]
        if cols != None: name_parts += cols # cols is None if all defaults
        function = 'pg_temp.'+('_'.join(map(clean_name, name_parts)))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        function_query = '''\
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
        run_query(db, function_query, cacheable=True)
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', order_by=None,
            table_is_esc=True) # function alias is required in AS clause
    
    return (query, params)
444
445
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT statement with mk_insert_select() and runs it.
    For params, see mk_insert_select() and run_query_into(). Note that
    cacheable defaults to True here.
    @param into Name of temp table to place RETURNING values in. Providing
        this turns on mk_insert_select()'s embeddable option.
    @return The cursor from run_query_into()
    '''
    # Remove the options that are not for mk_insert_select()
    into = kw_args.pop('into', None)
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    if into != None:
        kw_args['embeddable'] = True
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into, recover, cacheable)
456 2063 aaronmk
457 2066 aaronmk
default = object() # sentinel that tells insert() to use the default value for a column
458
459 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts one row into a table.
    For params, see insert_select()
    @param row Either a sequence of values (for the table's columns, in order)
        or a dict mapping column names to values. A value of `default` uses
        the column's default value.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = []
    values = []
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    # This checks labels rather than values so that a row consisting only of
    # `default` entries becomes VALUES (DEFAULT, ...). With a column list,
    # the DEFAULT VALUES form would not be valid SQL.
    if labels != []: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None # empty row: insert DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
481 11 aaronmk
482 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert on this connection.
    Uses lastval() for psycopg2 and the driver connection's insert_id() for
    MySQLdb.
    @return The ID, or None for any other driver
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() is a method of the driver connection (db.db), not of DbConn,
    # whose __getattr__ only provides the `db` attribute
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
487 13 aaronmk
488 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema The table's schema, or None to not qualify the table name
    @return The query's cursor
    '''
    table_name = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+table_name+' CASCADE')
490 832 aaronmk
491
##### Database structure queries
492
493 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of a table's pkey, which is assumed to be the first
    column in the table.
    The name is read from the result description of a zero-row SELECT.
    For params, see select().
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    names = col_names(cur)
    return names.next()
497 832 aaronmk
498 853 aaronmk
def index_cols(db, table, index):
    '''Returns a list of the names of the columns of an index, ordered by their
    position in the table. Columns referenced by index expressions are
    included.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    Only implemented for psycopg2; raises NotImplementedError for other
    drivers.
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The 1st subquery gets the plain columns of the index, and the 2nd gets
    # the columns that are referenced by the index's expressions
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
540 853 aaronmk
541 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns a list of the names of the columns of a constraint, ordered by
    their position in the table.
    Only implemented for psycopg2; raises NotImplementedError for other
    drivers.
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
559
560 2096 aaronmk
row_num_col = '_row_num' # name of the column that add_row_num() adds
561
562 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.
    The column is created as serial NOT NULL.
    '''
    check_name(table)
    col_def = row_num_col+' serial NOT NULL PRIMARY KEY'
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+col_def)
568 2086 aaronmk
569 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Lists the tables whose names match a LIKE pattern.
    @param schema The schema to look in. Only used for psycopg2.
    @param table_like The LIKE pattern that the table names must match
    @return The values() iterator over the query's result
    Raises NotImplementedError for drivers other than psycopg2 and MySQLdb.
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    
    # Pick the driver-specific query
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    
    return values(run_query(db, query, params, cacheable=True))
586 830 aaronmk
587 833 aaronmk
##### Database management
588
589 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
592 833 aaronmk
593 832 aaronmk
##### Heuristic queries
594
595 2076 aaronmk
def with_parsed_errors(db, func):
596
    '''Translates known DB errors to typed exceptions'''
597
    try: return func()
598 46 aaronmk
    except Exception, e:
599
        msg = str(e)
600 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
601
            r'"(([^\W_]+)_[^"]+)"', msg)
602
        if match:
603
            constraint, table = match.groups()
604 854 aaronmk
            try: cols = index_cols(db, table, constraint)
605 465 aaronmk
            except NotImplementedError: raise e
606 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
607 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
608
            'constraint', msg)
609 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
610 13 aaronmk
        raise # no specific exception raised
611 11 aaronmk
612 2076 aaronmk
def try_insert(db, table, row, returning=None):
    '''Inserts a row with recover=True, translating known DB errors to typed
    exceptions (see with_parsed_errors()).
    Recovers from errors.
    @return The cursor from insert()
    '''
    def insert_():
        return insert(db, table, row, returning, recover=True)
    return with_parsed_errors(db, insert_)
616
617 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
618 1554 aaronmk
    '''Recovers from errors.
619 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
620
    '''
621 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
622
623 471 aaronmk
    try:
624 2104 aaronmk
        cur = try_insert(db, table, row, pkey_)
625 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
626
            row_ct_ref[0] += cur.rowcount
627
        return value(cur)
628 471 aaronmk
    except DuplicateKeyException, e:
629 2104 aaronmk
        return value(select(db, table, [pkey_],
630 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
631 471 aaronmk
632 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the first row that matches the given column values.
    Recovers from errors.
    @param row dict of the column values to match (and to insert, if create)
    @param pkey The name of the pkey column
    @param create Whether to insert the row with put() if no row matches. If
        false, StopIteration is raised in that case.
    @param row_ct_ref See put()
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
638 2078 aaronmk
639 2125 aaronmk
def put_table(db, out_table, in_tables, mapping, pkey, row_ct_ref=None,
    table_is_esc=False):
    '''Inserts the rows of an input table into an output table.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The input tables. Only the first one is read.
    @param mapping dict mapping each output column to the input field to fill
        it from
    @param pkey The output column to return for each inserted row
    @param row_ct_ref If provided, a list whose first item is increased by the
        statement's rowcount
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
        available
    '''
    pkeys_table = clean_name(out_table)+'_pkeys'
    
    def insert_():
        select_query, select_params = mk_select(db, in_tables[0],
            mapping.values(), table_is_esc=table_is_esc)
        return insert_select(db, out_table, mapping.keys(), select_query,
            select_params, returning=pkey, into=pkeys_table, recover=True,
            table_is_esc=table_is_esc)
    
    # Errors, including DuplicateKeyException, propagate to the caller
    cur = with_parsed_errors(db, insert_)
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
    add_row_num(db, pkeys_table)
    
    return pkeys_table
662 2115 aaronmk
663
##### Data cleanup
664
665
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Cleans up the values in the given columns of a table: each value is
    trimmed, and then replaced with NULL if it equals '' or \\N.
    @param cols The names of the columns to clean up
    @param table_is_esc Whether the table name has already been escaped
    '''
    if not table_is_esc: check_name(table)
    esc_cols = [esc_name(db, col) for col in cols]
    
    # One "col = cleaned-up col" assignment per column
    assignments = ['\n'+col
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
        for col in esc_cols]
    query = 'UPDATE '+table+' SET\n'+(',\n'.join(assignments))
    run_query(db, query, dict(null0='', null1=r'\N'))