Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
    '''Gets the text of the query associated with a cursor.
    @param cur The cursor to inspect. Its `query` attribute is used if it has
        one, otherwise its `_last_executed` attribute.
    @param input_query Fallback query, used when the cursor provides no value
    @param input_params Fallback params, used together with input_query
    @return The value found on the cursor, or else
        repr(input_query)+' % '+repr(input_params)
    '''
    raw_query = None
    if hasattr(cur, 'query'):
        raw_query = cur.query
    elif hasattr(cur, '_last_executed'):
        raw_query = cur._last_executed
    
    if raw_query is not None:
        return raw_query
    # The cursor didn't provide anything usable, so describe the inputs instead
    return repr(input_query)+' % '+repr(input_params)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Adds the query from get_cur_query() to an exception's message.
    For params, see get_cur_query()
    @param e The exception to annotate (passed to exc.add_msg())
    '''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+str(query))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the exceptions defined in this module'''
    
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The exception that caused this one, if any
        @param cur If given, the cursor whose query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None:
            _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about a named object. The name is kept in self.name.'''
    
    def __init__(self, name, cause=None):
        '''
        @param name The name the error is about
        @param cause The exception that caused this one, if any
        '''
        DbException.__init__(self, 'for name: '+str(name), cause)
        self.name = name
40 360 aaronmk
41 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about a set of columns. They are kept in self.cols.'''
    
    def __init__(self, cols, cause=None):
        '''
        @param cols Sequence of column name strs the error is about
        @param cause The exception that caused this one, if any
        '''
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
        self.cols = cols
45 11 aaronmk
46 2143 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name with disallowed characters'''
47
48 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by run_query() when a unique constraint was violated'''
49 13 aaronmk
50 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by run_query() when a not-null constraint was violated'''
51 13 aaronmk
52 2143 aaronmk
class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() when a relation to create already exists'''
53
54 89 aaronmk
class EmptyRowException(DbException):
    '''Presumably signals an empty row (not raised in this part of the module)
    '''
55
56 865 aaronmk
##### Warnings
57
58
class DbWarning(UserWarning):
    '''Warning category used by this module (e.g. by mk_select())'''
59
60 1930 aaronmk
##### Result retrieval
61
62
def col_names(cur):
    '''Lazily yields the first item (the name) of each cur.description entry
    '''
    return (col_desc[0] for col_desc in cur.description)
63
64
def rows(cur):
    '''Iterates over a cursor's rows until cur.fetchone() returns None'''
    def fetch_next():
        return cur.fetchone()
    return iter(fetch_next, None)
65
66
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
69
70
def next_row(cur):
    '''Fetches the cursor's next row. Raises StopIteration if there is none.
    '''
    return next(rows(cur))
71
72
def row(cur):
    '''Fetches the cursor's next row and then consumes the remaining rows (so
    the result will be cached). Raises StopIteration if there is no row.
    '''
    first_row = next_row(cur)
    consume_rows(cur)
    return first_row
76
77
def next_value(cur):
    '''Gets the first column of the cursor's next row (see next_row())'''
    next_row_ = next_row(cur)
    return next_row_[0]
78
79
def value(cur):
    '''Gets the first column of the row returned by row()'''
    only_row = row(cur)
    return only_row[0]
80
81
def values(cur):
    '''Wraps repeated next_value(cur) calls with iters.func_iter()'''
    def get_next():
        return next_value(cur)
    return iters.func_iter(get_next)
82
83
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no row.
    '''
    try:
        return value(cur)
    except StopIteration:
        return None
86
87 2101 aaronmk
##### Input validation
88
89
def clean_name(name):
    '''Removes every non-word character (anything but letters, digits, and _)
    from a name.
    '''
    non_word_char = re.compile(r'\W')
    return non_word_char.sub(r'', name)
90
91
def check_name(name):
    '''Validates that a name consists only of word characters.
    @raise NameException If the name contains any other character
    '''
    if re.search(r'\W', name) is not None:
        raise NameException('Name "'+name
            +'" may contain only alphanumeric characters and _')
94
95
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for the database driven by a DB-API module.
    @param module The DB-API module's name: 'psycopg2' or 'MySQLdb'
    @param name The identifier to escape
    @param ignore_case Only used for psycopg2: returns the name unquoted, after
        validating it with check_name()
    @return The name enclosed in the database's quote char, with any embedded
        quote chars removed
    @raise NotImplementedError For any other module
    '''
    if module == 'psycopg2':
        if ignore_case:
            # Don't enclose in quotes because this disables case-insensitivity
            check_name(name)
            return name
        quote = '"'
    elif module == 'MySQLdb':
        quote = '`'
    else:
        raise NotImplementedError("Can't escape name for "+module+' database')
    return quote + name.replace(quote, '') + quote
105
106
def esc_name_by_engine(engine, name, **kw_args):
    '''Like esc_name_by_module(), but takes an engine name.
    @param engine A key of db_engines
    For kw_args, see esc_name_by_module()
    '''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
108
109
def esc_name(db, name, **kw_args):
    '''Like esc_name_by_module(), but determines the module from the
    connection (using util.root_module(db.db)).
    For kw_args, see esc_name_by_module()
    '''
    module_name = util.root_module(db.db)
    return esc_name_by_module(module_name, name, **kw_args)
111
112
def qual_name(db, schema, table):
    '''Builds an escaped, optionally schema-qualified table name.
    @param schema The schema name, or None to leave the table unqualified
    @return 'schema.table' with both parts escaped, or just the escaped table
    '''
    esc_table = esc_name(db, table)
    if schema is None:
        return esc_table
    return esc_name(db, schema)+'.'+esc_table
117
118 1869 aaronmk
##### Database connections
119 1849 aaronmk
120 2097 aaronmk
# Names of db_config settings (presumably the full set callers may supply)
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps each engine name to tuple(DB-API module name, dict mapping db_config
# key names to the names that module's connect() uses instead)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# Exception classes to treat as database errors. _add_module() adds each
# loaded DB-API module's DatabaseError and rebuilds the tuple.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
129
130
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class in DatabaseErrors_set
    and rebuilds the DatabaseErrors tuple from it.
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
134
135
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>".
    @param db_config dict with at least the keys 'engine' and 'database'
    '''
    return db_config['engine']+' database '+db_config['database']
137
138 1909 aaronmk
def _query_lookup(query, params):
    '''Creates the cache key for a query: tuple(query, hashable params)'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
139 1894 aaronmk
140 1901 aaronmk
def log_debug_none(msg):
    '''No-op debug logger. DbConn.run_query() compares against it to skip
    building debug messages.
    '''
    return None
141
142 1849 aaronmk
class DbConn:
    '''A lazily-opened database connection with optional query result caching.
    
    The underlying DB-API connection is created the first time the `db`
    attribute is accessed (see _db()). Results of cacheable queries are stored
    in self.query_results, keyed by _query_lookup(query, params).
    '''
    
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
        caching=True):
        '''
        @param db_config dict of connection settings. Must contain 'engine' (a
            key of db_engines) and may contain 'schemas' (comma-separated str).
            The remaining entries are passed to the DB-API module's connect().
        @param serializable Whether to set the SERIALIZABLE isolation level
            when connecting
        @param log_debug Callback that is passed debug message strs
        @param caching Set to False to make run_query() ignore cacheable=True
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.log_debug = log_debug
        self.caching = caching
        
        self.__db = None # the DB-API connection; created by _db()
        self.query_results = {} # cache of _query_lookup() -> CacheCursor
        self._savepoint = 0 # current savepoint nesting level
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the state to pickle, without the unpicklable members'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self):
        '''Whether the underlying connection has been created yet'''
        return self.__db is not None
    
    def _db(self):
        '''Returns the DB-API connection, creating and configuring it first if
        needed.
        '''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename the settings to the names this module's connect() expects
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # setting not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable:
                run_raw_query(self,
                    'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas is not None:
                # Prepend the schemas to the existing search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a DB-API cursor and records what it fetches, so that the
        result (or the exception raised) can be stored in the query cache.
        '''
        
        def __init__(self, outer):
            '''@param outer The DbConn to create the cursor on'''
            Proxy.__init__(self, outer.db.cursor())
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor. If it raises, the query
            is added to the exception's message and the exception is cached as
            the result before being re-raised.
            '''
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                try: return_value = self.inner.execute(query, params)
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches and records the next row. Once the rows are exhausted,
            the recorded rows are stored in the cache.
            '''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores the recorded result in the query cache'''
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Cursor stand-in that replays a cached result'''
            
            def __init__(self, cached_result):
                '''@param cached_result dict of the cursor attrs to replay'''
                self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises a cached exception, or restarts row iteration'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None when exhausted'''
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query, using and populating the result cache if allowed.
        Exceptions raised by cacheable queries are cached as well: see
        self.DbCursor.execute().
        @param cacheable Whether the cache may be used for this query. Ignored
            if self.caching is off.
        @return The cursor the query was run on: a cached cursor on a cache
            hit, a DbCursor on a cache miss, or a plain DB-API cursor for
            non-cacheable queries
        '''
        assert query is not None
        
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the logging below also works if no cursor was created
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            cur.execute(query, params)
        finally:
            if self.log_debug != log_debug_none: # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                self.log_debug(cache_status+': '
                    +strings.one_line(str(get_cur_query(cur, query, params))))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Whether a result is cached for this query and params'''
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint, which is rolled back if func()
        raises and released otherwise.
        @return func()'s return value
        '''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except: # roll back on *any* error, then let it propagate
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint)
            return return_val
292 1849 aaronmk
293 1869 aaronmk
connect = DbConn # alias: connect(db_config, ...) creates a DbConn
294
295 832 aaronmk
##### Querying
296
297 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on the connection, without error recovery.
    For params, see DbConn.run_query()
    @return Whatever db.run_query() returns
    '''
    return db.run_query(*args, **kw_args)
300 11 aaronmk
301 2068 aaronmk
def mogrify(db, query, params):
    '''Gets the query with its params filled in, using the cursor's mogrify().
    @raise NotImplementedError If the connection's module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
306
307 832 aaronmk
##### Recoverable querying
308 15 aaronmk
309 2139 aaronmk
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func)'''
    result = db.with_savepoint(func)
    return result
310 11 aaronmk
311 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally translating known DB errors to typed
    exceptions.
    @param recover Whether to run the query in a savepoint (unless its result
        is cached) and translate recognized error messages. None means False.
    @param cacheable See DbConn.run_query()
    @return The cursor the query was run on
    @raise DuplicateKeyException, NullValueException, DuplicateTableException
        Only if recover is set. Other errors are re-raised unchanged.
    '''
    if recover is None: recover = False
    
    try:
        def run(): return run_raw_query(db, query, params, cacheable)
        if recover and not db.is_cached(query, params):
            return with_savepoint(db, run)
        else: return run() # don't need savepoint if cached
    except Exception as e:
        if not recover: raise # need savepoint to run index_cols()
        msg = str(e)
        
        # The table name is taken to be the part of the constraint name before
        # the first "_" (not counting a leading "_")
        match = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+)"', msg)
        if match:
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        
        match = re.search(r'relation "(\w+)" already exists', msg)
        if match: raise DuplicateTableException(match.group(1), e)
        
        raise # no specific exception raised
335 830 aaronmk
336 832 aaronmk
##### Basic queries
337
338 2153 aaronmk
def next_version(name):
    '''Creates the next versioned form of a name: "name" -> "v1_name",
    "v1_name" -> "v2_name", etc.
    Prepends the version # so it won't be removed if the name is truncated
    '''
    version = 1 # first existing name was version 0
    match = re.match(r'^v(\d+)_(.*)$', name)
    if match:
        # Already versioned, so increment the existing version #
        version = int(match.group(1))+1
        name = match.group(2)
    return 'v'+str(version)+'_'+name
346
347 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into_ref One-item list with the name of the temp table to create,
        or None to just run the query. If the table already exists, the name
        in the list is replaced with its next_version() until creation
        succeeds.
    '''
    if into_ref is None: return run_query(db, query, params, *args, **kw_args)
    
    # Place rows in temp table
    check_name(into_ref[0])
    kw_args['recover'] = True # needed to get DuplicateTableException
    while True:
        try:
            return run_query(db, 'CREATE TEMP TABLE '+into_ref[0]+' AS '
                +query, params, *args, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateTableException:
            into_ref[0] = next_version(into_ref[0])
            # try again with next version of name
363 2085 aaronmk
364 2120 aaronmk
# Sentinel values recognized by mk_select(). Each is a unique object.

order_by_pkey = object() # tells mk_select() to order by the pkey

join_using = object() # tells mk_select() to join the column with USING

filter_out = object() # tells mk_select() to filter out rows that match the join
369 2180 aaronmk
370 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, joins), ...]
        
        joins has the format: dict(right_col=left_col, ...)
        * if left_col is join_using, left_col is set to right_col
        * if left_col is filter_out, the tables are LEFT JOINed together and the
          query is filtered by `right_col IS NULL` (indicating no match)
    @param fields Use None to select all fields in the table
    @param conds dict mapping each column to the value it must have. Not
        modified.
    @param limit int|None Max # rows
    @param start int|None Row offset
    @param order_by The column to order by, None for no ordering, or
        order_by_pkey to order by the first table's pkey
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = tables.pop(0) # first table is separate
    
    if conds is None: conds = {}
    else: conds = copy.copy(conds) # don't modify input! (filter_out adds to it)
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    if order_by is order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = []
    
    def parse_col(field, default_table=None):
        '''Parses fields'''
        if field is None: field = (field,) # for None values, tuple is optional
        is_tuple = isinstance(field, tuple)
        if is_tuple and len(field) == 1: # field is literal value
            value, = field
            sql_ = '%s'
            params.append(value)
        elif is_tuple and len(field) == 2: # field is col with table
            table, col = field
            if not table_is_esc: table = esc_name_(table)
            sql_ = table+'.'+esc_name_(col)
        else:
            sql_ = esc_name_(field) # field is col name
            if default_table is not None: sql_ = default_table+'.'+sql_
        return sql_
    
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = parse_col(col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_
    
    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        if not joins:
            raise ValueError('No join columns provided for table '+table)
        
        left_join_ref = [False]
        
        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry
            
            # Parse special values
            if left_col is None: left_col = (left_col,)
                # for None values, tuple is optional
            elif left_col is join_using: left_col = right_col
            elif left_col is filter_out:
                left_col = right_col
                left_join_ref[0] = True
                conds[(table, right_col)] = None # filter query by no match
            
            # Create SQL
            right_col = table+'.'+esc_name_(right_col)
            if isinstance(left_col, tuple) and len(left_col) == 1:
                # col is literal value
                value, = left_col
                if value is None: sql_ = right_col+' IS %s'
                else: sql_ = right_col+' = %s'
                params.append(value)
            else: # col is name
                left_col = parse_col(left_col, left_table)
                # Also match when both cols are NULL. Enclosed in parens
                # because the conditions are ANDed together and AND binds
                # tighter than OR.
                sql_ = ('('+right_col+' = '+left_col+' OR ('+right_col
                    +' IS NULL AND '+left_col+' IS NULL))')
            
            return sql_
        
        # Create join condition and determine join type
        if all(v is join_using for v in joins.values()):
            # all cols w/ USING, so can use simpler USING syntax
            join_cond = 'USING ('+(', '.join(map(esc_name_, joins.keys())))+')'
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.items())))
        
        # Create join
        if left_join_ref[0]: query += ' LEFT'
        query += ' JOIN '+table+' '+join_cond
        
        left_table = table
    
    missing = True
    if conds != {}:
        query += ' WHERE '+(' AND '.join(map(cond, conds.items())))
        params += conds.values()
        missing = False
    if order_by is not None: query += ' ORDER BY '+parse_col(order_by, table0)
    if limit is not None: query += ' LIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
492 11 aaronmk
493 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT query.
    For params, see mk_select() and run_query()
    @param cacheable Defaults to True here
    @return The cursor from run_query()
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    # The remaining args are for mk_select()
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, recover, cacheable)
500
501 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''Builds an INSERT query whose rows come from select_query.
    @param cols The names of the columns to insert into, or None/[] if unknown
    @param select_query The SQL that provides the rows. None inserts
        DEFAULT VALUES.
    @param params The params for select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This creates a pg_temp function that runs the insert, so it requires
        returning to be set and mogrify() to support the database.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query is None: select_query = 'DEFAULT VALUES'
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols is not None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        name_parts = ['insert', table] + list(cols or []) # cols may be None
        function = 'pg_temp.'+('_'.join(map(clean_name, name_parts)))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        function_query = '''\
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
        run_query(db, function_query)
            # don't cache because function def could change and then change back
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
            order_by=None, table_is_esc=True)# AS clause requires function alias
    
    return (query, params)
543
544
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT query.
    For params, see mk_insert_select() and run_query_into()
    @param into_ref List with name of temp table to place RETURNING values in.
        Setting this turns on mk_insert_select()'s embeddable option.
    @param cacheable Defaults to True here
    '''
    into_ref = kw_args.pop('into_ref', None)
    if into_ref is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    # The remaining args are for mk_insert_select()
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into_ref, recover=recover,
        cacheable=cacheable)
556 2063 aaronmk
557 2066 aaronmk
default = object() # tells insert() to use the default value for a column
    # (sentinel: insert() emits the SQL keyword DEFAULT in place of a param)
558
559 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row.
    For params, see insert_select()
    @param row Sequence of values (col names unknown), or dict mapping col
        names to values. Use `default` as a value to insert the column's
        default.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = [] # the SQL placeholder or keyword for each value
    values = [] # the params for the placeholders
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    # When col names are listed, a VALUES clause is needed even if every value
    # is DEFAULT, because a col list can't be combined with DEFAULT VALUES
    if values != [] or cols: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None # insert DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
581 11 aaronmk
582 135 aaronmk
def last_insert_id(db):
    '''Gets the ID generated by the last insert on this connection.
    @return The ID, or None if the connection's module isn't supported
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb':
        # insert_id() is a method of the underlying connection, not of DbConn
        return db.db.insert_id()
    else: return None
587 13 aaronmk
588 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema The table's schema, or None to not qualify the table name
    '''
    table_name = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+table_name+' CASCADE')
590 832 aaronmk
591
##### Database structure queries
592
593 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Gets the name of a table's pkey.
    Assumed to be first column in table
    For params, see select()
    '''
    # Select no rows, just to get the column names
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    return next(col_names(cur))
597 832 aaronmk
598 853 aaronmk
def index_cols(db, table, index):
    '''Gets the names of the columns in an index.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return list of column names, ordered by attnum
    @raise NotImplementedError If the connection's module is not psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The first subquery gets the cols listed in indkey. The second one
        # gets the cols referenced in the index's expressions (indexprs).
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
640 853 aaronmk
641 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Gets the names of the columns in a constraint.
    @return list of column names, ordered by attnum
    @raise NotImplementedError If the connection's module is not psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
659
660 2096 aaronmk
row_num_col = '_row_num' # name of the column that add_row_num() creates
661
662 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    check_name(table)
    col_def = row_num_col+' serial NOT NULL PRIMARY KEY'
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+col_def)
668 2086 aaronmk
669 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Lists the names of the tables in the database.
    @param schema The schema to list the tables of (only used for psycopg2)
    @param table_like LIKE pattern that the table names must match
    @return Iterator over the table names
    @raise NotImplementedError If the connection's module is not supported
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        return values(run_query(db, '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
''',
            params, cacheable=True))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
686 830 aaronmk
687 833 aaronmk
##### Database management
688
689 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
692 833 aaronmk
693 832 aaronmk
##### Heuristic queries
694
695 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row, or finds the existing row if it's a duplicate.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The name of the table's pkey. Looked up with pkey() if None.
    @param row_ct_ref One-item list whose item is incremented by the # of rows
        inserted
    @return The pkey of the inserted or already existing row
    '''
    if pkey_ is None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = insert(db, table, row, pkey_, recover=True)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row by the columns of the violated constraint
        return value(select(db, table, [pkey_],
            util.dict_subset_right_join(row, e.cols), recover=True))
709 471 aaronmk
710 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Gets the pkey of the first row that matches the given column values.
    Recovers from errors
    @param row dict of column values to match (and to insert, if create)
    @param pkey The name of the table's pkey
    @param create Whether to insert the row with put() if it doesn't exist
    @raise StopIteration If no row matches and create is not set
    For row_ct_ref, see put()
    '''
    try: return value(select(db, table, [pkey], row, 1, recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
716 2078 aaronmk
717 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
    row_ct_ref=None, table_is_esc=False):
    '''Inserts rows selected from the input tables into the output table.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict mapping each output column to its input column
    @param row_ct_ref One-item list whose item is incremented by the # of rows
        inserted
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
        available
    '''
    temp_suffix = clean_name(out_table)
        # suffix, not prefix, so main name won't be removed if name is truncated
    pkeys_ref = ['pkeys_'+temp_suffix]
    
    # Join together input tables
    in_tables = list(in_tables) # don't modify input!
    in_tables0 = in_tables.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
    insert_joins = [in_tables0]+[(t, {in_pkey: join_using}) for t in in_tables]
    
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
    pkeys_cols = [in_pkey, out_pkey]
    
    def mk_select_(cols):
        '''Selects the given cols from the joined input tables'''
        return mk_select(db, insert_joins, cols, limit=limit, start=start,
            table_is_esc=table_is_esc)
    
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
    def insert_(pkeys_table_exists=False):
        '''Inserts and capture output pkeys.'''
        cur = insert_select(db, out_table, list(mapping.keys()),
            *mk_select_(list(mapping.values())), returning=out_pkey,
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        # Always needed (not just when counting rows), because the join below
        # uses this column
        add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
        
        # Get input pkeys corresponding to rows in insert
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
        
        # Join together output and input pkeys
        select_query = mk_select(db, [in_pkeys_ref[0],
            (out_pkeys_ref[0], {row_num_col: join_using})], pkeys_cols, start=0)
        if pkeys_table_exists:
            insert_select(db, pkeys_ref[0], pkeys_cols, *select_query)
        else: run_query_into(db, *select_query, into_ref=pkeys_ref)
    
    # Do inserts and selects
    try: insert_()
    except DuplicateKeyException as e:
        # Join the output table on the columns of the violated constraint
        join_cols = util.dict_subset_right_join(mapping, e.cols)
        select_joins = insert_joins + [(out_table, join_cols)]
        
        # Get pkeys of already existing rows
        run_query_into(db, *mk_select(db, select_joins, pkeys_cols, start=0,
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
    
    return (pkeys_ref[0], out_pkey)
776 2115 aaronmk
777
##### Data cleanup
778
779
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Cleans up the given columns of a table: trims each value, and sets it
    to NULL if the result is the empty string or the backslash-N marker.
    @param cols The names of the columns to clean up
    @param table_is_esc Whether the table name has already been escaped
    '''
    if not table_is_esc: check_name(table)
    
    assignments = []
    for col in cols:
        esc_col = esc_name(db, col)
        assignments.append('\n'+esc_col
            +' = nullif(nullif(trim(both from '+esc_col
            +"), %(null0)s), %(null1)s)")
    
    query = 'UPDATE '+table+' SET\n'+(',\n'.join(assignments))
    run_query(db, query, dict(null0='', null1=r'\N'))