Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
20
    if hasattr(cur, 'query'): return cur.query
21
    elif hasattr(cur, '_last_executed'): return cur._last_executed
22
    else: return None
23 14 aaronmk
24 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
27 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
28 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
29 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
30
31 2143 aaronmk
class ExceptionWithName(DbException):
32
    def __init__(self, name, cause=None):
33 2145 aaronmk
        DbException.__init__(self, 'for name: '+str(name), cause)
34 2143 aaronmk
        self.name = name
35 360 aaronmk
36 468 aaronmk
class ExceptionWithColumns(DbException):
37
    def __init__(self, cols, cause=None):
38 2145 aaronmk
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
39 468 aaronmk
        self.cols = cols
40 11 aaronmk
41 2143 aaronmk
class NameException(DbException): pass
42
43 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
44 13 aaronmk
45 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
46 13 aaronmk
47 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
48
49 89 aaronmk
class EmptyRowException(DbException): pass
50
51 865 aaronmk
##### Warnings
52
53
class DbWarning(UserWarning): pass
54
55 1930 aaronmk
##### Result retrieval
56
57
def col_names(cur): return (col[0] for col in cur.description)
58
59
def rows(cur): return iter(lambda: cur.fetchone(), None)
60
61
def consume_rows(cur):
62
    '''Used to fetch all rows so result will be cached'''
63
    iters.consume_iter(rows(cur))
64
65
def next_row(cur): return rows(cur).next()
66
67
def row(cur):
68
    row_ = next_row(cur)
69
    consume_rows(cur)
70
    return row_
71
72
def next_value(cur): return next_row(cur)[0]
73
74
def value(cur): return row(cur)[0]
75
76
def values(cur): return iters.func_iter(lambda: next_value(cur))
77
78
def value_or_none(cur):
79
    try: return value(cur)
80
    except StopIteration: return None
81
82 2101 aaronmk
##### Input validation
83
84
def clean_name(name): return re.sub(r'\W', r'', name)
85
86
def check_name(name):
87
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
88
        +'" may contain only alphanumeric characters and _')
89
90
def esc_name_by_module(module, name, ignore_case=False):
91
    if module == 'psycopg2':
92
        if ignore_case:
93
            # Don't enclose in quotes because this disables case-insensitivity
94
            check_name(name)
95
            return name
96
        else: quote = '"'
97
    elif module == 'MySQLdb': quote = '`'
98
    else: raise NotImplementedError("Can't escape name for "+module+' database')
99
    return quote + name.replace(quote, '') + quote
100
101
def esc_name_by_engine(engine, name, **kw_args):
102
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
103
104
def esc_name(db, name, **kw_args):
105
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
106
107
def qual_name(db, schema, table):
108
    def esc_name_(name): return esc_name(db, name)
109
    table = esc_name_(table)
110
    if schema != None: return esc_name_(schema)+'.'+table
111
    else: return table
112
113 1869 aaronmk
##### Database connections
114 1849 aaronmk
115 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
116 1926 aaronmk
117 1869 aaronmk
db_engines = {
118
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
119
    'PostgreSQL': ('psycopg2', {}),
120
}
121
122
DatabaseErrors_set = set([DbException])
123
DatabaseErrors = tuple(DatabaseErrors_set)
124
125
def _add_module(module):
126
    DatabaseErrors_set.add(module.DatabaseError)
127
    global DatabaseErrors
128
    DatabaseErrors = tuple(DatabaseErrors_set)
129
130
def db_config_str(db_config):
131
    return db_config['engine']+' database '+db_config['database']
132
133 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
134 1894 aaronmk
135 1901 aaronmk
log_debug_none = lambda msg: None
136
137 1849 aaronmk
class DbConn:
138 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
139 2050 aaronmk
        caching=True):
140 1869 aaronmk
        self.db_config = db_config
141
        self.serializable = serializable
142 1901 aaronmk
        self.log_debug = log_debug
143 2047 aaronmk
        self.caching = caching
144 1869 aaronmk
145
        self.__db = None
146 1889 aaronmk
        self.query_results = {}
147 2139 aaronmk
        self._savepoint = 0
148 1869 aaronmk
149
    def __getattr__(self, name):
150
        if name == '__dict__': raise Exception('getting __dict__')
151
        if name == 'db': return self._db()
152
        else: raise AttributeError()
153
154
    def __getstate__(self):
155
        state = copy.copy(self.__dict__) # shallow copy
156 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
157 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
158
        return state
159
160
    def _db(self):
161
        if self.__db == None:
162
            # Process db_config
163
            db_config = self.db_config.copy() # don't modify input!
164 2097 aaronmk
            schemas = db_config.pop('schemas', None)
165 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
166
            module = __import__(module_name)
167
            _add_module(module)
168
            for orig, new in mappings.iteritems():
169
                try: util.rename_key(db_config, orig, new)
170
                except KeyError: pass
171
172
            # Connect
173
            self.__db = module.connect(**db_config)
174
175
            # Configure connection
176
            if self.serializable: run_raw_query(self,
177
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
178 2101 aaronmk
            if schemas != None:
179
                schemas_ = ''.join((esc_name(self, s)+', '
180
                    for s in schemas.split(',')))
181
                run_raw_query(self, "SELECT set_config('search_path', \
182
%s || current_setting('search_path'), false)", [schemas_])
183 1869 aaronmk
184
        return self.__db
185 1889 aaronmk
186 1891 aaronmk
    class DbCursor(Proxy):
187 1927 aaronmk
        def __init__(self, outer):
188 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
189 1927 aaronmk
            self.query_results = outer.query_results
190 1894 aaronmk
            self.query_lookup = None
191 1891 aaronmk
            self.result = []
192 1889 aaronmk
193 1894 aaronmk
        def execute(self, query, params=None):
194 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
195 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
196 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
197
            except Exception, e:
198
                self.result = e # cache the exception as the result
199
                self._cache_result()
200
                raise
201
            finally: self.query = get_cur_query(self.inner)
202 1930 aaronmk
            # Fetch all rows so result will be cached
203
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
204 1894 aaronmk
            return return_value
205
206 1891 aaronmk
        def fetchone(self):
207
            row = self.inner.fetchone()
208 1899 aaronmk
            if row != None: self.result.append(row)
209
            # otherwise, fetched all rows
210 1904 aaronmk
            else: self._cache_result()
211
            return row
212
213
        def _cache_result(self):
214 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
215
            # idempotent, but an invalid insert will always be invalid
216 1930 aaronmk
            if self.query_results != None and (not self._is_insert
217 1906 aaronmk
                or isinstance(self.result, Exception)):
218
219 1894 aaronmk
                assert self.query_lookup != None
220 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
221
                    util.dict_subset(dicts.AttrsDictView(self),
222
                    ['query', 'result', 'rowcount', 'description']))
223 1906 aaronmk
224 1916 aaronmk
        class CacheCursor:
225
            def __init__(self, cached_result): self.__dict__ = cached_result
226
227 1927 aaronmk
            def execute(self, *args, **kw_args):
228 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
229
                # otherwise, result is a rows list
230
                self.iter = iter(self.result)
231
232
            def fetchone(self):
233
                try: return self.iter.next()
234
                except StopIteration: return None
235 1891 aaronmk
236 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
237 2047 aaronmk
        if not self.caching: cacheable = False
238 1903 aaronmk
        used_cache = False
239
        try:
240 1927 aaronmk
            # Get cursor
241
            if cacheable:
242
                query_lookup = _query_lookup(query, params)
243
                try:
244
                    cur = self.query_results[query_lookup]
245
                    used_cache = True
246
                except KeyError: cur = self.DbCursor(self)
247
            else: cur = self.db.cursor()
248
249
            # Run query
250
            try: cur.execute(query, params)
251
            except Exception, e:
252
                _add_cursor_info(e, cur)
253
                raise
254 1903 aaronmk
        finally:
255
            if self.log_debug != log_debug_none: # only compute msg if needed
256
                if used_cache: cache_status = 'Cache hit'
257
                elif cacheable: cache_status = 'Cache miss'
258
                else: cache_status = 'Non-cacheable'
259 1927 aaronmk
                self.log_debug(cache_status+': '
260
                    +strings.one_line(get_cur_query(cur)))
261 1903 aaronmk
262
        return cur
263 1914 aaronmk
264
    def is_cached(self, query, params=None):
265
        return _query_lookup(query, params) in self.query_results
266 2139 aaronmk
267
    def with_savepoint(self, func):
268
        savepoint = 'savepoint_'+str(self._savepoint)
269
        self.run_query('SAVEPOINT '+savepoint)
270
        self._savepoint += 1
271
        try:
272
            try: return_val = func()
273
            finally:
274
                self._savepoint -= 1
275
                assert self._savepoint >= 0
276
        except:
277
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
278
            raise
279
        else:
280
            self.run_query('RELEASE SAVEPOINT '+savepoint)
281
            return return_val
282 1849 aaronmk
283 1869 aaronmk
connect = DbConn
284
285 832 aaronmk
##### Querying
286
287 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
288 2085 aaronmk
    '''For params, see DbConn.run_query()'''
289 1894 aaronmk
    return db.run_query(*args, **kw_args)
290 11 aaronmk
291 2068 aaronmk
def mogrify(db, query, params):
292
    module = util.root_module(db.db)
293
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
294
    else: raise NotImplementedError("Can't mogrify query for "+module+
295
        ' database')
296
297 832 aaronmk
##### Recoverable querying
298 15 aaronmk
299 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
300 11 aaronmk
301 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
302 830 aaronmk
    if recover == None: recover = False
303
304 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
305 1914 aaronmk
    if recover and not db.is_cached(query, params):
306
        return with_savepoint(db, run)
307
    else: return run() # don't need savepoint if cached
308 830 aaronmk
309 832 aaronmk
##### Basic queries
310
311 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
312
    '''Outputs a query to a temp table.
313
    For params, see run_query().
314
    '''
315
    if into == None: return run_query(db, query, params, *args, **kw_args)
316
    else: # place rows in temp table
317
        check_name(into)
318
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
319
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
320
321 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
322
323 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
324
325 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
326 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
327 1981 aaronmk
    '''
328 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
329 2124 aaronmk
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
330 1981 aaronmk
    @param fields Use None to select all fields in the table
331
    @param table_is_esc Whether the table name has already been escaped
332 2054 aaronmk
    @return tuple(query, params)
333 1981 aaronmk
    '''
334 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
335 2058 aaronmk
336 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
337 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
338 2121 aaronmk
    table0 = tables.pop(0) # first table is separate
339
340 1135 aaronmk
    if conds == None: conds = {}
341 135 aaronmk
    assert limit == None or type(limit) == int
342 865 aaronmk
    assert start == None or type(start) == int
343 2120 aaronmk
    if order_by == order_by_pkey:
344 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
345
    if not table_is_esc: table0 = esc_name_(table0)
346 865 aaronmk
347 2056 aaronmk
    params = []
348
349
    def parse_col(field):
350
        '''Parses fields'''
351 2121 aaronmk
        if isinstance(field, tuple): # field is literal value
352 2056 aaronmk
            value, col = field
353
            sql_ = '%s'
354
            params.append(value)
355 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
356
        else: sql_ = esc_name_(field) # field is col name
357 2056 aaronmk
        return sql_
358 11 aaronmk
    def cond(entry):
359 2056 aaronmk
        '''Parses conditions'''
360 13 aaronmk
        col, value = entry
361 2058 aaronmk
        cond_ = esc_name_(col)+' '
362 11 aaronmk
        if value == None: cond_ += 'IS'
363
        else: cond_ += '='
364
        cond_ += ' %s'
365
        return cond_
366 2056 aaronmk
367 1135 aaronmk
    query = 'SELECT '
368
    if fields == None: query += '*'
369 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
370 2121 aaronmk
    query += ' FROM '+table0
371 865 aaronmk
372 2122 aaronmk
    # Add joins
373
    left_table = table0
374
    for table, joins in tables:
375
        if not table_is_esc: table = esc_name_(table)
376 2127 aaronmk
        query += ' JOIN '+table
377 2122 aaronmk
378
        def join(entry):
379 2127 aaronmk
            '''Parses non-USING joins'''
380 2124 aaronmk
            right_col, left_col = entry
381
            right_col = table+'.'+esc_name_(right_col)
382 2122 aaronmk
            left_col = left_table+'.'+esc_name_(left_col)
383 2123 aaronmk
            return (right_col+' = '+left_col
384
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL)')
385 2122 aaronmk
386 2127 aaronmk
        if reduce(operator.and_, (v == join_using for v in joins.itervalues())):
387
            # all cols w/ USING
388 2130 aaronmk
            query += ' USING ('+(', '.join(joins.iterkeys()))+')'
389 2127 aaronmk
        else: query += ' ON '+(' AND '.join(map(join, joins.iteritems())))
390
391 2122 aaronmk
        left_table = table
392
393 865 aaronmk
    missing = True
394 89 aaronmk
    if conds != {}:
395 2122 aaronmk
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
396 2056 aaronmk
        params += conds.values()
397 865 aaronmk
        missing = False
398 2120 aaronmk
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
399 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
400
    if start != None:
401
        if start != 0: query += ' OFFSET '+str(start)
402
        missing = False
403
    if missing: warnings.warn(DbWarning(
404
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
405
406 2056 aaronmk
    return (query, params)
407 11 aaronmk
408 2054 aaronmk
def select(db, *args, **kw_args):
409
    '''For params, see mk_select() and run_query()'''
410
    recover = kw_args.pop('recover', None)
411
    cacheable = kw_args.pop('cacheable', True)
412
413
    query, params = mk_select(db, *args, **kw_args)
414
    return run_query(db, query, params, recover, cacheable)
415
416 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
417 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
418 1960 aaronmk
    '''
419
    @param returning str|None An inserted column (such as pkey) to return
420 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
421 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
422
        query will be fully cached, not just if it raises an exception.
423 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
424
    '''
425 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
426
    if cols == []: cols = None # no cols (all defaults) = unknown col names
427 1960 aaronmk
    if not table_is_esc: check_name(table)
428 2063 aaronmk
429
    # Build query
430
    query = 'INSERT INTO '+table
431
    if cols != None:
432
        map(check_name, cols)
433
        query += ' ('+', '.join(cols)+')'
434
    query += ' '+select_query
435
436
    if returning != None:
437
        check_name(returning)
438
        query += ' RETURNING '+returning
439
440 2070 aaronmk
    if embeddable:
441
        # Create function
442 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
443
            ['insert', table] + cols)))
444 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
445
        function_query = '''\
446 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
447 2070 aaronmk
    LANGUAGE sql
448
    AS $$'''+mogrify(db, query, params)+''';$$;
449
'''
450
        run_query(db, function_query, cacheable=True)
451
452
        # Return query that uses function
453 2134 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
454
            order_by=None, table_is_esc=True)# AS clause requires function alias
455 2070 aaronmk
456 2066 aaronmk
    return (query, params)
457
458
def insert_select(db, *args, **kw_args):
459 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
460 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
461
    '''
462
    into = kw_args.pop('into', None)
463
    if into != None: kw_args['embeddable'] = True
464 2066 aaronmk
    recover = kw_args.pop('recover', None)
465
    cacheable = kw_args.pop('cacheable', True)
466
467
    query, params = mk_insert_select(db, *args, **kw_args)
468 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
469 2063 aaronmk
470 2066 aaronmk
default = object() # tells insert() to use the default value for a column
471
472 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
473 2085 aaronmk
    '''For params, see insert_select()'''
474 1960 aaronmk
    if lists.is_seq(row): cols = None
475
    else:
476
        cols = row.keys()
477
        row = row.values()
478
    row = list(row) # ensure that "!= []" works
479
480 1961 aaronmk
    # Check for special values
481
    labels = []
482
    values = []
483
    for value in row:
484
        if value == default: labels.append('DEFAULT')
485
        else:
486
            labels.append('%s')
487
            values.append(value)
488
489
    # Build query
490 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
491
    else: query = None
492 1554 aaronmk
493 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
494 11 aaronmk
495 135 aaronmk
def last_insert_id(db):
496 1849 aaronmk
    module = util.root_module(db.db)
497 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
498
    elif module == 'MySQLdb': return db.insert_id()
499
    else: return None
500 13 aaronmk
501 1968 aaronmk
def truncate(db, table, schema='public'):
502
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
503 832 aaronmk
504
##### Database structure queries
505
506 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
507 832 aaronmk
    '''Assumed to be first column in table'''
508 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
509 2084 aaronmk
        table_is_esc=table_is_esc)).next()
510 832 aaronmk
511 853 aaronmk
def index_cols(db, table, index):
512
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
513
    automatically created. When you don't know whether something is a UNIQUE
514
    constraint or a UNIQUE index, use this function.'''
515
    check_name(table)
516
    check_name(index)
517 1909 aaronmk
    module = util.root_module(db.db)
518
    if module == 'psycopg2':
519
        return list(values(run_query(db, '''\
520 853 aaronmk
SELECT attname
521 866 aaronmk
FROM
522
(
523
        SELECT attnum, attname
524
        FROM pg_index
525
        JOIN pg_class index ON index.oid = indexrelid
526
        JOIN pg_class table_ ON table_.oid = indrelid
527
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
528
        WHERE
529
            table_.relname = %(table)s
530
            AND index.relname = %(index)s
531
    UNION
532
        SELECT attnum, attname
533
        FROM
534
        (
535
            SELECT
536
                indrelid
537
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
538
                    AS indkey
539
            FROM pg_index
540
            JOIN pg_class index ON index.oid = indexrelid
541
            JOIN pg_class table_ ON table_.oid = indrelid
542
            WHERE
543
                table_.relname = %(table)s
544
                AND index.relname = %(index)s
545
        ) s
546
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
547
) s
548 853 aaronmk
ORDER BY attnum
549
''',
550 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
551
    else: raise NotImplementedError("Can't list index columns for "+module+
552
        ' database')
553 853 aaronmk
554 464 aaronmk
def constraint_cols(db, table, constraint):
555
    check_name(table)
556
    check_name(constraint)
557 1849 aaronmk
    module = util.root_module(db.db)
558 464 aaronmk
    if module == 'psycopg2':
559
        return list(values(run_query(db, '''\
560
SELECT attname
561
FROM pg_constraint
562
JOIN pg_class ON pg_class.oid = conrelid
563
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
564
WHERE
565
    relname = %(table)s
566
    AND conname = %(constraint)s
567
ORDER BY attnum
568
''',
569
            {'table': table, 'constraint': constraint})))
570
    else: raise NotImplementedError("Can't list constraint columns for "+module+
571
        ' database')
572
573 2096 aaronmk
row_num_col = '_row_num'
574
575 2086 aaronmk
def add_row_num(db, table):
576 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
577
    be the primary key.'''
578 2086 aaronmk
    check_name(table)
579 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
580 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
581 2086 aaronmk
582 1968 aaronmk
def tables(db, schema='public', table_like='%'):
583 1849 aaronmk
    module = util.root_module(db.db)
584 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
585 832 aaronmk
    if module == 'psycopg2':
586 1968 aaronmk
        return values(run_query(db, '''\
587
SELECT tablename
588
FROM pg_tables
589
WHERE
590
    schemaname = %(schema)s
591
    AND tablename LIKE %(table_like)s
592
ORDER BY tablename
593
''',
594
            params, cacheable=True))
595
    elif module == 'MySQLdb':
596
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
597
            cacheable=True))
598 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
599 830 aaronmk
600 833 aaronmk
##### Database management
601
602 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
603
    '''For kw_args, see tables()'''
604
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
605 833 aaronmk
606 832 aaronmk
##### Heuristic queries
607
608 2076 aaronmk
def with_parsed_errors(db, func):
609
    '''Translates known DB errors to typed exceptions'''
610
    try: return func()
611 46 aaronmk
    except Exception, e:
612
        msg = str(e)
613 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
614 2140 aaronmk
            r'"((_?[^\W_]+)_[^"]+)"', msg)
615 465 aaronmk
        if match:
616
            constraint, table = match.groups()
617 854 aaronmk
            try: cols = index_cols(db, table, constraint)
618 465 aaronmk
            except NotImplementedError: raise e
619 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
620 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
621
            'constraint', msg)
622 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
623 2147 aaronmk
        match = re.search(r'relation "(\w+)" already exists', msg)
624 2143 aaronmk
        if match: raise DuplicateTableException(match.group(1), e)
625 13 aaronmk
        raise # no specific exception raised
626 11 aaronmk
627 2076 aaronmk
def try_insert(db, table, row, returning=None):
628
    '''Recovers from errors'''
629
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
630
        recover=True))
631
632 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
633 1554 aaronmk
    '''Recovers from errors.
634 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
635
    '''
636 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
637
638 471 aaronmk
    try:
639 2104 aaronmk
        cur = try_insert(db, table, row, pkey_)
640 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
641
            row_ct_ref[0] += cur.rowcount
642
        return value(cur)
643 471 aaronmk
    except DuplicateKeyException, e:
644 2104 aaronmk
        return value(select(db, table, [pkey_],
645 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
646 471 aaronmk
647 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
648 830 aaronmk
    '''Recovers from errors'''
649
    try: return value(select(db, table, [pkey], row, 1, recover=True))
650 14 aaronmk
    except StopIteration:
651 40 aaronmk
        if not create: raise
652 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
653 2078 aaronmk
654 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
655
    row_ct_ref=None, table_is_esc=False):
656 2078 aaronmk
    '''Recovers from errors.
657
    Only works under PostgreSQL (uses INSERT RETURNING).
658 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
659
        tables to join with it using the main input table's pkey
660 2133 aaronmk
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
661 2078 aaronmk
        available
662
    '''
663 2142 aaronmk
    temp_prefix = '_'.join(map(clean_name,
664
        [out_table] + list(iters.flatten(mapping.items()))))
665
    pkeys = temp_prefix+'_pkeys'
666 2131 aaronmk
667 2132 aaronmk
    # Join together input tables
668 2131 aaronmk
    in_tables = in_tables[:] # don't modify input!
669
    in_tables0 = in_tables.pop(0) # first table is separate
670
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
671 2142 aaronmk
    in_joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
672 2131 aaronmk
673 2142 aaronmk
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
674
    pkeys_cols = [in_pkey, out_pkey]
675
676 2132 aaronmk
    def mk_select_(cols):
677 2142 aaronmk
        return mk_select(db, in_joins, cols, limit=limit, start=start,
678 2134 aaronmk
            table_is_esc=table_is_esc)
679 2132 aaronmk
680 2142 aaronmk
    out_pkeys = temp_prefix+'_out_pkeys'
681 2078 aaronmk
    def insert_():
682 2132 aaronmk
        cur = insert_select(db, out_table, mapping.keys(),
683 2133 aaronmk
            *mk_select_(mapping.values()), returning=out_pkey,
684 2132 aaronmk
            into=out_pkeys, recover=True, table_is_esc=table_is_esc)
685 2078 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
686
            row_ct_ref[0] += cur.rowcount
687 2132 aaronmk
        add_row_num(db, out_pkeys) # for joining it with in_pkeys
688 2086 aaronmk
689 2132 aaronmk
        # Get input pkeys corresponding to rows in insert
690 2142 aaronmk
        in_pkeys = temp_prefix+'_in_pkeys'
691 2132 aaronmk
        run_query_into(db, *mk_select_([in_pkey]), into=in_pkeys)
692
        add_row_num(db, in_pkeys) # for joining it with out_pkeys
693 2086 aaronmk
694 2132 aaronmk
        # Join together out_pkeys and in_pkeys
695
        run_query_into(db, *mk_select(db,
696
            [in_pkeys, (out_pkeys, {row_num_col: join_using})],
697 2142 aaronmk
            pkeys_cols, start=0), into=pkeys)
698 2132 aaronmk
699 2142 aaronmk
    # Do inserts and selects
700 2132 aaronmk
    try:
701
        # Insert and capture output pkeys
702
        with_parsed_errors(db, insert_)
703 2142 aaronmk
    except DuplicateKeyException, e:
704
        join_cols = util.dict_subset_right_join(mapping, e.cols)
705
        joins = in_joins + [(out_table, join_cols)]
706
        run_query_into(db, *mk_select(db, joins, pkeys_cols,
707
            table_is_esc=table_is_esc), into=pkeys)
708
709
    return (pkeys, out_pkey)
710 2115 aaronmk
711
##### Data cleanup
712
713
def cleanup_table(db, table, cols, table_is_esc=False):
714
    def esc_name_(name): return esc_name(db, name)
715
716
    if not table_is_esc: check_name(table)
717
    cols = map(esc_name_, cols)
718
719
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
720
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
721
            for col in cols))),
722
        dict(null0='', null1=r'\N'))