Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 135 aaronmk
def get_cur_query(cur):
    '''Returns the last query recorded on a cursor.
    Looks for a `query` attribute first, then `_last_executed`.
    @return The first of those attributes that exists, or None if neither does
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr): return getattr(cur, attr)
    return None
23 14 aaronmk
24 2164 aaronmk
def _add_cursor_info(e, cur):
    '''Adds the cursor's last query (see get_cur_query()) to e's message'''
    last_query = get_cur_query(cur)
    exc.add_msg(e, 'query: '+str(last_query))
25 135 aaronmk
26 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of this module's exceptions.
    @param msg The error message
    @param cause The underlying exception, if any
    @param cur If given, this cursor's last query is added to the message
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None:
            _add_cursor_info(self, cur)
30
31 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about a named database object.
    The name is kept in self.name and included in the message.
    '''
    def __init__(self, name, cause=None):
        msg = 'for name: '+str(name)
        DbException.__init__(self, msg, cause)
        self.name = name
35 360 aaronmk
36 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about a set of columns.
    The column names are kept in self.cols and listed in the message.
    '''
    def __init__(self, cols, cause=None):
        msg = 'for columns: '+(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.cols = cols
40 11 aaronmk
41 2143 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing non-word characters'''

class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by run_query() when a unique constraint is violated'''

class NullValueException(ExceptionWithColumns):
    '''Raised by run_query() when a not-null constraint is violated'''

class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() when the relation to create already exists'''

class EmptyRowException(DbException):
    '''Presumably for a query that returned no row (not raised in the code
    visible here -- TODO confirm against callers)'''
50
51 865 aaronmk
##### Warnings
52
53
class DbWarning(UserWarning):
    '''Warning category for suspicious queries (see mk_select())'''
54
55 1930 aaronmk
##### Result retrieval
56
57
def col_names(cur):
    '''Returns a generator of the column names in the cursor's description
    (the first item of each description entry)'''
    description = cur.description
    return (col_desc[0] for col_desc in description)
58
59
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() until it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
60
61
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached.
    The fetched rows are discarded.
    '''
    remaining = rows(cur)
    iters.consume_iter(remaining)
64
65
def next_row(cur):
    '''Fetches the next row.
    @raise StopIteration if there are no more rows
    '''
    return next(rows(cur))
66
67
def row(cur):
    '''Fetches the next row, then fetches and discards all the others.
    @raise StopIteration if there are no more rows
    '''
    first = next_row(cur)
    consume_rows(cur) # fetch the rest so the result will be cached
    return first
71
72
def next_value(cur):
    '''Fetches the next row and returns its first column.
    @raise StopIteration if there are no more rows
    '''
    next_row_ = next_row(cur)
    return next_row_[0]
73
74
def value(cur):
    '''Returns the first column of the next row, discarding all other rows.
    @raise StopIteration if there are no more rows
    '''
    only_row = row(cur)
    return only_row[0]
75
76
def values(cur):
    '''Returns an iterator, built by iters.func_iter(), over the first column
    of each remaining row'''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
77
78
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there are no rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
81
82 2101 aaronmk
##### Input validation
83
84
def clean_name(name):
    '''Removes every non-word character (anything matching \\W) from name'''
    non_word = re.compile(r'\W')
    return non_word.sub(r'', name)
85
86
def check_name(name):
    '''Verifies that name is safe to place in a query unescaped.
    @raise NameException if name contains a non-word character (\\W)
    '''
    if re.search(r'\W', name) is None: return # only word characters
    raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
89
90
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for the database handled by a driver module.
    @param module Name of the driver module: 'psycopg2' or 'MySQLdb'
    @param ignore_case For psycopg2 only: return the name unquoted (after
        check_name() validates it) instead of quoting it
    @return The name wrapped in the driver's quote character, with any
        embedded quote characters removed
    @raise NotImplementedError for any other module
    '''
    quote_chars = {'psycopg2': '"', 'MySQLdb': '`'}
    if module not in quote_chars:
        raise NotImplementedError("Can't escape name for "+module+' database')
    
    if module == 'psycopg2' and ignore_case:
        # Don't enclose in quotes because this disables case-insensitivity
        check_name(name)
        return name
    
    quote = quote_chars[module]
    return quote + name.replace(quote, '') + quote
100
101
def esc_name_by_engine(engine, name, **kw_args):
    '''Like esc_name_by_module(), but takes a db_engines key (engine name)
    instead of a driver module name'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
103
104
def esc_name(db, name, **kw_args):
    '''Like esc_name_by_module(), but determines the driver module from the
    connection object db.db'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
106
107
def qual_name(db, schema, table):
    '''Returns the escaped table name, prefixed with the escaped schema and a
    "." unless schema is None'''
    esc_table = esc_name(db, table)
    if schema is None: return esc_table
    return esc_name(db, schema)+'.'+esc_table
112
113 1869 aaronmk
##### Database connections
114 1849 aaronmk
115 2097 aaronmk
# Presumably the keys a db_config dict may contain (not referenced elsewhere in
# this section -- verify against callers). DbConn itself reads 'engine' and
# 'schemas' and passes the remaining entries to the driver's connect().
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
116 1926 aaronmk
117 1869 aaronmk
# engine name -> (driver module name, {db_config key: driver's connect() key})
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', dict()),
)
121
122
# Exception classes that count as database errors. _add_module() adds each
# driver's DatabaseError when that driver is first loaded.
DatabaseErrors_set = set((DbException,))
# Tuple form of DatabaseErrors_set; rebuilt by _add_module()
DatabaseErrors = tuple(DatabaseErrors_set)
124
125
def _add_module(module):
    '''Registers a driver module's DatabaseError class in DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
129
130
def db_config_str(db_config):
    '''Returns a short description of a db_config: its engine and database'''
    engine = db_config['engine']
    database = db_config['database']
    return ' database '.join([engine, database])
132
133 1909 aaronmk
def _query_lookup(query, params):
    '''Returns the key under which a query's result is cached: the query text
    plus its params converted by dicts.make_hashable()'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
134 1894 aaronmk
135 1901 aaronmk
# Default DbConn log_debug callback: ignores the message. DbConn.run_query()
# compares its callback against this one to skip building log messages.
log_debug_none = lambda msg: None
136
137 1849 aaronmk
class DbConn:
    '''A database connection that is opened on first use and can cache query
    results.
    @param db_config dict of connection settings. 'engine' must be a key of
        db_engines. 'schemas', if present, is a comma-separated string of
        schemas to prepend to the search_path. The other entries are passed to
        the driver's connect().
    @param serializable Whether to run SET TRANSACTION ISOLATION LEVEL
        SERIALIZABLE right after connecting
    @param log_debug Callback taking a one-line message about each query run
    @param caching If false, run_query() ignores its cacheable param
    '''
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
        caching=True):
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        self.caching = caching
        
        self.__db = None # the driver's connection; opened by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current nesting depth of with_savepoint()
    
    def __getattr__(self, name):
        '''Provides the `db` attribute, which connects on first access'''
        # NOTE(review): presumably guards against lookups on an instance whose
        # __dict__ has not been set up yet (e.g. while unpickling) -- confirm
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the state to pickle, without the callback and connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def _db(self):
        '''Returns the driver's connection, opening and configuring it if this
        is the first use'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # this setting was not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas != None:
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a driver cursor and records what it fetches, so that the
        result can be stored in the outer DbConn's query_results'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query. If it raises, the exception gets the query added
            to its message and is cached as the result before being re-raised.
            '''
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                # The finally must run before the except clause, which reads
                # self.query via _add_cursor_info()
                try: return_value = self.inner.execute(query, params)
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches a row and records it. Caches the result once the inner
            cursor returns None (all rows fetched).'''
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result: either re-raises the cached exception
            or iterates over the cached rows'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns its cursor.
        @param cacheable Whether to look the result up in, and store it in, the
            result cache (ignored if self.caching is false). See
            self.DbCursor for what gets cached.
        @return A CacheCursor on a cache hit, a DbCursor on a cache miss, or
            the driver's own cursor if not cacheable
        '''
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the finally clause works if no cursor could be created
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            cur.execute(query, params)
        finally:
            log_debug = self.log_debug
            # log_debug is None after unpickling (see __getstate__())
            if log_debug is not None and log_debug != log_debug_none:
                # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                logged_query = get_cur_query(cur)
                # fall back to the unformatted query if the cursor has none
                if logged_query is None: logged_query = query
                log_debug(cache_status+': '
                    +strings.one_line(str(logged_query)))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Returns whether a result is cached for this query and params'''
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Calls func() inside a new savepoint.
        Rolls back to the savepoint and re-raises if func() raises; otherwise
        releases the savepoint.
        @return func()'s return value
        '''
        savepoint = 'savepoint_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except: # deliberately catches everything: always re-raised below
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint)
            return return_val
283 1849 aaronmk
284 1869 aaronmk
# Alias, so a connection can be created with connect(db_config, ...)
connect = DbConn
285
286 832 aaronmk
##### Querying
287
288 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query without any error recovery.
    For params, see DbConn.run_query()
    '''
    run = db.run_query
    return run(*args, **kw_args)
291 11 aaronmk
292 2068 aaronmk
def mogrify(db, query, params):
    '''Returns the result of the driver cursor's mogrify(query, params).
    @raise NotImplementedError unless the driver is psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
297
298 832 aaronmk
##### Recoverable querying
299 15 aaronmk
300 2139 aaronmk
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func)'''
    return db.with_savepoint(func)
301 11 aaronmk
302 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally translating known DB errors to typed exceptions.
    @param recover If true, an uncached query is run inside a savepoint, and an
        error whose message is recognized is re-raised as a
        DuplicateKeyException, NullValueException, or DuplicateTableException.
        None means False.
    @param cacheable See DbConn.run_query()
    @return The query's cursor
    '''
    if recover is None: recover = False
    
    def run(): return run_raw_query(db, query, params, cacheable)
    
    try:
        # don't need savepoint if cached
        if not recover or db.is_cached(query, params): return run()
        return with_savepoint(db, run)
    except Exception as e:
        if not recover: raise # need savepoint to run index_cols()
        msg = str(e)
        
        dup_key = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+)"', msg)
        if dup_key:
            constraint, table = dup_key.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        
        null_value = re.search(
            r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if null_value: raise NullValueException([null_value.group(1)], e)
        
        dup_table = re.search(r'relation "(\w+)" already exists', msg)
        if dup_table: raise DuplicateTableException(dup_table.group(1), e)
        
        raise # no specific exception raised
326 830 aaronmk
327 832 aaronmk
##### Basic queries
328
329 2153 aaronmk
def next_version(name):
    '''Returns name with the next version # prefix: "v1_"+name for an
    unversioned name, or "v<n+1>_..." for a name starting with "v<n>_".
    Prepends the version # so it won't be removed if the name is truncated.
    '''
    versioned = re.match(r'^v(\d+)_(.*)$', name)
    if versioned:
        version = int(versioned.group(1))+1
        name = versioned.group(2)
    else: version = 1 # first existing name was version 0
    return 'v%d_%s' % (version, name)
337
338 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into_ref One-item list holding the temp table's name, or None to
        just run the query. If a table with that name already exists, the item
        is replaced with the next_version() of the name until creation works.
    '''
    if into_ref is None: return run_query(db, query, params, *args, **kw_args)
    
    # place rows in temp table
    check_name(into_ref[0])
    kw_args['recover'] = True # needed to get DuplicateTableException
    while True:
        create_query = 'CREATE TEMP TABLE '+into_ref[0]+' AS '+query
        try:
            # CREATE TABLE AS sets rowcount to # rows in query
            return run_query(db, create_query, params, *args, **kw_args)
        except DuplicateTableException:
            # try again with next version of name
            into_ref[0] = next_version(into_ref[0])
354 2085 aaronmk
355 2120 aaronmk
# Sentinel that is the default value of mk_select()'s order_by param
order_by_pkey = object() # tells mk_select() to order by the pkey
356
357 2127 aaronmk
# Sentinel used as the left_col value in a mk_select() join mapping
join_using = object() # tells mk_select() to join the column with USING
358
359 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''
    Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, dict(right_col=left_col, ...)), ...]
        A left_col is parsed like a field (see fields), with the previous
        table as its default table. If every left_col of a table is
        join_using, the columns are joined with USING instead.
    @param fields Use None to select all fields in the table. Otherwise a list
        whose items are each a column name, a (table, col) tuple, or a
        (value,) tuple for a literal value.
    @param conds dict of column name -> value the column must equal (None
        matches NULL)
    @param limit int|None Max # rows
    @param start int|None Row offset. 0 adds no OFFSET clause, but (like any
        other int) suppresses the missing-clause warning.
    @param order_by Column to order by, None for no ORDER BY, or order_by_pkey
        to look up the first table's pkey
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    @raise ValueError if a joined table has no join columns
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = tables.pop(0) # first table is separate
    
    if conds is None: conds = {}
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by == order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = [] # in the order their placeholders appear in the query
    
    def parse_col(field, default_table=None):
        '''Parses fields'''
        is_tuple = isinstance(field, tuple)
        if is_tuple and len(field) == 1: # field is literal value
            value, = field
            sql_ = '%s'
            params.append(value)
        elif is_tuple and len(field) == 2: # field is col with table
            table, col = field
            if not table_is_esc: table = esc_name_(table)
            sql_ = table+'.'+esc_name_(col)
        else:
            sql_ = esc_name_(field) # field is col name
            if default_table != None: sql_ = default_table+'.'+sql_
        return sql_
    
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = esc_name_(col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join([parse_col(field) for field in fields])
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        if not joins: raise ValueError('No join columns for table '+table)
        query += ' JOIN '+table
        
        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry
            right_col = table+'.'+esc_name_(right_col)
            first_new_param = len(params)
            left_col = parse_col(left_col, left_table)
            # left_col is output twice below, so a literal's param must be too
            params.extend(params[first_new_param:])
            # Parenthesized because the terms are combined with AND, which
            # binds tighter than the OR
            return ('('+right_col+' = '+left_col
                +' OR ('+right_col+' IS NULL AND '+left_col+' IS NULL))')
        
        if all(v == join_using for v in joins.values()):
            # all cols w/ USING
            query += ' USING ('+(', '.join(joins.keys()))+')'
        else:
            query += ' ON '+(' AND '.join([join(e) for e in joins.items()]))
        
        left_table = table
    
    missing = True
    if conds:
        query += ' WHERE '+(' AND '.join([cond(e) for e in conds.items()]))
        params += conds.values()
        missing = False
    if order_by != None: query += ' ORDER BY '+esc_name_(order_by)
    if limit != None:
        query += ' LIMIT '+str(limit)
        missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
447 11 aaronmk
448 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT query with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query().
    cacheable defaults to True here.
    '''
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, **run_opts)
455
456 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''
    Builds an INSERT query whose rows come from select_query.
    @param cols Names of the columns to insert into. None or empty means the
        column names are not listed in the query.
    @param select_query The query text providing the rows. None inserts
        DEFAULT VALUES.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This wraps the INSERT in a pg_temp function (created immediately) and
        returns a SELECT from that function instead. Requires returning.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query is None: select_query = 'DEFAULT VALUES'
    if cols is not None:
        cols = list(cols)
        if not cols: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols is not None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning != None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        name_parts = ['insert', table] + (cols or []) # cols may be None
        function = 'pg_temp.'+('_'.join(map(clean_name, name_parts)))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        function_query = '''\
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
        run_query(db, function_query, cacheable=True)
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
            order_by=None, table_is_esc=True)# AS clause requires function alias
    
    return (query, params)
497
498
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT query with mk_insert_select() and runs it.
    For params, see mk_insert_select() and run_query_into().
    cacheable defaults to True here.
    @param into_ref List with name of temp table to place RETURNING values in.
        Providing it turns on mk_insert_select()'s embeddable option.
    '''
    into_ref = kw_args.pop('into_ref', None)
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    if into_ref is not None: kw_args['embeddable'] = True
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into_ref, **run_opts)
510 2063 aaronmk
511 2066 aaronmk
# Sentinel row value: insert() outputs the keyword DEFAULT in its place
default = object() # tells insert() to use the default value for a column
512
513 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts one row.
    For params, see insert_select()
    @param row A sequence of values (for all columns, in order), or a dict of
        column name -> value. A value of `default` inserts the column's
        default value.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    
    # Check for special values
    labels = [] # one placeholder or keyword per column
    values = [] # the params for the placeholders
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    # Test labels, not values: a row of only `default` values has no params
    # but still needs a VALUES list, because DEFAULT VALUES can't follow a
    # column list
    if labels: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None # no columns at all: insert DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
535 11 aaronmk
536 135 aaronmk
def last_insert_id(db):
    '''Returns the id generated by the last insert on this connection.
    @return lastval() under psycopg2, the connection's insert_id() under
        MySQLdb, or None for any other driver
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the driver's connection (db.db), not to DbConn
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
541 13 aaronmk
542 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on the schema-qualified table'''
    target = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+target+' CASCADE')
544 832 aaronmk
545
##### Database structure queries
546
547 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of the table's pkey.
    Assumed to be first column in table.
    '''
    # A zero-row select is enough to get the column names
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    return next(col_names(cur))
551 832 aaronmk
552 853 aaronmk
def index_cols(db, table, index):
    '''Returns a list of the names of the columns in an index.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @raise NotImplementedError unless the driver is psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    cur = run_query(db, query, {'table': table, 'index': index},
        cacheable=True)
    return list(values(cur))
594 853 aaronmk
595 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns a list of the names of the columns in a constraint.
    @raise NotImplementedError unless the driver is psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    cur = run_query(db, query, {'table': table, 'constraint': constraint})
    return list(values(cur))
613
614 2096 aaronmk
# Name of the column that add_row_num() adds to a table
row_num_col = '_row_num'
615
616 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.
    @param table The table's name, which must pass check_name()
    '''
    check_name(table)
    query = ' '.join(['ALTER TABLE', table, 'ADD COLUMN', row_num_col,
        'serial NOT NULL PRIMARY KEY'])
    run_query(db, query)
622 2086 aaronmk
623 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Returns an iterator over the names of the tables matching table_like.
    @param schema Only used under psycopg2
    @param table_like LIKE pattern that the table names must match
    @raise NotImplementedError if the driver is not psycopg2 or MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2': query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    
    return values(run_query(db, query, params, cacheable=True))
640 830 aaronmk
641 833 aaronmk
##### Database management
642
643 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    for table_name in tables(db, schema, **kw_args):
        truncate(db, table_name, schema)
646 833 aaronmk
647 832 aaronmk
##### Heuristic queries
648
649 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row, or finds the existing row that it duplicates.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The table's pkey column. Looked up with pkey() if None.
    @param row_ct_ref One-item list to add the inserted row count to, or None
    @return The pkey of the inserted row or, on a duplicate key, of the
        existing row that matches the row on the violated constraint's columns
    '''
    if pkey_ is None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = insert(db, table, row, pkey_, recover=True)
        count_rows = row_ct_ref is not None and cur.rowcount >= 0
        if count_rows: row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look the existing row up by the columns that caused the conflict
        conflict_conds = util.dict_subset_right_join(row, e.cols)
        return value(select(db, table, [pkey_], conflict_conds, recover=True))
663 471 aaronmk
664 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Returns the pkey of the first row matching the values in row.
    Recovers from errors.
    @param create Whether to put() the row if no row matches
    @param row_ct_ref Passed to put() when creating the row
    @raise StopIteration if no row matches and create is false
    '''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
670 2078 aaronmk
671 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
    row_ct_ref=None, table_is_esc=False):
    '''Inserts the rows of the input tables into out_table.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict of output column -> input field (see mk_select()'s
        fields param)
    @param limit,start Which input rows to use (see mk_select())
    @param row_ct_ref One-item list to add the inserted row count to, or None
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
        available
    '''
    temp_suffix = clean_name(out_table)
        # suffix, not prefix, so main name won't be removed if name is truncated
    pkeys_ref = ['pkeys_'+temp_suffix]
    
    # Join together input tables
    in_tables = list(in_tables) # don't modify input!
    in_tables0 = in_tables.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
    in_joins = [in_tables0] + [(t, {in_pkey: join_using}) for t in in_tables]
    
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
    pkeys_cols = [in_pkey, out_pkey]
    
    def mk_select_(cols):
        '''Selects the given fields from the joined input tables'''
        return mk_select(db, in_joins, cols, limit=limit, start=start,
            table_is_esc=table_is_esc)
    
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
    def insert_():
        '''Inserts and capture output pkeys.'''
        cur = insert_select(db, out_table, mapping.keys(),
            *mk_select_(mapping.values()), returning=out_pkey,
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        # Must be done whether or not rows are being counted, because the join
        # below uses this column
        add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
        
        # Get input pkeys corresponding to rows in insert
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
        
        # Join together output and input pkeys
        run_query_into(db, *mk_select(db,
            [in_pkeys_ref[0], (out_pkeys_ref[0], {row_num_col: join_using})],
            pkeys_cols, start=0), into_ref=pkeys_ref)
    
    # Do inserts and selects
    try: insert_()
    except DuplicateKeyException as e:
        # Instead, match the input rows to the existing output rows on the
        # columns of the violated constraint
        join_cols = util.dict_subset_right_join(mapping, e.cols)
        joins = in_joins + [(out_table, join_cols)]
        run_query_into(db, *mk_select(db, joins, pkeys_cols,
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
    
    return (pkeys_ref[0], out_pkey)
726 2115 aaronmk
727
##### Data cleanup
728
729
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Trims the given columns' values in every row of the table, setting
    values that are then the empty string or \\N to NULL.
    @param cols Names of the columns to clean up
    @param table_is_esc Whether the table name has already been escaped
    '''
    if not table_is_esc: check_name(table)
    
    def assignment(col):
        '''Returns the SET clause entry for a column'''
        col = esc_name(db, col)
        return ('\n'+col+' = nullif(nullif(trim(both from '+col
            +"), %(null0)s), %(null1)s)")
    
    assignments = ',\n'.join(map(assignment, cols))
    run_query(db, 'UPDATE '+table+' SET\n'+assignments,
        dict(null0='', null1=r'\N'))