# Database access

import copy
import re
import time
import warnings

import exc
import dicts
import iters
import lists
import profiling
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cond, cols, cause=None):
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond != None: msg += ' with condition '+cond
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class CheckException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None

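# Example (a sketch): these helpers operate on any DB-API cursor, such as the
# one returned by run_query() (defined below), e.g.
#     cur = run_query(db, 'SELECT 1')
#     value(cur) # -> 1 (first column of the first row)
#     for row_ in rows(cur): ... # or iterate over all rows
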
##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table

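# Example (a sketch): esc_name() picks the quote character from the
# connection's driver module ('"' for psycopg2, '`' for MySQLdb), so something
# like
#     esc_name(db, 'my col')
# yields a quoted identifier that can safely be spliced into SQL strings.
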
##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.src = src
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self): return self.__db != None
    
    def close(self):
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        self._savepoint -= 1
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Record that a transaction is already open
            self._savepoint += 1
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions.
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if self.autocommit and self.src != None:
            query = sql_gen.esc_comment(self.src)+'\t'+query
        
        if not self.caching: cacheable = False
        used_cache = False
        
        if self.debug:
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        try:
            # Get cursor
            if cacheable:
                try: cur = self.query_results[query]
                except KeyError: cur = self.DbCursor(self)
                else: used_cache = True
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                raise
            else: self.do_autocommit()
        finally:
            if self.debug:
                profiler.stop(self.profile_row_ct)
                
                ## Log or return query
                
                query = str(get_cur_query(cur, query))
                # Put the src comment on a separate line in the log file
                query = query.replace('\t', '\n', 1)
                
                msg = 'DB query: '
                
                if used_cache: msg += 'cache hit'
                elif cacheable: msg += 'cache miss'
                else: msg += 'non-cacheable'
                
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
                
                if debug_msg_ref != None: debug_msg_ref[0] = msg
                else: self.log_debug(msg, log_level)
                
                self.print_notices()
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)
        
        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        
        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1
        
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            return func()
            if top: self.run_query('COMMIT', log_level=4)
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)
            
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col, cacheable=True):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn

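# Example (a sketch; all values below are hypothetical): a db_config is a dict
# using the keys in db_config_names, with 'engine' being a key of db_engines:
#     db = connect({'engine': 'PostgreSQL', 'host': 'localhost',
#         'user': 'username', 'password': 'password', 'database': 'mydb',
#         'schemas': 'public'}, autocommit=False)
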
##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    debug_msg_ref = [None]
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            msg = strings.ustr(e.args[0])
            msg = re.sub(r'^PL/Python: \w+: ', r'', msg)
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"(.+?)"', msg)
            if match:
                constraint, = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, None, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match:
                col, = match.groups()
                raise NullValueException('NOT NULL', None, [col], e)
            
            match = re.match(r'^new row for relation "(.+?)" violates check '
                r'constraint "(.+?)"', msg)
            if match:
                table, constraint = match.groups()
                constraint = sql_gen.Col(constraint, table)
                cond = None
                if recover: # need auto-rollback to run constraint_cond()
                    try: cond = constraint_cond(db, constraint)
                    except NotImplementedError: pass
                raise CheckException(constraint.to_str(db), cond, [], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)

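# Example (a sketch; the table name is hypothetical):
#     cur = run_query(db, 'SELECT * FROM plot', recover=True, cacheable=True)
#     for row_ in rows(cur): ...
# With recover=True the query runs inside a savepoint, and common PostgreSQL
# errors are translated into the exceptions defined above (e.g.
# DuplicateKeyException, NullValueException, InvalidValueException).
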
##### Basic queries

def is_explainable(query):
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    return re.match(r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
        , query)

def explain(db, query, **kw_args):
    '''
    For params, see run_query().
    '''
    kw_args.setdefault('log_level', 4)
    
    return strings.join_lines(values(run_query(db, 'EXPLAIN '+query,
        recover=True, cacheable=True, **kw_args)))
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides

def has_comment(query): return query.endswith('*/')

def with_explain_comment(db, query, **kw_args):
    if db.autoexplain and not has_comment(query) and is_explainable(query):
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
            +explain(db, query, **kw_args))
    return query

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))

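# Example (a sketch): next_version() appends or increments a #<version> suffix
# (the concatenation itself is done by sql_gen.concat()):
#     next_version('table') # -> 'table#1'
#     next_version('table#1') # -> 'table#2'
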
def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_pkey_: add_pkey(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur

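# Example (a sketch; the table names are hypothetical):
#     into = sql_gen.Table('plot_tmp')
#     run_query_into(db, mk_select(db, 'plot', order_by=None), into=into)
#     # into now names a temp table holding the query's results; on a name
#     # collision, into.name is bumped to the next #<version>.
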
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    missing = True
    if conds != []:
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
    
    query = with_explain_comment(db, query)
    
    return query

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)

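# Example (a sketch; the table and column names are hypothetical):
#     cur = select(db, 'plot', ['plot_id', 'name'], {'site_id': 5}, limit=10)
#     for row_ in rows(cur): ...
# conds may also be given as [(col, value), ...]; to join several tables, pass
# a list whose elements after the first are sql_gen.Join objects (see
# mk_select()).
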
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        if src != None: function_name = src+': '+function_name
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, order_by=None)
    
    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)

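# Example (a sketch; the table and column names are hypothetical):
#     insert(db, 'plot', {'name': 'plot 1', 'site_id': 5})
#     # or positionally, using `default` for columns that should keep their
#     # default value:
#     insert(db, 'plot', [default, 'plot 1', 5])
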
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table), cacheable_).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

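# Example (a sketch; the table and column names are hypothetical):
#     update(db, 'plot', [('verified', True)],
#         sql_gen.ColValueCond('site_id', 5))
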
def mk_delete(db, table, cond=None):
    '''
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    query = 'DELETE FROM '+table.to_str(db)
    if cond != None: query += '\nWHERE '+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query

def delete(db, table, *args, **kw_args):
    '''For params, see mk_delete() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_delete(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, order_by=None, limit=limit,
        start=start), into=into, add_pkey_=True)
    return dict(items)

##### Database structure introspection

#### Expressions

bool_re = r'(?:true|false)'

def simplify_expr(expr):
    expr = expr.replace('(NULL IS NULL)', 'true')
    expr = expr.replace('(NULL IS NOT NULL)', 'false')
    expr = re.sub(r' OR '+bool_re, r'', expr)
    expr = re.sub(bool_re+r' OR ', r'', expr)
    while True:
        expr, n = re.subn(r'\((\([^()]*\))\)', r'\1', expr)
        if n == 0: break
    return expr

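# Example (a sketch) of the simplification rules above:
#     simplify_expr('(NULL IS NULL) OR ("a" = 1)') # -> '("a" = 1)'
#     simplify_expr('((("a" IS NOT NULL)))') # -> '("a" IS NOT NULL)'
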
name_re = r'(?:\w+|(?:"[^"]*")+)'

def parse_expr_col(str_):
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
    if match: str_ = match.group(1)
    return sql_gen.unesc_name(str_)

def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = sql_gen.to_name_only_col(out)
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
            ).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        expr = re.sub(r'\b'+out.name+r'\b', in_str, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)

def constraint_cond(db, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't get constraint conditions for "
        +module+' database')

def index_cols(db, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        return map(parse_expr_col, values(run_query(db, '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported as
    expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        if isinstance(col, sql_gen.Col): col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)

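# Example (a sketch; the table and column names are hypothetical):
#     add_index(db, 'name', 'plot') # plain index on one column
#     add_index(db, ['site_id', 'name'], 'plot', unique=True)
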
already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop(db, type_, name):
    name = sql_gen.as_Name(name)
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')

def drop_table(db, table): drop(db, 'TABLE', table)

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
    if table.is_temp:
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now

def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table