Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the text of the query most recently run on a cursor.
    
    Uses the cursor's `query` attribute if it has one, otherwise its
    `_last_executed` attribute. When that yields nothing, falls back to a
    repr()-based rendering of the caller-supplied inputs.
    @param cur The cursor (or any object) to inspect
    @param input_query The query that was passed to execute(), if known
    @param input_params The params that were passed to execute(), if known
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query is not None: return raw_query
    else: return repr(input_query)+' % '+repr(input_params)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Appends the text of the query being run to an exception's message.
    For args and kw_args, see get_cur_query().
    '''
    query_text = str(get_cur_query(*args, **kw_args))
    exc.add_msg(e, 'query: '+query_text)
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of this module's exceptions'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the query last run on this cursor is appended to
            the message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about one named object; the name is kept in self.name'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+str(name)
        DbException.__init__(self, msg, cause)
        self.name = name
40 360 aaronmk
41 468 aaronmk
class ExceptionWithColumns(DbException):
    '''A DbException about a list of columns; the list is kept in self.cols'''
    def __init__(self, cols, cause=None):
        col_list = ', '.join(cols) # cols must all be strings
        DbException.__init__(self, 'for columns: '+col_list, cause)
        self.cols = cols
45 11 aaronmk
46 2143 aaronmk
class NameException(DbException):
    '''Raised by check_name() for a name containing a non-word character'''
47
48 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns):
    '''Raised by run_query() when a unique constraint is violated; self.cols
    holds the constraint's columns'''
49 13 aaronmk
50 468 aaronmk
class NullValueException(ExceptionWithColumns):
    '''Raised by run_query() when a not-null constraint is violated; self.cols
    holds the offending column'''
51 13 aaronmk
52 2143 aaronmk
class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() when a relation already exists; self.name holds
    the relation's name'''
53
54 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName):
    '''Raised by run_query() when a function already exists; self.name holds
    the function's name'''
55
56 89 aaronmk
class EmptyRowException(DbException):
    '''DbException subtype for an empty row.
    NOTE(review): no raise site is visible in this part of the module;
    confirm its intended use against callers.
    '''
57
58 865 aaronmk
##### Warnings
59
60
class DbWarning(UserWarning):
    '''Category of the warnings issued by this module (see mk_select())'''
61
62 1930 aaronmk
##### Result retrieval
63
64
def col_names(cur):
    '''Returns an iterator over the column names in a cursor's description
    (the first item of each description entry)'''
    description = cur.description
    return (entry[0] for entry in description)
65
66
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() until it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
67
68
def consume_rows(cur):
    '''Fetches every remaining row, so that a caching cursor sees the end of
    the result and caches it'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
71
72
def next_row(cur):
    '''Fetches the next row from a cursor.
    @throws StopIteration if no rows are left
    '''
    # next() builtin rather than the Python-2-only .next() method
    return next(rows(cur))
73
74
def row(cur):
    '''Returns the next row, then fetches and discards all rows after it.
    @throws StopIteration if no rows are left
    '''
    first = next_row(cur)
    consume_rows(cur) # read to the end of the result
    return first
78
79
def next_value(cur):
    '''Returns the first column of the next row'''
    fetched = next_row(cur)
    return fetched[0]
80
81
def value(cur):
    '''Returns the first column of the next row and discards all rows after
    it (see row())'''
    only_row = row(cur)
    return only_row[0]
82
83
def values(cur):
    '''Returns an iterator built by iters.func_iter() over the first column of
    each remaining row'''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
84
85
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there are no rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
88
89 2101 aaronmk
##### Input validation
90
91 2198 aaronmk
def clean_name(name):
    '''Removes every non-word character from a name and lowercases it'''
    word_chars_only = re.sub(r'\W', r'', name)
    return word_chars_only.lower()
92 2101 aaronmk
93
def check_name(name):
    '''Validates that a name contains only word characters.
    @throws NameException if the name contains any other character
    '''
    if re.search(r'\W', name) is not None:
        raise NameException('Name "'+name
            +'" may contain only alphanumeric characters and _')
96
97
def esc_name_by_module(module, name, ignore_case=False):
    '''Escapes an identifier for the database driven by a DB-API module.
    
    Any occurrence of the quote character inside the name is removed rather
    than escaped.
    @param module The driver module's name: 'psycopg2' or 'MySQLdb'
    @param ignore_case psycopg2 only: return the name unquoted (after
        validating it with check_name()) so it stays case-insensitive
    @throws NotImplementedError for any other module
    '''
    if module == 'psycopg2':
        if ignore_case:
            # Don't enclose in quotes because this disables case-insensitivity
            check_name(name)
            return name
        quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else:
        # str() so that a non-string module still gives the intended error
        raise NotImplementedError("Can't escape name for "+str(module)
            +' database')
    return quote + name.replace(quote, '') + quote
107
108
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for an engine named in db_engines.
    For kw_args, see esc_name_by_module().
    '''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
110
111
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the database behind a connection.
    For kw_args, see esc_name_by_module().
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
113
114
def qual_name(db, schema, table):
    '''Returns the escaped table name, qualified with the escaped schema name
    unless schema is None'''
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema is not None: return esc_name_(schema)+'.'+table
    else: return table
119
120 1869 aaronmk
##### Database connections
121 1849 aaronmk
122 2097 aaronmk
# Names of the settings a db_config dict may carry
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# engine name -> (DB-API module name,
#     {db_config key: name of that key in the module's connect() args})
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# Exception classes that indicate a database error. _add_module() adds each
# loaded driver's DatabaseError and rebuilds the tuple.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
131
132
def _add_module(module):
    '''Registers a driver module's DatabaseError class and rebuilds the
    DatabaseErrors tuple to include it'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
136
137
def db_config_str(db_config):
    '''Returns a short description of a db_config, built from its 'engine' and
    'database' entries'''
    return db_config['engine']+' database '+db_config['database']
139
140 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is cached: the query plus
    a hashable form of its params'''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
141 1894 aaronmk
142 1901 aaronmk
def log_debug_none(msg):
    '''Debug callback that discards the message (DbConn's default)'''
    return None
143
144 1849 aaronmk
class DbConn:
    '''Database connection wrapper that connects lazily and adds query result
    caching, nested savepoints, and optional autocommit.
    
    The underlying DB-API connection is available as self.db and is created
    the first time that attribute is read.
    '''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict of connection settings. 'engine' selects the
            entry of db_engines to use, 'schemas' is an optional
            comma-separated list of schemas to prepend to the search_path,
            and the remaining entries are passed to the driver's connect().
        @param serializable Whether to use the SERIALIZABLE isolation level
            (only applied when autocommit is off)
        @param autocommit Whether to commit after each query that runs outside
            a savepoint
        @param caching Whether cacheable queries may use the result cache
        @param log_debug Callback that is passed each debug message
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        
        self.__db = None # the DB-API connection; created by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
    
    def __getattr__(self, name):
        '''Provides the lazily-created `db` attribute.
        Only called for attributes that normal lookup does not find.
        '''
        # NOTE(review): presumably guards against endless recursion when
        # __dict__ itself is looked up through here -- confirm
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name)
    
    def __getstate__(self):
        '''Returns the state to pickle, without the debug callback and the
        connection'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self):
        '''Whether the underlying connection has been created'''
        return self.__db is not None
    
    def _db(self):
        '''Returns the underlying connection, creating and configuring it on
        first use'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # this setting was not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit:
                self.db.set_session(isolation_level='SERIALIZABLE')
            if schemas is not None:
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a real cursor and records what it returns so that the result
        (or the exception raised) can be stored in the query cache'''
        def __init__(self, outer):
            '''@param outer The DbConn this cursor belongs to'''
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = [] # rows fetched so far, or the exception raised
        
        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor.
            An exception has the query appended to its message and is cached
            as the query's result before being re-raised.
            '''
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            '''Fetches and records the next row; caches the result once the
            wrapped cursor reports no more rows'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores this cursor's result in the query cache'''
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result through the cursor methods that this
            module uses'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises the cached exception, or rewinds to the first
                cached row. The arguments are ignored.'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None after the last one'''
                try: return next(self.iter)
                except StopIteration: return None
    
    def run_query(self, query, params=None, cacheable=False):
        '''Runs a query and returns the cursor holding its result.
        Translates known DB errors to typed exceptions:
        See self.DbCursor.execute().
        @param cacheable Whether the result may be served from and stored in
            the cache (has no effect if caching is off for this connection)
        '''
        assert query is not None
        
        if not self.caching: cacheable = False
        used_cache = False
        cur = None # so the debug message below works if no cursor was created
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            cur.execute(query, params)
        finally:
            if self.debug: # only compute msg if needed
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                self.log_debug(cache_status+': '
                    +strings.one_line(str(get_cur_query(cur, query, params))))
        
        return cur
    
    def is_cached(self, query, params=None):
        '''Whether the query cache has an entry for this query and params'''
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint.
        If func() raises, rolls back to the savepoint and re-raises;
        otherwise releases the savepoint and returns func()'s value.
        '''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommiting')
            self.db.commit()
307 1849 aaronmk
308 1869 aaronmk
# Alias: calling connect(...) constructs a DbConn with the same arguments
connect = DbConn
309
310 832 aaronmk
##### Querying
311
312 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query by delegating directly to db.run_query().
    For params, see DbConn.run_query().
    '''
    cur = db.run_query(*args, **kw_args)
    return cur
315 11 aaronmk
316 2068 aaronmk
def mogrify(db, query, params):
    '''Returns the query with its params filled in, as produced by the
    driver cursor's mogrify().
    @throws NotImplementedError unless the driver is psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    cur = db.db.cursor()
    return cur.mogrify(query, params)
321
322 832 aaronmk
##### Recoverable querying
323 15 aaronmk
324 2139 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint. See DbConn.with_savepoint().'''
    return db.with_savepoint(func)
325 11 aaronmk
326 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
    '''Runs a query, optionally translating known errors to typed exceptions.
    @param recover Whether to run the query inside a savepoint (unless its
        result is already cached) and translate recognized error messages.
        If false, any exception propagates unchanged.
    @param cacheable See DbConn.run_query()
    @return The cursor holding the result
    @throws DuplicateKeyException|NullValueException|DuplicateTableException|
        DuplicateFunctionException when recover is set and the error message
        matches; any other exception is re-raised as-is
    '''
    if recover is None: recover = False
    
    try:
        def run(): return run_raw_query(db, query, params, cacheable)
        if recover and not db.is_cached(query, params):
            return with_savepoint(db, run)
        else: return run() # don't need savepoint if cached
    except Exception as e:
        if not recover: raise # need savepoint to run index_cols()
        msg = str(e)
        match = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+)"', msg)
        if match:
            # the table name is taken to be the constraint name's first part
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(cols, e)
        match = re.search(r'null value in column "(\w+)" violates not-null '
            'constraint', msg)
        if match: raise NullValueException([match.group(1)], e)
        match = re.search(r'relation "(\w+)" already exists', msg)
        if match: raise DuplicateTableException(match.group(1), e)
        match = re.search(r'function "(\w+)" already exists', msg)
        if match: raise DuplicateFunctionException(match.group(1), e)
        raise # no specific exception raised
352 830 aaronmk
353 832 aaronmk
##### Basic queries
354
355 2153 aaronmk
def next_version(name):
    '''Returns the name with its version # incremented.
    Prepends the version # so it won't be removed if the name is truncated.
    An unversioned name counts as version 0, so it becomes 'v1_'+name.
    '''
    version = 1 # first existing name was version 0
    match = re.match(r'^v(\d+)_(.*)$', name)
    if match:
        version = int(match.group(1))+1
        name = match.group(2)
    return 'v'+str(version)+'_'+name
363
364 2151 aaronmk
def run_query_into(db, query, params, into_ref=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into_ref One-item list holding the name of the table to create, or
        None to just run the query. If the name is already taken, the item is
        replaced with the versioned name (see next_version()) that was used.
        The table is created non-TEMP when db.debug is set.
    '''
    if into_ref is None: return run_query(db, query, params, *args, **kw_args)
    
    # place rows in temp table
    check_name(into_ref[0])
    kw_args['recover'] = True # needed to get DuplicateTableException
    while True:
        try:
            create_query = 'CREATE'
            if not db.debug: create_query += ' TEMP'
            create_query += ' TABLE '+into_ref[0]+' AS '+query
            
            return run_query(db, create_query, params, *args, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateTableException:
            into_ref[0] = next_version(into_ref[0])
            # try again with next version of name
383 2085 aaronmk
384 2120 aaronmk
# Unique marker objects recognized by mk_select()

order_by_pkey = object() # tells mk_select() to order by the pkey

join_using = object() # tells mk_select() to join the column with USING

filter_out = object() # tells mk_select() to filter out rows that match the join

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
391
392
def mk_select(db, tables, fields=None, conds=None, distinct_on=None, limit=None,
    start=None, order_by=order_by_pkey, table_is_esc=False):
    '''Builds a SELECT query.
    
    A column may be given as a name, a (table, col) tuple, or a one-item
    tuple/None holding a literal value, which becomes a query param.
    @param tables The single table to select from, or a list of tables to join
        together: [table0, (table1, joins), ...]
        
        joins has the format: dict(right_col=left_col, ...)
        * if left_col is join_using, left_col is set to right_col
        * if left_col is filter_out, the tables are LEFT JOINed together and the
          query is filtered by `right_col IS NULL` (indicating no match)
    @param fields Use None to select all fields in the table
    @param conds dict(col=value, ...) of conditions to AND together. Not
        modified.
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit int|None Max # rows
    @param start int|None # rows to skip
    @param order_by The column to order by, None for no ordering, or
        order_by_pkey to look up and use the first table's pkey
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    def esc_name_(name): return esc_name(db, name)
    
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = tables.pop(0) # first table is separate
    
    if conds is None: conds = {}
    else: conds = copy.copy(conds) # don't modify input! (filter_out adds to it)
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    if order_by is order_by_pkey:
        order_by = pkey(db, table0, recover=True, table_is_esc=table_is_esc)
    if not table_is_esc: table0 = esc_name_(table0)
    
    params = [] # must be appended to in the order the %s's appear in the query
    
    def parse_col(field, default_table=None):
        '''Parses fields'''
        if field is None: field = (field,) # for None values, tuple is optional
        is_tuple = isinstance(field, tuple)
        if is_tuple and len(field) == 1: # field is literal value
            value, = field
            sql_ = '%s'
            params.append(value)
        elif is_tuple and len(field) == 2: # field is col with table
            table, col = field
            if not table_is_esc: table = esc_name_(table)
            sql_ = table+'.'+esc_name_(col)
        else:
            sql_ = esc_name_(field) # field is col name
            if default_table is not None: sql_ = default_table+'.'+sql_
        return sql_
    def cond(entry):
        '''Parses conditions'''
        col, value = entry
        cond_ = parse_col(col)+' '
        if value is None: cond_ += 'IS'
        else: cond_ += '='
        cond_ += ' %s'
        # Added here rather than in bulk afterwards so that the value directly
        # follows any param that parse_col() added for a literal column
        params.append(value)
        return cond_
    
    query = 'SELECT'
    
    # DISTINCT ON columns
    if distinct_on is not None:
        query += ' DISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += ' '
    if fields is None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    
    # Main table
    query += ' FROM '+table0
    
    # Add joins
    left_table = table0
    for table, joins in tables:
        if not table_is_esc: table = esc_name_(table)
        
        left_join_ref = [False] # list so that join() can set it
        
        def join(entry):
            '''Parses non-USING joins'''
            right_col, left_col = entry
            
            # Parse special values
            if left_col is None: left_col = (left_col,)
                # for None values, tuple is optional
            elif left_col is join_using: left_col = right_col
            elif left_col is filter_out:
                left_col = right_col
                left_join_ref[0] = True
                conds[(table, right_col)] = None # filter query by no match
            
            # Create SQL
            right_col = table+'.'+esc_name_(right_col)
            if isinstance(left_col, tuple) and len(left_col) == 1:
                # col is literal value
                value, = left_col
                sql_ = right_col+' '
                if value is None: sql_ += 'IS'
                else: sql_ += '='
                sql_ += ' %s'
                params.append(value)
            else: # col is name
                left_col = parse_col(left_col, left_table)
                # Parenthesized because AND binds tighter than OR, and these
                # conditions get ANDed together
                sql_ = ('('+right_col+' = '+left_col+' OR ('+right_col
                    +' IS NULL AND '+left_col+' IS NULL))')
            
            return sql_
        
        # Create join condition and determine join type
        if not joins:
            raise TypeError('joins for table '+table+' must not be empty')
        if all(v is join_using for v in joins.values()):
            # all cols w/ USING, so can use simpler USING syntax
            join_cond = 'USING ('+(', '.join(joins))+')'
        else: join_cond = 'ON '+(' AND '.join(map(join, joins.items())))
        
        # Create join
        if left_join_ref[0]: query += ' LEFT'
        query += ' JOIN '+table+' '+join_cond
        
        left_table = table
    
    missing = True # whether the query has nothing that limits its rows
    if conds:
        query += ' WHERE '+(' AND '.join(map(cond, conds.items())))
        missing = False
    if order_by is not None: query += ' ORDER BY '+parse_col(order_by, table0)
    if limit is not None:
        query += ' LIMIT '+str(limit)
        missing = False
    if start is not None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, params)
527 11 aaronmk
528 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT query with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query().
    Unlike run_query(), cacheable defaults to True.
    '''
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True))
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, **run_opts)
535
536 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False, table_is_esc=False):
    '''Builds an INSERT query whose rows come from select_query.
    @param cols The names of the columns to insert into, or None/[] if they
        are not known (e.g. when inserting all defaults)
    @param select_query The SQL that supplies the rows. Defaults to
        'DEFAULT VALUES'.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This wraps the INSERT in a newly created function (which requires
        `returning` to be set) and returns a SELECT from that function.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param table_is_esc Whether the table name has already been escaped
    @return tuple(query, params)
    '''
    if select_query is None: select_query = 'DEFAULT VALUES'
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if not table_is_esc: check_name(table)
    
    # Build query
    query = 'INSERT INTO '+table
    if cols is not None:
        for col in cols: check_name(col)
        query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning is not None:
        check_name(returning)
        query += ' RETURNING '+returning
    
    if embeddable:
        # Create function
        # (cols may be None here, so it can't be concatenated directly)
        function_name = '_'.join(map(clean_name,
            ['insert', table] + list(cols or [])))
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
        while True:
            try:
                function = function_name
                if not db.debug: function = 'pg_temp.'+function
                
                function_query = '''\
CREATE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
                run_query(db, function_query, recover=True, cacheable=True)
                break # this version was successful
            except DuplicateFunctionException:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        return mk_select(db, function+'() AS f ('+returning+')', start=0,
            order_by=None, table_is_esc=True)# AS clause requires function alias
    
    return (query, params)
585
586
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT query with mk_insert_select() and runs it.
    For params, see mk_insert_select() and run_query_into().
    Unlike run_query(), cacheable defaults to True.
    @param into_ref List with name of temp table to place RETURNING values in.
        Providing it makes the query embeddable.
    '''
    into_ref = kw_args.pop('into_ref', None)
    if into_ref is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into_ref, recover=recover,
        cacheable=cacheable)
598 2063 aaronmk
599 2066 aaronmk
# Marker object: insert() emits the SQL keyword DEFAULT, instead of a query
# param, for each row value that is this object
default = object() # tells insert() to use the default value for a column
600
601 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts one row.
    For params, see insert_select().
    @param row A sequence of values, or a dict of column name -> value. A
        value that is `default` inserts the column's default value.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = list(row.keys())
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = [] # the SQL for each value
    params = [] # the values that are passed as query params
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            params.append(value)
    
    # Build query
    # Only a row with no values at all uses DEFAULT VALUES (query=None): a row
    # of just `default`s still needs a VALUES list to go with its column names
    if labels != []: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None
    
    return insert_select(db, table, cols, query, params, *args, **kw_args)
623 11 aaronmk
624 135 aaronmk
def last_insert_id(db):
    '''Returns the ID generated by the last insert, or None if the driver is
    not supported'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() belongs to the driver's connection, not to the DbConn wrapper
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
629 13 aaronmk
630 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table and returns the cursor'''
    target = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+target+' CASCADE')
632 832 aaronmk
633
##### Database structure queries
634
635 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
    '''Returns the name of the table's pkey.
    Assumed to be first column in table.
    '''
    # limit=0 fetches no rows; only the result's column names are needed
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        table_is_esc=table_is_esc)
    return next(col_names(cur))
639 832 aaronmk
640 853 aaronmk
def index_cols(db, table, index):
    '''Returns a list of the names of the columns in an index, ordered by
    attnum.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @throws NotImplementedError unless the driver is psycopg2
    '''
    check_name(table)
    check_name(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # The first subquery matches the columns listed in indkey; the second
    # matches the column numbers extracted from indexprs
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    return list(values(run_query(db, query, params, cacheable=True)))
682 853 aaronmk
683 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns a list of the names of the columns in a constraint, ordered by
    attnum.
    @throws NotImplementedError unless the driver is psycopg2
    '''
    check_name(table)
    check_name(constraint)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    return list(values(run_query(db, query, params)))
701
702 2096 aaronmk
# Name of the column that add_row_num() adds to a table
row_num_col = '_row_num'
703
704 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.
    @throws NameException if the table name fails check_name()
    '''
    check_name(table)
    col_def = row_num_col+' serial NOT NULL PRIMARY KEY'
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+col_def)
710 2086 aaronmk
711 1968 aaronmk
def tables(db, schema='public', table_like='%'):
    '''Returns an iterator over the names of the tables matching a LIKE
    pattern.
    @param schema The schema to look in (only used for psycopg2)
    @param table_like The LIKE pattern the table names must match
    @throws NotImplementedError unless the driver is psycopg2 or MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb': query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query, params, cacheable=True))
728 830 aaronmk
729 833 aaronmk
##### Database management
730
731 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates each table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    table_names = tables(db, schema, **kw_args)
    for table_name in table_names:
        truncate(db, table_name, schema)
734 833 aaronmk
735 832 aaronmk
##### Heuristic queries
736
737 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row; if the insert raises a DuplicateKeyException, selects the
    already existing row instead.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param row The row to insert. NOTE(review): presumably a dict of column
        name to value, since it is passed to util.dict_subset_right_join();
        confirm against insert()
    @param pkey_ The pkey column whose value is returned. If not set, it is
        looked up with pkey().
    @param row_ct_ref If set, a single-element list whose element is
        incremented by the insert's row count (when the row count is known,
        i.e. >= 0)
    @return The value() of the insert's cursor or, for a duplicate row, of the
        select of pkey_ from the existing row
    '''
    if pkey_ is None: pkey_ = pkey(db, table, recover=True)

    try:
        cur = insert(db, table, row, pkey_, recover=True)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException as e:
        # Look up the existing row using only the columns of the violated key
        return value(select(db, table, [pkey_],
            util.dict_subset_right_join(row, e.cols), recover=True))
751 471 aaronmk
752 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Selects the pkey column of (at most one) row matching `row`.
    Recovers from errors.
    @param pkey The name of the column to select
    @param row_ct_ref Passed to put() if a new row is created
    @param create Whether to insert the row with put() if it is not found.
        If not set, the StopIteration from the lookup is re-raised.
    @return The value() of the select, or the result of put()
    '''
    try:
        cur = select(db, table, [pkey], row, limit=1, recover=True)
        return value(cur)
    except StopIteration:
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
758 2078 aaronmk
759 2134 aaronmk
def put_table(db, out_table, in_tables, mapping, limit=None, start=0,
    row_ct_ref=None, table_is_esc=False):
    '''Inserts the rows selected from the input tables into out_table and
    collects the (input pkey, output pkey) pairs in a pkeys table.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict whose keys are the out_table columns to insert into and
        whose values are the input columns to select for them
    @param limit Passed to the main select
    @param start Passed to the main select
    @param row_ct_ref If set, a single-element list whose element is
        incremented by the insert's row count (when the row count is known,
        i.e. >= 0)
    @param table_is_esc Passed through to pkey(), mk_select(), and
        insert_select()
    @return (table, col) Where the pkeys (from INSERT RETURNING) are made
        available
    '''
    temp_suffix = clean_name(out_table)
        # suffix, not prefix, so main name won't be removed if name is truncated
    pkeys_ref = ['pkeys_'+temp_suffix]

    # Join together input tables
    in_tables = in_tables[:] # don't modify input!
    in_tables0 = in_tables.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True, table_is_esc=table_is_esc)
    insert_joins = [in_tables0]+[(t, {in_pkey: join_using}) for t in in_tables]

    out_pkey = pkey(db, out_table, recover=True, table_is_esc=table_is_esc)
    pkeys_cols = [in_pkey, out_pkey]

    # The first query creates the pkeys table; later ones append to it
    pkeys_table_exists_ref = [False]
    def run_query_into_pkeys(query, params):
        if pkeys_table_exists_ref[0]:
            insert_select(db, pkeys_ref[0], pkeys_cols, query, params)
        else:
            run_query_into(db, query, params, into_ref=pkeys_ref)
            pkeys_table_exists_ref[0] = True

    conds = {}
    distinct_on = None
    def mk_main_select(cols):
        return mk_select(db, insert_joins, cols, conds, distinct_on,
            order_by=None, limit=limit, start=start, table_is_esc=table_is_esc)

    # Do inserts and selects
    out_pkeys_ref = ['out_pkeys_'+temp_suffix]
    while True:
        try:
            cur = insert_select(db, out_table, mapping.keys(),
                *mk_main_select(mapping.values()), returning=out_pkey,
                into_ref=out_pkeys_ref, recover=True, table_is_esc=table_is_esc)
            if row_ct_ref is not None and cur.rowcount >= 0:
                row_ct_ref[0] += cur.rowcount
            # Must run whether or not a row count was requested: the join of
            # output and input pkeys below uses row_num_col on both tables
            add_row_num(db, out_pkeys_ref[0]) # for joining with input pkeys

            # Get input pkeys corresponding to rows in insert
            in_pkeys_ref = ['in_pkeys_'+temp_suffix]
            run_query_into(db, *mk_main_select([in_pkey]),
                into_ref=in_pkeys_ref)
            add_row_num(db, in_pkeys_ref[0]) # for joining with output pkeys

            # Join together output and input pkeys
            run_query_into_pkeys(*mk_select(db, [in_pkeys_ref[0],
                (out_pkeys_ref[0], {row_num_col: join_using})], pkeys_cols,
                start=0))

            break # insert successful
        except DuplicateKeyException as e:
            join_cols = util.dict_subset_right_join(mapping, e.cols)
            select_joins = insert_joins + [(out_table, join_cols)]

            # Get pkeys of already existing rows
            run_query_into_pkeys(*mk_select(db, select_joins, pkeys_cols,
                order_by=None, start=0, table_is_esc=table_is_esc))

            # Save existing pkeys in temp table for joining on
            existing_pkeys_ref = ['existing_pkeys_'+temp_suffix]
            run_query_into(db, *mk_select(db, pkeys_ref[0], [in_pkey],
                order_by=None, start=0, table_is_esc=True),
                into_ref=existing_pkeys_ref)
                # need table_is_esc=True to make table name case-insensitive

            # rerun loop with additional constraints
            break # but until NullValueExceptions are handled, end loop here

    return (pkeys_ref[0], out_pkey)
837 2115 aaronmk
838
##### Data cleanup
839
840
def cleanup_table(db, table, cols, table_is_esc=False):
    '''Cleans up the given columns of a table with a single UPDATE: each value
    is trimmed (trim(both from ...)) and then set to NULL if the result equals
    the empty string or \\N.
    @param cols The names of the columns to clean up; escaped with esc_name()
    @param table_is_esc If not set, the table name is validated with
        check_name()
    '''
    if not table_is_esc: check_name(table)

    # One "col = nullif(...)" assignment per column
    assignments = []
    for name in cols:
        col = esc_name(db, name)
        assignments.append('\n'+col
            +' = nullif(nullif(trim(both from '+col
            +"), %(null0)s), %(null1)s)")

    query = 'UPDATE '+table+' SET\n'+',\n'.join(assignments)
    run_query(db, query, {'null0': '', 'null1': r'\N'})