Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the query most recently run on a cursor.
    Uses the cursor's `query` attribute if it has one, otherwise its
    `_last_executed` attribute. If neither exists (or the one found is None),
    falls back to a repr of the input query and params.
    @param cur The cursor (or any other object) to inspect
    @param input_query The query that was passed to the cursor, for the fallback
    @param input_params The params that were passed with it, for the fallback
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed

    if raw_query is not None: return raw_query
    return repr(input_query)+' % '+repr(input_params)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Adds the query from get_cur_query() to an exception using exc.add_msg().
    For params, see get_cur_query()
    @param e The exception to annotate
    '''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+str(query))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the cursor whose query is added to the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about a named database object. The name is kept in
    self.name.'''
    def __init__(self, name, cause=None):
        '''
        @param name The name of the object the error is about
        @param cause The underlying exception, if any
        '''
        DbException.__init__(self, 'for name: '+str(name), cause)
        self.name = name
40 360 aaronmk
41 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about a set of columns. The columns are kept in
    self.cols.'''
    def __init__(self, cols, cause=None):
        '''
        @param cols The names (strings) of the columns the error is about
        @param cause The underlying exception, if any
        '''
        DbException.__init__(self, 'for columns: '+(', '.join(cols)), cause)
        self.cols = cols
45 11 aaronmk
46 2143 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing disallowed characters'''
47
48 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by run_query() when a unique constraint was violated'''
49 13 aaronmk
50 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by run_query() when a not-null constraint was violated'''
51 13 aaronmk
52 2143 aaronmk
class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() when a relation already exists'''
53
54 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName):
    '''Raised by run_query() when a function already exists'''
55
56 89 aaronmk
class EmptyRowException(DbException):
    '''DbException subtype for empty rows (adds no behavior of its own)'''
57
58 865 aaronmk
##### Warnings
59
60
class DbWarning(UserWarning):
    '''Category for the warnings issued by this module'''
61
62 1930 aaronmk
##### Result retrieval
63
64
def col_names(cur):
    '''Returns a generator over the first element (the column name) of each
    entry in cur.description'''
    description = cur.description
    return (col_desc[0] for col_desc in description)
65
66
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() until it returns None'''
    def fetch_next(): return cur.fetchone()
    return iter(fetch_next, None)
67
68
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached.
    Passes rows(cur) to iters.consume_iter().
    '''
    iters.consume_iter(rows(cur))
71
72
def next_row(cur):
    '''Returns the next row of the cursor.
    Raises StopIteration if there are no more rows.
    '''
    # next() builtin instead of the Python-2-only .next() method
    return next(rows(cur))
73
74
def row(cur):
    '''Returns the next row, then consumes all the remaining rows.
    Raises StopIteration (from next_row()) if there is no row.
    '''
    row_ = next_row(cur)
    consume_rows(cur)
    return row_
78
79
def next_value(cur):
    '''Returns the first column of the cursor's next row'''
    next_row_ = next_row(cur)
    return next_row_[0]
80
81
def value(cur):
    '''Returns the first column of row(cur)'''
    row_ = row(cur)
    return row_[0]
82
83
def values(cur):
    '''Wraps repeated next_value(cur) calls with iters.func_iter()'''
    def get_next_value(): return next_value(cur)
    return iters.func_iter(get_next_value)
84
85
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there is no row'''
    try: return value(cur)
    except StopIteration: return None
88
89 2101 aaronmk
##### Input validation
90
91
def clean_name(name):
    '''Removes every non-word character (anything but alphanumerics and _)'''
    non_word_char = re.compile(r'\W')
    return non_word_char.sub(r'', name)
92
93
def check_name(name):
    '''Raises NameException if name contains any non-word character (anything
    but alphanumerics and _)'''
    if re.search(r'\W', name) is not None: raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
96
97
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for use in SQL.
    @param module Name of the DB-API module: 'psycopg2' or 'MySQLdb'
    @param name The identifier to escape
    @param ignore_case psycopg2 only: return the name unquoted, after validating
        it with check_name()
    @return The name enclosed in the module's quote char, with any occurrences
        of that char removed from the name
    Raises NotImplementedError for any other module.
    '''
    if module == 'psycopg2':
        if ignore_case:
            # Don't enclose in quotes because this disables case-insensitivity
            check_name(name)
            return name
        quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return quote + name.replace(quote, '') + quote
107
108
def esc_name_by_engine(engine, name, **kw_args):
    '''Like esc_name_by_module(), but takes an engine name (a key of
    db_engines) instead of a module name.
    For kw_args, see esc_name_by_module()
    '''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
110
111
def esc_name(db, name, **kw_args):
    '''Like esc_name_by_module(), but gets the module from
    util.root_module(db.db).
    For kw_args, see esc_name_by_module()
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
113
114
def qual_name(db, schema, table):
    '''Returns the escaped table name, prefixed with the escaped schema and a
    `.` if schema is not None'''
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema is not None: return esc_name_(schema)+'.'+table
    return table
119
120 1869 aaronmk
##### Database connections
121 1849 aaronmk
122 2097 aaronmk
# Names of the settings in a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# engine name -> (DB module name, {db_config key: that module's connect() key})
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# Exception classes to treat as database errors. _add_module() extends the set
# and rebuilds the tuple.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
131
132
def _add_module(module):
    '''Registers a DB module's DatabaseError class in DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple from it'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
136
137
def db_config_str(db_config):
    '''Returns a short description of a db_config: "<engine> database <database>"
    @param db_config dict with (at least) the keys 'engine' and 'database'
    '''
    return db_config['engine']+' database '+db_config['database']
139
140 1909 aaronmk
def _query_lookup(query, params):
    '''Returns the key under which a query's result is cached: the query paired
    with dicts.make_hashable(params)'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
141 1894 aaronmk
142 1901 aaronmk
def log_debug_none(msg):
    '''Debug callback that does nothing with the message'''
    return None
143
144 1849 aaronmk
class DbConn:
    '''A database connection that is opened on first use and can cache query
    results.

    The underlying connection is created the first time the `db` attribute is
    accessed. Results of cacheable queries are stored in self.query_results
    under the key returned by _query_lookup().
    '''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict with an 'engine' entry (a key of db_engines) and
            optionally a 'schemas' entry; the remaining entries are passed to
            the DB module's connect()
        @param serializable Whether to use the SERIALIZABLE isolation level
            (only applied if autocommit is off)
        @param autocommit Whether to commit after each query that is run
            outside a savepoint (see do_autocommit())
        @param caching Whether cacheable queries may actually be cached
        @param log_debug Callback taking a debug message
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug

        self.__db = None
        self.query_results = {}
        self._savepoint = 0

    def __getattr__(self, name):
        '''Provides the `db` attribute, which connects on first use'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        raise AttributeError(name)

    def __getstate__(self):
        '''Returns the state to pickle, without the debug callback and the
        connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def connected(self): return self.__db is not None

    def _db(self):
        '''Returns the underlying connection, opening and configuring it if
        that hasn't been done yet'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass

            # Connect
            self.__db = module.connect(**db_config)

            # Configure connection
            # NOTE(review): set_session() is a psycopg2 connection method;
            # confirm what should happen here for other engines
            if self.serializable and not self.autocommit:
                self.db.set_session(isolation_level='SERIALIZABLE')
            if schemas is not None:
                # each escaped schema is followed by ', ' so the list can be
                # prepended to the existing search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])

        return self.__db

    class DbCursor(Proxy):
        '''Cursor wrapper that records the fetched rows (or the exception
        raised) so the result can be cached in outer.query_results'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []

        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor. An exception is annotated
            with the query, cached as the result, and re-raised.'''
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value

        def fetchone(self):
            '''Fetches a row from the wrapped cursor and records it. Caches the
            result once all the rows have been fetched.'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        class CacheCursor:
            '''Replays a cached result through the cursor methods execute() and
            fetchone()'''
            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                try: return next(self.iter)
                except StopIteration: return None

    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query on this connection and returns the cursor.
        If cacheable (and caching is on), a cached result is replayed if there
        is one; otherwise the query is run through a DbCursor so that its
        result gets cached. See self.DbCursor.execute().
        @param query The SQL to run (must not be None)
        @param params The query's params, passed to the cursor's execute()
        @param cacheable Whether the result may be served from/stored in the
            cache
        '''
        assert query is not None

        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so logging below still works if getting the cursor fails
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()

            # Run query
            cur.execute(query, params)
        finally:
            log_debug = self.log_debug
            # only compute msg if needed. log_debug is None after unpickling
            # (see __getstate__()).
            if log_debug is not None and log_debug != log_debug_none:
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                log_debug(cache_status+': '
                    +strings.one_line(str(get_cur_query(cur, query, params))))

        return cur

    def is_cached(self, query, params=None):
        '''Whether there is a cached result for this query and params'''
        return _query_lookup(query, params) in self.query_results

    def with_savepoint(self, func):
        '''Runs func() inside a savepoint, which is rolled back to if func()
        raises (the exception is re-raised) and released otherwise.
        @return What func() returned
        '''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except: # deliberately catches everything: must roll back, then re-raise
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint)
            self.do_autocommit()
            return return_val

    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            # log_debug is None after unpickling (see __getstate__())
            if self.log_debug is not None: self.log_debug('Autocommiting')
            self.db.commit()
306 1849 aaronmk
307 1869 aaronmk
connect = DbConn # alias, so that a DbConn can be created with connect(...)
308
309 832 aaronmk
##### Querying
310
311 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Function form of db.run_query(): passes all the other args through and
    returns its result.
    For params, see DbConn.run_query()
    '''
    return db.run_query(*args, **kw_args)
314 11 aaronmk
315 2068 aaronmk
def mogrify(db, query, params):
    '''Returns what a psycopg2 cursor's mogrify() returns for the query and
    params.
    Raises NotImplementedError if the DB module is not psycopg2.
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
    else: raise NotImplementedError("Can't mogrify query for "+module+
        ' database')
320
321 832 aaronmk
##### Recoverable querying
322 15 aaronmk
323 2139 aaronmk
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func)'''
    return_val = db.with_savepoint(func)
    return return_val
324 11 aaronmk
325 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally in a way that errors can be recovered from.
    @param params The query's params
    @param recover If set, the query is run inside a savepoint (unless its
        result is cached) and errors whose message matches a known pattern are
        translated to typed exceptions:
        * unique constraint violated: DuplicateKeyException
        * not-null constraint violated: NullValueException
        * relation already exists: DuplicateTableException
        * function already exists: DuplicateFunctionException
        Any other error is re-raised as is.
    @param cacheable See DbConn.run_query()
    @return The cursor from run_raw_query()
    '''
    if recover is None: recover = False

    try:
        def run(): return run_raw_query(db, query, params, cacheable)
        if recover and not db.is_cached(query, params):
            return with_savepoint(db, run)
        else: return run() # don't need savepoint if cached
    except Exception as e:
        if not recover: raise # need savepoint to run index_cols()
        msg = str(e)
        # the table is taken to be the part of the constraint name before the
        # first non-leading _
        match = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+)"', msg)
        if match:
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        match = re.search(r'relation "(\w+)" already exists', msg)
        if match: raise DuplicateTableException(match.group(1), e)
        match = re.search(r'function "(\w+)" already exists', msg)
        if match: raise DuplicateFunctionException(match.group(1), e)
        raise # no specific exception raised
351 830 aaronmk
352 832 aaronmk
##### Basic queries
353
354 2153 aaronmk
def next_version(name):
    '''Returns the name with its version # incremented.
    Prepends the version # so it won't be removed if the name is truncated.
    A name with no `v<#>_` prefix is version 0, so it becomes `v1_<name>`.
    '''
    version = 1 # first existing name was version 0
    match = re.match(r'^v(\d+)_(.*)$', name)
    if match:
        version = int(match.group(1))+1
        name = match.group(2)
    return 'v'+str(version)+'_'+name
362
363 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into_ref None to just run the query, or a one-element list with the
        name of the temp table to create. If a table of that name already
        exists, the element is replaced with the next version of the name
        (see next_version()) until the table can be created.
    '''
    if into_ref is None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        check_name(into_ref[0])
        kw_args['recover'] = True # needed to get DuplicateTableException
        while True:
            try:
                return run_query(db, 'CREATE TEMP TABLE '+into_ref[0]+' AS '
                    +query, params, *args, **kw_args)
                    # CREATE TABLE AS sets rowcount to # rows in query
            except DuplicateTableException:
                into_ref[0] = next_version(into_ref[0])
                # try again with next version of name
379 2085 aaronmk
380 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey

join_using = object() # tells mk_select() to join the column with USING

filter_out = object() # tells mk_select() to filter out rows that match the join
385 2180 aaronmk
386 2121 aaronmk
def mk_select(db, tables, fields=None, conds=None, limit=None, start=None,
    order_by=order_by_pkey, table_is_esc=False):
    '''
    Builds (but does not run) a SELECT statement.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, joins), ...]

        joins has the format: dict(right_col=left_col, ...)
        * if left_col is join_using, left_col is set to right_col
        * if left_col is filter_out, the tables are LEFT JOINed together and the
          query is filtered by `right_col IS NULL` (indicating no match)
        Each table is joined to the table before it in the list.
    @param fields Use None to select all fields in the table. Each field is a
        col name, a (table, col) tuple, or a (value,) tuple for a literal value.
    @param conds dict(col=value, ...) of conditions that are ANDed together in
        the WHERE clause. Not modified.
    @param limit int|None
    @param start int|None The row offset
    @param order_by The col to order by, None for no ORDER BY, or order_by_pkey
        (the default) to order by the first table's pkey()
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    def esc_name_(name): return esc_name(db, name)

    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = tables.pop(0) # first table is separate

    if conds is None: conds = {}
    else: conds = copy.copy(conds) # don't modify input! (filter_out adds to it)
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    if order_by is order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)

    params = []

    def parse_col(field, default_table=None):
        '''Parses fields'''
        if field is None: field = (field,) # for None values, tuple is optional
        is_tuple = isinstance(field, tuple)
        if is_tuple and len(field) == 1: # field is literal value
            value, = field
            sql_ = '%s'
            params.append(value)
        elif is_tuple and len(field) == 2: # field is col with table
            table, col = field
            if not table_is_esc: table = esc_name_(table)
            sql_ = table+'.'+esc_name_(col)
        else:
            sql_ = esc_name_(field) # field is col name
            if default_table is not None: sql_ = default_table+'.'+sql_
        return sql_
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = parse_col(col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        return cond_

    query = 'SELECT '
    if fields is None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    query += ' FROM '+table0

    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        if not joins: raise ValueError('No join columns given for table '+table)

        left_join_ref = [False] # list, so that join() can set it

        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry

            # Parse special values
            if left_col is None: left_col = (left_col,)
                # for None values, tuple is optional
            elif left_col is join_using: left_col = right_col
            elif left_col is filter_out:
                left_col = right_col
                left_join_ref[0] = True
                conds[(table, right_col)] = None # filter query by no match

            # Create SQL
            right_col = table+'.'+esc_name_(right_col)
            sql_ = right_col+' '
            if isinstance(left_col, tuple) and len(left_col) == 1:
                # col is literal value
                value, = left_col
                if value is None: sql_ += 'IS'
                else: sql_ += '='
                sql_ += ' %s'
                params.append(value)
            else: # col is name
                left_col = parse_col(left_col, left_table)
                # also match rows where both cols are NULL
                sql_ += ('= '+left_col+' OR ('+right_col+' IS NULL AND '
                    +left_col+' IS NULL)')

            return sql_

        # Create join condition and determine join type
        if all(v is join_using for v in joins.values()):
            # all cols w/ USING, so can use simpler USING syntax
            # (col names are escaped, like in the ON syntax)
            join_cond = 'USING ('+(', '.join(map(esc_name_, joins)))+')'
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.items())))

        # Create join
        if left_join_ref[0]: query += ' LEFT'
        query += ' JOIN '+table+' '+join_cond

        left_table = table

    missing = True
    if conds:
        query += ' WHERE '+(' AND '.join(map(cond, conds.items())))
        params += conds.values()
        missing = False
    if order_by is not None: query += ' ORDER BY '+parse_col(order_by, table0)
    if limit is not None: query += ' LIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return (query, params)
508 11 aaronmk
509 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT statement with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query()
    @param recover Passed to run_query(). Defaults to None.
    @param cacheable Passed to run_query(). Defaults to True.
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)

    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, recover, cacheable)
516
517 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''
    Builds an INSERT statement whose rows come from select_query.
    @param cols list|None The names of the columns to insert into
    @param select_query str|None The source of the rows. None inserts
        DEFAULT VALUES.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This creates a pg_temp function that runs the INSERT (so it requires
        returning) and returns a SELECT on that function instead.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query is None: select_query = 'DEFAULT VALUES'
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)

    # Build query
    query = 'INSERT INTO '+table
    if cols is not None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query

    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning

    if embeddable:
        # Create function
        # (cols may be None here, so can't be concatenated directly)
        function_name = '_'.join(map(clean_name,
            ['insert', table] + list(cols or [])))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        while True:
            try:
                function = 'pg_temp.'+function_name
                function_query = '''\
CREATE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
                run_query(db, function_query, recover=True, cacheable=True)
                break # this version was successful
            except DuplicateFunctionException:
                function_name = next_version(function_name)
                # try again with next version of name

        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
            order_by=None, table_is_esc=True)# AS clause requires function alias

    return (query, params)
564
565
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT statement with mk_insert_select() and runs it with
    run_query_into().
    For params, see mk_insert_select() and run_query_into()
    @param into_ref List with name of temp table to place RETURNING values in.
        If given, the INSERT is made embeddable.
    @param recover Passed to run_query_into(). Defaults to None.
    @param cacheable Passed to run_query_into(). Defaults to True.
    '''
    into_ref = kw_args.pop('into_ref', None)
    if into_ref is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)

    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into_ref, recover=recover,
        cacheable=cacheable)
577 2063 aaronmk
578 2066 aaronmk
default = object() # sentinel row value: tells insert() to use the default value for a column (emitted as DEFAULT)
579
580 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts one row.
    For params, see insert_select()
    @param row A sequence of values (for all the columns, in order) or a dict
        of col name -> value. A value of `default` uses the column's default.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    row = list(row)

    # Check for special values
    labels = []
    literals = [] # the values that are passed as query params
    for val in row:
        if val is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            literals.append(val)

    # Build query
    if literals: query = ' VALUES ('+(', '.join(labels))+')'
    else:
        # Every column uses its default, so DEFAULT VALUES will be used. This
        # can't be combined with a column list.
        query = None
        cols = None

    return insert_select(db, table, cols, query, literals, *args, **kw_args)
602 11 aaronmk
603 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert: lastval() for psycopg2,
    the connection's insert_id() for MySQLdb, None for any other module'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the underlying connection, not to the DbConn
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
608 13 aaronmk
609 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on the schema-qualified table and returns the
    cursor'''
    table = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+table+' CASCADE')
611 832 aaronmk
612
##### Database structure queries
613
614 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of the table's pkey.
    Assumed to be first column in table.
    @param recover See run_query()
    @param table_is_esc Whether the table name has already been escaped
    '''
    # select no rows: only the cursor's column names are needed
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    return next(col_names(cur))
618 832 aaronmk
619 853 aaronmk
def index_cols(db, table, index):
    '''Returns a list of the names of the columns in an index.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    Raises NotImplementedError if the DB module is not psycopg2.
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The 2nd half of the UNION gets the columns from indexprs instead of
        # from indkey
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
661 853 aaronmk
662 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns a list of the names of the columns in a constraint.
    Raises NotImplementedError if the DB module is not psycopg2.
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
680
681 2096 aaronmk
row_num_col = '_row_num' # name of the column that add_row_num() adds
682
683 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    check_name(table)
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY')
689 2086 aaronmk
690 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Returns an iterator over the names of the tables matching a LIKE pattern.
    @param schema The schema to look in (not used for MySQLdb)
    @param table_like LIKE pattern that the table names must match
    Raises NotImplementedError if the DB module is not psycopg2 or MySQLdb.
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        return values(run_query(db, '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
''',
            params, cacheable=True))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
707 830 aaronmk
708 833 aaronmk
##### Database management
709
710 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() returns for the schema.
    For kw_args, see tables()
    '''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
713 833 aaronmk
714 832 aaronmk
##### Heuristic queries
715
716 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row and returns value() of the insert's cursor.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The pkey column; looked up with pkey() if not given
    @param row_ct_ref If set, a sequence whose [0] entry is incremented by the
        insert's rowcount (when the rowcount is >= 0)
    @return On DuplicateKeyException, value() of a select of pkey_ filtered by
        util.dict_subset_right_join(row, e.cols) instead
    '''
    if pkey_ == None:
        pkey_ = pkey(db, table, recover=True)

    try:
        cur = insert(db, table, row, pkey_, recover=True)
        if row_ct_ref != None:
            if cur.rowcount >= 0: row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Fall back to selecting the pkey, filtering on the columns in e.cols
        conflict_cond = util.dict_subset_right_join(row, e.cols)
        existing = select(db, table, [pkey_], conflict_cond, recover=True)
        return value(existing)
730 471 aaronmk
731 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Selects the pkey column of the first row matching the given row.
    Recovers from errors.
    @param create If set, a row that is not found is inserted with put()
        (which receives row_ct_ref) instead of re-raising StopIteration
    @return value() of the select's cursor, or the result of put()
    '''
    try:
        cur = select(db, table, [pkey], row, 1, recover=True)
        return value(cur)
    except StopIteration:
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
737 2078 aaronmk
738 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
    row_ct_ref=None, table_is_esc=False):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping Its keys are passed to insert_select() as the columns of
        out_table; its values are the columns selected from the input tables
    @param limit Passed to mk_select() when selecting the input rows
    @param start Passed to mk_select() when selecting the input rows
    @param row_ct_ref If set, a sequence whose [0] entry is incremented by the
        insert's rowcount (when the rowcount is >= 0)
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
        available
    '''
    temp_suffix = clean_name(out_table)
        # suffix, not prefix, so main name won't be removed if name is truncated
    pkeys_ref = ['pkeys_'+temp_suffix]

    # Join together input tables
    in_tables = in_tables[:] # don't modify input!
    in_tables0 = in_tables.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
    insert_joins = [in_tables0]+[(t, {in_pkey: join_using}) for t in in_tables]

    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
    pkeys_cols = [in_pkey, out_pkey]

    def mk_select_(cols):
        '''Selects the given columns from the joined input tables.'''
        return mk_select(db, insert_joins, cols, limit=limit, start=start,
            table_is_esc=table_is_esc)

    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
    def insert_(pkeys_table_exists=False):
        '''Inserts and captures output pkeys.'''
        cur = insert_select(db, out_table, mapping.keys(),
            *mk_select_(mapping.values()), returning=out_pkey,
            into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        # This must run whether or not a row count was requested: the join
        # below is on row_num_col, so the output pkeys table needs that column
        # even when row_ct_ref is None.
        add_row_num(db, out_pkeys_ref[0]) # for joining it with input pkeys

        # Get input pkeys corresponding to rows in insert
        in_pkeys_ref = ['in_pkeys_'+temp_suffix]
        run_query_into(db, *mk_select_([in_pkey]), into_ref=in_pkeys_ref)
        add_row_num(db, in_pkeys_ref[0]) # for joining it with output pkeys

        # Join together output and input pkeys
        select_query = mk_select(db, [in_pkeys_ref[0],
            (out_pkeys_ref[0], {row_num_col: join_using})], pkeys_cols, start=0)
        if pkeys_table_exists:
            insert_select(db, pkeys_ref[0], pkeys_cols, *select_query)
        else: run_query_into(db, *select_query, into_ref=pkeys_ref)

    # Do inserts and selects
    try: insert_()
    except DuplicateKeyException as e:
        join_cols = util.dict_subset_right_join(mapping, e.cols)
        select_joins = insert_joins + [(out_table, join_cols)]

        # Get pkeys of already existing rows
        run_query_into(db, *mk_select(db, select_joins, pkeys_cols, start=0,
            table_is_esc=table_is_esc), into_ref=pkeys_ref, recover=True)

    return (pkeys_ref[0], out_pkey)
797 2115 aaronmk
798
##### Data cleanup
799
800
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Runs one UPDATE that sets each given column to its trimmed value, or to
    NULL where the trimmed value equals the empty string or a backslash
    followed by N.
    @param cols Column names; each is escaped with esc_name() before use
    @param table_is_esc If not set, the table name is passed to check_name()
    '''
    def esc_name_(name):
        return esc_name(db, name)

    if not table_is_esc:
        check_name(table)
    cols = [esc_name_(col) for col in cols]

    # One "col = nullif(...)" assignment per column
    assignments = []
    for col in cols:
        assignments.append('\n'+col
            +' = nullif(nullif(trim(both from '+col
            +"), %(null0)s), %(null1)s)")
    update_stmt = 'UPDATE '+table+' SET\n'+',\n'.join(assignments)
    run_query(db, update_stmt, dict(null0='', null1=r'\N'))