Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2168 aaronmk
def get_cur_query(cur, input_query=None):
20
    raw_query = None
21
    if hasattr(cur, 'query'): raw_query = cur.query
22
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
23
    return util.coalesce(raw_query, input_query)
24 14 aaronmk
25 2168 aaronmk
def _add_cursor_info(e, cur, input_query=None):
26
    exc.add_msg(e, 'query: '+str(get_cur_query(cur, input_query)))
27 135 aaronmk
28 300 aaronmk
class DbException(exc.ExceptionWithCause):
29 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
30 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
31 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
32
33 2143 aaronmk
class ExceptionWithName(DbException):
34
    def __init__(self, name, cause=None):
35 2145 aaronmk
        DbException.__init__(self, 'for name: '+str(name), cause)
36 2143 aaronmk
        self.name = name
37 360 aaronmk
38 468 aaronmk
class ExceptionWithColumns(DbException):
39
    def __init__(self, cols, cause=None):
40 2145 aaronmk
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
41 468 aaronmk
        self.cols = cols
42 11 aaronmk
43 2143 aaronmk
class NameException(DbException): pass
44
45 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
46 13 aaronmk
47 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
48 13 aaronmk
49 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
50
51 89 aaronmk
class EmptyRowException(DbException): pass
52
53 865 aaronmk
##### Warnings
54
55
class DbWarning(UserWarning): pass
56
57 1930 aaronmk
##### Result retrieval
58
59
def col_names(cur): return (col[0] for col in cur.description)
60
61
def rows(cur): return iter(lambda: cur.fetchone(), None)
62
63
def consume_rows(cur):
64
    '''Used to fetch all rows so result will be cached'''
65
    iters.consume_iter(rows(cur))
66
67
def next_row(cur): return rows(cur).next()
68
69
def row(cur):
70
    row_ = next_row(cur)
71
    consume_rows(cur)
72
    return row_
73
74
def next_value(cur): return next_row(cur)[0]
75
76
def value(cur): return row(cur)[0]
77
78
def values(cur): return iters.func_iter(lambda: next_value(cur))
79
80
def value_or_none(cur):
81
    try: return value(cur)
82
    except StopIteration: return None
83
84 2101 aaronmk
##### Input validation
85
86
def clean_name(name): return re.sub(r'\W', r'', name)
87
88
def check_name(name):
89
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
90
        +'" may contain only alphanumeric characters and _')
91
92
def esc_name_by_module(module, name, ignore_case=False):
93
    if module == 'psycopg2':
94
        if ignore_case:
95
            # Don't enclose in quotes because this disables case-insensitivity
96
            check_name(name)
97
            return name
98
        else: quote = '"'
99
    elif module == 'MySQLdb': quote = '`'
100
    else: raise NotImplementedError("Can't escape name for "+module+' database')
101
    return quote + name.replace(quote, '') + quote
102
103
def esc_name_by_engine(engine, name, **kw_args):
104
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
105
106
def esc_name(db, name, **kw_args):
107
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
108
109
def qual_name(db, schema, table):
110
    def esc_name_(name): return esc_name(db, name)
111
    table = esc_name_(table)
112
    if schema != None: return esc_name_(schema)+'.'+table
113
    else: return table
114
115 1869 aaronmk
##### Database connections
116 1849 aaronmk
117 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
118 1926 aaronmk
119 1869 aaronmk
db_engines = {
120
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
121
    'PostgreSQL': ('psycopg2', {}),
122
}
123
124
DatabaseErrors_set = set([DbException])
125
DatabaseErrors = tuple(DatabaseErrors_set)
126
127
def _add_module(module):
128
    DatabaseErrors_set.add(module.DatabaseError)
129
    global DatabaseErrors
130
    DatabaseErrors = tuple(DatabaseErrors_set)
131
132
def db_config_str(db_config):
133
    return db_config['engine']+' database '+db_config['database']
134
135 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
136 1894 aaronmk
137 1901 aaronmk
log_debug_none = lambda msg: None
138
139 1849 aaronmk
class DbConn:
140 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
141 2050 aaronmk
        caching=True):
142 1869 aaronmk
        self.db_config = db_config
143
        self.serializable = serializable
144 1901 aaronmk
        self.log_debug = log_debug
145 2047 aaronmk
        self.caching = caching
146 1869 aaronmk
147
        self.__db = None
148 1889 aaronmk
        self.query_results = {}
149 2139 aaronmk
        self._savepoint = 0
150 1869 aaronmk
151
    def __getattr__(self, name):
152
        if name == '__dict__': raise Exception('getting __dict__')
153
        if name == 'db': return self._db()
154
        else: raise AttributeError()
155
156
    def __getstate__(self):
157
        state = copy.copy(self.__dict__) # shallow copy
158 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
159 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
160
        return state
161
162 2165 aaronmk
    def connected(self): return self.__db != None
163
164 1869 aaronmk
    def _db(self):
165
        if self.__db == None:
166
            # Process db_config
167
            db_config = self.db_config.copy() # don't modify input!
168 2097 aaronmk
            schemas = db_config.pop('schemas', None)
169 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
170
            module = __import__(module_name)
171
            _add_module(module)
172
            for orig, new in mappings.iteritems():
173
                try: util.rename_key(db_config, orig, new)
174
                except KeyError: pass
175
176
            # Connect
177
            self.__db = module.connect(**db_config)
178
179
            # Configure connection
180
            if self.serializable: run_raw_query(self,
181
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
182 2101 aaronmk
            if schemas != None:
183
                schemas_ = ''.join((esc_name(self, s)+', '
184
                    for s in schemas.split(',')))
185
                run_raw_query(self, "SELECT set_config('search_path', \
186
%s || current_setting('search_path'), false)", [schemas_])
187 1869 aaronmk
188
        return self.__db
189 1889 aaronmk
190 1891 aaronmk
    class DbCursor(Proxy):
191 1927 aaronmk
        def __init__(self, outer):
192 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
193 1927 aaronmk
            self.query_results = outer.query_results
194 1894 aaronmk
            self.query_lookup = None
195 1891 aaronmk
            self.result = []
196 1889 aaronmk
197 1894 aaronmk
        def execute(self, query, params=None):
198 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
199 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
200 2148 aaronmk
            try:
201
                try: return_value = self.inner.execute(query, params)
202
                finally: self.query = get_cur_query(self.inner)
203 1904 aaronmk
            except Exception, e:
204 2168 aaronmk
                _add_cursor_info(e, self, query)
205 1904 aaronmk
                self.result = e # cache the exception as the result
206
                self._cache_result()
207
                raise
208 1930 aaronmk
            # Fetch all rows so result will be cached
209
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
210 1894 aaronmk
            return return_value
211
212 1891 aaronmk
        def fetchone(self):
213
            row = self.inner.fetchone()
214 1899 aaronmk
            if row != None: self.result.append(row)
215
            # otherwise, fetched all rows
216 1904 aaronmk
            else: self._cache_result()
217
            return row
218
219
        def _cache_result(self):
220 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
221
            # idempotent, but an invalid insert will always be invalid
222 1930 aaronmk
            if self.query_results != None and (not self._is_insert
223 1906 aaronmk
                or isinstance(self.result, Exception)):
224
225 1894 aaronmk
                assert self.query_lookup != None
226 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
227
                    util.dict_subset(dicts.AttrsDictView(self),
228
                    ['query', 'result', 'rowcount', 'description']))
229 1906 aaronmk
230 1916 aaronmk
        class CacheCursor:
231
            def __init__(self, cached_result): self.__dict__ = cached_result
232
233 1927 aaronmk
            def execute(self, *args, **kw_args):
234 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
235
                # otherwise, result is a rows list
236
                self.iter = iter(self.result)
237
238
            def fetchone(self):
239
                try: return self.iter.next()
240
                except StopIteration: return None
241 1891 aaronmk
242 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
243 2148 aaronmk
        '''Translates known DB errors to typed exceptions:
244
        See self.DbCursor.execute().'''
245 2167 aaronmk
        assert query != None
246
247 2047 aaronmk
        if not self.caching: cacheable = False
248 1903 aaronmk
        used_cache = False
249
        try:
250 1927 aaronmk
            # Get cursor
251
            if cacheable:
252
                query_lookup = _query_lookup(query, params)
253
                try:
254
                    cur = self.query_results[query_lookup]
255
                    used_cache = True
256
                except KeyError: cur = self.DbCursor(self)
257
            else: cur = self.db.cursor()
258
259
            # Run query
260 2148 aaronmk
            cur.execute(query, params)
261 1903 aaronmk
        finally:
262
            if self.log_debug != log_debug_none: # only compute msg if needed
263
                if used_cache: cache_status = 'Cache hit'
264
                elif cacheable: cache_status = 'Cache miss'
265
                else: cache_status = 'Non-cacheable'
266 1927 aaronmk
                self.log_debug(cache_status+': '
267 2164 aaronmk
                    +strings.one_line(str(get_cur_query(cur))))
268 1903 aaronmk
269
        return cur
270 1914 aaronmk
271
    def is_cached(self, query, params=None):
272
        return _query_lookup(query, params) in self.query_results
273 2139 aaronmk
274
    def with_savepoint(self, func):
275
        savepoint = 'savepoint_'+str(self._savepoint)
276
        self.run_query('SAVEPOINT '+savepoint)
277
        self._savepoint += 1
278
        try:
279
            try: return_val = func()
280
            finally:
281
                self._savepoint -= 1
282
                assert self._savepoint >= 0
283
        except:
284
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
285
            raise
286
        else:
287
            self.run_query('RELEASE SAVEPOINT '+savepoint)
288
            return return_val
289 1849 aaronmk
290 1869 aaronmk
connect = DbConn
291
292 832 aaronmk
##### Querying
293
294 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
295 2085 aaronmk
    '''For params, see DbConn.run_query()'''
296 1894 aaronmk
    return db.run_query(*args, **kw_args)
297 11 aaronmk
298 2068 aaronmk
def mogrify(db, query, params):
299
    module = util.root_module(db.db)
300
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
301
    else: raise NotImplementedError("Can't mogrify query for "+module+
302
        ' database')
303
304 832 aaronmk
##### Recoverable querying
305 15 aaronmk
306 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
307 11 aaronmk
308 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
309 830 aaronmk
    if recover == None: recover = False
310
311 2148 aaronmk
    try:
312
        def run(): return run_raw_query(db, query, params, cacheable)
313
        if recover and not db.is_cached(query, params):
314
            return with_savepoint(db, run)
315
        else: return run() # don't need savepoint if cached
316
    except Exception, e:
317
        if not recover: raise # need savepoint to run index_cols()
318
        msg = str(e)
319
        match = re.search(r'duplicate key value violates unique constraint '
320
            r'"((_?[^\W_]+)_[^"]+)"', msg)
321
        if match:
322
            constraint, table = match.groups()
323
            try: cols = index_cols(db, table, constraint)
324
            except NotImplementedError: raise e
325
            else: raise DuplicateKeyException(cols, e)
326
        match = re.search(r'null value in column "(\w+)" violates not-null '
327
            'constraint', msg)
328
        if match: raise NullValueException([match.group(1)], e)
329
        match = re.search(r'relation "(\w+)" already exists', msg)
330
        if match: raise DuplicateTableException(match.group(1), e)
331
        raise # no specific exception raised
332 830 aaronmk
333 832 aaronmk
##### Basic queries
334
335 2153 aaronmk
def next_version(name):
336
    '''Prepends the version # so it won't be removed if the name is truncated'''
337 2163 aaronmk
    version = 1 # first existing name was version 0
338 2153 aaronmk
    match = re.match(r'^v(\d+)_(.*)$', name)
339
    if match:
340
        version = int(match.group(1))+1
341
        name = match.group(2)
342
    return 'v'+str(version)+'_'+name
343
344 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
345 2085 aaronmk
    '''Outputs a query to a temp table.
346
    For params, see run_query().
347
    '''
348 2151 aaronmk
    if into_ref == None: return run_query(db, query, params, *args, **kw_args)
349 2085 aaronmk
    else: # place rows in temp table
350 2151 aaronmk
        check_name(into_ref[0])
351 2153 aaronmk
        kw_args['recover'] = True
352
        while True:
353
            try:
354
                return run_query(db, 'CREATE TEMP TABLE '+into_ref[0]+' AS '
355
                    +query, params, *args, **kw_args)
356
                    # CREATE TABLE AS sets rowcount to # rows in query
357
            except DuplicateTableException, e:
358
                into_ref[0] = next_version(into_ref[0])
359
                # try again with next version of name
360 2085 aaronmk
361 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
362
363 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
364
365 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
366 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
367 1981 aaronmk
    '''
368 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
369 2124 aaronmk
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
370 1981 aaronmk
    @param fields Use None to select all fields in the table
371
    @param table_is_esc Whether the table name has already been escaped
372 2054 aaronmk
    @return tuple(query, params)
373 1981 aaronmk
    '''
374 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
375 2058 aaronmk
376 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
377 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
378 2121 aaronmk
    table0 = tables.pop(0) # first table is separate
379
380 1135 aaronmk
    if conds == None: conds = {}
381 135 aaronmk
    assert limit == None or type(limit) == int
382 865 aaronmk
    assert start == None or type(start) == int
383 2120 aaronmk
    if order_by == order_by_pkey:
384 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
385
    if not table_is_esc: table0 = esc_name_(table0)
386 865 aaronmk
387 2056 aaronmk
    params = []
388
389 2161 aaronmk
    def parse_col(field, default_table=None):
390 2056 aaronmk
        '''Parses fields'''
391 2157 aaronmk
        is_tuple = isinstance(field, tuple)
392
        if is_tuple and len(field) == 1: # field is literal value
393
            value, = field
394 2056 aaronmk
            sql_ = '%s'
395
            params.append(value)
396 2157 aaronmk
        elif is_tuple and len(field) == 2: # field is col with table
397
            table, col = field
398
            if not table_is_esc: table = esc_name_(table)
399
            sql_ = table+'.'+esc_name_(col)
400 2161 aaronmk
        else:
401
            sql_ = esc_name_(field) # field is col name
402
            if default_table != None: sql_ = default_table+'.'+sql_
403 2056 aaronmk
        return sql_
404 11 aaronmk
    def cond(entry):
405 2056 aaronmk
        '''Parses conditions'''
406 13 aaronmk
        col, value = entry
407 2058 aaronmk
        cond_ = esc_name_(col)+' '
408 11 aaronmk
        if value == None: cond_ += 'IS'
409
        else: cond_ += '='
410
        cond_ += ' %s'
411
        return cond_
412 2056 aaronmk
413 1135 aaronmk
    query = 'SELECT '
414
    if fields == None: query += '*'
415 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
416 2121 aaronmk
    query += ' FROM '+table0
417 865 aaronmk
418 2122 aaronmk
    # Add joins
419
    left_table = table0
420
    for table, joins in tables:
421
        if not table_is_esc: table = esc_name_(table)
422 2127 aaronmk
        query += ' JOIN '+table
423 2122 aaronmk
424
        def join(entry):
425 2127 aaronmk
            '''Parses non-USING joins'''
426 2124 aaronmk
            right_col, left_col = entry
427
            right_col = table+'.'+esc_name_(right_col)
428 2161 aaronmk
            left_col = parse_col(left_col, left_table)
429 2123 aaronmk
            return (right_col+' = '+left_col
430
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL)')
431 2122 aaronmk
432 2127 aaronmk
        if reduce(operator.and_, (v == join_using for v in joins.itervalues())):
433
            # all cols w/ USING
434 2130 aaronmk
            query += ' USING ('+(', '.join(joins.iterkeys()))+')'
435 2127 aaronmk
        else: query += ' ON '+(' AND '.join(map(join, joins.iteritems())))
436
437 2122 aaronmk
        left_table = table
438
439 865 aaronmk
    missing = True
440 89 aaronmk
    if conds != {}:
441 2122 aaronmk
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
442 2056 aaronmk
        params += conds.values()
443 865 aaronmk
        missing = False
444 2120 aaronmk
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
445 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
446
    if start != None:
447
        if start != 0: query += ' OFFSET '+str(start)
448
        missing = False
449
    if missing: warnings.warn(DbWarning(
450
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
451
452 2056 aaronmk
    return (query, params)
453 11 aaronmk
454 2054 aaronmk
def select(db, *args, **kw_args):
455
    '''For params, see mk_select() and run_query()'''
456
    recover = kw_args.pop('recover', None)
457
    cacheable = kw_args.pop('cacheable', True)
458
459
    query, params = mk_select(db, *args, **kw_args)
460
    return run_query(db, query, params, recover, cacheable)
461
462 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
463 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
464 1960 aaronmk
    '''
465
    @param returning str|None An inserted column (such as pkey) to return
466 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
467 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
468
        query will be fully cached, not just if it raises an exception.
469 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
470
    '''
471 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
472
    if cols == []: cols = None # no cols (all defaults) = unknown col names
473 1960 aaronmk
    if not table_is_esc: check_name(table)
474 2063 aaronmk
475
    # Build query
476
    query = 'INSERT INTO '+table
477
    if cols != None:
478
        map(check_name, cols)
479
        query += ' ('+', '.join(cols)+')'
480
    query += ' '+select_query
481
482
    if returning != None:
483
        check_name(returning)
484
        query += ' RETURNING '+returning
485
486 2070 aaronmk
    if embeddable:
487
        # Create function
488 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
489
            ['insert', table] + cols)))
490 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
491
        function_query = '''\
492 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
493 2070 aaronmk
    LANGUAGE sql
494
    AS $$'''+mogrify(db, query, params)+''';$$;
495
'''
496
        run_query(db, function_query, cacheable=True)
497
498
        # Return query that uses function
499 2134 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
500
            order_by=None, table_is_esc=True)# AS clause requires function alias
501 2070 aaronmk
502 2066 aaronmk
    return (query, params)
503
504
def insert_select(db, *args, **kw_args):
505 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
506 2152 aaronmk
    @param into_ref List with name of temp table to place RETURNING values in
507 2072 aaronmk
    '''
508 2151 aaronmk
    into_ref = kw_args.pop('into_ref', None)
509
    if into_ref != None: kw_args['embeddable'] = True
510 2066 aaronmk
    recover = kw_args.pop('recover', None)
511
    cacheable = kw_args.pop('cacheable', True)
512
513
    query, params = mk_insert_select(db, *args, **kw_args)
514 2153 aaronmk
    return run_query_into(db, query, params, into_ref, recover=recover,
515
        cacheable=cacheable)
516 2063 aaronmk
517 2066 aaronmk
default = object() # tells insert() to use the default value for a column
518
519 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
520 2085 aaronmk
    '''For params, see insert_select()'''
521 1960 aaronmk
    if lists.is_seq(row): cols = None
522
    else:
523
        cols = row.keys()
524
        row = row.values()
525
    row = list(row) # ensure that "!= []" works
526
527 1961 aaronmk
    # Check for special values
528
    labels = []
529
    values = []
530
    for value in row:
531
        if value == default: labels.append('DEFAULT')
532
        else:
533
            labels.append('%s')
534
            values.append(value)
535
536
    # Build query
537 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
538
    else: query = None
539 1554 aaronmk
540 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
541 11 aaronmk
542 135 aaronmk
def last_insert_id(db):
543 1849 aaronmk
    module = util.root_module(db.db)
544 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
545
    elif module == 'MySQLdb': return db.insert_id()
546
    else: return None
547 13 aaronmk
548 1968 aaronmk
def truncate(db, table, schema='public'):
549
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
550 832 aaronmk
551
##### Database structure queries
552
553 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
554 832 aaronmk
    '''Assumed to be first column in table'''
555 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
556 2084 aaronmk
        table_is_esc=table_is_esc)).next()
557 832 aaronmk
558 853 aaronmk
def index_cols(db, table, index):
559
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
560
    automatically created. When you don't know whether something is a UNIQUE
561
    constraint or a UNIQUE index, use this function.'''
562
    check_name(table)
563
    check_name(index)
564 1909 aaronmk
    module = util.root_module(db.db)
565
    if module == 'psycopg2':
566
        return list(values(run_query(db, '''\
567 853 aaronmk
SELECT attname
568 866 aaronmk
FROM
569
(
570
        SELECT attnum, attname
571
        FROM pg_index
572
        JOIN pg_class index ON index.oid = indexrelid
573
        JOIN pg_class table_ ON table_.oid = indrelid
574
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
575
        WHERE
576
            table_.relname = %(table)s
577
            AND index.relname = %(index)s
578
    UNION
579
        SELECT attnum, attname
580
        FROM
581
        (
582
            SELECT
583
                indrelid
584
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
585
                    AS indkey
586
            FROM pg_index
587
            JOIN pg_class index ON index.oid = indexrelid
588
            JOIN pg_class table_ ON table_.oid = indrelid
589
            WHERE
590
                table_.relname = %(table)s
591
                AND index.relname = %(index)s
592
        ) s
593
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
594
) s
595 853 aaronmk
ORDER BY attnum
596
''',
597 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
598
    else: raise NotImplementedError("Can't list index columns for "+module+
599
        ' database')
600 853 aaronmk
601 464 aaronmk
def constraint_cols(db, table, constraint):
602
    check_name(table)
603
    check_name(constraint)
604 1849 aaronmk
    module = util.root_module(db.db)
605 464 aaronmk
    if module == 'psycopg2':
606
        return list(values(run_query(db, '''\
607
SELECT attname
608
FROM pg_constraint
609
JOIN pg_class ON pg_class.oid = conrelid
610
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
611
WHERE
612
    relname = %(table)s
613
    AND conname = %(constraint)s
614
ORDER BY attnum
615
''',
616
            {'table': table, 'constraint': constraint})))
617
    else: raise NotImplementedError("Can't list constraint columns for "+module+
618
        ' database')
619
620 2096 aaronmk
row_num_col = '_row_num'
621
622 2086 aaronmk
def add_row_num(db, table):
623 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
624
    be the primary key.'''
625 2086 aaronmk
    check_name(table)
626 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
627 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
628 2086 aaronmk
629 1968 aaronmk
def tables(db, schema='public', table_like='%'):
630 1849 aaronmk
    module = util.root_module(db.db)
631 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
632 832 aaronmk
    if module == 'psycopg2':
633 1968 aaronmk
        return values(run_query(db, '''\
634
SELECT tablename
635
FROM pg_tables
636
WHERE
637
    schemaname = %(schema)s
638
    AND tablename LIKE %(table_like)s
639
ORDER BY tablename
640
''',
641
            params, cacheable=True))
642
    elif module == 'MySQLdb':
643
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
644
            cacheable=True))
645 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
646 830 aaronmk
647 833 aaronmk
##### Database management
648
649 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
650
    '''For kw_args, see tables()'''
651
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
652 833 aaronmk
653 832 aaronmk
##### Heuristic queries
654
655 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
656 1554 aaronmk
    '''Recovers from errors.
657 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
658
    '''
659 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
660
661 471 aaronmk
    try:
662 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
663 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
664
            row_ct_ref[0] += cur.rowcount
665
        return value(cur)
666 471 aaronmk
    except DuplicateKeyException, e:
667 2104 aaronmk
        return value(select(db, table, [pkey_],
668 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
669 471 aaronmk
670 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
671 830 aaronmk
    '''Recovers from errors'''
672
    try: return value(select(db, table, [pkey], row, 1, recover=True))
673 14 aaronmk
    except StopIteration:
674 40 aaronmk
        if not create: raise
675 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
676 2078 aaronmk
677 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
678
    row_ct_ref=None, table_is_esc=False):
679 2078 aaronmk
    '''Recovers from errors.
680
    Only works under PostgreSQL (uses INSERT RETURNING).
681 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
682
        tables to join with it using the main input table's pkey
683 2133 aaronmk
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
684 2078 aaronmk
        available
685
    '''
686 2162 aaronmk
    temp_suffix = clean_name(out_table)
687 2158 aaronmk
        # suffix, not prefix, so main name won't be removed if name is truncated
688
    pkeys_ref = ['pkeys_'+temp_suffix]
689 2131 aaronmk
690 2132 aaronmk
    # Join together input tables
691 2131 aaronmk
    in_tables = in_tables[:] # don't modify input!
692
    in_tables0 = in_tables.pop(0) # first table is separate
693
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
694 2142 aaronmk
    in_joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
695 2131 aaronmk
696 2142 aaronmk
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
697
    pkeys_cols = [in_pkey, out_pkey]
698
699 2132 aaronmk
    def mk_select_(cols):
700 2142 aaronmk
        return mk_select(db, in_joins, cols, limit=limit, start=start,
701 2134 aaronmk
            table_is_esc=table_is_esc)
702 2132 aaronmk
703 2158 aaronmk
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
704 2078 aaronmk
    def insert_():
705 2148 aaronmk
        '''Inserts and capture output pkeys.'''
706 2132 aaronmk
        cur = insert_select(db, out_table, mapping.keys(),
707 2133 aaronmk
            *mk_select_(mapping.values()), returning=out_pkey,
708 2155 aaronmk
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
709 2078 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
710
            row_ct_ref[0] += cur.rowcount
711 2155 aaronmk
            add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
712 2086 aaronmk
713 2132 aaronmk
        # Get input pkeys corresponding to rows in insert
714 2158 aaronmk
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
715 2155 aaronmk
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
716
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
717 2086 aaronmk
718 2155 aaronmk
        # Join together output and input pkeys
719 2132 aaronmk
        run_query_into(db, *mk_select(db,
720 2155 aaronmk
            [in_pkeys_ref[0], (out_pkeys_ref[0], {row_num_col: join_using})],
721 2154 aaronmk
            pkeys_cols, start=0), into_ref=pkeys_ref)
722 2132 aaronmk
723 2142 aaronmk
    # Do inserts and selects
724 2148 aaronmk
    try: insert_()
725 2142 aaronmk
    except DuplicateKeyException, e:
726
        join_cols = util.dict_subset_right_join(mapping, e.cols)
727
        joins = in_joins + [(out_table, join_cols)]
728
        run_query_into(db, *mk_select(db, joins, pkeys_cols,
729 2154 aaronmk
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
730 2142 aaronmk
731 2154 aaronmk
    return (pkeys_ref[0], out_pkey)
732 2115 aaronmk
733
##### Data cleanup
734
735
def cleanup_table(db, table, cols, table_is_esc=False):
736
    def esc_name_(name): return esc_name(db, name)
737
738
    if not table_is_esc: check_name(table)
739
    cols = map(esc_name_, cols)
740
741
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
742
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
743
            for col in cols))),
744
        dict(null0='', null1=r'\N'))