Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
    '''Returns the query text recorded on a cursor, if any.
    Checks the `query` attribute first, then `_last_executed`.
    @return The attribute's value, or None if the cursor has neither
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
23 14 aaronmk
24 2164 aaronmk
def _add_cursor_info(e, cur):
    '''Appends the cursor's query (see get_cur_query()) to exception e's
    message using exc.add_msg()'''
    query_str = str(get_cur_query(cur))
    exc.add_msg(e, 'query: '+query_str)
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for this module's exceptions'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If set, the cursor whose query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None:
            _add_cursor_info(self, cur)
30
31 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about a named object; the name is kept in self.name'''
    
    def __init__(self, name, cause=None):
        msg = 'for name: '+str(name)
        DbException.__init__(self, msg, cause)
        self.name = name
35 360 aaronmk
36 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about a list of columns, which is kept in self.cols'''
    
    def __init__(self, cols, cause=None):
        col_list = ', '.join(cols)
        DbException.__init__(self, 'for columns: '+col_list, cause)
        self.cols = cols
40 11 aaronmk
41 2143 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing non-word characters'''
42
43 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by run_query() when a unique constraint is violated; cols holds
    the constraint's columns'''
44 13 aaronmk
45 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by run_query() when a not-null constraint is violated; cols holds
    the offending column'''
46 13 aaronmk
47 2143 aaronmk
class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() when a relation already exists; name holds the
    relation's name'''
48
49 89 aaronmk
class EmptyRowException(DbException):
    '''Presumably signals a row with no values; not raised in this part of the
    file (TODO confirm usage)'''
50
51 865 aaronmk
##### Warnings
52
53
class DbWarning(UserWarning):
    '''Category for warnings issued by this module (see mk_select())'''
54
55 1930 aaronmk
##### Result retrieval
56
57
def col_names(cur):
    '''Iterates over the column names (first item of each entry) in
    cur.description'''
    return (col_desc[0] for col_desc in cur.description)
58
59
def rows(cur):
    '''Iterates over the remaining rows of a cursor, stopping when
    cur.fetchone() returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
60
61
def consume_rows(cur):
    '''Fetches (and discards) all remaining rows, so that a caching cursor
    will have seen the whole result'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
64
65
def next_row(cur):
    '''Returns the next row of a cursor.
    @throws StopIteration if there are no more rows
    '''
    return next(rows(cur))
66
67
def row(cur):
    '''Returns the next row of a cursor and then fetches all the rows after
    it (see consume_rows()).
    @throws StopIteration if there are no more rows
    '''
    first = next_row(cur)
    consume_rows(cur)
    return first
71
72
def next_value(cur):
    '''Returns the first column of the cursor's next row'''
    next_row_ = next_row(cur)
    return next_row_[0]
73
74
def value(cur):
    '''Returns the first column of the cursor's next row, consuming the
    remaining rows (see row())'''
    row_ = row(cur)
    return row_[0]
75
76
def values(cur):
    '''Iterates over the first column of each remaining row, by wrapping
    next_value() with iters.func_iter()'''
    def next_(): return next_value(cur)
    return iters.func_iter(next_)
77
78
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no more rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
81
82 2101 aaronmk
##### Input validation
83
84
def clean_name(name):
    '''Removes every non-word character (anything matching \\W) from a name'''
    non_word_re = re.compile(r'\W')
    return non_word_re.sub(r'', name)
85
86
def check_name(name):
    '''Validates that a name contains only word characters.
    @throws NameException if the name contains a non-word character
    '''
    if re.search(r'\W', name) == None: return # name is valid
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
89
90
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for the database driver named by module.
    @param module Driver module name: 'psycopg2' or 'MySQLdb'
    @param ignore_case For psycopg2, return the name unquoted (after
        validating it with check_name()) instead of quoting it
    @return The name with any quote characters removed, enclosed in the
        driver's quote character
    @throws NotImplementedError for any other module
    '''
    if module == 'psycopg2' and ignore_case:
        # Don't enclose in quotes because this disables case-insensitivity
        check_name(name)
        return name
    
    if module == 'psycopg2': quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    
    stripped = name.replace(quote, '') # name can't contain the quote char
    return quote + stripped + quote
100
101
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a name for an engine listed in db_engines.
    For kw_args, see esc_name_by_module().
    '''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
103
104
def esc_name(db, name, **kw_args):
    '''Escapes a name for the driver module of db's underlying connection.
    For kw_args, see esc_name_by_module().
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
106
107
def qual_name(db, schema, table):
    '''Returns the escaped table name, prefixed by the escaped schema name and
    a "." when schema is not None'''
    # The table is escaped first, as in the original statement order
    parts = [esc_name(db, table)]
    if schema != None: parts.insert(0, esc_name(db, schema))
    return '.'.join(parts)
112
113 1869 aaronmk
##### Database connections
114 1849 aaronmk
115 2097 aaronmk
# Names of db_config entries; presumably consumed by callers outside this part
# of the file (DbConn._db() handles 'engine' and 'schemas' specially)
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
116 1926 aaronmk
117 1869 aaronmk
# Maps engine name -> (driver module name, {db_config key: driver's connect()
# keyword}) for keys that the driver names differently
db_engines = dict(
    MySQL=('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    PostgreSQL=('psycopg2', {}),
)
121
122
# All known database error classes. _add_module() adds each loaded driver's
# DatabaseError and rebuilds the tuple form (usable in an except clause).
DatabaseErrors_set = set((DbException,))
DatabaseErrors = tuple(DatabaseErrors_set)
124
125
def _add_module(module):
    '''Registers a driver module's DatabaseError class in DatabaseErrors_set
    and rebuilds the module-level DatabaseErrors tuple'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
129
130
def db_config_str(db_config):
    '''Returns a short description of a db_config: its engine and database'''
    engine = db_config['engine']
    database = db_config['database']
    return engine+' database '+database
132
133 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is stored in
    DbConn.query_results: the query plus a hashable form of its params'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
134 1894 aaronmk
135 1901 aaronmk
def log_debug_none(msg):
    '''No-op debug logger; DbConn compares against it to skip building log
    messages'''
    return None
136
137 1849 aaronmk
class DbConn:
    '''A database connection that connects lazily (on first access of the `db`
    attribute) and can cache the results of queries run through run_query().
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
        caching=True):
        '''
        @param db_config dict with an 'engine' key (see db_engines), an
            optional 'schemas' key, and the driver's connection parameters
        @param serializable Whether to set the transaction isolation level to
            SERIALIZABLE upon connecting
        @param log_debug Callback taking a debug message
        @param caching If False, run_query() never caches
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        self.caching = caching
        
        self.__db = None # the underlying connection, created by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name) # include the name in the message
    
    def __getstate__(self):
        '''Pickles everything except the debug callback and the connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self): return self.__db != None
    
    def _db(self):
        '''Returns the underlying connection, connecting and configuring it
        first if needed'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # this param was not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas != None:
                # Prepend the schemas to the existing search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a driver cursor and records each query's rows (or exception)
        so that they can be stored in the connection's query_results'''
        
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query. If it raises an exception, the exception is
            annotated with the query, cached as the result, and re-raised.'''
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try:
                # The inner finally runs before the except handler, so
                # self.query is set when the exception is annotated
                try: return_value = self.inner.execute(query, params)
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches a row, recording it; caches the result once all rows
            have been fetched'''
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result through the cursor methods used here'''
            
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        @param cacheable Whether to reuse/store the result in query_results
            (ignored if the connection was created with caching=False)
        @return The cursor: a driver cursor, a DbCursor, or on a cache hit a
            DbCursor.CacheCursor
        For how exceptions are annotated and cached, see
        self.DbCursor.execute().
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the finally clause works if cursor creation fails
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            cur.execute(query, params)
        finally:
            log_debug = self.log_debug
            # log_debug is None after unpickling (see __getstate__())
            if log_debug != None and log_debug != log_debug_none:
                # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                if cur != None: logged_query = get_cur_query(cur)
                else: logged_query = query # no cursor could be created
                log_debug(cache_status+': '
                    +strings.one_line(str(logged_query)))
        
        return cur
    
    def is_cached(self, query, params=None):
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint, which is rolled back if func()
        raises an exception (the exception is re-raised) and released
        otherwise.
        @return func()'s return value
        '''
        savepoint = 'savepoint_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint)
            return return_val
287 1849 aaronmk
288 1869 aaronmk
# Alias so that a connection can be created with connect(db_config, ...)
connect = DbConn
289
290 832 aaronmk
##### Querying
291
292 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query without any error recovery by delegating to
    db.run_query(). For params, see DbConn.run_query().'''
    cur = db.run_query(*args, **kw_args)
    return cur
295 11 aaronmk
296 2068 aaronmk
def mogrify(db, query, params):
    '''Returns the query with params substituted in, using the driver
    cursor's mogrify().
    @throws NotImplementedError for drivers other than psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    cur = db.db.cursor()
    return cur.mogrify(query, params)
301
302 832 aaronmk
##### Recoverable querying
303 15 aaronmk
304 2139 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint. See DbConn.with_savepoint().'''
    return db.with_savepoint(func)
305 11 aaronmk
306 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally translating known database errors.
    @param recover If true, the query runs inside a savepoint (unless its
        result is already cached) and recognized error messages are re-raised
        as DuplicateKeyException, NullValueException, or
        DuplicateTableException. Otherwise, errors propagate unchanged.
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    '''
    if recover == None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    
    try:
        # don't need savepoint if cached
        if not recover or db.is_cached(query, params): return run()
        return with_savepoint(db, run)
    except Exception as e:
        if not recover: raise # need savepoint to run index_cols()
        msg = str(e)
        
        dup_key = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+)"', msg)
        if dup_key:
            # The table name is taken from the start of the constraint name
            constraint, table = dup_key.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        
        null_value = re.search(r'null value in column "(\w+)" violates '
            r'not-null constraint', msg)
        if null_value: raise NullValueException([null_value.group(1)], e)
        
        dup_table = re.search(r'relation "(\w+)" already exists', msg)
        if dup_table: raise DuplicateTableException(dup_table.group(1), e)
        
        raise # no specific exception raised
330 830 aaronmk
331 832 aaronmk
##### Basic queries
332
333 2153 aaronmk
def next_version(name):
    '''Returns the name with its version # incremented. The version is a
    "v<#>_" prefix; it is prepended so it won't be removed if the name is
    truncated. A name without a version becomes version 1.'''
    match = re.match(r'^v(\d+)_(.*)$', name)
    if not match: return 'v1_'+name # first existing name was version 0
    prev_version, base_name = match.groups()
    return 'v'+str(int(prev_version)+1)+'_'+base_name
341
342 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into_ref One-element list with the temp table's name, or None to
        just run the query. If the name is already taken, the element is
        replaced with the next version of the name (see next_version()).
    '''
    if into_ref == None: return run_query(db, query, params, *args, **kw_args)
    
    # place rows in temp table
    check_name(into_ref[0])
    kw_args['recover'] = True # needed to get DuplicateTableException
    while True:
        create_query = 'CREATE TEMP TABLE '+into_ref[0]+' AS '+query
        try:
            return run_query(db, create_query, params, *args, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateTableException:
            # try again with next version of name
            into_ref[0] = next_version(into_ref[0])
358 2085 aaronmk
359 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
360
361 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
362
363 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
        A left_col of join_using joins that column with USING.
    @param fields Use None to select all fields in the table. Each field is a
        column name, a (table, col) tuple, or a (value,) literal tuple.
    @param conds dict of column name -> required value (None means IS NULL)
    @param order_by Column to order by, None for no ordering, or
        order_by_pkey to order by the first table's pkey
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    @throws ValueError if a joined table has no join columns
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = tables.pop(0) # first table is separate
    
    if conds == None: conds = {}
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by == order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = [] # filled in the order the placeholders appear in the query
    
    def parse_col(field, default_table=None):
        '''Parses fields'''
        is_tuple = isinstance(field, tuple)
        if is_tuple and len(field) == 1: # field is literal value
            value, = field
            sql_ = '%s'
            params.append(value)
        elif is_tuple and len(field) == 2: # field is col with table
            table, col = field
            if not table_is_esc: table = esc_name_(table)
            sql_ = table+'.'+esc_name_(col)
        else:
            sql_ = esc_name_(field) # field is col name
            if default_table != None: sql_ = default_table+'.'+sql_
        return sql_
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = esc_name_(col)+' '
        if value == None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    
    query = 'SELECT '
    if fields == None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not joins: raise ValueError('No join columns provided for table '
            +str(table))
        if not table_is_esc: table = esc_name_(table)
        query += ' JOIN '+table
        
        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry
            right_col = table+'.'+esc_name_(right_col)
            left_col = parse_col(left_col, left_table)
            # Parenthesized because the conditions are combined with AND,
            # which binds more tightly than the OR inside each condition
            return ('('+right_col+' = '+left_col
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL))')
        
        if all(v == join_using for v in joins.values()):
            # all cols w/ USING
            query += ' USING ('+(', '.join(joins.keys()))+')'
        else: query += ' ON '+(' AND '.join(map(join, joins.items())))
        
        left_table = table
    
    missing = True
    if conds != {}:
        query += ' WHERE '+(' AND '.join(map(cond, conds.items())))
        params += conds.values()
        missing = False
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
    if limit != None: query += ' LIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
451 11 aaronmk
452 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT query with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query(). Unlike run_query(),
    cacheable defaults to True.
    '''
    # Separate run_query()'s keyword args from mk_select()'s
    run_opts = [kw_args.pop('recover', None), kw_args.pop('cacheable', True)]
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, *run_opts)
459
460 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''Builds an INSERT query whose rows come from select_query.
    @param cols The columns to insert into. None or [] means unknown col
        names (no column list in the query).
    @param select_query The SQL after the column list. None inserts DEFAULT
        VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This creates a pg_temp function wrapping the insert (PostgreSQL only)
        and requires returning to be set.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query == None: select_query = 'DEFAULT VALUES'
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols != None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning != None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        name_parts = ['insert', table] + list(cols or []) # cols may be None
        function = 'pg_temp.'+('_'.join(map(clean_name, name_parts)))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        function_query = '''\
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
        run_query(db, function_query, cacheable=True)
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
            order_by=None, table_is_esc=True)# AS clause requires function alias
    
    return (query, params)
501
502
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT query with mk_insert_select() and runs it.
    For params, see mk_insert_select() and run_query_into()
    @param into_ref List with name of temp table to place RETURNING values in.
        Setting it forces embeddable=True.
    '''
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    into_ref = kw_args.pop('into_ref', None)
    if into_ref != None: kw_args['embeddable'] = True
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into_ref, **run_opts)
514 2063 aaronmk
515 2066 aaronmk
default = object() # tells insert() to use the default value for a column
516
517 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row.
    For params, see insert_select()
    @param row A sequence of values (column names unknown) or a dict of
        column name -> value. A value that is the module-level `default`
        object inserts the column's DEFAULT.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = []
    values = []
    for value in row:
        # Identity, not ==, so values with custom equality can't match
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    # Test labels rather than values: a row made up only of DEFAULTs still
    # needs a VALUES list, because a column list can't be combined with
    # DEFAULT VALUES
    if labels != []: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None # empty row, so insert DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
539 11 aaronmk
540 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert.
    @return lastval() under psycopg2, the connection's insert_id() under
        MySQLdb, or None for any other driver
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the driver connection (db.db), not to the DbConn
    # wrapper, which only forwards the `db` attribute
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
545 13 aaronmk
546 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Empties a table, and (via CASCADE) the tables that reference it.
    @return The query's cursor
    '''
    table_name = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+table_name+' CASCADE')
548 832 aaronmk
549
##### Database structure queries
550
551 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of the table's pkey.
    Assumed to be first column in table: runs a zero-row SELECT and returns
    the first column name of the result.
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    return next(col_names(cur))
555 832 aaronmk
556 853 aaronmk
def index_cols(db, table, index):
    '''Lists the columns of an index, ordered by attribute number.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @throws NotImplementedError for drivers other than psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The first SELECT finds columns listed in indkey; the second finds
    # columns referenced from indexprs
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
598 853 aaronmk
599 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Lists the columns of a table constraint, ordered by attribute number.
    @throws NotImplementedError for drivers other than psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
617
618 2096 aaronmk
# Name of the row number column that add_row_num() adds to a table
row_num_col = '_row_num'
619
620 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key (a NOT NULL serial column).'''
    check_name(table)
    col_def = row_num_col+' serial NOT NULL PRIMARY KEY'
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+col_def)
626 2086 aaronmk
627 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Iterates over the names of the tables matching a LIKE pattern.
    @param schema The schema to look in (only used under psycopg2)
    @param table_like LIKE pattern that the table names must match
    @throws NotImplementedError for drivers other than psycopg2 and MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    
    return values(run_query(db, query, params, cacheable=True))
644 830 aaronmk
645 833 aaronmk
##### Database management
646
647 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
650 833 aaronmk
651 832 aaronmk
##### Heuristic queries
652
653 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row and returns its pkey. If the row violates a unique
    constraint, returns the pkey of the existing row instead.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The pkey column's name; looked up with pkey() if None
    @param row_ct_ref One-element list whose value is incremented by the
        insert's rowcount
    '''
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = insert(db, table, row, pkey_, recover=True)
        count_rows = row_ct_ref != None and cur.rowcount >= 0
        if count_rows: row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row by the columns of the violated constraint
        key_conds = util.dict_subset_right_join(row, e.cols)
        return value(select(db, table, [pkey_], key_conds, recover=True))
667 471 aaronmk
668 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the first row matching the values in row.
    Recovers from errors.
    @param create Whether to insert the row (with put()) if it isn't found
    @throws StopIteration if the row isn't found and create is false
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
674 2078 aaronmk
675 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
    row_ct_ref=None, table_is_esc=False):
    '''Inserts the rows of the input tables into out_table and creates a temp
    table that maps input pkeys to output pkeys.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict of output column -> input field (see mk_select())
    @param row_ct_ref One-element list whose value is incremented by the
        insert's rowcount
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
        available
    '''
    temp_suffix = clean_name(out_table)
        # suffix, not prefix, so main name won't be removed if name is truncated
    pkeys_ref = ['pkeys_'+temp_suffix]
    
    # Join together input tables
    in_tables = in_tables[:] # don't modify input!
    in_tables0 = in_tables.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
    in_joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
    
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
    pkeys_cols = [in_pkey, out_pkey]
    
    def mk_select_(cols):
        return mk_select(db, in_joins, cols, limit=limit, start=start,
            table_is_esc=table_is_esc)
    
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
    def insert_():
        '''Inserts and capture output pkeys.'''
        cur = insert_select(db, out_table, list(mapping.keys()),
            *mk_select_(list(mapping.values())), returning=out_pkey,
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        # Must run even without a row_ct_ref, because the join below uses
        # this column
        add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
        
        # Get input pkeys corresponding to rows in insert
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
        
        # Join together output and input pkeys
        run_query_into(db, *mk_select(db,
            [in_pkeys_ref[0], (out_pkeys_ref[0], {row_num_col: join_using})],
            pkeys_cols, start=0), into_ref=pkeys_ref)
    
    # Do inserts and selects
    try: insert_()
    except DuplicateKeyException as e:
        # Find the existing rows by joining the input to the output table on
        # the columns of the violated constraint
        join_cols = util.dict_subset_right_join(mapping, e.cols)
        joins = in_joins + [(out_table, join_cols)]
        run_query_into(db, *mk_select(db, joins, pkeys_cols,
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
    
    return (pkeys_ref[0], out_pkey)
730 2115 aaronmk
731
##### Data cleanup
732
733
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Trims the given columns of every row in a table, setting a column to
    NULL where the trimmed value is the empty string or \\N.
    @param table_is_esc Whether the table name has already been escaped
    '''
    if not table_is_esc: check_name(table)
    
    assignments = []
    for col in cols:
        col = esc_name(db, col)
        assignments.append('\n'+col+' = nullif(nullif(trim(both from '+col
            +"), %(null0)s), %(null1)s)")
    
    query = 'UPDATE '+table+' SET\n'+(',\n'.join(assignments))
    run_query(db, query, dict(null0='', null1=r'\N'))