Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
20
    if hasattr(cur, 'query'): return cur.query
21
    elif hasattr(cur, '_last_executed'): return cur._last_executed
22
    else: return None
23 14 aaronmk
24 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
27 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
28 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
29 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
30
31 360 aaronmk
class NameException(DbException): pass
32
33 468 aaronmk
class ExceptionWithColumns(DbException):
34
    def __init__(self, cols, cause=None):
35
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
36
        self.cols = cols
37 11 aaronmk
38 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
39 13 aaronmk
40 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
41 13 aaronmk
42 89 aaronmk
class EmptyRowException(DbException): pass
43
44 865 aaronmk
##### Warnings
45
46
class DbWarning(UserWarning): pass
47
48 1930 aaronmk
##### Result retrieval
49
50
def col_names(cur): return (col[0] for col in cur.description)
51
52
def rows(cur): return iter(lambda: cur.fetchone(), None)
53
54
def consume_rows(cur):
55
    '''Used to fetch all rows so result will be cached'''
56
    iters.consume_iter(rows(cur))
57
58
def next_row(cur): return rows(cur).next()
59
60
def row(cur):
61
    row_ = next_row(cur)
62
    consume_rows(cur)
63
    return row_
64
65
def next_value(cur): return next_row(cur)[0]
66
67
def value(cur): return row(cur)[0]
68
69
def values(cur): return iters.func_iter(lambda: next_value(cur))
70
71
def value_or_none(cur):
72
    try: return value(cur)
73
    except StopIteration: return None
74
75 2101 aaronmk
##### Input validation
76
77
def clean_name(name): return re.sub(r'\W', r'', name)
78
79
def check_name(name):
80
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
81
        +'" may contain only alphanumeric characters and _')
82
83
def esc_name_by_module(module, name, ignore_case=False):
84
    if module == 'psycopg2':
85
        if ignore_case:
86
            # Don't enclose in quotes because this disables case-insensitivity
87
            check_name(name)
88
            return name
89
        else: quote = '"'
90
    elif module == 'MySQLdb': quote = '`'
91
    else: raise NotImplementedError("Can't escape name for "+module+' database')
92
    return quote + name.replace(quote, '') + quote
93
94
def esc_name_by_engine(engine, name, **kw_args):
95
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
96
97
def esc_name(db, name, **kw_args):
98
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
99
100
def qual_name(db, schema, table):
101
    def esc_name_(name): return esc_name(db, name)
102
    table = esc_name_(table)
103
    if schema != None: return esc_name_(schema)+'.'+table
104
    else: return table
105
106 1869 aaronmk
##### Database connections
107 1849 aaronmk
108 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
109 1926 aaronmk
110 1869 aaronmk
db_engines = {
111
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
112
    'PostgreSQL': ('psycopg2', {}),
113
}
114
115
DatabaseErrors_set = set([DbException])
116
DatabaseErrors = tuple(DatabaseErrors_set)
117
118
def _add_module(module):
119
    DatabaseErrors_set.add(module.DatabaseError)
120
    global DatabaseErrors
121
    DatabaseErrors = tuple(DatabaseErrors_set)
122
123
def db_config_str(db_config):
124
    return db_config['engine']+' database '+db_config['database']
125
126 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
127 1894 aaronmk
128 1901 aaronmk
log_debug_none = lambda msg: None
129
130 1849 aaronmk
class DbConn:
131 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
132 2050 aaronmk
        caching=True):
133 1869 aaronmk
        self.db_config = db_config
134
        self.serializable = serializable
135 1901 aaronmk
        self.log_debug = log_debug
136 2047 aaronmk
        self.caching = caching
137 1869 aaronmk
138
        self.__db = None
139 1889 aaronmk
        self.query_results = {}
140 2139 aaronmk
        self._savepoint = 0
141 1869 aaronmk
142
    def __getattr__(self, name):
143
        if name == '__dict__': raise Exception('getting __dict__')
144
        if name == 'db': return self._db()
145
        else: raise AttributeError()
146
147
    def __getstate__(self):
148
        state = copy.copy(self.__dict__) # shallow copy
149 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
150 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
151
        return state
152
153
    def _db(self):
154
        if self.__db == None:
155
            # Process db_config
156
            db_config = self.db_config.copy() # don't modify input!
157 2097 aaronmk
            schemas = db_config.pop('schemas', None)
158 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
159
            module = __import__(module_name)
160
            _add_module(module)
161
            for orig, new in mappings.iteritems():
162
                try: util.rename_key(db_config, orig, new)
163
                except KeyError: pass
164
165
            # Connect
166
            self.__db = module.connect(**db_config)
167
168
            # Configure connection
169
            if self.serializable: run_raw_query(self,
170
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
171 2101 aaronmk
            if schemas != None:
172
                schemas_ = ''.join((esc_name(self, s)+', '
173
                    for s in schemas.split(',')))
174
                run_raw_query(self, "SELECT set_config('search_path', \
175
%s || current_setting('search_path'), false)", [schemas_])
176 1869 aaronmk
177
        return self.__db
178 1889 aaronmk
179 1891 aaronmk
    class DbCursor(Proxy):
180 1927 aaronmk
        def __init__(self, outer):
181 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
182 1927 aaronmk
            self.query_results = outer.query_results
183 1894 aaronmk
            self.query_lookup = None
184 1891 aaronmk
            self.result = []
185 1889 aaronmk
186 1894 aaronmk
        def execute(self, query, params=None):
187 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
188 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
189 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
190
            except Exception, e:
191
                self.result = e # cache the exception as the result
192
                self._cache_result()
193
                raise
194
            finally: self.query = get_cur_query(self.inner)
195 1930 aaronmk
            # Fetch all rows so result will be cached
196
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
197 1894 aaronmk
            return return_value
198
199 1891 aaronmk
        def fetchone(self):
200
            row = self.inner.fetchone()
201 1899 aaronmk
            if row != None: self.result.append(row)
202
            # otherwise, fetched all rows
203 1904 aaronmk
            else: self._cache_result()
204
            return row
205
206
        def _cache_result(self):
207 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
208
            # idempotent, but an invalid insert will always be invalid
209 1930 aaronmk
            if self.query_results != None and (not self._is_insert
210 1906 aaronmk
                or isinstance(self.result, Exception)):
211
212 1894 aaronmk
                assert self.query_lookup != None
213 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
214
                    util.dict_subset(dicts.AttrsDictView(self),
215
                    ['query', 'result', 'rowcount', 'description']))
216 1906 aaronmk
217 1916 aaronmk
        class CacheCursor:
218
            def __init__(self, cached_result): self.__dict__ = cached_result
219
220 1927 aaronmk
            def execute(self, *args, **kw_args):
221 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
222
                # otherwise, result is a rows list
223
                self.iter = iter(self.result)
224
225
            def fetchone(self):
226
                try: return self.iter.next()
227
                except StopIteration: return None
228 1891 aaronmk
229 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
230 2047 aaronmk
        if not self.caching: cacheable = False
231 1903 aaronmk
        used_cache = False
232
        try:
233 1927 aaronmk
            # Get cursor
234
            if cacheable:
235
                query_lookup = _query_lookup(query, params)
236
                try:
237
                    cur = self.query_results[query_lookup]
238
                    used_cache = True
239
                except KeyError: cur = self.DbCursor(self)
240
            else: cur = self.db.cursor()
241
242
            # Run query
243
            try: cur.execute(query, params)
244
            except Exception, e:
245
                _add_cursor_info(e, cur)
246
                raise
247 1903 aaronmk
        finally:
248
            if self.log_debug != log_debug_none: # only compute msg if needed
249
                if used_cache: cache_status = 'Cache hit'
250
                elif cacheable: cache_status = 'Cache miss'
251
                else: cache_status = 'Non-cacheable'
252 1927 aaronmk
                self.log_debug(cache_status+': '
253
                    +strings.one_line(get_cur_query(cur)))
254 1903 aaronmk
255
        return cur
256 1914 aaronmk
257
    def is_cached(self, query, params=None):
258
        return _query_lookup(query, params) in self.query_results
259 2139 aaronmk
260
    def with_savepoint(self, func):
261
        savepoint = 'savepoint_'+str(self._savepoint)
262
        self.run_query('SAVEPOINT '+savepoint)
263
        self._savepoint += 1
264
        try:
265
            try: return_val = func()
266
            finally:
267
                self._savepoint -= 1
268
                assert self._savepoint >= 0
269
        except:
270
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
271
            raise
272
        else:
273
            self.run_query('RELEASE SAVEPOINT '+savepoint)
274
            return return_val
275 1849 aaronmk
276 1869 aaronmk
connect = DbConn
277
278 832 aaronmk
##### Querying
279
280 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
281 2085 aaronmk
    '''For params, see DbConn.run_query()'''
282 1894 aaronmk
    return db.run_query(*args, **kw_args)
283 11 aaronmk
284 2068 aaronmk
def mogrify(db, query, params):
285
    module = util.root_module(db.db)
286
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
287
    else: raise NotImplementedError("Can't mogrify query for "+module+
288
        ' database')
289
290 832 aaronmk
##### Recoverable querying
291 15 aaronmk
292 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
293 11 aaronmk
294 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
295 830 aaronmk
    if recover == None: recover = False
296
297 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
298 1914 aaronmk
    if recover and not db.is_cached(query, params):
299
        return with_savepoint(db, run)
300
    else: return run() # don't need savepoint if cached
301 830 aaronmk
302 832 aaronmk
##### Basic queries
303
304 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
305
    '''Outputs a query to a temp table.
306
    For params, see run_query().
307
    '''
308
    if into == None: return run_query(db, query, params, *args, **kw_args)
309
    else: # place rows in temp table
310
        check_name(into)
311
312
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', *args, **kw_args)
313
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
314
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
315
316 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
317
318 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
319
320 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
321 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
322 1981 aaronmk
    '''
323 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
324 2124 aaronmk
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
325 1981 aaronmk
    @param fields Use None to select all fields in the table
326
    @param table_is_esc Whether the table name has already been escaped
327 2054 aaronmk
    @return tuple(query, params)
328 1981 aaronmk
    '''
329 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
330 2058 aaronmk
331 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
332
    tables = tables[:] # don't modify input!
333
    table0 = tables.pop(0) # first table is separate
334
335 1135 aaronmk
    if conds == None: conds = {}
336 135 aaronmk
    assert limit == None or type(limit) == int
337 865 aaronmk
    assert start == None or type(start) == int
338 2120 aaronmk
    if order_by == order_by_pkey:
339 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
340
    if not table_is_esc: table0 = esc_name_(table0)
341 865 aaronmk
342 2056 aaronmk
    params = []
343
344
    def parse_col(field):
345
        '''Parses fields'''
346 2121 aaronmk
        if isinstance(field, tuple): # field is literal value
347 2056 aaronmk
            value, col = field
348
            sql_ = '%s'
349
            params.append(value)
350 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
351
        else: sql_ = esc_name_(field) # field is col name
352 2056 aaronmk
        return sql_
353 11 aaronmk
    def cond(entry):
354 2056 aaronmk
        '''Parses conditions'''
355 13 aaronmk
        col, value = entry
356 2058 aaronmk
        cond_ = esc_name_(col)+' '
357 11 aaronmk
        if value == None: cond_ += 'IS'
358
        else: cond_ += '='
359
        cond_ += ' %s'
360
        return cond_
361 2056 aaronmk
362 1135 aaronmk
    query = 'SELECT '
363
    if fields == None: query += '*'
364 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
365 2121 aaronmk
    query += ' FROM '+table0
366 865 aaronmk
367 2122 aaronmk
    # Add joins
368
    left_table = table0
369
    for table, joins in tables:
370
        if not table_is_esc: table = esc_name_(table)
371 2127 aaronmk
        query += ' JOIN '+table
372 2122 aaronmk
373
        def join(entry):
374 2127 aaronmk
            '''Parses non-USING joins'''
375 2124 aaronmk
            right_col, left_col = entry
376
            right_col = table+'.'+esc_name_(right_col)
377 2122 aaronmk
            left_col = left_table+'.'+esc_name_(left_col)
378 2123 aaronmk
            return (right_col+' = '+left_col
379
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL)')
380 2122 aaronmk
381 2127 aaronmk
        if reduce(operator.and_, (v == join_using for v in joins.itervalues())):
382
            # all cols w/ USING
383 2130 aaronmk
            query += ' USING ('+(', '.join(joins.iterkeys()))+')'
384 2127 aaronmk
        else: query += ' ON '+(' AND '.join(map(join, joins.iteritems())))
385
386 2122 aaronmk
        left_table = table
387
388 865 aaronmk
    missing = True
389 89 aaronmk
    if conds != {}:
390 2122 aaronmk
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
391 2056 aaronmk
        params += conds.values()
392 865 aaronmk
        missing = False
393 2120 aaronmk
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
394 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
395
    if start != None:
396
        if start != 0: query += ' OFFSET '+str(start)
397
        missing = False
398
    if missing: warnings.warn(DbWarning(
399
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
400
401 2056 aaronmk
    return (query, params)
402 11 aaronmk
403 2054 aaronmk
def select(db, *args, **kw_args):
404
    '''For params, see mk_select() and run_query()'''
405
    recover = kw_args.pop('recover', None)
406
    cacheable = kw_args.pop('cacheable', True)
407
408
    query, params = mk_select(db, *args, **kw_args)
409
    return run_query(db, query, params, recover, cacheable)
410
411 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
412 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
413 1960 aaronmk
    '''
414
    @param returning str|None An inserted column (such as pkey) to return
415 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
416 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
417
        query will be fully cached, not just if it raises an exception.
418 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
419
    '''
420 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
421
    if cols == []: cols = None # no cols (all defaults) = unknown col names
422 1960 aaronmk
    if not table_is_esc: check_name(table)
423 2063 aaronmk
424
    # Build query
425
    query = 'INSERT INTO '+table
426
    if cols != None:
427
        map(check_name, cols)
428
        query += ' ('+', '.join(cols)+')'
429
    query += ' '+select_query
430
431
    if returning != None:
432
        check_name(returning)
433
        query += ' RETURNING '+returning
434
435 2070 aaronmk
    if embeddable:
436
        # Create function
437 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
438
            ['insert', table] + cols)))
439 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
440
        function_query = '''\
441 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
442 2070 aaronmk
    LANGUAGE sql
443
    AS $$'''+mogrify(db, query, params)+''';$$;
444
'''
445
        run_query(db, function_query, cacheable=True)
446
447
        # Return query that uses function
448 2134 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
449
            order_by=None, table_is_esc=True)# AS clause requires function alias
450 2070 aaronmk
451 2066 aaronmk
    return (query, params)
452
453
def insert_select(db, *args, **kw_args):
454 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
455 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
456
    '''
457
    into = kw_args.pop('into', None)
458
    if into != None: kw_args['embeddable'] = True
459 2066 aaronmk
    recover = kw_args.pop('recover', None)
460
    cacheable = kw_args.pop('cacheable', True)
461
462
    query, params = mk_insert_select(db, *args, **kw_args)
463 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
464 2063 aaronmk
465 2066 aaronmk
default = object() # tells insert() to use the default value for a column
466
467 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
468 2085 aaronmk
    '''For params, see insert_select()'''
469 1960 aaronmk
    if lists.is_seq(row): cols = None
470
    else:
471
        cols = row.keys()
472
        row = row.values()
473
    row = list(row) # ensure that "!= []" works
474
475 1961 aaronmk
    # Check for special values
476
    labels = []
477
    values = []
478
    for value in row:
479
        if value == default: labels.append('DEFAULT')
480
        else:
481
            labels.append('%s')
482
            values.append(value)
483
484
    # Build query
485 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
486
    else: query = None
487 1554 aaronmk
488 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
489 11 aaronmk
490 135 aaronmk
def last_insert_id(db):
491 1849 aaronmk
    module = util.root_module(db.db)
492 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
493
    elif module == 'MySQLdb': return db.insert_id()
494
    else: return None
495 13 aaronmk
496 1968 aaronmk
def truncate(db, table, schema='public'):
497
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
498 832 aaronmk
499
##### Database structure queries
500
501 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
502 832 aaronmk
    '''Assumed to be first column in table'''
503 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
504 2084 aaronmk
        table_is_esc=table_is_esc)).next()
505 832 aaronmk
506 853 aaronmk
def index_cols(db, table, index):
507
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
508
    automatically created. When you don't know whether something is a UNIQUE
509
    constraint or a UNIQUE index, use this function.'''
510
    check_name(table)
511
    check_name(index)
512 1909 aaronmk
    module = util.root_module(db.db)
513
    if module == 'psycopg2':
514
        return list(values(run_query(db, '''\
515 853 aaronmk
SELECT attname
516 866 aaronmk
FROM
517
(
518
        SELECT attnum, attname
519
        FROM pg_index
520
        JOIN pg_class index ON index.oid = indexrelid
521
        JOIN pg_class table_ ON table_.oid = indrelid
522
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
523
        WHERE
524
            table_.relname = %(table)s
525
            AND index.relname = %(index)s
526
    UNION
527
        SELECT attnum, attname
528
        FROM
529
        (
530
            SELECT
531
                indrelid
532
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
533
                    AS indkey
534
            FROM pg_index
535
            JOIN pg_class index ON index.oid = indexrelid
536
            JOIN pg_class table_ ON table_.oid = indrelid
537
            WHERE
538
                table_.relname = %(table)s
539
                AND index.relname = %(index)s
540
        ) s
541
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
542
) s
543 853 aaronmk
ORDER BY attnum
544
''',
545 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
546
    else: raise NotImplementedError("Can't list index columns for "+module+
547
        ' database')
548 853 aaronmk
549 464 aaronmk
def constraint_cols(db, table, constraint):
550
    check_name(table)
551
    check_name(constraint)
552 1849 aaronmk
    module = util.root_module(db.db)
553 464 aaronmk
    if module == 'psycopg2':
554
        return list(values(run_query(db, '''\
555
SELECT attname
556
FROM pg_constraint
557
JOIN pg_class ON pg_class.oid = conrelid
558
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
559
WHERE
560
    relname = %(table)s
561
    AND conname = %(constraint)s
562
ORDER BY attnum
563
''',
564
            {'table': table, 'constraint': constraint})))
565
    else: raise NotImplementedError("Can't list constraint columns for "+module+
566
        ' database')
567
568 2096 aaronmk
row_num_col = '_row_num'
569
570 2086 aaronmk
def add_row_num(db, table):
571 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
572
    be the primary key.'''
573 2086 aaronmk
    check_name(table)
574 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
575 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
576 2086 aaronmk
577 1968 aaronmk
def tables(db, schema='public', table_like='%'):
578 1849 aaronmk
    module = util.root_module(db.db)
579 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
580 832 aaronmk
    if module == 'psycopg2':
581 1968 aaronmk
        return values(run_query(db, '''\
582
SELECT tablename
583
FROM pg_tables
584
WHERE
585
    schemaname = %(schema)s
586
    AND tablename LIKE %(table_like)s
587
ORDER BY tablename
588
''',
589
            params, cacheable=True))
590
    elif module == 'MySQLdb':
591
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
592
            cacheable=True))
593 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
594 830 aaronmk
595 833 aaronmk
##### Database management
596
597 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
598
    '''For kw_args, see tables()'''
599
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
600 833 aaronmk
601 832 aaronmk
##### Heuristic queries
602
603 2076 aaronmk
def with_parsed_errors(db, func):
604
    '''Translates known DB errors to typed exceptions'''
605
    try: return func()
606 46 aaronmk
    except Exception, e:
607
        msg = str(e)
608 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
609
            r'"(([^\W_]+)_[^"]+)"', msg)
610
        if match:
611
            constraint, table = match.groups()
612 854 aaronmk
            try: cols = index_cols(db, table, constraint)
613 465 aaronmk
            except NotImplementedError: raise e
614 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
615 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
616
            'constraint', msg)
617 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
618 13 aaronmk
        raise # no specific exception raised
619 11 aaronmk
620 2076 aaronmk
def try_insert(db, table, row, returning=None):
621
    '''Recovers from errors'''
622
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
623
        recover=True))
624
625 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
626 1554 aaronmk
    '''Recovers from errors.
627 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
628
    '''
629 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
630
631 471 aaronmk
    try:
632 2104 aaronmk
        cur = try_insert(db, table, row, pkey_)
633 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
634
            row_ct_ref[0] += cur.rowcount
635
        return value(cur)
636 471 aaronmk
    except DuplicateKeyException, e:
637 2104 aaronmk
        return value(select(db, table, [pkey_],
638 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
639 471 aaronmk
640 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
641 830 aaronmk
    '''Recovers from errors'''
642
    try: return value(select(db, table, [pkey], row, 1, recover=True))
643 14 aaronmk
    except StopIteration:
644 40 aaronmk
        if not create: raise
645 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
646 2078 aaronmk
647 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
648
    row_ct_ref=None, table_is_esc=False):
649 2078 aaronmk
    '''Recovers from errors.
650
    Only works under PostgreSQL (uses INSERT RETURNING).
651 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
652
        tables to join with it using the main input table's pkey
653 2133 aaronmk
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
654 2078 aaronmk
        available
655
    '''
656 2132 aaronmk
    out_table_clean = clean_name(out_table)
657
    pkeys = out_table_clean+'_pkeys'
658 2131 aaronmk
659 2133 aaronmk
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
660
661 2132 aaronmk
    # Join together input tables
662 2131 aaronmk
    in_tables = in_tables[:] # don't modify input!
663
    in_tables0 = in_tables.pop(0) # first table is separate
664
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
665
    joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
666
667 2132 aaronmk
    def mk_select_(cols):
668 2134 aaronmk
        return mk_select(db, joins, cols, limit=limit, start=start,
669
            table_is_esc=table_is_esc)
670 2132 aaronmk
671
    out_pkeys = out_table_clean+'_out_pkeys'
672 2078 aaronmk
    def insert_():
673 2132 aaronmk
        cur = insert_select(db, out_table, mapping.keys(),
674 2133 aaronmk
            *mk_select_(mapping.values()), returning=out_pkey,
675 2132 aaronmk
            into=out_pkeys, recover=True, table_is_esc=table_is_esc)
676 2078 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
677
            row_ct_ref[0] += cur.rowcount
678 2132 aaronmk
        add_row_num(db, out_pkeys) # for joining it with in_pkeys
679 2086 aaronmk
680 2132 aaronmk
        # Get input pkeys corresponding to rows in insert
681
        in_pkeys = out_table_clean+'_in_pkeys'
682
        run_query_into(db, *mk_select_([in_pkey]), into=in_pkeys)
683
        add_row_num(db, in_pkeys) # for joining it with out_pkeys
684 2086 aaronmk
685 2132 aaronmk
        # Join together out_pkeys and in_pkeys
686
        run_query_into(db, *mk_select(db,
687
            [in_pkeys, (out_pkeys, {row_num_col: join_using})],
688 2134 aaronmk
            [in_pkey, out_pkey], start=0), into=pkeys)
689 2132 aaronmk
690
    try:
691
        # Insert and capture output pkeys
692
        with_parsed_errors(db, insert_)
693
694 2133 aaronmk
        return (pkeys, out_pkey)
695 2078 aaronmk
    except DuplicateKeyException, e: raise
696 2115 aaronmk
697
##### Data cleanup
698
699
def cleanup_table(db, table, cols, table_is_esc=False):
700
    def esc_name_(name): return esc_name(db, name)
701
702
    if not table_is_esc: check_name(table)
703
    cols = map(esc_name_, cols)
704
705
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
706
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
707
            for col in cols))),
708
        dict(null0='', null1=r'\N'))