Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 3238 aaronmk
import time
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 3241 aaronmk
import profiling
13 1889 aaronmk
from Proxy import Proxy
14 1872 aaronmk
import rand
15 2217 aaronmk
import sql_gen
16 862 aaronmk
import strings
17 131 aaronmk
import util
18 11 aaronmk
19 832 aaronmk
##### Exceptions
20
21 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Returns the query text that was run on a cursor.
    
    Prefers the raw query string recorded on the cursor by the DB driver
    (`query`, then `_last_executed`); falls back to the caller-supplied
    input_query, prefixed with '[input] '.
    '''
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    else: raw_query = None
    
    if raw_query == None: return '[input] '+strings.ustr(input_query)
    return raw_query
28 14 aaronmk
29 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query().
    '''
    cur_query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(cur_query))
32 135 aaronmk
33 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for database errors; optionally records the failed query
    (taken from the cursor) in the exception message.'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None:
            _add_cursor_info(self, cur)
37
38 2143 aaronmk
class ExceptionWithName(DbException):
    '''A database error associated with a named object.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
42 360 aaronmk
43 3109 aaronmk
class ExceptionWithValue(DbException):
    '''A database error associated with a specific input value.'''
    def __init__(self, value, cause=None):
        msg = 'for value: '+strings.as_tt(repr(value))
        DbException.__init__(self, msg, cause)
        self.value = value
48
49 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''A database error associated with a named object of a particular type.'''
    def __init__(self, type_, name, cause=None):
        msg = ('for type: '+strings.as_tt(str(type_))+'; name: '
            +strings.as_tt(name))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.name = name
55
56 2306 aaronmk
class ConstraintException(DbException):
    '''A violated database constraint, with the condition and columns involved
    (when known).'''
    def __init__(self, name, cond, cols, cause=None):
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond != None:
            msg += ' with condition '+cond
        if cols != []:
            msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols
65 11 aaronmk
66 2523 aaronmk
class MissingCastException(DbException):
    '''Raised when a column's value cannot be used without a cast to the
    column's type.'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
72
73 2143 aaronmk
class NameException(DbException):
    '''A database error involving an identifier name.'''
74
75 2306 aaronmk
class DuplicateKeyException(ConstraintException):
    '''A unique-constraint violation (duplicate key).'''
76 13 aaronmk
77 2306 aaronmk
class NullValueException(ConstraintException):
    '''A NOT NULL constraint violation.'''
78 13 aaronmk
79 3346 aaronmk
class CheckException(ConstraintException):
    '''A check-constraint violation.'''
80
81 3109 aaronmk
class InvalidValueException(ExceptionWithValue):
    '''An input value that the database rejected (bad syntax or out of range).'''
82 2239 aaronmk
83 2945 aaronmk
class DuplicateException(ExceptionWithNameType):
    '''An object of this type and name already exists.'''
84 2143 aaronmk
85 89 aaronmk
class EmptyRowException(DbException):
    '''A database error involving an empty row.'''
86
87 865 aaronmk
##### Warnings
88
89
class DbWarning(UserWarning):
    '''Warning category for database-related messages.'''
90
91 1930 aaronmk
##### Result retrieval
92
93
def col_names(cur):
    '''Returns an iterator over the column names of the cursor's result set.'''
    return (descr[0] for descr in cur.description)
94
95
def rows(cur):
    '''Iterates over the cursor's rows, stopping when fetchone() returns None.'''
    while True:
        row_ = cur.fetchone()
        if row_ == None: break
        yield row_
96
97
def consume_rows(cur):
    '''Fetches all remaining rows so the result will be cached.'''
    row_iter = rows(cur)
    iters.consume_iter(row_iter)
100
101
def next_row(cur):
    '''Returns the next row; raises StopIteration if there are none left.'''
    return next(rows(cur))
102
103
def row(cur):
    '''Returns the first row, then fetches the rest so the result is cached.'''
    first = next_row(cur)
    consume_rows(cur)
    return first
107
108
def next_value(cur):
    '''Returns the first column of the next row.'''
    row_ = next_row(cur)
    return row_[0]
109
110
def value(cur):
    '''Returns the first column of the first row.'''
    row_ = row(cur)
    return row_[0]
111
112
def values(cur):
    '''Iterates over the first column of each row.'''
    def next_(): return next_value(cur)
    return iters.func_iter(next_)
113
114
def value_or_none(cur):
    '''Returns the first value, or None if the result set is empty.'''
    try:
        return value(cur)
    except StopIteration:
        return None
117
118 2762 aaronmk
##### Escaping
119 2101 aaronmk
120 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quote character of the given DB-API
    module ('psycopg2'/None -> ", 'MySQLdb' -> `).'''
    if module in ('psycopg2', None): quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
125 2101 aaronmk
126
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for the given engine (a key of db_engines).'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
128
129
def esc_name(db, name, **kw_args):
    '''Escapes an identifier using the quoting style of db's driver module.'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
131
132
def qual_name(db, schema, table):
    '''Returns the escaped table name, qualified with the escaped schema name
    if one is given.'''
    table = esc_name(db, table)
    if schema == None: return table
    return esc_name(db, schema)+'.'+table
137
138 1869 aaronmk
##### Database connections
139 1849 aaronmk
140 2097 aaronmk
# Recognized connection-config keys, in canonical (PostgreSQL-style) naming
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, key renames applied to the config)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All known database error classes; extended by _add_module() as each DB-API
# driver is loaded. DatabaseErrors is the tuple form, for `except` clauses.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
149
150
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class in DatabaseErrors.'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
154
155
def db_config_str(db_config):
    '''Returns a one-line description of a connection config,
    e.g. "PostgreSQL database vegbien".'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
157
158 2448 aaronmk
def log_debug_none(msg, level=2):
    '''No-op debug logger, the default for DbConn(log_debug=...).
    DbConn compares its log_debug against this object to decide whether
    debugging is enabled, so this must remain a single module-level object.
    '''
    # PEP 8 (E731): use a def instead of a lambda assigned to a name
    pass
159 1901 aaronmk
160 1849 aaronmk
class DbConn:
    '''A database connection wrapper providing lazy connecting, query-result
    caching, savepoint/transaction bookkeeping, autocommit, and debug logging.
    The underlying DB-API connection is opened on first access of `self.db`.
    '''
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param db_config Connection params dict (see db_config_names)
        @param autocommit Whether to commit automatically after top-level
            queries (see do_autocommit())
        @param caching Whether query results may be cached (see DbCursor)
        @param log_debug Debug-message callback; debugging is considered
            enabled iff this is not the default no-op logger
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.src = src
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None
        
        # _savepoint counts open transactions/savepoints (0 = not connected)
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        '''Provides the lazily-opened `db` attribute (the DB-API connection).'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        '''Pickling support: drops the unpicklable callback and connection.'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    # Empties the query-result cache used by DbCursor
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        '''Returns the object to its initial, disconnected state.'''
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    # Whether the underlying DB-API connection has been opened
    def connected(self): return self.__db != None
    
    def close(self):
        '''Closes the connection (if open) and resets cached state.'''
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        self._savepoint -= 1
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        '''Closes the connection so it will be reopened on the next query.'''
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        '''Opens the DB-API connection on first use and configures it
        (isolation level, search_path).'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Record that a transaction is already open
            self._savepoint += 1
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                # Prepend the given schemas to the existing search_path
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a DB-API cursor, recording the query and its result so they
        can be cached in outer.query_results (see _cache_result()).'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            '''Runs the query, caching exceptions and certain results.'''
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            '''Fetches a row, recording it; caches the result once exhausted.'''
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores a replayable CacheCursor in query_results for this query.'''
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Stands in for a cursor, replaying a previously cached result
            (or re-raising a cached exception).'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Returns the SQL-literal representation of value, as unicode.'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        '''Whether the driver supports cursor.mogrify() (psycopg2 only).'''
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        '''Returns the query with params interpolated, via the driver.'''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        '''Logs any server notices that have not been logged before.'''
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a single query, using the cache when possible, and logs it
        (with profiling info) when debugging is enabled.
        
        @param cacheable Whether this query's result may be served from /
            stored in the query-result cache (requires self.caching)
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        (Note: log_ignore_excs is handled by the module-level run_query().)
        '''
        assert query != None
        
        if self.autocommit and self.src != None:
            # Tag the query with its data source for pg_stat_activity
            query = sql_gen.esc_comment(self.src)+'\t'+query
        
        if not self.caching: cacheable = False
        used_cache = False
        
        if self.debug:
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        try:
            # Get cursor
            if cacheable:
                try: cur = self.query_results[query]
                except KeyError: cur = self.DbCursor(self)
                else: used_cache = True
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                raise
            else: self.do_autocommit()
        finally:
            if self.debug:
                profiler.stop(self.profile_row_ct)
                
                ## Log or return query
                
                query = str(get_cur_query(cur, query))
                # Put the src comment on a separate line in the log file
                query = query.replace('\t', '\n', 1)
                
                msg = 'DB query: '
                
                if used_cache: msg += 'cache hit'
                elif cacheable: msg += 'cache miss'
                else: msg += 'non-cacheable'
                
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
                
                if debug_msg_ref != None: debug_msg_ref[0] = msg
                else: self.log_debug(msg, log_level)
                
                self.print_notices()
        
        return cur
    
    # Whether this query has a cached result available
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        '''Runs func() with the connection in autocommit isolation level,
        restoring the previous isolation level afterward (psycopg2 only).'''
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        '''Runs func() inside a new savepoint (or, at the top level, a new
        transaction), rolling back if func() raises.'''
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)
        
        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        
        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1
        
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            return func()
            # NOTE(review): the next line is unreachable (it follows `return`),
            # so top-level transactions are never explicitly COMMITted here;
            # autocommit mode relies on do_autocommit() in the finally block
            if top: self.run_query('COMMIT', log_level=4)
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)
            
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col, cacheable=True):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns; returns a sql_gen.TypedCol.'''
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        '''Returns a sql_gen.Function in the pg_temp schema, or with no schema
        in debug_temp mode (so it persists for debugging).'''
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)
485 1849 aaronmk
486 1869 aaronmk
connect = DbConn # alias: connect(db_config, ...) constructs a DbConn
487
488 832 aaronmk
##### Recoverable querying
489 15 aaronmk
490 2139 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint on db (see DbConn.with_savepoint()).'''
    return db.with_savepoint(func)
491 11 aaronmk
492 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, translating the driver's
    error messages into the specific DbException subclasses defined above.
    
    For other params, see DbConn.run_query().
    @param recover Whether to run inside a savepoint so a failure does not
        abort the enclosing transaction (defaults to False)
    @param log_ignore_excs Exception types whose log_level is raised by 2
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    debug_msg_ref = [None] # mutable holder so the log can be deferred
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # Parse the driver's error message into a specific exception.
            # The patterns are tried in order; each match raises immediately.
            msg = strings.ustr(e.args[0])
            
            # unique-constraint violation
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"(.+?)"', msg)
            if match:
                constraint, = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, None, cols, e)
            
            # NOT NULL violation
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match:
                col, = match.groups()
                raise NullValueException('NOT NULL', None, [col], e)
            
            # check-constraint violation
            match = re.match(r'^new row for relation "(.+?)" violates check '
                r'constraint "(.+?)"', msg)
            if match:
                table, constraint = match.groups()
                constraint = sql_gen.Col(constraint, table)
                cond = None
                if recover: # need auto-rollback to run constraint_cond()
                    try: cond = constraint_cond(db, constraint)
                    except NotImplementedError: pass
                raise CheckException(constraint.to_str(db), cond, [], e)
            
            # bad input value (syntax or range)
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            # type mismatch requiring an explicit cast
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            # object already exists
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
562 830 aaronmk
563 832 aaronmk
##### Basic queries
564
565 3256 aaronmk
def is_explainable(query):
    '''Whether the query is a statement type that EXPLAIN accepts.'''
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    explainable_cmds = r'SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE'
    return re.match(r'^(?:'+explainable_cmds+r')\b', query)
569 3256 aaronmk
570 3263 aaronmk
def explain(db, query, **kw_args):
    '''Returns the EXPLAIN plan for the query, as a single string.
    For params, see run_query().
    '''
    kw_args.setdefault('log_level', 4)
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides
    
    plan_cur = run_query(db, 'EXPLAIN '+query, recover=True, cacheable=True,
        **kw_args)
    return strings.join_lines(values(plan_cur))
578 3256 aaronmk
        # not a higher log_level because it's useful to see what query is being
579
        # run before it's executed, which EXPLAIN effectively provides
580
581 3265 aaronmk
def has_comment(query):
    '''Whether the query ends with a closing SQL block comment.'''
    return query[-2:] == '*/'
582
583
def with_explain_comment(db, query, **kw_args):
    '''Appends the query's EXPLAIN plan to it as an SQL comment, when
    db.autoexplain is on and the query is explainable and not yet annotated.'''
    should_annotate = (db.autoexplain and not has_comment(query)
        and is_explainable(query))
    if should_annotate:
        plan = explain(db, query, **kw_args)
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'+plan)
    return query
588
589 2153 aaronmk
def next_version(name):
    '''Returns the name with its "#<version>" suffix incremented.
    A name without a suffix counts as version 0, so it becomes "<name>#1".
    '''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        base, version_str = match.groups()
        version = int(version_str)+1
    else:
        base = name
        version = 1 # first existing name was version 0
    return sql_gen.concat(base, '#'+str(version))
596 2153 aaronmk
597 2899 aaronmk
def lock_table(db, table, mode):
    '''Acquires a table-level lock in the given mode (e.g. 'EXCLUSIVE').'''
    table = sql_gen.as_Table(table)
    query = 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE'
    run_query(db, query)
600
601 3303 aaronmk
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table to create, or None to just run the query.
        Modified in place: marked temp, schema removed, and renamed to the
        next version on a name collision.
    @param add_pkey_ Whether to call add_pkey() on the created table
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_pkey_: add_pkey(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
643 2085 aaronmk
644 2120 aaronmk
# Sentinel values, compared by identity (`is`) in mk_select()
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
647
648 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds (but does not run) a SELECT statement.
    
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit int|None Max # rows to return
    @param start int|None OFFSET (0 produces no OFFSET clause)
    @param order_by Column to ORDER BY; defaults to the table's pkey
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # (removed dead code: an unused `missing` flag and an unused `whitespace`
    # recomputation in the WHERE branch below)
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
    
    query = with_explain_comment(db, query)
    
    return query
730 11 aaronmk
731 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT query, returning its cursor.
    For params, see mk_select() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''Builds an INSERT ... SELECT query (as a string).
    @param cols [sql_gen.Col...]|None The columns to insert into. None means the
        column names are unknown, so no column list is emitted.
    @param select_query str|None The SELECT (or VALUES) providing the rows.
        Defaults to inserting a single all-defaults row.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    @return str query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        # Builds the plain INSERT statement around the given row source
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        # Choose the loop variable(s) used to iterate over the SELECT's rows
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        # plpgsql body: insert each row separately, swallowing duplicate keys
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function to wrap the query so it can be nested in a SELECT
        function_name = sql_gen.clean_name(first_line)
        if src != None: function_name = src+': '+function_name
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, order_by=None)
    
    return query
def insert_select(db, table, *args, **kw_args):
    '''Runs an INSERT ... SELECT query and returns its cursor.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    into = kw_args.pop('into', None)
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    if into != None: kw_args['embeddable'] = True
    if ignore: recover = True # duplicate keys are expected, so recoverable
    
    # When ignoring duplicates with no RETURNING col, the returned rows are all
    # NULL; keep them in a server-side temp table so only the rowcount travels
    # back to the client.
    rowcount_only = ignore and returning == None
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
# Sentinel: pass as a value in a row to have the column's DEFAULT used instead
default = sql_gen.default # tells insert() to use the default value for a column
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row into a table.
    For params, see insert_select().
    @param row dict|sequence The row's values. A dict maps col name -> value.
    '''
    cols = None
    values = row
    if not lists.is_seq(row): # a mapping: split into parallel cols and values
        cols = row.keys()
        values = row.values()
    values = list(values) # ensure that "== []" works
    
    # An empty row inserts the default value for every column
    if values == []: query = None
    else: query = sql_gen.Values(values).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Builds an UPDATE query (as a string).
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(col, table), sql_gen.as_Value(value))
        for col, value in changes]
    
    if in_place:
        assert cond == None # ALTER TABLE can't express a WHERE clause
        
        # Rewrite each column via ALTER COLUMN ... TYPE ... USING, which
        # rewrites the rows in place instead of leaving dead tuples behind
        alters = ['ALTER COLUMN '+col.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(col, table), cacheable_)
            .type+'\nUSING '+value.to_str(db) for col, value in changes]
        query = 'ALTER TABLE '+table.to_str(db)+'\n'+(',\n'.join(alters))
    else:
        assignments = [col.to_str(db)+' = '+value.to_str(db)
            for col, value in changes]
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'+(',\n'.join(assignments))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return with_explain_comment(db, query)
def update(db, table, *args, **kw_args):
    '''Builds and runs an UPDATE query, returning its cursor.
    For params, see mk_update() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # keep planner stats fresh after data changes
    return cur
def mk_delete(db, table, cond=None):
    '''Builds a DELETE query (as a string).
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    parts = ['DELETE FROM '+table.to_str(db)]
    if cond != None: parts.append('WHERE '+cond.to_str(db))
    return with_explain_comment(db, '\n'.join(parts))
def delete(db, table, *args, **kw_args):
    '''Builds and runs a DELETE query, returning its cursor.
    For params, see mk_delete() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # keep planner stats fresh after data changes
    return cur
def last_insert_id(db):
    '''Returns the last autogenerated id on this connection, or None if the
    underlying driver is not supported.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': return db.insert_id()
    return None
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve: # preserved cols keep their name but move to `into`
        orig_col = copy.copy(col)
        col.table = into # note: mutates the caller's Col object, as documented
        items.append((orig_col, col))
    preserved = set(preserve)
    for col in cols:
        if col in preserved: continue # already mapped above
        items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if as_items: return items
    return dict(items)
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Flattens a set of joined tables into a single new table.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query = mk_select(db, joins, select_cols, order_by=None, limit=limit,
        start=start)
    run_query_into(db, query, into=into, add_pkey_=True)
    return dict(items)
##### Database structure introspection
1005 2414 aaronmk
1006 3321 aaronmk
#### Expressions
1007
1008 3353 aaronmk
bool_re = r'(?:true|false)'

def simplify_expr(expr):
    '''Simplifies a SQL expression string: folds trivial NULL tests into
    boolean literals, drops operands made redundant by OR with a literal, and
    collapses doubled parentheses.'''
    for null_test, literal in (('(NULL IS NULL)', 'true'),
        ('(NULL IS NOT NULL)', 'false')):
        expr = expr.replace(null_test, literal)
    # `x OR true`/`true OR x` etc.: remove the boolean-literal operand
    expr = re.sub(r' OR '+bool_re, r'', expr)
    expr = re.sub(bool_re+r' OR ', r'', expr)
    # Repeatedly collapse ((...)) -> (...) until no more substitutions occur
    while True:
        expr, num_subs = re.subn(r'\((\([^()]*\))\)', r'\1', expr)
        if not num_subs: break
    return expr
# A SQL identifier: a plain word, or one or more double-quoted segments
name_re = r'(?:\w+|(?:"[^"]*")+)'

def parse_expr_col(str_):
    '''Extracts the column name from an index expression of the form
    (func(col, ...)); any other string is simply unescaped.'''
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
    if match != None: str_ = match.group(1)
    return sql_gen.unesc_name(str_)
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The SQL expression to rewrite
    @param mapping dict(out_col=in_col, ...) Output-to-input column mapping
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The rewritten, simplified expression
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr # saved to detect whether a replacement happened
        out = sql_gen.to_name_only_col(out)
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
            ).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        expr = re.sub(r'\b'+out.name+r'\b', in_str, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    # Replacements may have produced tautologies (e.g. NULL tests); clean up
    return simplify_expr(expr)
#### Tables
1047
1048
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over the names of the tables matching the given patterns.
    @param exact If set, patterns are compared with = instead of LIKE
    '''
    compare = '=' if exact else 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    if module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '
            +db.esc_value(table_like), cacheable=True, log_level=4))
    raise NotImplementedError("Can't list tables for "+module+' database')
def table_exists(db, table):
    '''Returns whether the given table exists in the database.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return list(matches) != []
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in a table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    return value(run_query(db, query, recover=recover, log_level=3))
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    # A LIMIT 0 query retrieves the cursor description without fetching rows
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
def pkey(db, table, recover=None):
    '''Returns the primary key, assumed to be the first column in the table.'''
    cols = table_cols(db, table, recover)
    return cols[0]
# Well-known name checked for by table_not_null_col()
not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
def constraint_cond(db, constraint):
    '''Returns the check condition (pg_constraint.consrc) of a constraint.
    @param constraint sql_gen object with .table and .name attributes
    @raise NotImplementedError for non-PostgreSQL databases
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    # Fixed copy-pasted error message (previously said "list index columns")
    else: raise NotImplementedError("Can't get constraint conditions for "
        +module+' database')
def index_cols(db, index):
    '''Returns the column expressions of an index.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    qual_index = sql_gen.Literal(index.to_str(db))
    query = '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
    return map(parse_expr_col, values(run_query(db, query, cacheable=True,
        log_level=4)))
#### Functions
1120
1121
def function_exists(db, function):
    '''Returns whether a callable (non-trigger) function with this name
    exists.'''
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    matches = values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
        # TODO: order_by search_path schema order
    return list(matches) != []
##### Structural changes
1136
1137
#### Columns
1138
1139
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column, versioning its name on collision.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        query = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: query += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, query, recover=True, cacheable=True, **kw_args)
        except DuplicateException:
            # Name collision: bump the version suffix and retry
            col.name = next_version(col.name)
        else: break # succeeded
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to an existing column.'''
    table = col.table
    col = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '+col.to_str(db)
        +' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
# Name of the surrogate row-number column created by add_row_num()
row_num_col = '_row_num'

# Column definition used by add_row_num(): an autoincrementing primary key
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)
#### Indexes
1177
1178
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(c).to_str(db) for c in cols]
    
    query = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')')
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported expressions.
    @param exprs Column(s)/expression(s) to index; a single value or a sequence
    @param table The table to index; inferred from exprs' cols if None
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col: for a function-call expr, index the first argument's col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr)
        
        # Extract table: infer the indexed table from the first table col seen
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # Column name must be unqualified inside the index definition
        if isinstance(col, sql_gen.Col): col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    # recover=True: an identical index may already exist, which is fine
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        # Identity check distinguishes the sentinel from a plain truthy value
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:] # the pkey already covers the first column
    for col in cols: add_index(db, col, table)
#### Tables
1255 2772 aaronmk
1256 3079 aaronmk
### Maintenance
1257 2772 aaronmk
1258 3079 aaronmk
def analyze(db, table):
    '''Refreshes the query planner's statistics for a table.'''
    table = sql_gen.as_Table(table)
    query = 'ANALYZE '+table.to_str(db)
    run_query(db, query, log_level=3)
def autoanalyze(db, table):
    '''Runs analyze() on the table, but only when enabled on the connection.'''
    if not db.autoanalyze: return
    analyze(db, table)
def vacuum(db, table):
    '''VACUUM ANALYZEs a table. Runs in autocommit mode because VACUUM cannot
    execute inside a transaction.'''
    table = sql_gen.as_Table(table)
    def run_vacuum():
        run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run_vacuum)
### Lifecycle
1271
1272 3247 aaronmk
def drop(db, type_, name):
    '''Drops a database object if it exists, cascading to dependents.
    @param type_ str The object kind, e.g. 'TABLE' or 'FUNCTION'
    '''
    query = ('DROP '+type_+' IF EXISTS '+sql_gen.as_Name(name).to_str(db)
        +' CASCADE')
    run_query(db, query)
def drop_table(db, table):
    '''Drops a table if it exists, cascading to dependents.'''
    drop(db, 'TABLE', table)
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None If set, copy this table's structure
        (LIKE ... INCLUDING ALL) in addition to any cols.
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table, versioning the name on collision
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # CREATE TABLE already made the pkey
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, like=src, has_pkey=False, col_indexes=False)
### Data
1327 2684 aaronmk
1328 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Empties a table, cascading to dependents.
    For params, see run_query().
    '''
    table = sql_gen.as_Table(table, schema)
    query = 'TRUNCATE '+table.to_str(db)+' CASCADE'
    return run_query(db, query, **kw_args)
def empty_temp(db, tables):
    '''Truncates one temp table, or a sequence of them.'''
    for table in lists.mk_seq(tables): truncate(db, table, log_level=3)
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in a schema. For kw_args, see tables().'''
    for table_name in tables(db, schema, **kw_args):
        truncate(db, table_name, schema)
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    # Drop NULL/literal cols; they can't meaningfully participate in DISTINCT ON
    distinct_on = filter(lambda c: not sql_gen.is_null(c),
        map(sql_gen.remove_col_rename, distinct_on))
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        # Deduplication works by inserting with ignore=True against this
        # unique index, which rejects rows that duplicate the key
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table