Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
    '''Returns the query text stored on a cursor.
    Checks the `query` attribute first, then `_last_executed`.
    @return The attribute's value, or None if the cursor has neither attribute
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur):
    '''Adds the query stored on a cursor to an exception's message.
    @param e The exception to annotate (passed to exc.add_msg())
    @param cur The cursor whose query text should be reported
    '''
    query = get_cur_query(cur)
    # get_cur_query() returns None when the cursor exposes no query text.
    # Concatenating None would raise a TypeError that hides e.
    if query is None: query = '<unknown>'
    exc.add_msg(e, 'query: '+query)
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of this module's database exceptions'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The exception that caused this one, if any
        @param cur The cursor that ran the failed query, if any. Its query is
            added to the message.
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause)
        # Identity test: != would be dispatched to the cursor object's own
        # comparison methods
        if cur is not None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing non-word characters'''
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException that concerns a list of columns.
    The column names are available in the `cols` attribute.
    '''
    
    def __init__(self, cols, cause=None):
        '''
        @param cols The names of the affected columns
        @param cause The exception that caused this one, if any
        '''
        msg = 'columns: ' + ', '.join(cols)
        DbException.__init__(self, msg, cause)
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by with_parsed_errors() for a unique constraint violation'''
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by with_parsed_errors() for a not-null constraint violation'''
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException):
    '''A DbException subtype (not raised anywhere in the visible code;
    presumably signals a row without values -- TODO confirm against callers)'''
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning):
    '''Warning category used by this module (see mk_select())'''
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur):
    '''Iterates over the column names of a cursor's result.
    @return A generator of the first item of each cur.description entry
    '''
    description = cur.description
    return (col_info[0] for col_info in description)
50
51
def rows(cur):
    '''Iterates over a cursor's remaining rows.
    The iterator calls cur.fetchone() until it returns None.
    '''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
52
53
def consume_rows(cur):
    '''Fetches all of a cursor's remaining rows, so that the result will be
    cached'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
56
57
def next_row(cur):
    '''Returns the next row of a cursor.
    @throws StopIteration if there are no more rows
    '''
    # next() builtin instead of the Python-2-only .next() method
    return next(rows(cur))
58
59
def row(cur):
    '''Returns the next row of a cursor, then fetches all the remaining rows
    (see consume_rows()).
    @throws StopIteration if there are no more rows
    '''
    first = next_row(cur)
    consume_rows(cur)
    return first
63
64
def next_value(cur):
    '''Returns the first column of a cursor's next row'''
    next_ = next_row(cur)
    return next_[0]
65
66
def value(cur):
    '''Returns the first column of the row returned by row()'''
    row_ = row(cur)
    return row_[0]
67
68
def values(cur):
    '''Wraps next_value() for this cursor with iters.func_iter()'''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
69
70
def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no more rows'''
    try: value_ = value(cur)
    except StopIteration: return None
    else: return value_
73
74 2101 aaronmk
##### Input validation
75
76
def clean_name(name):
    '''Removes all non-word characters (regex \\W) from a name'''
    non_word_chars = re.compile(r'\W')
    return non_word_chars.sub(r'', name)
77
78
def check_name(name):
    '''Ensures that a name contains only word characters (regex \\w).
    @throws NameException if the name contains any other character
    '''
    if re.search(r'\W', name) == None: return # name is valid
    
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
81
82
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes a name for use in a query.
    @param module The name of the DB-API module: 'psycopg2' or 'MySQLdb'
    @param name The name to escape
    @param ignore_case Only used for psycopg2: returns the name unquoted, after
        validating it with check_name()
    @return The name with the quote character removed from it and enclosed in
        quotes (or unchanged when ignore_case applies)
    @throws NotImplementedError for any other module
    '''
    if module == 'psycopg2' and ignore_case:
        # Don't enclose in quotes because this disables case-insensitivity
        check_name(name)
        return name
    
    quotes = {'psycopg2': '"', 'MySQLdb': '`'}
    if module not in quotes:
        raise NotImplementedError("Can't escape name for "+module+' database')
    quote = quotes[module]
    return quote + name.replace(quote, '') + quote
92
93
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a name for an engine listed in db_engines.
    For kw_args, see esc_name_by_module().
    '''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
95
96
def esc_name(db, name, **kw_args):
    '''Escapes a name for the module of a connection's underlying db object.
    For kw_args, see esc_name_by_module().
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
98
99
def qual_name(db, schema, table):
    '''Builds an escaped, optionally schema-qualified table name.
    @param schema The schema name, or None for an unqualified name
    '''
    parts = [esc_name(db, table)]
    if schema != None: parts.insert(0, esc_name(db, schema))
    return '.'.join(parts)
104
105 1869 aaronmk
##### Database connections
106 1849 aaronmk
107 2097 aaronmk
# Names of the db_config settings (see DbConn._db() for how 'engine' and
# 'schemas' are used)
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
108 1926 aaronmk
109 1869 aaronmk
# Maps each engine name to a tuple of:
# - the name of the DB-API module to import
# - renamings of db_config keys to that module's connect() param names
db_engines = dict(
    MySQL=('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    PostgreSQL=('psycopg2', {}),
)
113
114
# The known database error classes. _add_module() adds each loaded DB module's
# DatabaseError to the set and rebuilds the tuple.
DatabaseErrors_set = set((DbException,))
DatabaseErrors = tuple(DatabaseErrors_set)
116
117
def _add_module(module):
    '''Adds a DB module's DatabaseError class to DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple from it'''
    global DatabaseErrors
    
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
121
122
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>"'''
    engine = db_config['engine']
    database = db_config['database']
    return ' database '.join((engine, database))
124
125 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is stored in a DbConn's
    query_results'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
126 1894 aaronmk
127 1901 aaronmk
def log_debug_none(msg):
    '''The default debug callback, which discards the message'''
    return None
128
129 1849 aaronmk
class DbConn:
    '''A lazily-opened database connection with an optional cache of query
    results.
    
    The underlying DB-API connection is only created the first time the `db`
    attribute is accessed (see _db()).
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
        caching=True):
        '''
        @param db_config dict of connection settings. 'engine' must be a key of
            db_engines. 'schemas' (optional) is a comma-separated list of
            schemas to prepend to the search_path. The remaining entries are
            passed to the DB module's connect().
        @param serializable Whether to set the SERIALIZABLE isolation level
            after connecting
        @param log_debug Callback that is passed each debug message
        @param caching Whether cacheable queries may use the result cache
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        self.caching = caching
        
        self.__db = None # created on demand by _db()
        self.query_results = {} # maps _query_lookup() keys to CacheCursors
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name) # name the missing attribute
    
    def __getstate__(self):
        '''Returns the state to pickle, without the callback and connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the underlying connection, opening and configuring it on
        first use'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # setting not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas is not None:
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', \
%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a DB cursor and records the rows fetched through it (or the
        exception raised by execute()), so that the result can be stored in the
        connection's query_results'''
        
        def __init__(self, outer):
            '''@param outer The DbConn to create the cursor on'''
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor.
            If it raises an exception, the exception is cached as the result.
            '''
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try: return_value = self.inner.execute(query, params)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            finally: self.query = get_cur_query(self.inner)
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches the next row and records it. When there are no more
            rows, caches the recorded result.'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores the recorded result in query_results as a CacheCursor'''
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached query result using the cursor methods
            execute() and fetchone()'''
            
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises the cached exception, or restarts the iteration
                over the cached rows. The args are ignored.'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None after the last one'''
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query, using the result cache if allowed.
        @param params The query's params, passed to the cursor's execute()
        @param cacheable Whether the result may be looked up in and stored to
            the cache. Ignored if the connection was created with
            caching=False.
        @return The cursor that ran (or replayed) the query
        '''
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the finally block works if creating the cursor fails
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query, params)
            except Exception as e:
                _add_cursor_info(e, cur)
                raise
        finally:
            log_debug = self.log_debug
            # log_debug is None on an unpickled connection (see __getstate__)
            if log_debug is not None and log_debug != log_debug_none:
                # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                
                logged_query = None
                if cur is not None: logged_query = get_cur_query(cur)
                # fall back to the unformatted query if the cursor has none
                if logged_query is None: logged_query = query
                log_debug(cache_status+': '+strings.one_line(logged_query))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Returns whether a result for this query and params is cached'''
        return _query_lookup(query, params) in self.query_results
257 1849 aaronmk
258 1869 aaronmk
# Alias: connect(...) creates a DbConn (for params, see DbConn.__init__())
connect = DbConn
259
260 832 aaronmk
##### Querying
261
262 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on the connection.
    For params, see DbConn.run_query()
    '''
    run = db.run_query
    return run(*args, **kw_args)
265 11 aaronmk
266 2068 aaronmk
def mogrify(db, query, params):
    '''Returns the result of the DB cursor's mogrify() for a query and its
    params.
    @throws NotImplementedError if the DB module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
271
272 832 aaronmk
##### Recoverable querying
273 15 aaronmk
274 11 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint.
    If func() raises, rolls back to the savepoint and re-raises. Otherwise,
    releases the savepoint.
    @return The return value of func()
    '''
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
    def savepoint_cmd(cmd): run_raw_query(db, cmd+savepoint)
    
    savepoint_cmd('SAVEPOINT ')
    try: result = func()
    except:
        savepoint_cmd('ROLLBACK TO SAVEPOINT ')
        raise
    savepoint_cmd('RELEASE SAVEPOINT ')
    return result
284
285 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally inside a savepoint.
    @param recover Whether to wrap the query in a savepoint (see
        with_savepoint()). None means False.
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    '''
    if recover == None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    
    # don't need savepoint if cached
    if not recover or db.is_cached(query, params): return run()
    return with_savepoint(db, run)
292 830 aaronmk
293 832 aaronmk
##### Basic queries
294
295 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into The name of the temp table to create from the query's rows. It
        is dropped first if it exists. If None, the query is just run.
    @return The cursor of the query (or of the CREATE TEMP TABLE statement)
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        check_name(into)
        
        # The DROP statement has no params of its own. Pass None explicitly
        # so the positional *args (recover, cacheable) stay aligned with
        # run_query()'s signature.
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', None, *args,
            **kw_args)
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
306
307 2120 aaronmk
# Sentinel default for mk_select()'s order_by param
order_by_pkey = object() # tells mk_select() to order by the pkey
308
309 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, in the form: [table0, (table1, joins_dict), ...]
        Each joins_dict maps a column of the previous table to the column of
        the joined table that it must match.
    @param fields Use None to select all fields in the table. Each field is a
        column name, or a (value, alias) tuple for a literal value.
    @param conds dict of column name to required value (None matches with IS)
    @param limit int|None The max # of rows
    @param start int|None The # of rows to skip
    @param order_by The column to order by, None for no ordering, or
        order_by_pkey (the default) to order by the first table's pkey()
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (also accepts tuples)
    table0 = tables.pop(0) # first table is separate
    
    if conds is None: conds = {}
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    if order_by is order_by_pkey: # sentinel, so compare by identity
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = []
    
    def parse_col(field):
        '''Parses fields'''
        if isinstance(field, tuple): # field is literal value
            value, col = field
            sql_ = '%s'
            params.append(value)
            if col is not None: sql_ += ' AS '+esc_name_(col)
        else: sql_ = esc_name_(field) # field is col name
        return sql_
    
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = esc_name_(col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    
    def join(left_table, right_table, entry):
        '''Parses joins'''
        left_col, right_col = entry
        left_col = left_table+'.'+esc_name_(left_col)
        right_col = right_table+'.'+esc_name_(right_col)
        # The outer parentheses keep the OR from being split up by the ANDs
        # between the conditions of a multi-column join
        return ('('+right_col+' = '+left_col
            +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL))')
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([parse_col(field) for field in fields])
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        query += ' JOIN '+table+' ON '+(' AND '.join(
            [join(left_table, table, entry) for entry in joins.items()]))
        left_table = table
    
    missing = True
    if conds != {}:
        # Single snapshot, so the conditions and their params are in the same
        # order
        cond_items = list(conds.items())
        query += ' WHERE '+(' AND '.join([cond(entry) for entry in cond_items]))
        params.extend([value for col, value in cond_items])
        missing = False
    if order_by is not None: query += ' ORDER BY '+esc_name_(order_by)
    if limit is not None: query += ' LIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
387 11 aaronmk
388 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT query. The query is cacheable by default.
    For params, see mk_select() and run_query()
    '''
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, run_opts['recover'],
        run_opts['cacheable'])
395
396 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''Builds an INSERT query that takes its rows from a query.
    @param cols The names of the columns to insert into. None or an empty
        sequence omits the column list.
    @param select_query The query providing the rows. If None, inserts
        DEFAULT VALUES.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This wraps the INSERT in a temp function (PostgreSQL only) and returns
        a SELECT from that function.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query is None: select_query = 'DEFAULT VALUES'
    if cols is not None:
        cols = list(cols) # so any empty sequence is detected, not just []
        if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols is not None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        # cols is None when inserting all defaults, so use [] in the name
        function = 'pg_temp.'+('_'.join(map(clean_name,
            ['insert', table] + (cols or []))))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        function_query = '''\
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
        run_query(db, function_query, cacheable=True)
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')',
            table_is_esc=True) # function alias is required in AS clause
    
    return (query, params)
437
438
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT ... SELECT query. The query is cacheable by
    default.
    For params, see mk_insert_select() and run_query_into()
    @param into Name of temp table to place RETURNING values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # so rows can be selected
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    query_and_params = mk_insert_select(db, *args, **kw_args)
    query, params = query_and_params
    return run_query_into(db, query, params, into, recover, cacheable)
449 2063 aaronmk
450 2066 aaronmk
# Sentinel value for an entry of insert()'s row param
default = object() # tells insert() to use the default value for a column
451
452 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a row into a table.
    For params, see insert_select()
    @param row A sequence of values for the table's columns in order, or a dict
        of column name to value. Use the `default` sentinel as a value to
        insert that column's default.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = []
    params = []
    for value in row:
        # sentinel, so compare by identity (== could be overridden by value)
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            params.append(value)
    
    # Build query
    # A column list can't be followed by DEFAULT VALUES, so when columns are
    # named, the VALUES clause is needed even if every value is DEFAULT
    if params != [] or (cols is not None and labels != []):
        query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None
    
    return insert_select(db, table, cols, query, params, *args, **kw_args)
474 11 aaronmk
475 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert.
    @return lastval() for psycopg2, the connection's insert_id() for MySQLdb,
        or None for any other DB module
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the underlying connection. DbConn only forwards
    # the `db` attribute, so db.insert_id would raise AttributeError.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
480 13 aaronmk
481 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table.
    @return The query's cursor
    '''
    qual_table = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+qual_table+' CASCADE')
483 832 aaronmk
484
##### Database structure queries
485
486 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of a table's pkey.
    Assumed to be first column in table
    For recover and table_is_esc, see select().
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    # next() builtin instead of the Python-2-only .next() method
    return next(col_names(cur))
490 832 aaronmk
491 853 aaronmk
def index_cols(db, table, index):
    '''Lists the names of the columns used by an index, ordered by attnum.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @throws NotImplementedError if the DB module is not psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The 1st subquery gets the columns listed in indkey. The 2nd gets the
    # columns referenced by the index's expressions (indexprs).
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
533 853 aaronmk
534 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Lists the names of the columns of a table's constraint, ordered by
    attnum.
    @throws NotImplementedError if the DB module is not psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
552
553 2096 aaronmk
# Name of the column that add_row_num() adds
row_num_col = '_row_num'
554
555 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    check_name(table)
    
    col_def = row_num_col+' serial NOT NULL PRIMARY KEY'
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+col_def)
561 2086 aaronmk
562 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Lists the names of the tables matching a LIKE pattern.
    @param schema The schema to search (only used by the psycopg2 query)
    @param table_like The LIKE pattern that the table names must match
    @return The result of values() for the listing query
    @throws NotImplementedError if the DB module is not psycopg2 or MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    
    return values(run_query(db, query, params, cacheable=True))
579 830 aaronmk
580 833 aaronmk
##### Database management
581
582 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table listed by tables() for the schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
585 833 aaronmk
586 832 aaronmk
##### Heuristic queries
587
588 2076 aaronmk
def with_parsed_errors(db, func):
    '''Translates known DB errors to typed exceptions.
    @return The return value of func()
    @throws DuplicateKeyException for a unique constraint violation, with the
        columns of the violated constraint
    @throws NullValueException for a not-null constraint violation, with the
        affected column
    Any other exception is re-raised unchanged.
    '''
    try: return func()
    except Exception as e:
        msg = str(e)
        
        dup_key = re.search(r'duplicate key value violates unique constraint '
            r'"(([^\W_]+)_[^"]+)"', msg)
        if dup_key:
            # the table name is the part of the constraint name before the
            # first _
            constraint, table = dup_key.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        
        null_value = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if null_value: raise NullValueException([null_value.group(1)], e)
        
        raise # no specific exception raised
604 11 aaronmk
605 2076 aaronmk
def try_insert(db, table, row, returning=None):
    '''Inserts a row inside a savepoint and translates known DB errors to
    typed exceptions (see with_parsed_errors()).
    Recovers from errors
    '''
    def insert_(): return insert(db, table, row, returning, recover=True)
    return with_parsed_errors(db, insert_)
609
610 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row, or looks up the existing row if it is a duplicate.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The name of the pkey column. If None, uses pkey().
    @param row_ct_ref If set, a list whose first item is incremented by the
        insert's rowcount
    @return The pkey of the inserted or existing row
    '''
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = try_insert(db, table, row, pkey_)
        count_rows = row_ct_ref != None and cur.rowcount >= 0
        if count_rows: row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row by the columns of the violated constraint
        dup_conds = util.dict_subset_right_join(row, e.cols)
        return value(select(db, table, [pkey_], dup_conds, recover=True))
624 471 aaronmk
625 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up the pkey of the row matching the given column values.
    Recovers from errors
    @param row dict of column name to value, used as the select conditions
    @param pkey The name of the pkey column
    @param create Whether to insert the row with put() if it doesn't exist
    @throws StopIteration if the row doesn't exist and create is not set
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
631 2078 aaronmk
632 2087 aaronmk
def put_table(db, out_table, out_cols, in_tables, in_cols, pkey,
    row_ct_ref=None, table_is_esc=False):
    '''Inserts the rows selected from an input table into an output table.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The input tables. Only the first one is selected from.
    @param row_ct_ref If set, a list whose first item is incremented by the
        insert's rowcount
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
        available
    '''
    pkeys_table = clean_name(out_table)+'_pkeys'
    
    def insert_():
        select_query, select_params = mk_select(db, in_tables[0], in_cols,
            table_is_esc=table_is_esc)
        return insert_select(db, out_table, out_cols, select_query,
            select_params, returning=pkey, into=pkeys_table, recover=True,
            table_is_esc=table_is_esc)
    
    try:
        cur = with_parsed_errors(db, insert_)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        
        # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
        add_row_num(db, pkeys_table)
        
        return pkeys_table
    except DuplicateKeyException: raise
655 2115 aaronmk
656
##### Data cleanup
657
658
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Trims the values of the given columns and replaces the values '' and
    \\N with NULL.
    @param cols The names of the columns to clean up. If empty, does nothing.
    @param table_is_esc Whether the table name has already been escaped
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not table_is_esc: check_name(table)
    cols = [esc_name_(col) for col in cols]
    # With no columns, the SET clause would be empty, which is invalid SQL
    if cols == []: return
    
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
            for col in cols))),
        dict(null0='', null1=r'\N'))