Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 3238 aaronmk
import time
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 3241 aaronmk
import profiling
13 1889 aaronmk
from Proxy import Proxy
14 1872 aaronmk
import rand
15 5349 aaronmk
import regexp
16 2217 aaronmk
import sql_gen
17 862 aaronmk
import strings
18 131 aaronmk
import util
19 11 aaronmk
20 832 aaronmk
##### Exceptions
21
22 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Gets the text of the query that a cursor ran.
    @param cur The cursor-like object to inspect. Its `query` attribute is
        used if it has one; otherwise its `_last_executed` attribute is used
        if it has one.
    @param input_query The query as passed in by the caller. Used only as a
        fallback when the cursor does not provide the query text.
    @return The cursor's own query text, or '[input] ' followed by
        input_query if the cursor has none
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    # Identity test, so objects with custom comparison can't confuse the check
    if raw_query is not None: return raw_query
    else: return '[input] '+strings.ustr(input_query)
29 14 aaronmk
30 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Adds the query text to an exception's message.
    @param e The exception to annotate (passed to exc.add_msg())
    For the remaining params, see get_cur_query()
    '''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(query))
33 135 aaronmk
34 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If specified, the query of this cursor is added to the
            message (see get_cur_query())
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
38
39 2143 aaronmk
class ExceptionWithName(DbException):
    '''A database exception that is about a named object'''
    def __init__(self, name, cause=None):
        '''
        @param name The name the error applies to; stored as self.name
        @param cause The underlying exception, if any
        '''
        name_str = strings.as_tt(strings.ustr(name))
        DbException.__init__(self, 'for name: '+name_str, cause)
        self.name = name
44 360 aaronmk
45 3109 aaronmk
class ExceptionWithValue(DbException):
    '''A database exception that is about a particular value'''
    def __init__(self, value, cause=None):
        '''
        @param value The value the error applies to; stored as self.value
        @param cause The underlying exception, if any
        '''
        value_str = strings.as_tt(strings.urepr(value))
        DbException.__init__(self, 'for value: '+value_str, cause)
        self.value = value
50
51 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''A database exception that is about a named object of a given type'''
    def __init__(self, type_, name, cause=None):
        '''
        @param type_ The kind of object; stored as self.type
        @param name The object's name; stored as self.name
        @param cause The underlying exception, if any
        '''
        type_str = strings.as_tt(strings.ustr(type_))
        name_str = strings.as_tt(name)
        DbException.__init__(self, 'for type: '+type_str+'; name: '+name_str,
            cause)
        self.type = type_
        self.name = name
57
58 2306 aaronmk
class ConstraintException(DbException):
    '''Raised when a query violates a constraint'''
    def __init__(self, name, cond, cols, cause=None):
        '''
        @param name The constraint's name; stored as self.name
        @param cond The constraint's condition, or None if not known; stored
            as self.cond
        @param cols The names of the columns involved. May be empty. Stored
            unchanged as self.cols.
        @param cause The underlying exception, if any
        '''
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond is not None: msg += ' with condition '+cond
        # Truthiness test so that any empty sequence (or None) omits the
        # column list, rather than only the exact value []
        if cols: msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols
67 11 aaronmk
68 2523 aaronmk
class MissingCastException(DbException):
    '''Raised when a query needs an explicit cast to a type'''
    def __init__(self, type_, col=None, cause=None):
        '''
        @param type_ The type that must be cast to; stored as self.type
        @param col The column the cast is needed on, or None if not known;
            stored as self.col
        @param cause The underlying exception, if any
        '''
        msg = 'Missing cast to type '+strings.as_tt(type_)
        if col is not None: msg += ' on column: '+strings.as_tt(col)
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
75
76 2143 aaronmk
class NameException(DbException):
    '''A DbException subclass that adds no behavior of its own'''

class DuplicateKeyException(ConstraintException):
    '''A ConstraintException for a violated unique constraint'''

class NullValueException(ConstraintException):
    '''A ConstraintException for a violated not-null constraint'''

class CheckException(ConstraintException):
    '''A ConstraintException for a violated check constraint'''

class InvalidValueException(ExceptionWithValue):
    '''An ExceptionWithValue for a value the database rejected as invalid'''

class DuplicateException(ExceptionWithNameType):
    '''An ExceptionWithNameType for an object that already exists'''

class DoesNotExistException(ExceptionWithNameType):
    '''An ExceptionWithNameType for an object that does not exist'''

class EmptyRowException(DbException):
    '''A DbException subclass that adds no behavior of its own'''
91
92 865 aaronmk
##### Warnings
93
94
class DbWarning(UserWarning):
    '''Warning category for warnings issued by this module'''
95
96 1930 aaronmk
##### Result retrieval
97
98
def col_names(cur):
    '''Iterates over the names of the columns in a cursor's result.
    @return A generator of the first element of each entry in cur.description
    '''
    return (col_desc[0] for col_desc in cur.description)
99
100
def rows(cur):
    '''Iterates over the remaining rows of a cursor.
    @return An iterator that calls cur.fetchone() until it returns None
    '''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
101
102
def consume_rows(cur):
    '''Fetches (and discards) every remaining row of a cursor.
    Used to fetch all rows so result will be cached.
    '''
    remaining = rows(cur)
    iters.consume_iter(remaining)
105
106
def next_row(cur):
    '''Fetches the next row of a cursor.
    @return The row
    @throws StopIteration if the cursor has no more rows
    '''
    # next() builtin instead of the Python-2-only .next() iterator method
    return next(rows(cur))
107
108
def row(cur):
    '''Fetches the next row of a cursor, then consumes the rest of its rows.
    @return The first row fetched
    '''
    first = next_row(cur)
    consume_rows(cur) # the remaining rows are discarded
    return first
112
113
def next_value(cur):
    '''Fetches the next row of a cursor and returns its first column'''
    fetched = next_row(cur)
    return fetched[0]
114
115
def value(cur):
    '''Returns the first column of the row returned by row(cur)'''
    only_row = row(cur)
    return only_row[0]
116
117
def values(cur):
    '''Iterates over the first column of each remaining row of a cursor.
    @return The iterator made by iters.func_iter() from repeated
        next_value(cur) calls
    '''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
118
119
def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no rows'''
    try:
        result = value(cur)
    except StopIteration:
        return None
    return result
122
123 2762 aaronmk
##### Escaping
124 2101 aaronmk
125 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quote character of a database module.
    @param module The name of the database module: 'psycopg2' or None (both
        use `"`), or 'MySQLdb' (uses a backtick)
    @param name The identifier to escape
    @return The result of sql_gen.esc_name() with the chosen quote character
    @throws NotImplementedError for any other module
    '''
    if module is None or module == 'psycopg2': quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else:
        # str() so that a non-string module still produces the intended error
        raise NotImplementedError("Can't escape name for "+str(module)
            +' database')
    return sql_gen.esc_name(name, quote)
130 2101 aaronmk
131
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a database engine.
    @param engine A key of db_engines
    For the other params, see esc_name_by_module()
    '''
    module = db_engines[engine][0] # first element is the module name
    return esc_name_by_module(module, name, **kw_args)
133
134
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the database module behind a connection.
    @param db The DbConn whose underlying connection (db.db) determines the
        module
    For the other params, see esc_name_by_module()
    '''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
136
137
def qual_name(db, schema, table):
    '''Builds an escaped, optionally schema-qualified table name.
    @param schema The schema name, or None for an unqualified name
    @return The escaped table name, prefixed by the escaped schema and a `.`
        if a schema was given
    '''
    table = esc_name(db, table)
    if schema is not None: return esc_name(db, schema)+'.'+table
    else: return table
142
143 1869 aaronmk
##### Database connections
144 1849 aaronmk
145 2097 aaronmk
# The keys that a db_config dict may contain
db_config_names = [
    'engine',
    'host',
    'user',
    'password',
    'database',
    'schemas',
]

# engine name -> (database module name, {db_config key: module's name for it})
db_engines = {
    'MySQL': (
        'MySQLdb',
        {'password': 'passwd', 'database': 'db'},
    ),
    'PostgreSQL': (
        'psycopg2',
        {},
    ),
}
151
152
# The exception classes registered so far, as a set (for deduplication) and as
# a tuple rebuilt from that set
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    '''Registers a database module's DatabaseError class in DatabaseErrors.
    @param module A module that has a DatabaseError attribute
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild from the updated set
159
160
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>"'''
    parts = [db_config['engine'], ' database ', db_config['database']]
    return ''.join(parts)
162
163 2448 aaronmk
def log_debug_none(msg, level=2):
    '''A debug logging callback that discards the message'''
    return None
164 1901 aaronmk
165 1849 aaronmk
class DbConn:
    '''A database connection that is opened on first use of its `db`
    attribute, can cache query results, logs queries when debugging is on, and
    tracks how deeply transactions/savepoints are nested.
    '''
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param db_config dict of connection settings. The 'engine' entry
            selects the database module via db_engines, the optional 'schemas'
            entry is a comma-separated list put at the front of the
            search_path, and the remaining entries are passed to the module's
            connect().
        @param autocommit Whether to commit after every query that runs
            outside a savepoint
        @param caching Whether cacheable queries may use the result cache
        @param log_debug Callback taking (msg, level) which receives debug
            messages. Debugging is on iff this is not log_debug_none.
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.src = src
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        '''Provides the lazily-connected `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name) # include the name to aid debugging
    
    def __getstate__(self):
        '''Returns the state to pickle, without the unpicklable members'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self):
        '''Discards all cached query results'''
        self.query_results = {}
    
    def _reset(self):
        '''Resets the per-connection state to "not connected"'''
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self):
        '''Whether the underlying connection is currently open'''
        return self.__db is not None
    
    def close(self):
        '''Closes the underlying connection, if it is open'''
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        self._savepoint -= 1
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        '''Closes the connection so that it is reopened on the next query.
        Does nothing unless in autocommit mode.
        '''
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        '''Returns the underlying connection, first opening and configuring it
        if it is not open'''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename keys to the names the module's connect() expects
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # this setting was not provided
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Record that a transaction is already open
            self._savepoint += 1
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas is not None:
                # Put the configured schemas in front of the existing path
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a real cursor, recording its results so that they can be
        stored in the owning DbConn's query cache'''
        def __init__(self, outer):
            '''
            @param outer The DbConn that this cursor belongs to
            '''
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            '''Runs a query on the wrapped cursor.
            If the query fails, the exception is cached as its result and
            then re-raised.
            @return The return value of the wrapped cursor's execute()
            '''
            # NOTE(review): this is tested before the query is passed through
            # sql_gen.lstrip() below, so a query with a leading comment may
            # not be detected as an INSERT -- confirm whether that is intended
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            '''Fetches the next row, recording it for the cache. Once all rows
            have been fetched, the result is cached.'''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            '''Stores the current result in the query cache, when allowed'''
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result through the cursor methods used here'''
            def __init__(self, cached_result):
                '''
                @param cached_result dict which becomes this object's
                    attributes
                '''
                self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Restarts iteration over the cached rows, or re-raises the
                cached exception. The arguments are ignored.'''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None when there are no more
                '''
                try: return next(self.iter)
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Escapes a value for inclusion in a query.
        @return The escaped value, converted by strings.to_unicode()
        @throws NotImplementedError if the database module is not supported
        '''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise # bare raise keeps the original traceback
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        '''Whether mogrify() is supported for this connection's module'''
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        '''Returns the result of the cursor's mogrify(query, params).
        @throws NotImplementedError if not can_mogrify()
        '''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        '''Logs each of the connection's notices that was not already logged
        '''
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a query, using the result cache and debug logging as
        configured.
        @param cacheable Whether the result may be read from/stored in the
            cache. Ignored (treated as False) if caching is off.
        @param log_level The level to log the query at
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        @return The cursor the query was run on
        '''
        assert query is not None
        
        if self.autocommit and self.src is not None:
            query = sql_gen.esc_comment(self.src)+'\t'+query
        
        if not self.caching: cacheable = False
        used_cache = False
        
        # Must be defined before the try block: the finally block reads it even
        # if creating the cursor failed
        cur = None
        
        if self.debug:
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        try:
            # Get cursor
            if cacheable:
                try: cur = self.query_results[query]
                except KeyError: cur = self.DbCursor(self)
                else: used_cache = True
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query)
            except Exception as e:
                _add_cursor_info(e, self, query)
                raise
            else: self.do_autocommit()
        finally:
            if self.debug:
                profiler.stop(self.profile_row_ct)
                
                ## Log or return query
                
                query = strings.ustr(get_cur_query(cur, query))
                # Put the src comment on a separate line in the log file
                query = query.replace('\t', '\n', 1)
                
                msg = 'DB query: '
                
                if used_cache: msg += 'cache hit'
                elif cacheable: msg += 'cache miss'
                else: msg += 'non-cacheable'
                
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
                
                if debug_msg_ref is not None: debug_msg_ref[0] = msg
                else: self.log_debug(msg, log_level)
                
                self.print_notices()
        
        return cur
    
    def is_cached(self, query):
        '''Whether the query cache has an entry for this exact query text'''
        return query in self.query_results
    
    def with_autocommit(self, func):
        '''Runs func() with the connection's isolation level set to
        autocommit, restoring the previous isolation level afterwards.
        @return The return value of func()
        '''
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        '''Runs func() inside a transaction (at the top level) or a savepoint
        (when nested), rolling back if it raises.
        @return The return value of func()
        '''
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)
        
        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        
        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1
        
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            # Capture the result instead of returning right away, so that the
            # COMMIT below is actually reached
            return_val = func()
            if top: self.run_query('COMMIT', log_level=4)
            return return_val
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)
            
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col, cacheable=True):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns.
        @param col The column to look up; its table's name and (if set) schema
            are used to find it
        @return sql_gen.TypedCol
        @throws sql_gen.NoUnderlyingTableException if no matching column is
            found
        '''
        table = sql_gen.Table('columns', 'information_schema')
        cols = [sql_gen.Col('data_type'), sql_gen.Col('udt_name'),
            'column_default', sql_gen.Cast('boolean',
            sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name),
            ('column_name', strings.ustr(col.name))]
        schema = col.table.schema
        if schema is not None: conds.append(('table_schema', schema))
        
        cur = select(self, table, cols, conds, order_by='table_schema', limit=1,
            cacheable=cacheable, log_level=4) # TODO: order by search_path order
        try: type_, extra_type, default, nullable = row(cur)
        except StopIteration: raise sql_gen.NoUnderlyingTableException(col)
        default = sql_gen.as_Code(default, self)
        # For these two data_type values, the actual type is in udt_name
        if type_ == 'USER-DEFINED': type_ = extra_type
        elif type_ == 'ARRAY':
            type_ = sql_gen.ArrayType(strings.remove_prefix('_', extra_type,
                require=True))
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        '''Creates a sql_gen.Function in the pg_temp schema, or in no schema
        when debug_temp is on'''
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)
495 1849 aaronmk
496 1869 aaronmk
# Alias, so that a connection can be created with connect(db_config, ...)
connect = DbConn
497
498 832 aaronmk
##### Recoverable querying
499 15 aaronmk
500 2139 aaronmk
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func).
    @return Whatever db.with_savepoint(func) returns
    '''
    return db.with_savepoint(func)
501 11 aaronmk
502 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, translating recognized database error messages into this
    module's exception classes.
    For params, see DbConn.run_query()
    @param recover Whether to run the query inside a savepoint (unless its
        result is already cached), so that the transaction is still usable
        after an error. Defaults to False.
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    @return The cursor returned by DbConn.run_query()
    '''
    if recover is None: recover = False
    if log_ignore_excs is None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    debug_msg_ref = [None]
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception as e:
            # An exception may have no args (e.g. a bare AssertionError). Use
            # an empty message then, so that the original exception is
            # re-raised below instead of being masked by an IndexError.
            if e.args: msg = strings.ustr(e.args[0])
            else: msg = ''
            msg = re.sub(r'^(?:PL/Python: )?ValueError: ', r'', msg)
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"(.+?)"', msg)
            if match:
                constraint, = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, None, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match:
                col, = match.groups()
                raise NullValueException('NOT NULL', None, [col], e)
            
            match = re.match(r'^new row for relation "(.+?)" violates check '
                r'constraint "(.+?)"', msg)
            if match:
                table, constraint = match.groups()
                constraint = sql_gen.Col(constraint, table)
                cond = None
                if recover: # need auto-rollback to run constraint_cond()
                    try: cond = constraint_cond(db, constraint)
                    except NotImplementedError: pass
                raise CheckException(constraint.to_str(db), cond, [], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^could not determine polymorphic type because '
                r'input has type "unknown"', msg)
            if match: raise MissingCastException('text', None, e)
            
            match = re.match(r'^.+? types .+? and .+? cannot be matched', msg)
            if match: raise MissingCastException('text', None, e)
            
            # Matches e.g. `<type> "<name>"`, optionally followed by
            # ` of relation "<relation>"`
            typed_name_re = r'^(\S+) "(.+?)"(?: of relation ".+?")?'
            
            match = re.match(typed_name_re+r'.*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            match = re.match(r'more than one (\S+) named ""(.+?)""', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            match = re.match(typed_name_re+r' does not exist', msg)
            if match:
                type_, name = match.groups()
                raise DoesNotExistException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # Output the deferred log message at the (possibly increased) level
        if debug_msg_ref[0] is not None:
            db.log_debug(debug_msg_ref[0], log_level)
592 830 aaronmk
593 832 aaronmk
##### Basic queries
594
595 3256 aaronmk
def is_explainable(query):
    '''Whether a query starts with a statement keyword that EXPLAIN accepts.
    @return A match object (truthy) if so, otherwise None
    '''
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    explainable_re = (
        r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b')
    return re.match(explainable_re, query)
599 3256 aaronmk
600 3263 aaronmk
def explain(db, query, **kw_args):
    '''Runs EXPLAIN on a query and returns the output as one string.
    For params, see run_query().
    log_level defaults to 4 if not given.
    '''
    kw_args.setdefault('log_level', 4)
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides
    
    cur = run_query(db, 'EXPLAIN '+query, recover=True, cacheable=True,
        **kw_args)
    plan_lines = values(cur)
    return strings.ustr(strings.join_lines(plan_lines))
610
611 3265 aaronmk
def has_comment(query):
    '''Whether a query ends with the close of a block comment (`*/`)'''
    comment_end = '*/'
    return query.endswith(comment_end)
612
613
def with_explain_comment(db, query, **kw_args):
    '''Appends a query's EXPLAIN output to it as a comment, when applicable.
    The query is returned unchanged unless db.autoexplain is on, the query
    does not already end in a comment, and the query is explainable.
    For the other params, see explain()
    '''
    if (not db.autoexplain or has_comment(query)
        or not is_explainable(query)): return query
    
    explain_comment = sql_gen.esc_comment(' EXPLAIN:\n'
        +explain(db, query, **kw_args))
    return query+'\n'+explain_comment
618
619 2153 aaronmk
def next_version(name):
    '''Makes the next versioned form of a name, using a `#<number>` suffix.
    A name without such a suffix gets `#1`; otherwise the number is
    incremented. The suffix is attached using sql_gen.concat().
    '''
    found = re.match(r'^(.*)#(\d+)$', name)
    if found:
        name = found.group(1)
        version = int(found.group(2))+1
    else: version = 1 # first existing name was version 0
    suffix = '#'+str(version)
    return sql_gen.concat(name, suffix)
626 2153 aaronmk
627 2899 aaronmk
def lock_table(db, table, mode):
    '''Runs a LOCK TABLE statement on a table.
    @param table The table; converted using sql_gen.as_Table()
    @param mode The lock mode text, inserted between IN and MODE
    '''
    table = sql_gen.as_Table(table)
    query = 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE'
    run_query(db, query)
630
631 3303 aaronmk
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into The sql_gen.Table to create, or None to just run the query.
        This object is modified: is_temp is set, schema is cleared, and name
        is changed to the next version if a table by that name already exists.
    @param add_pkey_ Whether to call add_pkey() on the new table
    @return The cursor of the query that was run
    '''
    if into is None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_pkey_: add_pkey(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
673 2085 aaronmk
674 2120 aaronmk
# Sentinel value for mk_select()'s order_by param (and its default):
# tells mk_select() to order by the pkey
order_by_pkey = object()
675
676 2199 aaronmk
# Sentinel value for mk_select()'s distinct_on param:
# tells mk_select() to SELECT DISTINCT ON all columns
distinct_on_all = object()
677
678 3420 aaronmk
def mk_select(db, tables=None, fields=None, conds=None, distinct_on=[],
    limit=None, start=None, order_by=order_by_pkey, default_table=None,
    explain=True):
    '''
    Builds (but does not run) a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit None|int Emitted as a LIMIT clause when not None
    @param start None|int Emitted as an OFFSET clause when not None and not 0
    @param order_by Column to ORDER BY, or None for no ORDER BY. The default,
        order_by_pkey, orders by the first table's pkey (looked up with pkey()),
        unless there is no first table or distinct_on is set, in which case no
        ORDER BY is emitted.
    @param default_table Passed to sql_gen.as_Col() when parsing the fields and
        distinct_on columns
    @param explain If set, the query is passed through with_explain_comment()
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if table0 == None or distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    # Once the query spans several lines, keep each clause on its own line
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    if table0 != None: query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # WHERE, ORDER BY, LIMIT, OFFSET
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None and start != 0: query += '\nOFFSET '+str(start)
    
    if explain: query = with_explain_comment(db, query)
    
    return query
761 11 aaronmk
762 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query().
    The recover, cacheable (default True) and log_level (default 2) keyword
    args go to run_query(); everything else goes to mk_select().
    '''
    # Strip out the run_query() options before forwarding to mk_select()
    run_args = (kw_args.pop('recover', None), kw_args.pop('cacheable', True))
    level = kw_args.pop('log_level', 2)
    
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, *run_args, log_level=level)
770 2054 aaronmk
771 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''
    Builds an INSERT INTO query whose rows come from select_query.
    @param cols The columns to insert into. None or [] omits the column list.
    @param select_query str|None The query supplying the rows. If None,
        DEFAULT VALUES is inserted.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
        When set, a function wrapping the INSERT is created (via run_query())
        and a SELECT from that function is returned instead.
    @param ignore Whether to ignore duplicate keys. Implies embeddable.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    @return str query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        # Reads cols and returning from the enclosing scope at call time
        lines = [first_line]
        if cols != None:
            lines.append('('+', '.join([c.to_str(db) for c in cols])+')')
        lines.append(select_query)
        if returning != None:
            name_only = sql_gen.to_name_only_col(returning)
            lines.append('RETURNING '+name_only.to_str(db))
        return '\n'.join(lines)
    
    # Must be decided before the ignore branch substitutes a NULL column
    if returning != None: return_type = sql_gen.ColType(returning)
    else: return_type = sql_gen.CustomCode('unknown')
    
    if not ignore: query = mk_insert(select_query)
    else:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        
        if cols == None: row = [sql_gen.Col(sql_gen.all_cols, 'row')]
        else: row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        row_type = sql_gen.RowType(table)
        row_insert = mk_insert(sql_gen.Values(row).to_str(db))
        query = sql_gen.RowExcIgnore(row_type, select_query,
            sql_gen.ReturnQuery(row_insert), cols)
    
    if not embeddable: return query
    
    # Create function
    function_name = sql_gen.clean_name(first_line)
    if src != None: function_name = src+': '+function_name
    while True:
        try:
            func = db.TempFunction(function_name)
            def_ = sql_gen.FunctionDef(func, sql_gen.SetOf(return_type),
                query)
            
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
        except DuplicateException:
            # try again with next version of name
            function_name = next_version(function_name)
        else: break # this version was successful
    
    # Return query that uses function
    if returning != None: func_cols = [returning]
    else: func_cols = None
    func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(func),
        func_cols) # AS clause requires function alias
    return mk_select(db, func_table, order_by=None)
844 2066 aaronmk
845 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''Builds an INSERT with mk_insert_select() and runs it with
    run_query_into(). For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    # These two are only peeked at; they still go on to mk_insert_select()
    ignore = kw_args.get('ignore', False)
    returning = kw_args.get('returning', None)
    
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    
    # Options for run_query_into(), removed from what mk_insert_select() sees
    run_opts = dict(recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True),
        log_level=kw_args.pop('log_level', 2))
    if ignore: run_opts['recover'] = True
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, **run_opts)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
868 2063 aaronmk
869 2738 aaronmk
# Module-level alias of sql_gen.default
default = sql_gen.default # tells insert() to use the default value for a column
870 2066 aaronmk
871 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select()
    @param row A sequence of values, or a mapping of column name to value
    @param ignore (keyword) If set, a DuplicateKeyException or
        NullValueException makes this return None instead of propagating.
        Also defaults recover to True.
    '''
    ignore = kw_args.pop('ignore', False)
    if ignore: kw_args.setdefault('recover', True)
    
    cols = None
    if not lists.is_seq(row): # a mapping: split into names and values
        cols, row = row.keys(), row.values()
    values_ = list(row) # ensure that emptiness check works
    
    query = None # no values = let insert_select() pick its default
    if values_: query = sql_gen.Values(values_).to_str(db)
    
    try: return insert_select(db, table, cols, query, *args, **kw_args)
    except (DuplicateKeyException, NullValueException):
        if ignore: return None
        raise
889 11 aaronmk
890 3152 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''
    Builds (but does not run) a query that changes column values.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
        * NOTE(review): the default of None is iterated as-is, so callers
          must actually pass a value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
        * implemented as ALTER TABLE ... ALTER COLUMN ... TYPE ... USING
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    pairs = [(sql_gen.to_name_only_col(col, table), sql_gen.as_Value(value))
        for col, value in changes]
    
    if not in_place:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        assignments = []
        for col, value in pairs:
            assignments.append(col.to_str(db)+' = '+value.to_str(db))
        query += ',\n'.join(assignments)
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    else:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        alterations = []
        for col, value in pairs:
            # Re-declare the column with its current type, rewriting its value
            col_type = db.col_info(sql_gen.with_default_table(col, table),
                cacheable_).type
            alterations.append('ALTER COLUMN '+col.to_str(db)+' TYPE '
                +col_type+'\nUSING '+value.to_str(db))
        query += ',\n'.join(alterations)
    
    return with_explain_comment(db, query)
925
926 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''Builds a query with mk_update(), runs it with run_query(), then
    autoanalyze()s the table. For params, see mk_update() and run_query().
    Note that cacheable defaults to False here.
    @return The cursor from run_query()
    '''
    run_args = (kw_args.pop('recover', None), kw_args.pop('cacheable', False))
    level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, *run_args, log_level=level)
    autoanalyze(db, table)
    return cur
936 2402 aaronmk
937 3286 aaronmk
def mk_delete(db, table, cond=None):
    '''
    Builds (but does not run) a DELETE query.
    @param table sql_gen.Table|str The table to delete from. Normalized with
        sql_gen.as_Table(), like mk_update() does.
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
        If None, the query has no WHERE clause.
    @return str query
    '''
    table = sql_gen.as_Table(table) # also accept a plain table name
    
    query = 'DELETE FROM '+table.to_str(db)
    if cond != None: query += '\nWHERE '+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
948
949
def delete(db, table, *args, **kw_args):
    '''Builds a query with mk_delete(), runs it with run_query(), then
    autoanalyze()s the table. For params, see mk_delete() and run_query().
    @return The cursor from run_query()
    '''
    run_args = (kw_args.pop('recover', None), kw_args.pop('cacheable', True))
    level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, *run_args, log_level=level)
    autoanalyze(db, table)
    return cur
959
960 135 aaronmk
def last_insert_id(db):
    '''Returns the last inserted ID, using the mechanism of the underlying
    driver module: SELECT lastval() for psycopg2, insert_id() for MySQLdb.
    @return The ID, or None for any other driver module
    '''
    driver = util.root_module(db.db)
    if driver == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if driver == 'MySQLdb':
        return db.insert_id()
    return None
965 13 aaronmk
966 3490 aaronmk
def define_func(db, def_):
    '''Runs a function definition, renaming the function (def_.function.name)
    to its next_version() and retrying for as long as a DuplicateException is
    raised.
    @param def_ Object with a .function attribute and a to_str(db) method
    '''
    target = def_.function
    while True:
        try:
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
        except DuplicateException:
            # try again with next version of name
            target.name = next_version(target.name)
        else:
            return # successful
976
977 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    unique_cols = lists.uniqify(cols)
    
    pairs = []
    for kept in preserve:
        # Snapshot the col as it was, then retarget the caller's object
        pairs.append((copy.copy(kept), kept))
        kept.table = into
    kept_set = set(preserve) # built after the tables have been changed
    pairs.extend([(col, sql_gen.Col(strings.ustr(col), into, col.srcs))
        for col in unique_cols if col not in kept_set])
    
    if as_items: return pairs
    return dict(pairs)
1008 2383 aaronmk
1009 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects cols from joins into the table into, under the names chosen by
    mk_flatten_mapping(), and adds a pkey to it.
    For params, see mk_flatten_mapping()
    @param joins Passed to mk_select() as its tables param
    @param limit Passed to mk_select()
    @param start Passed to mk_select()
    @return See return value of mk_flatten_mapping()
    '''
    mapping_items = mk_flatten_mapping(db, into, cols, as_items=True,
        **kw_args)
    # Select each original col under its new name
    renamed = [sql_gen.NamedCol(new.name, old) for old, new in mapping_items]
    query = mk_select(db, joins, renamed, order_by=None, limit=limit,
        start=start)
    # (original note) don't cache because the temp table will usually be
    # truncated after use
    run_query_into(db, query, into=into, add_pkey_=True)
    return dict(mapping_items)
1019 2391 aaronmk
1020 3079 aaronmk
##### Database structure introspection
1021 2414 aaronmk
1022 3321 aaronmk
#### Expressions
1023
1024 5358 aaronmk
# Regex fragments for boolean literals and for "atoms": a boolean literal or a
# parenthesized group that contains no nested parentheses
true_re = 'true'
false_re = 'false'
bool_re = '(?:%s|%s)' % (true_re, false_re)
atom_re = r'(?:%s|\([^()]*\))' % bool_re
1028 3353 aaronmk
1029 5361 aaronmk
def logic_op_re(op, value_re, expr_re=''):
    '''Builds a regex matching a binary logic operation with value_re on either
    side of the operator.
    @param op str The operator keyword. It is surrounded by single spaces and
        used as-is (not regex-escaped).
    @param value_re str Regex for the operand on one side
    @param expr_re str Regex for the operand on the other side. If empty, the
        result matches just the value and the operator next to it.
    @return str A non-capturing group: expr op value | value op expr
    '''
    spaced_op = ''.join([' ', op, ' '])
    value_on_right = ''.join([expr_re, spaced_op, value_re])
    value_on_left = ''.join([value_re, spaced_op, expr_re])
    return ''.join(['(?:', value_on_right, '|', value_on_left, ')'])
1032 5346 aaronmk
1033 5364 aaronmk
# Patterns used by simplify_expr():
# * and_false_re: an atom ANDed with false (replaced there with 'false')
# * or_and_true_re: " AND true" or " OR <bool>", on either side of the
#   operator (deleted there)
# NOTE(review): or_re is built from bool_re, so it also matches " OR true";
# deleting that is not a logical identity -- confirm this is intended.
and_false_re = logic_op_re('AND', false_re, expr_re=atom_re)
and_true_re = logic_op_re('AND', true_re)
or_re = logic_op_re('OR', bool_re)
or_and_true_re = '(?:%s|%s)' % (and_true_re, or_re)
1037 5362 aaronmk
1038 5350 aaronmk
def simplify_parens(expr):
    '''Strips one pair of parentheses from around each atom (see atom_re),
    using regexp.sub_nested().
    '''
    wrapped_atom_re = r'\((' + atom_re + r')\)'
    return regexp.sub_nested(wrapped_atom_re, r'\1', expr)
1040 5350 aaronmk
1041 5355 aaronmk
def simplify_recursive(sub_func, expr):
    '''Applies sub_func through regexp.sub_recursive(), running
    simplify_parens() on the input of every sub_func call and on the final
    result.
    @param sub_func See regexp.sub_recursive() sub_func param
    '''
    def sub_without_parens(text):
        return sub_func(simplify_parens(text))
    
    substituted = regexp.sub_recursive(sub_without_parens, expr)
    return simplify_parens(substituted)
1047
1048 3353 aaronmk
def simplify_expr(expr):
    '''Simplifies a boolean expression string: replaces constant NULL tests
    with true/false, then repeatedly collapses logic operations on boolean
    literals via simplify_recursive().
    @return str The simplified expression
    '''
    def simplify_logic_ops(text):
        '''@return (new text, total # substitutions made)'''
        # (pattern, replacement), applied in this order
        rules = ((and_false_re, 'false'), (or_and_true_re, r''))
        replaced = 0
        for pattern, repl in rules:
            text, count = re.subn(pattern, repl, text)
            replaced += count
        return text, replaced
    
    null_tests = (('(NULL IS NULL)', 'true'), ('(NULL IS NOT NULL)', 'false'))
    for null_test, literal in null_tests:
        expr = expr.replace(null_test, literal)
    return simplify_recursive(simplify_logic_ops, expr)
1061
1062 3321 aaronmk
# Regex for a name: either a run of word characters, or one or more adjacent
# double-quoted segments. Used by parse_expr_col().
name_re = r'(?:\w+|(?:"[^"]*")+)'
1063
1064
def parse_expr_col(str_):
    '''Extracts the first name inside a string of the form
    (name(name ...)), i.e. the first argument of a parenthesized call. Any
    other string is used whole. The result is passed through
    sql_gen.unesc_name().
    '''
    call_re = r'^\(' + name_re + r'\((' + name_re + r').*\)\)$'
    found = re.match(call_re, str_)
    if found is not None: str_ = found.group(1)
    return sql_gen.unesc_name(str_)
1068
1069 3351 aaronmk
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The expression to rewrite
    @param mapping dict(out_col=in_col, ...)
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The rewritten expression, passed through simplify_expr()
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = sql_gen.to_name_only_col(out)
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
            ).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The unquoted name is matched literally (re.escape), and only where
        # it is not part of a quoted string, qualified name or subscript.
        # The replacement is given as a function so that backslashes in in_str
        # are inserted verbatim rather than read as template escapes.
        expr = re.sub(r'(?<!["\'\.\[])\b'+re.escape(out.name)
            +r'\b(?!["\'\.=\]])', lambda match, repl=in_str: repl, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)
1088 3351 aaronmk
1089 3079 aaronmk
#### Tables
1090
1091 4555 aaronmk
def tables(db, schema_like='public', table_like='%', exact=False,
    cacheable=True):
    '''Lists the names of the tables matching the given patterns.
    @param schema_like Pattern for the schema name (psycopg2 only)
    @param table_like Pattern for the table name
    @param exact If set, the patterns are compared with = instead of LIKE
        (psycopg2 only; the MySQLdb query always uses LIKE)
    @param cacheable Whether the listing query can be cached
    @return The table names, as returned by values()
    @raise NotImplementedError For any other driver module
    '''
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', cacheable=cacheable, log_level=4))
    elif module == 'MySQLdb':
        # Honor the caller's cacheable setting, as the psycopg2 branch does
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=cacheable, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
1106
1107 4556 aaronmk
def table_exists(db, table, cacheable=True):
    '''Whether tables() finds a table with exactly this schema and name.
    @param cacheable Passed to tables()
    '''
    table = sql_gen.as_Table(table)
    matches = list(tables(db, table.schema, table.name, exact=True,
        cacheable=cacheable))
    return len(matches) > 0
1110 3079 aaronmk
1111 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Returns the result of selecting sql_gen.row_count from the table.
    @param recover Passed to run_query()
    '''
    count_query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    cur = run_query(db, count_query, recover=recover, log_level=3)
    return value(cur)
1114 2426 aaronmk
1115 5337 aaronmk
def table_col_names(db, table, recover=None):
    '''Returns a list of the table's column names, taken from the cursor of a
    zero-row (LIMIT 0) select on it.
    @param recover Passed to select()
    '''
    empty_cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(empty_cur))
1118 2414 aaronmk
1119 4261 aaronmk
# NOTE(review): presumably the conventional name of a pkey column; it is not
# referenced in this part of the file -- confirm against callers
pkey_col = 'row_num'
1120
1121 2291 aaronmk
def pkey(db, table, recover=None):
    '''Returns the name of the first column of the table's primary key, looked
    up in information_schema.
    If no pkey, returns the first column in the table.
    @param recover Only used by the fallback (passed to table_col_names())
    '''
    table = sql_gen.as_Table(table)
    
    join_cols = ['table_schema', 'table_name', 'constraint_schema',
        'constraint_name']
    # Named so as not to shadow the module-level tables() function
    from_tables = [sql_gen.Table('key_column_usage', 'information_schema'),
        sql_gen.Join(sql_gen.Table('table_constraints', 'information_schema'),
            dict(((c, sql_gen.join_same_not_null) for c in join_cols)))]
    cols = [sql_gen.Col('column_name')]
    
    conds = [('constraint_type', 'PRIMARY KEY'), ('table_name', table.name)]
    schema = table.schema
    if schema != None: conds.append(('table_schema', schema))
    # ordinal_position is the column's position within the constraint key.
    # (position_in_unique_constraint is only set for foreign keys, so it is
    # NULL for every pkey column and cannot pick out the first one.)
    order_by = 'ordinal_position'
    
    try: return value(select(db, from_tables, cols, conds, order_by=order_by,
        limit=1, log_level=4))
    except StopIteration: return table_col_names(db, table, recover)[0]
1140 832 aaronmk
1141 5128 aaronmk
def pkey_col_(db, table, *args, **kw_args):
    '''Like pkey(), but returns a sql_gen.Col on the table instead of just the
    column name. For params, see pkey().
    '''
    pkey_name = pkey(db, table, *args, **kw_args)
    return sql_gen.Col(pkey_name, table)
1143
1144 2559 aaronmk
# Column name that table_not_null_col() looks for before falling back to the
# table's pkey
not_null_col = 'not_null_col'
1145 2340 aaronmk
1146
def table_not_null_col(db, table, recover=None):
    '''Returns the name of a column to use as the table's not-null column.
    Name assumed to be the value of not_null_col. If not found, uses pkey.
    @param recover Passed to table_col_names() and pkey()
    '''
    has_not_null_col = not_null_col in table_col_names(db, table, recover)
    if not has_not_null_col: return pkey(db, table, recover)
    return not_null_col
1150
1151 3348 aaronmk
def constraint_cond(db, constraint):
    '''Returns the source text (pg_constraint.consrc) of a constraint.
    @param constraint Object with .table (having a to_str(db) method) and .name
    @raise NotImplementedError For any driver module other than psycopg2
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    # The message used to say "list index columns" (copied from index_cols())
    else: raise NotImplementedError("Can't get constraint condition for "
        +module+' database')
1166
1167 3319 aaronmk
def index_cols(db, index):
    '''Returns the columns of an index, each parsed with parse_expr_col().
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @raise NotImplementedError For any driver module other than psycopg2
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    qual_index = sql_gen.Literal(index.to_str(db))
    # One result row per index column, in column order
    query = '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
    col_defs = values(run_query(db, query, cacheable=True, log_level=4))
    return [parse_expr_col(col_def) for col_def in col_defs]
1183 853 aaronmk
1184 3079 aaronmk
#### Functions
1185
1186
def function_exists(db, function):
    '''Whether the function exists, determined by trying to cast its qualified
    name to regproc.
    @param function Object with a to_str(db) method
    '''
    qual_function = sql_gen.Literal(function.to_str(db))
    try:
        select(db, fields=[sql_gen.Cast('regproc', qual_function)],
            recover=True, cacheable=True, log_level=4)
    except DoesNotExistException:
        return False
    except DuplicateException:
        pass # overloaded function, so it does exist
    return True
1194 3079 aaronmk
1195
##### Structural changes
1196
1197
#### Columns
1198
1199 5020 aaronmk
def add_col(db, table, col, comment=None, if_not_exists=False, **kw_args):
    '''
    Adds a column to a table with ALTER TABLE ... ADD COLUMN. On a
    DuplicateException, the column is renamed to its next_version() and the
    query is retried, unless if_not_exists is set.
    @param table Object with a to_str(db) method
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached.
        NOTE(review): the query is run with cacheable=True whether or not
        comment is set.
    @param if_not_exists If set, a DuplicateException is re-raised instead of
        retrying under a new column name
    @param kw_args Passed to run_query()
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    added = False
    while not added:
        # Rebuilt on every attempt because col.name may have changed
        parts = ['ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)]
        if comment != None: parts.append(sql_gen.esc_comment(comment))
        
        try:
            run_query(db, ' '.join(parts), recover=True, cacheable=True,
                **kw_args)
        except DuplicateException:
            if if_not_exists: raise
            # try again with next version of name
            col.name = next_version(col.name)
        else:
            added = True
1220
1221
def add_not_null(db, col):
    '''Sets NOT NULL on a column.
    @param col Column whose .table is the table to alter
    '''
    name_only = sql_gen.to_name_only_col(col)
    query = ' '.join(['ALTER TABLE', col.table.to_str(db), 'ALTER COLUMN',
        name_only.to_str(db), 'SET NOT NULL'])
    run_query(db, query, cacheable=True, log_level=3)
1226
1227 4443 aaronmk
def drop_not_null(db, col):
    '''Drops NOT NULL from a column.
    @param col Column whose .table is the table to alter
    '''
    name_only = sql_gen.to_name_only_col(col)
    query = ' '.join(['ALTER TABLE', col.table.to_str(db), 'ALTER COLUMN',
        name_only.to_str(db), 'DROP NOT NULL'])
    run_query(db, query, cacheable=True, log_level=3)
1232
1233 2096 aaronmk
# Default name of the column added by add_row_num()
row_num_col = '_row_num'
1234
1235 4997 aaronmk
# Template for the row number column: a non-nullable serial PRIMARY KEY.
# The name is left blank here; add_row_num() fills it in on a copy.
row_num_col_def = sql_gen.TypedCol('', 'serial', constraints='PRIMARY KEY',
    nullable=False)
1237
1238 4997 aaronmk
def add_row_num(db, table, name=row_num_col):
    '''Adds a row number column to a table. Its definition is in
    row_num_col_def. It will be the primary key.
    @param name The name to give the new column
    '''
    named_def = copy.copy(row_num_col_def) # don't modify the shared template
    named_def.name = name
    add_col(db, table, named_def, comment='', if_not_exists=True,
        log_level=3)
1244 3079 aaronmk
1245
#### Indexes
1246
1247
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @param recover Only used when cols is None (passed to pkey())
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    
    col_list = ', '.join([sql_gen.to_name_only_col(col).to_str(db)
        for col in cols])
    query = 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('+col_list+')'
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1260
1261 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported expressions.
    @param exprs A single column/expression or a sequence of them
    @param table The table to index. If None, it is taken from the first
        expression's column, which must then be a table column.
    @param unique If set, a UNIQUE index is created
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    # Parse exprs
    index_exprs = []
    for expr in lists.mk_seq(exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # Clear the table on the (copied) col so it is not table-qualified
        # inside the index definition
        if isinstance(col, sql_gen.Col): col.table = None
        
        index_exprs.append(expr)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in index_exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1308 2408 aaronmk
1309 3083 aaronmk
# Sentinel for add_indexes()'s has_pkey param; recognized by identity (`is`)
already_indexed = object() # tells add_indexes() the pkey has already been added
1310
1311
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    to_index = table_col_names(db, table)
    if has_pkey:
        # The first column is covered by the pkey, so skip it below
        to_index = to_index[1:]
        if has_pkey is not already_indexed: add_pkey(db, table)
    for col_name in to_index:
        add_index(db, col_name, table)
1322
1323 3079 aaronmk
#### Tables
1324 2772 aaronmk
1325 3079 aaronmk
### Maintenance
1326 2772 aaronmk
1327 3079 aaronmk
def analyze(db, table):
    '''Runs ANALYZE on a table.'''
    query = 'ANALYZE '+sql_gen.as_Table(table).to_str(db)
    run_query(db, query, log_level=3)
1330 2934 aaronmk
1331 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() on a table, but only if db.autoanalyze is set.'''
    if not db.autoanalyze: return
    analyze(db, table)
1333 2935 aaronmk
1334 3079 aaronmk
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on a table, inside db.with_autocommit().'''
    table = sql_gen.as_Table(table)
    
    def run_vacuum():
        return run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    
    db.with_autocommit(run_vacuum)
1338 2086 aaronmk
1339 3079 aaronmk
### Lifecycle
1340
1341 3247 aaronmk
def drop(db, type_, name):
    '''Issues DROP <type_> IF EXISTS <name> CASCADE.
    @param type_ str The SQL object type keyword (e.g. 'TABLE')
    @param name Anything accepted by sql_gen.as_Name()
    '''
    name_sql = sql_gen.as_Name(name).to_str(db)
    query = ' '.join(['DROP', type_, 'IF EXISTS', name_sql, 'CASCADE'])
    run_query(db, query)
1344 2889 aaronmk
1345 3247 aaronmk
def drop_table(db, table):
    '''Drops a table if it exists. For details, see drop().'''
    drop(db, 'TABLE', table)
1346
1347 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types. The passed
        sequence itself is never modified.
    @param has_pkey If set, the first column becomes the primary key.
        Requires at least one entry in the column list.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like If set, an object with to_str(db) naming the table whose
        definition is copied using LIKE ... INCLUDING ALL. This clause is
        placed before the entries of cols.
    '''
    table = sql_gen.as_Table(table)
    
    # Work on a private copy, so that neither the caller's list nor the shared
    # default argument is modified by the pkey substitution below
    cols = list(cols)
    if like != None:
        cols.insert(0,
            sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL'))
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
    if table.is_temp:
        # On a name collision, rename the temp table and retry until it works
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # the pkey was created above
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1394 2675 aaronmk
1395 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates dest with the same structure as src. (No data is copied.)'''
    create_table(db, dest, like=src, has_pkey=False, col_indexes=False)
1398 3084 aaronmk
1399 3079 aaronmk
### Data
1400 2684 aaronmk
1401 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema Passed to sql_gen.as_Table() together with table
    @param kw_args Forwarded to run_query()
    @return The result of run_query()
    '''
    table_sql = sql_gen.as_Table(table, schema).to_str(db)
    query = 'TRUNCATE '+table_sql+' CASCADE'
    return run_query(db, query, **kw_args)
1405 2732 aaronmk
1406 2965 aaronmk
def empty_temp(db, tables):
    '''Truncates each of the given tables.
    @param tables Passed through lists.mk_seq() before being iterated
    '''
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1409 2965 aaronmk
1410 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() yields for the schema.
    For kw_args, see tables()
    '''
    schema_tables = tables(db, schema, **kw_args)
    for table in schema_tables:
        truncate(db, table, schema)
1413 3094 aaronmk
1414
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on Any iterable of columns. Entries for which
        sql_gen.is_table_col() is false are ignored. If empty (after this
        filtering), creates a table with one row. This is useful if your
        distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    # Build a real list: filter() returns a tuple for tuple input on Python 2
    # and an iterator on Python 3, neither of which ever compares equal to []
    distinct_on = [c for c in distinct_on if sql_gen.is_table_col(c)]
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if not distinct_on: limit = 1 # one sample row
    else:
        # The unique index is what makes the insert below (ignore=True) keep
        # only one row per distinct_on value
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table