Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
20 2168 aaronmk
    raw_query = None
21
    if hasattr(cur, 'query'): raw_query = cur.query
22
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
23 2170 aaronmk
24
    if raw_query != None: return raw_query
25
    else: return repr(input_query)+' % '+repr(input_params)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
28
    '''For params, see get_cur_query()'''
29
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
32 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
33 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
34 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
37
    def __init__(self, name, cause=None):
38 2145 aaronmk
        DbException.__init__(self, 'for name: '+str(name), cause)
39 2143 aaronmk
        self.name = name
40 360 aaronmk
41 468 aaronmk
class ExceptionWithColumns(DbException):
42
    def __init__(self, cols, cause=None):
43 2145 aaronmk
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
44 468 aaronmk
        self.cols = cols
45 11 aaronmk
46 2143 aaronmk
class NameException(DbException): pass
47
48 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
49 13 aaronmk
50 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
51 13 aaronmk
52 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
53
54 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName): pass
55
56 89 aaronmk
class EmptyRowException(DbException): pass
57
58 865 aaronmk
##### Warnings
59
60
class DbWarning(UserWarning): pass
61
62 1930 aaronmk
##### Result retrieval
63
64
def col_names(cur): return (col[0] for col in cur.description)
65
66
def rows(cur): return iter(lambda: cur.fetchone(), None)
67
68
def consume_rows(cur):
69
    '''Used to fetch all rows so result will be cached'''
70
    iters.consume_iter(rows(cur))
71
72
def next_row(cur): return rows(cur).next()
73
74
def row(cur):
75
    row_ = next_row(cur)
76
    consume_rows(cur)
77
    return row_
78
79
def next_value(cur): return next_row(cur)[0]
80
81
def value(cur): return row(cur)[0]
82
83
def values(cur): return iters.func_iter(lambda: next_value(cur))
84
85
def value_or_none(cur):
86
    try: return value(cur)
87
    except StopIteration: return None
88
89 2101 aaronmk
##### Input validation
90
91
def clean_name(name): return re.sub(r'\W', r'', name)
92
93
def check_name(name):
94
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
95
        +'" may contain only alphanumeric characters and _')
96
97
def esc_name_by_module(module, name, ignore_case=False):
98
    if module == 'psycopg2':
99
        if ignore_case:
100
            # Don't enclose in quotes because this disables case-insensitivity
101
            check_name(name)
102
            return name
103
        else: quote = '"'
104
    elif module == 'MySQLdb': quote = '`'
105
    else: raise NotImplementedError("Can't escape name for "+module+' database')
106
    return quote + name.replace(quote, '') + quote
107
108
def esc_name_by_engine(engine, name, **kw_args):
109
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
110
111
def esc_name(db, name, **kw_args):
112
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
113
114
def qual_name(db, schema, table):
115
    def esc_name_(name): return esc_name(db, name)
116
    table = esc_name_(table)
117
    if schema != None: return esc_name_(schema)+'.'+table
118
    else: return table
119
120 1869 aaronmk
##### Database connections
121 1849 aaronmk
122 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
123 1926 aaronmk
124 1869 aaronmk
db_engines = {
125
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
126
    'PostgreSQL': ('psycopg2', {}),
127
}
128
129
DatabaseErrors_set = set([DbException])
130
DatabaseErrors = tuple(DatabaseErrors_set)
131
132
def _add_module(module):
133
    DatabaseErrors_set.add(module.DatabaseError)
134
    global DatabaseErrors
135
    DatabaseErrors = tuple(DatabaseErrors_set)
136
137
def db_config_str(db_config):
138
    return db_config['engine']+' database '+db_config['database']
139
140 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
141 1894 aaronmk
142 1901 aaronmk
log_debug_none = lambda msg: None
143
144 1849 aaronmk
class DbConn:
145 2190 aaronmk
    def __init__(self, db_config, serializable=True, autocommit=False,
146
        caching=True, log_debug=log_debug_none):
147 1869 aaronmk
        self.db_config = db_config
148
        self.serializable = serializable
149 2190 aaronmk
        self.autocommit = autocommit
150
        self.caching = caching
151 1901 aaronmk
        self.log_debug = log_debug
152 2193 aaronmk
        self.debug = log_debug != log_debug_none
153 1869 aaronmk
154
        self.__db = None
155 1889 aaronmk
        self.query_results = {}
156 2139 aaronmk
        self._savepoint = 0
157 1869 aaronmk
158
    def __getattr__(self, name):
159
        if name == '__dict__': raise Exception('getting __dict__')
160
        if name == 'db': return self._db()
161
        else: raise AttributeError()
162
163
    def __getstate__(self):
164
        state = copy.copy(self.__dict__) # shallow copy
165 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
166 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
167
        return state
168
169 2165 aaronmk
    def connected(self): return self.__db != None
170
171 1869 aaronmk
    def _db(self):
172
        if self.__db == None:
173
            # Process db_config
174
            db_config = self.db_config.copy() # don't modify input!
175 2097 aaronmk
            schemas = db_config.pop('schemas', None)
176 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
177
            module = __import__(module_name)
178
            _add_module(module)
179
            for orig, new in mappings.iteritems():
180
                try: util.rename_key(db_config, orig, new)
181
                except KeyError: pass
182
183
            # Connect
184
            self.__db = module.connect(**db_config)
185
186
            # Configure connection
187 2190 aaronmk
            if self.serializable and not self.autocommit:
188
                self.db.set_session(isolation_level='SERIALIZABLE')
189 2101 aaronmk
            if schemas != None:
190
                schemas_ = ''.join((esc_name(self, s)+', '
191
                    for s in schemas.split(',')))
192
                run_raw_query(self, "SELECT set_config('search_path', \
193
%s || current_setting('search_path'), false)", [schemas_])
194 1869 aaronmk
195
        return self.__db
196 1889 aaronmk
197 1891 aaronmk
    class DbCursor(Proxy):
198 1927 aaronmk
        def __init__(self, outer):
199 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
200 2191 aaronmk
            self.outer = outer
201 1927 aaronmk
            self.query_results = outer.query_results
202 1894 aaronmk
            self.query_lookup = None
203 1891 aaronmk
            self.result = []
204 1889 aaronmk
205 1894 aaronmk
        def execute(self, query, params=None):
206 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
207 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
208 2148 aaronmk
            try:
209 2191 aaronmk
                try:
210
                    return_value = self.inner.execute(query, params)
211
                    self.outer.do_autocommit()
212 2148 aaronmk
                finally: self.query = get_cur_query(self.inner)
213 1904 aaronmk
            except Exception, e:
214 2170 aaronmk
                _add_cursor_info(e, self, query, params)
215 1904 aaronmk
                self.result = e # cache the exception as the result
216
                self._cache_result()
217
                raise
218 1930 aaronmk
            # Fetch all rows so result will be cached
219
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
220 1894 aaronmk
            return return_value
221
222 1891 aaronmk
        def fetchone(self):
223
            row = self.inner.fetchone()
224 1899 aaronmk
            if row != None: self.result.append(row)
225
            # otherwise, fetched all rows
226 1904 aaronmk
            else: self._cache_result()
227
            return row
228
229
        def _cache_result(self):
230 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
231
            # idempotent, but an invalid insert will always be invalid
232 1930 aaronmk
            if self.query_results != None and (not self._is_insert
233 1906 aaronmk
                or isinstance(self.result, Exception)):
234
235 1894 aaronmk
                assert self.query_lookup != None
236 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
237
                    util.dict_subset(dicts.AttrsDictView(self),
238
                    ['query', 'result', 'rowcount', 'description']))
239 1906 aaronmk
240 1916 aaronmk
        class CacheCursor:
241
            def __init__(self, cached_result): self.__dict__ = cached_result
242
243 1927 aaronmk
            def execute(self, *args, **kw_args):
244 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
245
                # otherwise, result is a rows list
246
                self.iter = iter(self.result)
247
248
            def fetchone(self):
249
                try: return self.iter.next()
250
                except StopIteration: return None
251 1891 aaronmk
252 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
253 2148 aaronmk
        '''Translates known DB errors to typed exceptions:
254
        See self.DbCursor.execute().'''
255 2167 aaronmk
        assert query != None
256
257 2047 aaronmk
        if not self.caching: cacheable = False
258 1903 aaronmk
        used_cache = False
259
        try:
260 1927 aaronmk
            # Get cursor
261
            if cacheable:
262
                query_lookup = _query_lookup(query, params)
263
                try:
264
                    cur = self.query_results[query_lookup]
265
                    used_cache = True
266
                except KeyError: cur = self.DbCursor(self)
267
            else: cur = self.db.cursor()
268
269
            # Run query
270 2148 aaronmk
            cur.execute(query, params)
271 1903 aaronmk
        finally:
272 2193 aaronmk
            if self.debug: # only compute msg if needed
273 1903 aaronmk
                if used_cache: cache_status = 'Cache hit'
274
                elif cacheable: cache_status = 'Cache miss'
275
                else: cache_status = 'Non-cacheable'
276 1927 aaronmk
                self.log_debug(cache_status+': '
277 2170 aaronmk
                    +strings.one_line(str(get_cur_query(cur, query, params))))
278 1903 aaronmk
279
        return cur
280 1914 aaronmk
281
    def is_cached(self, query, params=None):
282
        return _query_lookup(query, params) in self.query_results
283 2139 aaronmk
284
    def with_savepoint(self, func):
285 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
286 2139 aaronmk
        self.run_query('SAVEPOINT '+savepoint)
287
        self._savepoint += 1
288
        try:
289
            try: return_val = func()
290
            finally:
291
                self._savepoint -= 1
292
                assert self._savepoint >= 0
293
        except:
294
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
295
            raise
296
        else:
297
            self.run_query('RELEASE SAVEPOINT '+savepoint)
298 2191 aaronmk
            self.do_autocommit()
299 2139 aaronmk
            return return_val
300 2191 aaronmk
301
    def do_autocommit(self):
302
        '''Autocommits if outside savepoint'''
303
        assert self._savepoint >= 0
304
        if self.autocommit and self._savepoint == 0:
305
            self.log_debug('Autocommiting')
306
            self.db.commit()
307 1849 aaronmk
308 1869 aaronmk
connect = DbConn
309
310 832 aaronmk
##### Querying
311
312 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
313 2085 aaronmk
    '''For params, see DbConn.run_query()'''
314 1894 aaronmk
    return db.run_query(*args, **kw_args)
315 11 aaronmk
316 2068 aaronmk
def mogrify(db, query, params):
317
    module = util.root_module(db.db)
318
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
319
    else: raise NotImplementedError("Can't mogrify query for "+module+
320
        ' database')
321
322 832 aaronmk
##### Recoverable querying
323 15 aaronmk
324 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
325 11 aaronmk
326 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
327 830 aaronmk
    if recover == None: recover = False
328
329 2148 aaronmk
    try:
330
        def run(): return run_raw_query(db, query, params, cacheable)
331
        if recover and not db.is_cached(query, params):
332
            return with_savepoint(db, run)
333
        else: return run() # don't need savepoint if cached
334
    except Exception, e:
335
        if not recover: raise # need savepoint to run index_cols()
336
        msg = str(e)
337
        match = re.search(r'duplicate key value violates unique constraint '
338
            r'"((_?[^\W_]+)_[^"]+)"', msg)
339
        if match:
340
            constraint, table = match.groups()
341
            try: cols = index_cols(db, table, constraint)
342
            except NotImplementedError: raise e
343
            else: raise DuplicateKeyException(cols, e)
344
        match = re.search(r'null value in column "(\w+)" violates not-null '
345
            'constraint', msg)
346
        if match: raise NullValueException([match.group(1)], e)
347
        match = re.search(r'relation "(\w+)" already exists', msg)
348
        if match: raise DuplicateTableException(match.group(1), e)
349 2188 aaronmk
        match = re.search(r'function "(\w+)" already exists', msg)
350
        if match: raise DuplicateFunctionException(match.group(1), e)
351 2148 aaronmk
        raise # no specific exception raised
352 830 aaronmk
353 832 aaronmk
##### Basic queries
354
355 2153 aaronmk
def next_version(name):
356
    '''Prepends the version # so it won't be removed if the name is truncated'''
357 2163 aaronmk
    version = 1 # first existing name was version 0
358 2153 aaronmk
    match = re.match(r'^v(\d+)_(.*)$', name)
359
    if match:
360
        version = int(match.group(1))+1
361
        name = match.group(2)
362
    return 'v'+str(version)+'_'+name
363
364 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
365 2085 aaronmk
    '''Outputs a query to a temp table.
366
    For params, see run_query().
367
    '''
368 2151 aaronmk
    if into_ref == None: return run_query(db, query, params, *args, **kw_args)
369 2085 aaronmk
    else: # place rows in temp table
370 2151 aaronmk
        check_name(into_ref[0])
371 2153 aaronmk
        kw_args['recover'] = True
372
        while True:
373
            try:
374 2194 aaronmk
                create_query = 'CREATE'
375
                if not db.debug: create_query += ' TEMP'
376
                create_query += ' TABLE '+into_ref[0]+' AS '+query
377
378
                return run_query(db, create_query, params, *args, **kw_args)
379 2153 aaronmk
                    # CREATE TABLE AS sets rowcount to # rows in query
380
            except DuplicateTableException, e:
381
                into_ref[0] = next_version(into_ref[0])
382
                # try again with next version of name
383 2085 aaronmk
384 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
385
386 2127 aaronmk
join_using = object() # tells mk_select() to join the column with USING
387
388 2187 aaronmk
filter_out = object() # tells mk_select() to filter out rows that match the join
389 2180 aaronmk
390 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
391 2120 aaronmk
    order_by=order_by_pkey, table_is_esc=False):
392 1981 aaronmk
    '''
393 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
394 2187 aaronmk
        together: [table0, (table1, joins), ...]
395
396
        joins has the format: dict(right_col=left_col, ...)
397
        * if left_col is join_using, left_col is set to right_col
398
        * if left_col is filter_out, the tables are LEFT JOINed together and the
399
          query is filtered by `right_col IS NULL` (indicating no match)
400 1981 aaronmk
    @param fields Use None to select all fields in the table
401
    @param table_is_esc Whether the table name has already been escaped
402 2054 aaronmk
    @return tuple(query, params)
403 1981 aaronmk
    '''
404 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
405 2058 aaronmk
406 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
407 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
408 2121 aaronmk
    table0 = tables.pop(0) # first table is separate
409
410 1135 aaronmk
    if conds == None: conds = {}
411 135 aaronmk
    assert limit == None or type(limit) == int
412 865 aaronmk
    assert start == None or type(start) == int
413 2120 aaronmk
    if order_by == order_by_pkey:
414 2121 aaronmk
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
415
    if not table_is_esc: table0 = esc_name_(table0)
416 865 aaronmk
417 2056 aaronmk
    params = []
418
419 2161 aaronmk
    def parse_col(field, default_table=None):
420 2056 aaronmk
        '''Parses fields'''
421 2176 aaronmk
        if field == None: field = (field,) # for None values, tuple is optional
422 2157 aaronmk
        is_tuple = isinstance(field, tuple)
423
        if is_tuple and len(field) == 1: # field is literal value
424
            value, = field
425 2056 aaronmk
            sql_ = '%s'
426
            params.append(value)
427 2157 aaronmk
        elif is_tuple and len(field) == 2: # field is col with table
428
            table, col = field
429
            if not table_is_esc: table = esc_name_(table)
430
            sql_ = table+'.'+esc_name_(col)
431 2161 aaronmk
        else:
432
            sql_ = esc_name_(field) # field is col name
433
            if default_table != None: sql_ = default_table+'.'+sql_
434 2056 aaronmk
        return sql_
435 11 aaronmk
    def cond(entry):
436 2056 aaronmk
        '''Parses conditions'''
437 13 aaronmk
        col, value = entry
438 2187 aaronmk
        cond_ = parse_col(col)+' '
439 11 aaronmk
        if value == None: cond_ += 'IS'
440
        else: cond_ += '='
441
        cond_ += ' %s'
442
        return cond_
443 2056 aaronmk
444 1135 aaronmk
    query = 'SELECT '
445
    if fields == None: query += '*'
446 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
447 2121 aaronmk
    query += ' FROM '+table0
448 865 aaronmk
449 2122 aaronmk
    # Add joins
450
    left_table = table0
451
    for table, joins in tables:
452
        if not table_is_esc: table = esc_name_(table)
453
454 2187 aaronmk
        left_join_ref = [False]
455
456 2122 aaronmk
        def join(entry):
457 2127 aaronmk
            '''Parses non-USING joins'''
458 2124 aaronmk
            right_col, left_col = entry
459 2173 aaronmk
460 2179 aaronmk
            # Parse special values
461 2176 aaronmk
            if left_col == None: left_col = (left_col,)
462
                # for None values, tuple is optional
463 2179 aaronmk
            elif left_col == join_using: left_col = right_col
464 2187 aaronmk
            elif left_col == filter_out:
465 2180 aaronmk
                left_col = right_col
466 2187 aaronmk
                left_join_ref[0] = True
467
                conds[(table, right_col)] = None # filter query by no match
468 2179 aaronmk
469
            # Create SQL
470 2185 aaronmk
            right_col = table+'.'+esc_name_(right_col)
471
            sql_ = right_col+' '
472 2173 aaronmk
            if isinstance(left_col, tuple) and len(left_col) == 1:
473
                # col is literal value
474
                value, = left_col
475
                if value == None: sql_ += 'IS'
476
                else: sql_ += '='
477
                sql_ += ' %s'
478
                params.append(value)
479
            else: # col is name
480
                left_col = parse_col(left_col, left_table)
481
                sql_ += ('= '+left_col+' OR ('+right_col+' IS NULL AND '
482
                    +left_col+' IS NULL)')
483
484
            return sql_
485 2122 aaronmk
486 2187 aaronmk
        # Create join condition and determine join type
487 2127 aaronmk
        if reduce(operator.and_, (v == join_using for v in joins.itervalues())):
488 2179 aaronmk
            # all cols w/ USING, so can use simpler USING syntax
489 2187 aaronmk
            join_cond = 'USING ('+(', '.join(joins.iterkeys()))+')'
490
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.iteritems())))
491 2127 aaronmk
492 2187 aaronmk
        # Create join
493
        if left_join_ref[0]: query += ' LEFT'
494
        query += ' JOIN '+table+' '+join_cond
495
496 2122 aaronmk
        left_table = table
497
498 865 aaronmk
    missing = True
499 89 aaronmk
    if conds != {}:
500 2122 aaronmk
        query += ' WHERE '+(' AND '.join(map(cond, conds.iteritems())))
501 2056 aaronmk
        params += conds.values()
502 865 aaronmk
        missing = False
503 2186 aaronmk
    if order_by != None: query += ' ORDER BY '+parse_col(order_by, table0)
504 865 aaronmk
    if limit != None: query += ' LIMIT '+str(limit); missing = False
505
    if start != None:
506
        if start != 0: query += ' OFFSET '+str(start)
507
        missing = False
508
    if missing: warnings.warn(DbWarning(
509
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
510
511 2056 aaronmk
    return (query, params)
512 11 aaronmk
513 2054 aaronmk
def select(db, *args, **kw_args):
514
    '''For params, see mk_select() and run_query()'''
515
    recover = kw_args.pop('recover', None)
516
    cacheable = kw_args.pop('cacheable', True)
517
518
    query, params = mk_select(db, *args, **kw_args)
519
    return run_query(db, query, params, recover, cacheable)
520
521 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
522 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
523 1960 aaronmk
    '''
524
    @param returning str|None An inserted column (such as pkey) to return
525 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
526 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
527
        query will be fully cached, not just if it raises an exception.
528 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
529
    '''
530 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
531
    if cols == []: cols = None # no cols (all defaults) = unknown col names
532 1960 aaronmk
    if not table_is_esc: check_name(table)
533 2063 aaronmk
534
    # Build query
535
    query = 'INSERT INTO '+table
536
    if cols != None:
537
        map(check_name, cols)
538
        query += ' ('+', '.join(cols)+')'
539
    query += ' '+select_query
540
541
    if returning != None:
542
        check_name(returning)
543
        query += ' RETURNING '+returning
544
545 2070 aaronmk
    if embeddable:
546
        # Create function
547 2189 aaronmk
        function_name = '_'.join(map(clean_name, ['insert', table] + cols))
548 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
549 2189 aaronmk
        while True:
550
            try:
551 2194 aaronmk
                function = function_name
552
                if not db.debug: function = 'pg_temp.'+function
553
554 2189 aaronmk
                function_query = '''\
555
CREATE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
556 2070 aaronmk
    LANGUAGE sql
557
    AS $$'''+mogrify(db, query, params)+''';$$;
558
'''
559 2189 aaronmk
                run_query(db, function_query, recover=True, cacheable=True)
560
                break # this version was successful
561
            except DuplicateFunctionException, e:
562
                function_name = next_version(function_name)
563
                # try again with next version of name
564 2070 aaronmk
565
        # Return query that uses function
566 2134 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
567
            order_by=None, table_is_esc=True)# AS clause requires function alias
568 2070 aaronmk
569 2066 aaronmk
    return (query, params)
570
571
def insert_select(db, *args, **kw_args):
572 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
573 2152 aaronmk
    @param into_ref List with name of temp table to place RETURNING values in
574 2072 aaronmk
    '''
575 2151 aaronmk
    into_ref = kw_args.pop('into_ref', None)
576
    if into_ref != None: kw_args['embeddable'] = True
577 2066 aaronmk
    recover = kw_args.pop('recover', None)
578
    cacheable = kw_args.pop('cacheable', True)
579
580
    query, params = mk_insert_select(db, *args, **kw_args)
581 2153 aaronmk
    return run_query_into(db, query, params, into_ref, recover=recover,
582
        cacheable=cacheable)
583 2063 aaronmk
584 2066 aaronmk
default = object() # tells insert() to use the default value for a column
585
586 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
587 2085 aaronmk
    '''For params, see insert_select()'''
588 1960 aaronmk
    if lists.is_seq(row): cols = None
589
    else:
590
        cols = row.keys()
591
        row = row.values()
592
    row = list(row) # ensure that "!= []" works
593
594 1961 aaronmk
    # Check for special values
595
    labels = []
596
    values = []
597
    for value in row:
598
        if value == default: labels.append('DEFAULT')
599
        else:
600
            labels.append('%s')
601
            values.append(value)
602
603
    # Build query
604 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
605
    else: query = None
606 1554 aaronmk
607 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
608 11 aaronmk
609 135 aaronmk
def last_insert_id(db):
610 1849 aaronmk
    module = util.root_module(db.db)
611 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
612
    elif module == 'MySQLdb': return db.insert_id()
613
    else: return None
614 13 aaronmk
615 1968 aaronmk
def truncate(db, table, schema='public'):
616
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
617 832 aaronmk
618
##### Database structure queries
619
620 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
621 832 aaronmk
    '''Assumed to be first column in table'''
622 2120 aaronmk
    return col_names(select(db, table, limit=0, order_by=None, recover=recover,
623 2084 aaronmk
        table_is_esc=table_is_esc)).next()
624 832 aaronmk
625 853 aaronmk
def index_cols(db, table, index):
626
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
627
    automatically created. When you don't know whether something is a UNIQUE
628
    constraint or a UNIQUE index, use this function.'''
629
    check_name(table)
630
    check_name(index)
631 1909 aaronmk
    module = util.root_module(db.db)
632
    if module == 'psycopg2':
633
        return list(values(run_query(db, '''\
634 853 aaronmk
SELECT attname
635 866 aaronmk
FROM
636
(
637
        SELECT attnum, attname
638
        FROM pg_index
639
        JOIN pg_class index ON index.oid = indexrelid
640
        JOIN pg_class table_ ON table_.oid = indrelid
641
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
642
        WHERE
643
            table_.relname = %(table)s
644
            AND index.relname = %(index)s
645
    UNION
646
        SELECT attnum, attname
647
        FROM
648
        (
649
            SELECT
650
                indrelid
651
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
652
                    AS indkey
653
            FROM pg_index
654
            JOIN pg_class index ON index.oid = indexrelid
655
            JOIN pg_class table_ ON table_.oid = indrelid
656
            WHERE
657
                table_.relname = %(table)s
658
                AND index.relname = %(index)s
659
        ) s
660
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
661
) s
662 853 aaronmk
ORDER BY attnum
663
''',
664 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
665
    else: raise NotImplementedError("Can't list index columns for "+module+
666
        ' database')
667 853 aaronmk
668 464 aaronmk
def constraint_cols(db, table, constraint):
669
    check_name(table)
670
    check_name(constraint)
671 1849 aaronmk
    module = util.root_module(db.db)
672 464 aaronmk
    if module == 'psycopg2':
673
        return list(values(run_query(db, '''\
674
SELECT attname
675
FROM pg_constraint
676
JOIN pg_class ON pg_class.oid = conrelid
677
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
678
WHERE
679
    relname = %(table)s
680
    AND conname = %(constraint)s
681
ORDER BY attnum
682
''',
683
            {'table': table, 'constraint': constraint})))
684
    else: raise NotImplementedError("Can't list constraint columns for "+module+
685
        ' database')
686
687 2096 aaronmk
row_num_col = '_row_num'
688
689 2086 aaronmk
def add_row_num(db, table):
690 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
691
    be the primary key.'''
692 2086 aaronmk
    check_name(table)
693 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
694 2117 aaronmk
        +' serial NOT NULL PRIMARY KEY')
695 2086 aaronmk
696 1968 aaronmk
def tables(db, schema='public', table_like='%'):
697 1849 aaronmk
    module = util.root_module(db.db)
698 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
699 832 aaronmk
    if module == 'psycopg2':
700 1968 aaronmk
        return values(run_query(db, '''\
701
SELECT tablename
702
FROM pg_tables
703
WHERE
704
    schemaname = %(schema)s
705
    AND tablename LIKE %(table_like)s
706
ORDER BY tablename
707
''',
708
            params, cacheable=True))
709
    elif module == 'MySQLdb':
710
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
711
            cacheable=True))
712 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
713 830 aaronmk
714 833 aaronmk
##### Database management
715
716 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
717
    '''For kw_args, see tables()'''
718
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
719 833 aaronmk
720 832 aaronmk
##### Heuristic queries
721
722 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
723 1554 aaronmk
    '''Recovers from errors.
724 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
725
    '''
726 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
727
728 471 aaronmk
    try:
729 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
730 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
731
            row_ct_ref[0] += cur.rowcount
732
        return value(cur)
733 471 aaronmk
    except DuplicateKeyException, e:
734 2104 aaronmk
        return value(select(db, table, [pkey_],
735 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
736 471 aaronmk
737 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
738 830 aaronmk
    '''Recovers from errors'''
739
    try: return value(select(db, table, [pkey], row, 1, recover=True))
740 14 aaronmk
    except StopIteration:
741 40 aaronmk
        if not create: raise
742 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
743 2078 aaronmk
744 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
745
    row_ct_ref=None, table_is_esc=False):
746 2078 aaronmk
    '''Recovers from errors.
747
    Only works under PostgreSQL (uses INSERT RETURNING).
748 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
749
        tables to join with it using the main input table's pkey
750 2133 aaronmk
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
751 2078 aaronmk
        available
752
    '''
753 2162 aaronmk
    temp_suffix = clean_name(out_table)
754 2158 aaronmk
        # suffix, not prefix, so main name won't be removed if name is truncated
755
    pkeys_ref = ['pkeys_'+temp_suffix]
756 2131 aaronmk
757 2132 aaronmk
    # Join together input tables
758 2131 aaronmk
    in_tables = in_tables[:] # don't modify input!
759
    in_tables0 = in_tables.pop(0) # first table is separate
760
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
761 2178 aaronmk
    insert_joins = [in_tables0]+[(t, {in_pkey: join_using}) for t in in_tables]
762 2131 aaronmk
763 2142 aaronmk
    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
764
    pkeys_cols = [in_pkey, out_pkey]
765
766 2132 aaronmk
    def mk_select_(cols):
767 2178 aaronmk
        return mk_select(db, insert_joins, cols, limit=limit, start=start,
768 2134 aaronmk
            table_is_esc=table_is_esc)
769 2132 aaronmk
770 2158 aaronmk
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
771 2181 aaronmk
    def insert_(pkeys_table_exists=False):
772 2148 aaronmk
        '''Inserts and capture output pkeys.'''
773 2132 aaronmk
        cur = insert_select(db, out_table, mapping.keys(),
774 2133 aaronmk
            *mk_select_(mapping.values()), returning=out_pkey,
775 2155 aaronmk
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
776 2078 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
777
            row_ct_ref[0] += cur.rowcount
778 2155 aaronmk
            add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys
779 2086 aaronmk
780 2132 aaronmk
        # Get input pkeys corresponding to rows in insert
781 2158 aaronmk
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
782 2155 aaronmk
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
783
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys
784 2086 aaronmk
785 2155 aaronmk
        # Join together output and input pkeys
786 2181 aaronmk
        select_query = mk_select(db, [in_pkeys_ref[0],
787
            (out_pkeys_ref[0], {row_num_col: join_using})], pkeys_cols, start=0)
788
        if pkeys_table_exists:
789
            insert_select(db, pkeys_ref[0], pkeys_cols, *select_query)
790
        else: run_query_into(db, *select_query, into_ref=pkeys_ref)
791 2132 aaronmk
792 2142 aaronmk
    # Do inserts and selects
793 2148 aaronmk
    try: insert_()
794 2142 aaronmk
    except DuplicateKeyException, e:
795
        join_cols = util.dict_subset_right_join(mapping, e.cols)
796 2178 aaronmk
        select_joins = insert_joins + [(out_table, join_cols)]
797
798
        # Get pkeys of already existing rows
799
        run_query_into(db, *mk_select(db, select_joins, pkeys_cols, start=0,
800 2154 aaronmk
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)
801 2142 aaronmk
802 2154 aaronmk
    return (pkeys_ref[0], out_pkey)
803 2115 aaronmk
804
##### Data cleanup
805
806
def cleanup_table(db, table, cols, table_is_esc=False):
807
    def esc_name_(name): return esc_name(db, name)
808
809
    if not table_is_esc: check_name(table)
810
    cols = map(esc_name_, cols)
811
812
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
813
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
814
            for col in cols))),
815
        dict(null0='', null1=r'\N'))