Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
20
    if hasattr(cur, 'query'): return cur.query
21
    elif hasattr(cur, '_last_executed'): return cur._last_executed
22
    else: return None
23 14 aaronmk
24 2164 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+str(get_cur_query(cur)))
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
27 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
28 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
29 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
30
31 2143 aaronmk
class ExceptionWithName(DbException):
32
    def __init__(self, name, cause=None):
33 2145 aaronmk
        DbException.__init__(self, 'for name: '+str(name), cause)
34 2143 aaronmk
        self.name = name
35 360 aaronmk
36 468 aaronmk
class ExceptionWithColumns(DbException):
37
    def __init__(self, cols, cause=None):
38 2145 aaronmk
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
39 468 aaronmk
        self.cols = cols
40 11 aaronmk
41 2143 aaronmk
class NameException(DbException): pass
42
43 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
44 13 aaronmk
45 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
46 13 aaronmk
47 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
48
49 89 aaronmk
class EmptyRowException(DbException): pass
50
51 865 aaronmk
##### Warnings
52
53
class DbWarning(UserWarning): pass
54
55 1930 aaronmk
##### Result retrieval
56
57
def col_names(cur): return (col[0] for col in cur.description)
58
59
def rows(cur): return iter(lambda: cur.fetchone(), None)
60
61
def consume_rows(cur):
62
    '''Used to fetch all rows so result will be cached'''
63
    iters.consume_iter(rows(cur))
64
65
def next_row(cur): return rows(cur).next()
66
67
def row(cur):
68
    row_ = next_row(cur)
69
    consume_rows(cur)
70
    return row_
71
72
def next_value(cur): return next_row(cur)[0]
73
74
def value(cur): return row(cur)[0]
75
76
def values(cur): return iters.func_iter(lambda: next_value(cur))
77
78
def value_or_none(cur):
79
    try: return value(cur)
80
    except StopIteration: return None
81
82 2101 aaronmk
##### Input validation
83
84
def clean_name(name): return re.sub(r'\W', r'', name)
85
86
def check_name(name):
87
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
88
        +'" may contain only alphanumeric characters and _')
89
90
def esc_name_by_module(module, name, ignore_case=False):
91
    if module == 'psycopg2':
92
        if ignore_case:
93
            # Don't enclose in quotes because this disables case-insensitivity
94
            check_name(name)
95
            return name
96
        else: quote = '"'
97
    elif module == 'MySQLdb': quote = '`'
98
    else: raise NotImplementedError("Can't escape name for "+module+' database')
99
    return quote + name.replace(quote, '') + quote
100
101
def esc_name_by_engine(engine, name, **kw_args):
102
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
103
104
def esc_name(db, name, **kw_args):
105
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
106
107
def qual_name(db, schema, table):
108
    def esc_name_(name): return esc_name(db, name)
109
    table = esc_name_(table)
110
    if schema != None: return esc_name_(schema)+'.'+table
111
    else: return table
112
113 1869 aaronmk
##### Database connections
114 1849 aaronmk
115 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
116 1926 aaronmk
117 1869 aaronmk
db_engines = {
118
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
119
    'PostgreSQL': ('psycopg2', {}),
120
}
121
122
DatabaseErrors_set = set([DbException])
123
DatabaseErrors = tuple(DatabaseErrors_set)
124
125
def _add_module(module):
126
    DatabaseErrors_set.add(module.DatabaseError)
127
    global DatabaseErrors
128
    DatabaseErrors = tuple(DatabaseErrors_set)
129
130
def db_config_str(db_config):
131
    return db_config['engine']+' database '+db_config['database']
132
133 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
134 1894 aaronmk
135 1901 aaronmk
log_debug_none = lambda msg: None
136
137 1849 aaronmk
class DbConn:
138 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
139 2050 aaronmk
        caching=True):
140 1869 aaronmk
        self.db_config = db_config
141
        self.serializable = serializable
142 1901 aaronmk
        self.log_debug = log_debug
143 2047 aaronmk
        self.caching = caching
144 1869 aaronmk
145
        self.__db = None
146 1889 aaronmk
        self.query_results = {}
147 2139 aaronmk
        self._savepoint = 0
148 1869 aaronmk
149
    def __getattr__(self, name):
150
        if name == '__dict__': raise Exception('getting __dict__')
151
        if name == 'db': return self._db()
152
        else: raise AttributeError()
153
154
    def __getstate__(self):
155
        state = copy.copy(self.__dict__) # shallow copy
156 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
157 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
158
        return state
159
160 2165 aaronmk
    def connected(self): return self.__db != None
161
162 1869 aaronmk
    def _db(self):
163
        if self.__db == None:
164
            # Process db_config
165
            db_config = self.db_config.copy() # don't modify input!
166 2097 aaronmk
            schemas = db_config.pop('schemas', None)
167 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
168
            module = __import__(module_name)
169
            _add_module(module)
170
            for orig, new in mappings.iteritems():
171
                try: util.rename_key(db_config, orig, new)
172
                except KeyError: pass
173
174
            # Connect
175
            self.__db = module.connect(**db_config)
176
177
            # Configure connection
178
            if self.serializable: run_raw_query(self,
179
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
180 2101 aaronmk
            if schemas != None:
181
                schemas_ = ''.join((esc_name(self, s)+', '
182
                    for s in schemas.split(',')))
183
                run_raw_query(self, "SELECT set_config('search_path', \
184
%s || current_setting('search_path'), false)", [schemas_])
185 1869 aaronmk
186
        return self.__db
187 1889 aaronmk
188 1891 aaronmk
    class DbCursor(Proxy):
189 1927 aaronmk
        def __init__(self, outer):
190 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
191 1927 aaronmk
            self.query_results = outer.query_results
192 1894 aaronmk
            self.query_lookup = None
193 1891 aaronmk
            self.result = []
194 1889 aaronmk
195 1894 aaronmk
        def execute(self, query, params=None):
196 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
197 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
198 2148 aaronmk
            try:
199
                try: return_value = self.inner.execute(query, params)
200
                finally: self.query = get_cur_query(self.inner)
201 1904 aaronmk
            except Exception, e:
202 2148 aaronmk
                _add_cursor_info(e, self)
203 1904 aaronmk
                self.result = e # cache the exception as the result
204
                self._cache_result()
205
                raise
206 1930 aaronmk
            # Fetch all rows so result will be cached
207
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
208 1894 aaronmk
            return return_value
209
210 1891 aaronmk
        def fetchone(self):
211
            row = self.inner.fetchone()
212 1899 aaronmk
            if row != None: self.result.append(row)
213
            # otherwise, fetched all rows
214 1904 aaronmk
            else: self._cache_result()
215
            return row
216
217
        def _cache_result(self):
218 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
219
            # idempotent, but an invalid insert will always be invalid
220 1930 aaronmk
            if self.query_results != None and (not self._is_insert
221 1906 aaronmk
                or isinstance(self.result, Exception)):
222
223 1894 aaronmk
                assert self.query_lookup != None
224 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
225
                    util.dict_subset(dicts.AttrsDictView(self),
226
                    ['query', 'result', 'rowcount', 'description']))
227 1906 aaronmk
228 1916 aaronmk
        class CacheCursor:
229
            def __init__(self, cached_result): self.__dict__ = cached_result
230
231 1927 aaronmk
            def execute(self, *args, **kw_args):
232 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
233
                # otherwise, result is a rows list
234
                self.iter = iter(self.result)
235
236
            def fetchone(self):
237
                try: return self.iter.next()
238
                except StopIteration: return None
239 1891 aaronmk
240 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
241 2148 aaronmk
        '''Translates known DB errors to typed exceptions:
242
        See self.DbCursor.execute().'''
243 2047 aaronmk
        if not self.caching: cacheable = False
244 1903 aaronmk
        used_cache = False
245
        try:
246 1927 aaronmk
            # Get cursor
247
            if cacheable:
248
                query_lookup = _query_lookup(query, params)
249
                try:
250
                    cur = self.query_results[query_lookup]
251
                    used_cache = True
252
                except KeyError: cur = self.DbCursor(self)
253
            else: cur = self.db.cursor()
254
255
            # Run query
256 2148 aaronmk
            cur.execute(query, params)
257 1903 aaronmk
        finally:
258
            if self.log_debug != log_debug_none: # only compute msg if needed
259
                if used_cache: cache_status = 'Cache hit'
260
                elif cacheable: cache_status = 'Cache miss'
261
                else: cache_status = 'Non-cacheable'
262 1927 aaronmk
                self.log_debug(cache_status+': '
263 2164 aaronmk
                    +strings.one_line(str(get_cur_query(cur))))
264 1903 aaronmk
265
        return cur
266 1914 aaronmk
267
    def is_cached(self, query, params=None):
268
        return _query_lookup(query, params) in self.query_results
269 2139 aaronmk
270
    def with_savepoint(self, func):
271
        savepoint = 'savepoint_'+str(self._savepoint)
272
        self.run_query('SAVEPOINT '+savepoint)
273
        self._savepoint += 1
274
        try:
275
            try: return_val = func()
276
            finally:
277
                self._savepoint -= 1
278
                assert self._savepoint >= 0
279
        except:
280
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
281
            raise
282
        else:
283
            self.run_query('RELEASE SAVEPOINT '+savepoint)
284
            return return_val
285 1849 aaronmk
286 1869 aaronmk
connect = DbConn
287
288 832 aaronmk
##### Querying
289
290 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
291 2085 aaronmk
    '''For params, see DbConn.run_query()'''
292 1894 aaronmk
    return db.run_query(*args, **kw_args)
293 11 aaronmk
294 2068 aaronmk
def mogrify(db, query, params):
295
    module = util.root_module(db.db)
296
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
297
    else: raise NotImplementedError("Can't mogrify query for "+module+
298
        ' database')
299
300 832 aaronmk
##### Recoverable querying
301 15 aaronmk
302 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
303 11 aaronmk
304 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
305 830 aaronmk
    if recover == None: recover = False
306
307 2148 aaronmk
    try:
308
        def run(): return run_raw_query(db, query, params, cacheable)
309
        if recover and not db.is_cached(query, params):
310
            return with_savepoint(db, run)
311
        else: return run() # don't need savepoint if cached
312
    except Exception, e:
313
        if not recover: raise # need savepoint to run index_cols()
314
        msg = str(e)
315
        match = re.search(r'duplicate key value violates unique constraint '
316
            r'"((_?[^\W_]+)_[^"]+)"', msg)
317
        if match:
318
            constraint, table = match.groups()
319
            try: cols = index_cols(db, table, constraint)
320
            except NotImplementedError: raise e
321
            else: raise DuplicateKeyException(cols, e)
322
        match = re.search(r'null value in column "(\w+)" violates not-null '
323
            'constraint', msg)
324
        if match: raise NullValueException([match.group(1)], e)
325
        match = re.search(r'relation "(\w+)" already exists', msg)
326
        if match: raise DuplicateTableException(match.group(1), e)
327
        raise # no specific exception raised
328 830 aaronmk
329 832 aaronmk
##### Basic queries
330
331 2153 aaronmk
def next_version(name):
332
    '''Prepends the version # so it won't be removed if the name is truncated'''
333 2163 aaronmk
    version = 1 # first existing name was version 0
334 2153 aaronmk
    match = re.match(r'^v(\d+)_(.*)$', name)
335
    if match:
336
        version = int(match.group(1))+1
337
        name = match.group(2)
338
    return 'v'+str(version)+'_'+name
339
340 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
341 2085 aaronmk
    '''Outputs a query to a temp table.
342
    For params, see run_query().
343
    '''
344 2151 aaronmk
    if into_ref == None: return run_query(db, query, params, *args, **kw_args)
345 2085 aaronmk
    else: # place rows in temp table
346 2151 aaronmk
        check_name(into_ref[0])
347 2153 aaronmk
        kw_args['recover'] = True
348
        while True:
349
            try:
350
                return run_query(db, 'CREATE TEMP TABLE '+into_ref[0]+' AS '
351
                    +query, params, *args, **kw_args)
352
                    # CREATE TABLE AS sets rowcount to # rows in query
353
            except DuplicateTableException, e:
354
                into_ref[0] = next_version(into_ref[0])
355
                # try again with next version of name
356 2085 aaronmk
357 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
358
359 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
360
361 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
362 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
363 1981 aaronmk
    '''
364 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
365 2124 aaronmk
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
366 1981 aaronmk
    @param fields Use None to select all fields in the table
367
    @param table_is_esc Whether the table name has already been escaped
368 2054 aaronmk
    @return tuple(query, params)
369 1981 aaronmk
    '''
370 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
371 2058 aaronmk
372 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
373 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
374 2121 aaronmk
    table0 = tables.pop(0) # first table is separate
375
376 1135 aaronmk
    if conds == None: conds = {}
377 135 aaronmk
    assert limit == None or type(limit) == int
378 865 aaronmk
    assert start == None or type(start) == int
379 2120 aaronmk
    if order_by == order_by_pkey:
380 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
381
    if not table_is_esc: table0 = esc_name_(table0)
382 865 aaronmk
383 2056 aaronmk
    params = []
384
385 2161 aaronmk
    def parse_col(field, default_table=None):
386 2056 aaronmk
        '''Parses fields'''
387 2157 aaronmk
        is_tuple = isinstance(field, tuple)
388
        if is_tuple and len(field) == 1: # field is literal value
389
            value, = field
390 2056 aaronmk
            sql_ = '%s'
391
            params.append(value)
392 2157 aaronmk
        elif is_tuple and len(field) == 2: # field is col with table
393
            table, col = field
394
            if not table_is_esc: table = esc_name_(table)
395
            sql_ = table+'.'+esc_name_(col)
396 2161 aaronmk
        else:
397
            sql_ = esc_name_(field) # field is col name
398
            if default_table != None: sql_ = default_table+'.'+sql_
399 2056 aaronmk
        return sql_
400 11 aaronmk
    def cond(entry):
401 2056 aaronmk
        '''Parses conditions'''
402 13 aaronmk
        col, value = entry
403 2058 aaronmk
        cond_ = esc_name_(col)+' '
404 11 aaronmk
        if value == None: cond_ += 'IS'
405
        else: cond_ += '='
406
        cond_ += ' %s'
407
        return cond_
408 2056 aaronmk
409 1135 aaronmk
    query = 'SELECT '
410
    if fields == None: query += '*'
411 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
412 2121 aaronmk
    query += ' FROM '+table0
413 865 aaronmk
414 2122 aaronmk
    # Add joins
415
    left_table = table0
416
    for table, joins in tables:
417
        if not table_is_esc: table = esc_name_(table)
418 2127 aaronmk
        query += ' JOIN '+table
419 2122 aaronmk
420
        def join(entry):
421 2127 aaronmk
            '''Parses non-USING joins'''
422 2124 aaronmk
            right_col, left_col = entry
423
            right_col = table+'.'+esc_name_(right_col)
424 2161 aaronmk
            left_col = parse_col(left_col, left_table)
425 2123 aaronmk
            return (right_col+' = '+left_col
426
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL)')
427 2122 aaronmk
428 2127 aaronmk
        if reduce(operator.and_, (v == join_using for v in joins.itervalues())):
429
            # all cols w/ USING
430 2130 aaronmk
            query += ' USING ('+(', '.join(joins.iterkeys()))+')'
431 2127 aaronmk
        else: query += ' ON '+(' AND '.join(map(join, joins.iteritems())))
432
433 2122 aaronmk
        left_table = table
434
435 865 aaronmk
    missing = True
436 89 aaronmk
    if conds != {}:
437 2122 aaronmk
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
438 2056 aaronmk
        params += conds.values()
439 865 aaronmk
        missing = False
440 2120 aaronmk
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
441 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
442
    if start != None:
443
        if start != 0: query += ' OFFSET '+str(start)
444
        missing = False
445
    if missing: warnings.warn(DbWarning(
446
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
447
448 2056 aaronmk
    return (query, params)
449 11 aaronmk
450 2054 aaronmk
def select(db, *args, **kw_args):
451
    '''For params, see mk_select() and run_query()'''
452
    recover = kw_args.pop('recover', None)
453
    cacheable = kw_args.pop('cacheable', True)
454
455
    query, params = mk_select(db, *args, **kw_args)
456
    return run_query(db, query, params, recover, cacheable)
457
458 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
459 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
460 1960 aaronmk
    '''
461
    @param returning str|None An inserted column (such as pkey) to return
462 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
463 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
464
        query will be fully cached, not just if it raises an exception.
465 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
466
    '''
467 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
468
    if cols == []: cols = None # no cols (all defaults) = unknown col names
469 1960 aaronmk
    if not table_is_esc: check_name(table)
470 2063 aaronmk
471
    # Build query
472
    query = 'INSERT INTO '+table
473
    if cols != None:
474
        map(check_name, cols)
475
        query += ' ('+', '.join(cols)+')'
476
    query += ' '+select_query
477
478
    if returning != None:
479
        check_name(returning)
480
        query += ' RETURNING '+returning
481
482 2070 aaronmk
    if embeddable:
483
        # Create function
484 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
485
            ['insert', table] + cols)))
486 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
487
        function_query = '''\
488 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
489 2070 aaronmk
    LANGUAGE sql
490
    AS $$'''+mogrify(db, query, params)+''';$$;
491
'''
492
        run_query(db, function_query, cacheable=True)
493
494
        # Return query that uses function
495 2134 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
496
            order_by=None, table_is_esc=True)# AS clause requires function alias
497 2070 aaronmk
498 2066 aaronmk
    return (query, params)
499
500
def insert_select(db, *args, **kw_args):
501 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
502 2152 aaronmk
    @param into_ref List with name of temp table to place RETURNING values in
503 2072 aaronmk
    '''
504 2151 aaronmk
    into_ref = kw_args.pop('into_ref', None)
505
    if into_ref != None: kw_args['embeddable'] = True
506 2066 aaronmk
    recover = kw_args.pop('recover', None)
507
    cacheable = kw_args.pop('cacheable', True)
508
509
    query, params = mk_insert_select(db, *args, **kw_args)
510 2153 aaronmk
    return run_query_into(db, query, params, into_ref, recover=recover,
511
        cacheable=cacheable)
512 2063 aaronmk
513 2066 aaronmk
default = object() # tells insert() to use the default value for a column
514
515 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
516 2085 aaronmk
    '''For params, see insert_select()'''
517 1960 aaronmk
    if lists.is_seq(row): cols = None
518
    else:
519
        cols = row.keys()
520
        row = row.values()
521
    row = list(row) # ensure that "!= []" works
522
523 1961 aaronmk
    # Check for special values
524
    labels = []
525
    values = []
526
    for value in row:
527
        if value == default: labels.append('DEFAULT')
528
        else:
529
            labels.append('%s')
530
            values.append(value)
531
532
    # Build query
533 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
534
    else: query = None
535 1554 aaronmk
536 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
537 11 aaronmk
538 135 aaronmk
def last_insert_id(db):
539 1849 aaronmk
    module = util.root_module(db.db)
540 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
541
    elif module == 'MySQLdb': return db.insert_id()
542
    else: return None
543 13 aaronmk
544 1968 aaronmk
def truncate(db, table, schema='public'):
545
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
546 832 aaronmk
547
##### Database structure queries
548
549 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
550 832 aaronmk
    '''Assumed to be first column in table'''
551 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
552 2084 aaronmk
        table_is_esc=table_is_esc)).next()
553 832 aaronmk
554 853 aaronmk
def index_cols(db, table, index):
555
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
556
    automatically created. When you don't know whether something is a UNIQUE
557
    constraint or a UNIQUE index, use this function.'''
558
    check_name(table)
559
    check_name(index)
560 1909 aaronmk
    module = util.root_module(db.db)
561
    if module == 'psycopg2':
562
        return list(values(run_query(db, '''\
563 853 aaronmk
SELECT attname
564 866 aaronmk
FROM
565
(
566
        SELECT attnum, attname
567
        FROM pg_index
568
        JOIN pg_class index ON index.oid = indexrelid
569
        JOIN pg_class table_ ON table_.oid = indrelid
570
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
571
        WHERE
572
            table_.relname = %(table)s
573
            AND index.relname = %(index)s
574
    UNION
575
        SELECT attnum, attname
576
        FROM
577
        (
578
            SELECT
579
                indrelid
580
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
581
                    AS indkey
582
            FROM pg_index
583
            JOIN pg_class index ON index.oid = indexrelid
584
            JOIN pg_class table_ ON table_.oid = indrelid
585
            WHERE
586
                table_.relname = %(table)s
587
                AND index.relname = %(index)s
588
        ) s
589
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
590
) s
591 853 aaronmk
ORDER BY attnum
592
''',
593 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
594
    else: raise NotImplementedError("Can't list index columns for "+module+
595
        ' database')
596 853 aaronmk
597 464 aaronmk
def constraint_cols(db, table, constraint):
598
    check_name(table)
599
    check_name(constraint)
600 1849 aaronmk
    module = util.root_module(db.db)
601 464 aaronmk
    if module == 'psycopg2':
602
        return list(values(run_query(db, '''\
603
SELECT attname
604
FROM pg_constraint
605
JOIN pg_class ON pg_class.oid = conrelid
606
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
607
WHERE
608
    relname = %(table)s
609
    AND conname = %(constraint)s
610
ORDER BY attnum
611
''',
612
            {'table': table, 'constraint': constraint})))
613
    else: raise NotImplementedError("Can't list constraint columns for "+module+
614
        ' database')
615
616 2096 aaronmk
row_num_col = '_row_num'
617
618 2086 aaronmk
def add_row_num(db, table):
619 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
620
    be the primary key.'''
621 2086 aaronmk
    check_name(table)
622 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
623 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
624 2086 aaronmk
625 1968 aaronmk
def tables(db, schema='public', table_like='%'):
626 1849 aaronmk
    module = util.root_module(db.db)
627 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
628 832 aaronmk
    if module == 'psycopg2':
629 1968 aaronmk
        return values(run_query(db, '''\
630
SELECT tablename
631
FROM pg_tables
632
WHERE
633
    schemaname = %(schema)s
634
    AND tablename LIKE %(table_like)s
635
ORDER BY tablename
636
''',
637
            params, cacheable=True))
638
    elif module == 'MySQLdb':
639
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
640
            cacheable=True))
641 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
642 830 aaronmk
643 833 aaronmk
##### Database management
644
645 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
646
    '''For kw_args, see tables()'''
647
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
648 833 aaronmk
649 832 aaronmk
##### Heuristic queries
650
651 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
652 1554 aaronmk
    '''Recovers from errors.
653 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
654
    '''
655 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
656
657 471 aaronmk
    try:
658 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
659 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
660
            row_ct_ref[0] += cur.rowcount
661
        return value(cur)
662 471 aaronmk
    except DuplicateKeyException, e:
663 2104 aaronmk
        return value(select(db, table, [pkey_],
664 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
665 471 aaronmk
666 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
667 830 aaronmk
    '''Recovers from errors'''
668
    try: return value(select(db, table, [pkey], row, 1, recover=True))
669 14 aaronmk
    except StopIteration:
670 40 aaronmk
        if not create: raise
671 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
672 2078 aaronmk
673 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
674
    row_ct_ref=None, table_is_esc=False):
675 2078 aaronmk
    '''Recovers from errors.
676
    Only works under PostgreSQL (uses INSERT RETURNING).
677 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
678
        tables to join with it using the main input table's pkey
679 2133 aaronmk
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
680 2078 aaronmk
        available
681
    '''
682 2162 aaronmk
    temp_suffix = clean_name(out_table)
683 2158 aaronmk
        # suffix, not prefix, so main name won't be removed if name is truncated
684
    pkeys_ref = ['pkeys_'+temp_suffix]
685 2131 aaronmk
686 2132 aaronmk
    # Join together input tables
687 2131 aaronmk
    in_tables = in_tables[:] # don't modify input!
688
    in_tables0 = in_tables.pop(0) # first table is separate
689
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
690 2142 aaronmk
    in_joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
691 2131 aaronmk
692 2142 aaronmk
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
693
    pkeys_cols = [in_pkey, out_pkey]
694
695 2132 aaronmk
    def mk_select_(cols):
696 2142 aaronmk
        return mk_select(db, in_joins, cols, limit=limit, start=start,
697 2134 aaronmk
            table_is_esc=table_is_esc)
698 2132 aaronmk
699 2158 aaronmk
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
700 2078 aaronmk
    def insert_():
701 2148 aaronmk
        '''Inserts and capture output pkeys.'''
702 2132 aaronmk
        cur = insert_select(db, out_table, mapping.keys(),
703 2133 aaronmk
            *mk_select_(mapping.values()), returning=out_pkey,
704 2155 aaronmk
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
705 2078 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
706
            row_ct_ref[0] += cur.rowcount
707 2155 aaronmk
            add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
708 2086 aaronmk
709 2132 aaronmk
        # Get input pkeys corresponding to rows in insert
710 2158 aaronmk
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
711 2155 aaronmk
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
712
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
713 2086 aaronmk
714 2155 aaronmk
        # Join together output and input pkeys
715 2132 aaronmk
        run_query_into(db, *mk_select(db,
716 2155 aaronmk
            [in_pkeys_ref[0], (out_pkeys_ref[0], {row_num_col: join_using})],
717 2154 aaronmk
            pkeys_cols, start=0), into_ref=pkeys_ref)
718 2132 aaronmk
719 2142 aaronmk
    # Do inserts and selects
720 2148 aaronmk
    try: insert_()
721 2142 aaronmk
    except DuplicateKeyException, e:
722
        join_cols = util.dict_subset_right_join(mapping, e.cols)
723
        joins = in_joins + [(out_table, join_cols)]
724
        run_query_into(db, *mk_select(db, joins, pkeys_cols,
725 2154 aaronmk
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
726 2142 aaronmk
727 2154 aaronmk
    return (pkeys_ref[0], out_pkey)
728 2115 aaronmk
729
##### Data cleanup
730
731
def cleanup_table(db, table, cols, table_is_esc=False):
732
    def esc_name_(name): return esc_name(db, name)
733
734
    if not table_is_esc: check_name(table)
735
    cols = map(esc_name_, cols)
736
737
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
738
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
739
            for col in cols))),
740
        dict(null0='', null1=r'\N'))