Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 3238 aaronmk
import time
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 3241 aaronmk
import profiling
13 1889 aaronmk
from Proxy import Proxy
14 1872 aaronmk
import rand
15 5349 aaronmk
import regexp
16 2217 aaronmk
import sql_gen
17 862 aaronmk
import strings
18 131 aaronmk
import util
19 11 aaronmk
20 832 aaronmk
##### Exceptions
21
22 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Gets the text of the query associated with a cursor.
    @param cur The object to inspect. Its `query` attribute is used if it has
        one; otherwise its `_last_executed` attribute is used if it has one.
    @param input_query The query text supplied by the caller. Used as the
        fallback when neither attribute is present or the attribute is None.
    @return The attribute's value, or else '[input] '+input_query
    '''
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr):
            recorded = getattr(cur, attr)
            if recorded != None: return recorded
            break # only the first attribute that is present is consulted
    return '[input] '+strings.ustr(input_query)
29 14 aaronmk
30 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Adds the text of the query to an exception's message.
    @param e The exception, which is passed to exc.add_msg()
    For the remaining params, see get_cur_query()
    '''
    query_text = strings.ustr(get_cur_query(*args, **kw_args))
    exc.add_msg(e, 'query: '+query_text)
33 135 aaronmk
34 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the text of this cursor's query is added to the
            message (see _add_cursor_info())
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None:
            _add_cursor_info(self, cur)
38
39 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about a name, which is kept in self.name'''
    def __init__(self, name, cause=None):
        '''
        @param name The name the error is about (any type; stringified only for
            the message)
        @param cause The underlying exception, if any
        '''
        name_str = strings.as_tt(strings.ustr(name))
        DbException.__init__(self, 'for name: '+name_str, cause)
        self.name = name
44 360 aaronmk
45 3109 aaronmk
class ExceptionWithValue(DbException):
    '''A DbException about a value, which is kept in self.value'''
    def __init__(self, value, cause=None):
        '''
        @param value The value the error is about (shown in the message via
            strings.urepr())
        @param cause The underlying exception, if any
        '''
        value_str = strings.as_tt(strings.urepr(value))
        DbException.__init__(self, 'for value: '+value_str, cause)
        self.value = value
50
51 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''A DbException about a typed, named object. The type and name are kept in
    self.type and self.name.'''
    def __init__(self, type_, name, cause=None):
        '''
        @param type_ The kind of object (stringified for the message)
        @param name The object's name (passed to strings.as_tt() as-is)
        @param cause The underlying exception, if any
        '''
        type_str = strings.as_tt(strings.ustr(type_))
        name_str = strings.as_tt(name)
        msg = 'for type: '+type_str+'; name: '+name_str
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.name = name
57
58 2306 aaronmk
class ConstraintException(DbException):
    '''A DbException for a violated constraint. The constructor args are kept in
    self.name, self.cond, and self.cols.'''
    def __init__(self, name, cond, cols, cause=None):
        '''
        @param name The constraint's name (or description), as a string
        @param cond The constraint's condition as a string, or None if unknown
        @param cols The names of the columns involved. May be empty (or None)
            if unknown, in which case no columns are listed in the message.
        @param cause The underlying exception, if any
        '''
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond != None: msg += ' with condition '+cond
        # Truthiness rather than `!= []`, so that an empty tuple or None is
        # also treated as "no columns" instead of producing an empty list in
        # the message or a TypeError
        if cols: msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols
67 11 aaronmk
68 2523 aaronmk
class MissingCastException(DbException):
    '''A DbException for a value that needs a cast to a type. The type and
    column are kept in self.type and self.col.'''
    def __init__(self, type_, col=None, cause=None):
        '''
        @param type_ The name of the type to cast to
        @param col The name of the column needing the cast, or None if unknown
        @param cause The underlying exception, if any
        '''
        parts = ['Missing cast to type '+strings.as_tt(type_)]
        if col != None: parts.append(' on column: '+strings.as_tt(col))
        DbException.__init__(self, ''.join(parts), cause)
        self.type = type_
        self.col = col
75
76 2143 aaronmk
class NameException(DbException):
    '''A DbException subclass that adds no behavior.
    NOTE(review): presumably for name-related errors; confirm against callers.
    '''
    pass
77
78 2306 aaronmk
class DuplicateKeyException(ConstraintException):
    '''Raised by run_query() when the database reports that a duplicate key
    value violates a unique constraint'''
    pass
79 13 aaronmk
80 2306 aaronmk
class NullValueException(ConstraintException):
    '''Raised by run_query() when the database reports that a null value
    violates a not-null constraint'''
    pass
81 13 aaronmk
82 3346 aaronmk
class CheckException(ConstraintException):
    '''Raised by run_query() when the database reports that a new row violates
    a check constraint'''
    pass
83
84 3109 aaronmk
class InvalidValueException(ExceptionWithValue):
    '''Raised by run_query() when the database reports invalid input or an
    out-of-range value'''
    pass
85 2239 aaronmk
86 2945 aaronmk
class DuplicateException(ExceptionWithNameType):
    '''Raised by run_query() when the database reports that an object already
    exists or that more than one object has the same name'''
    pass
87 2143 aaronmk
88 3419 aaronmk
class DoesNotExistException(ExceptionWithNameType):
    '''Raised by run_query() when the database reports that an object does not
    exist'''
    pass
89
90 89 aaronmk
class EmptyRowException(DbException):
    '''A DbException subclass that adds no behavior.
    NOTE(review): presumably for an unexpectedly empty row; confirm against
    callers.
    '''
    pass
91
92 865 aaronmk
##### Warnings
93
94
class DbWarning(UserWarning):
    '''Warning category for this module; a UserWarning that adds no behavior'''
    pass
95
96 1930 aaronmk
##### Result retrieval
97
98
def col_names(cur):
    '''Returns a generator of the first element of each cur.description entry
    (the column name, per the DB-API)'''
    return (col_desc[0] for col_desc in cur.description)
99
100
def rows(cur):
    '''Returns an iterator that calls cur.fetchone() for each item and stops
    when it returns None'''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
101
102
def consume_rows(cur):
    '''Fetches (and discards) all remaining rows, so that the result will be
    cached'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
105
106
def next_row(cur):
    '''Fetches the next row. Raises StopIteration if there are no more rows.'''
    row_iter = rows(cur)
    return row_iter.next()
107
108
def row(cur):
    '''Fetches the next row, then consumes the rest of the result set.
    Raises StopIteration if there are no more rows.
    '''
    first = next_row(cur)
    consume_rows(cur)
    return first
112
113
def next_value(cur):
    '''Returns the first column of the next row (see next_row())'''
    next_row_ = next_row(cur)
    return next_row_[0]
114
115
def value(cur):
    '''Returns the first column of the next row, consuming the rest of the
    result set (see row())'''
    row_ = row(cur)
    return row_[0]
116
117
def values(cur):
    '''Returns iters.func_iter() over a function that fetches the first column
    of the next row on each call'''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
118
119
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there are no rows'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
122
123 2762 aaronmk
##### Escaping
124 2101 aaronmk
125 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quote character of a DB-API module.
    @param module The module's name: 'psycopg2' or None for double quotes,
        'MySQLdb' for backticks
    @param name The identifier to escape
    @return The result of sql_gen.esc_name()
    @throws NotImplementedError for any other module
    '''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2' or module == None: quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
130 2101 aaronmk
131
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a database engine.
    @param engine A key of db_engines
    @param name The identifier to escape
    @param kw_args Forwarded to esc_name_by_module()
    '''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
133
134
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the database that a connection is using.
    @param db The DbConn; the module is determined from its db attribute
    @param name The identifier to escape
    @param kw_args Forwarded to esc_name_by_module()
    '''
    module_name = util.root_module(db.db)
    return esc_name_by_module(module_name, name, **kw_args)
136
137
def qual_name(db, schema, table):
    '''Builds an escaped table name, qualified by the schema if there is one.
    @param schema The schema name, or None for an unqualified name
    @return escaped_schema.escaped_table, or just escaped_table
    '''
    parts = [esc_name(db, table)]
    if schema != None: parts.insert(0, esc_name(db, schema))
    return '.'.join(parts)
142
143 1869 aaronmk
##### Database connections
144 1849 aaronmk
145 2097 aaronmk
# The names of the db_config keys
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# engine name -> (name of the DB-API module, {db_config key: the name that the
# module's connect() uses for it}); see DbConn._db()
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# The exception classes that count as database errors. _add_module() adds each
# loaded DB-API module's DatabaseError and rebuilds the tuple.
DatabaseErrors_set = set((DbException,))
DatabaseErrors = tuple(DatabaseErrors_set)
154
155
def _add_module(module):
    '''Adds a DB-API module's DatabaseError class to DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple from it'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
159
160
def db_config_str(db_config):
    '''Describes a db_config as "<engine> database <database>"'''
    engine = db_config['engine']
    database = db_config['database']
    return engine+' database '+database
162
163 2448 aaronmk
def log_debug_none(msg, level=2):
    '''The no-op debug logger. DbConn compares its log_debug against this to
    decide whether debugging is on.'''
    return None
164 1901 aaronmk
165 1849 aaronmk
class DbConn:
166 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        Stores the settings. The connection itself is opened lazily, the first
        time the db attribute is accessed.
        @param db_config dict of connection settings. 'engine' selects the
            db_engines entry, 'schemas' (optional) is a comma-separated list
            of schemas to put at the front of the search_path, and the rest is
            passed to the DB-API module's connect().
        @param autocommit Whether to commit after each query that runs outside
            of a savepoint
        @param caching Whether cacheable queries may use the results cache
        @param log_debug Function(msg, level) that receives debug messages
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        # Settings from the caller
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.src = src

        # Settings that callers may change after construction
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None

        # Connection state. _savepoint must be set before _reset(), which
        # asserts that it is 0.
        self._savepoint = 0
        self._reset()
187 1869 aaronmk
188
    def __getattr__(self, name):
189
        if name == '__dict__': raise Exception('getting __dict__')
190
        if name == 'db': return self._db()
191
        else: raise AttributeError()
192
193
    def __getstate__(self):
194
        state = copy.copy(self.__dict__) # shallow copy
195 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
196 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
197
        return state
198
199 3118 aaronmk
    def clear_cache(self): self.query_results = {}
200
201 3120 aaronmk
    def _reset(self):
202 3118 aaronmk
        self.clear_cache()
203 3124 aaronmk
        assert self._savepoint == 0
204 3118 aaronmk
        self._notices_seen = set()
205
        self.__db = None
206
207 2165 aaronmk
    def connected(self): return self.__db != None
208
209 3116 aaronmk
    def close(self):
        '''Closes the connection, if there is one, and resets this object's
        connection state'''
        if self.connected():
            # Record that the automatic transaction is now closed
            self._savepoint -= 1

            self.db.close()
            self._reset()
217 3116 aaronmk
218 3125 aaronmk
    def reconnect(self):
        '''Closes the connection so that the next query opens a new one.
        Does nothing unless in autocommit mode.
        '''
        # Do not do this in test mode as it would roll back everything
        if not self.autocommit: return

        self.close()
        # Connection will be reopened automatically on first query
222
223 1869 aaronmk
    def _db(self):
        '''Returns the DB-API connection, opening and configuring it first if
        it is not open yet'''
        if self.__db != None: return self.__db

        # Process db_config
        config = self.db_config.copy() # don't modify input!
        schemas = config.pop('schemas', None)
        module_name, mappings = db_engines[config.pop('engine')]
        module = __import__(module_name)
        _add_module(module)
        # Translate the db_config keys to the names the module's connect() uses
        for orig, new in mappings.iteritems():
            try: util.rename_key(config, orig, new)
            except KeyError: pass

        # Connect. From here on, self.db returns this connection.
        self.__db = module.connect(**config)

        # Record that a transaction is already open
        self._savepoint += 1

        # Configure connection
        if hasattr(self.db, 'set_isolation_level'):
            import psycopg2.extensions
            read_committed = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED
            self.db.set_isolation_level(read_committed)
        if schemas != None:
            # Put the configured schemas in front of the existing search_path
            search_path = [self.esc_name(s) for s in schemas.split(',')]
            existing_path = value(run_query(self, 'SHOW search_path',
                log_level=4))
            search_path.append(existing_path)
            run_query(self, 'SET search_path TO '+(','.join(search_path)),
                log_level=3)

        return self.__db
254 1889 aaronmk
255 1891 aaronmk
    class DbCursor(Proxy):
256 1927 aaronmk
        def __init__(self, outer):
257 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
258 2191 aaronmk
            self.outer = outer
259 1927 aaronmk
            self.query_results = outer.query_results
260 1894 aaronmk
            self.query_lookup = None
261 1891 aaronmk
            self.result = []
262 1889 aaronmk
263 2802 aaronmk
        def execute(self, query):
264 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
265 2797 aaronmk
            self.query_lookup = query
266 2148 aaronmk
            try:
267 3162 aaronmk
                try: cur = self.inner.execute(query)
268 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
269 1904 aaronmk
            except Exception, e:
270
                self.result = e # cache the exception as the result
271
                self._cache_result()
272
                raise
273 3004 aaronmk
274
            # Always cache certain queries
275 3183 aaronmk
            query = sql_gen.lstrip(query)
276 3004 aaronmk
            if query.startswith('CREATE') or query.startswith('ALTER'):
277 3007 aaronmk
                # structural changes
278 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
279
                # so don't cache ADD COLUMN unless it has distinguishing comment
280
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
281 3007 aaronmk
                    self._cache_result()
282 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
283 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
284 3004 aaronmk
285 2762 aaronmk
            return cur
286 1894 aaronmk
287 1891 aaronmk
        def fetchone(self):
288
            row = self.inner.fetchone()
289 1899 aaronmk
            if row != None: self.result.append(row)
290
            # otherwise, fetched all rows
291 1904 aaronmk
            else: self._cache_result()
292
            return row
293
294
        def _cache_result(self):
295 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
296
            # inserts are not idempotent. Other non-SELECT queries don't have
297
            # their result set read, so only exceptions will be cached (an
298
            # invalid query will always be invalid).
299 1930 aaronmk
            if self.query_results != None and (not self._is_insert
300 1906 aaronmk
                or isinstance(self.result, Exception)):
301
302 1894 aaronmk
                assert self.query_lookup != None
303 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
304
                    util.dict_subset(dicts.AttrsDictView(self),
305
                    ['query', 'result', 'rowcount', 'description']))
306 1906 aaronmk
307 1916 aaronmk
        class CacheCursor:
308
            def __init__(self, cached_result): self.__dict__ = cached_result
309
310 1927 aaronmk
            def execute(self, *args, **kw_args):
311 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
312
                # otherwise, result is a rows list
313
                self.iter = iter(self.result)
314
315
            def fetchone(self):
316
                try: return self.iter.next()
317
                except StopIteration: return None
318 1891 aaronmk
319 2212 aaronmk
    def esc_value(self, value):
        '''Escapes a value for use in a query.
        Uses mogrify() if the database module supports it, otherwise (for
        MySQLdb) _mysql.escape_string().
        @return The escaped value, as unicode
        @throws NotImplementedError if the module supports neither
        '''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError:
            module = util.root_module(self.db)
            # Bare raise (rather than `raise e`) keeps the original traceback
            if module != 'MySQLdb': raise
            import _mysql
            str_ = _mysql.escape_string(value)
        return strings.to_unicode(str_)
328 2212 aaronmk
329 2347 aaronmk
    def esc_name(self, name):
        '''Escapes an identifier for this database'''
        # Inside a method, the bare name refers to the module-level esc_name()
        return esc_name(self, name)
330
331 2814 aaronmk
    def std_code(self, str_):
332
        '''Standardizes SQL code.
333
        * Ensures that string literals are prefixed by `E`
334
        '''
335
        if str_.startswith("'"): str_ = 'E'+str_
336
        return str_
337
338 2665 aaronmk
    def can_mogrify(self):
        '''Returns whether mogrify() is supported, i.e. whether the connection
        comes from psycopg2'''
        return util.root_module(self.db) == 'psycopg2'
341 2663 aaronmk
342 2665 aaronmk
    def mogrify(self, query, params=None):
        '''Returns the result of a new cursor's mogrify(query, params).
        @throws NotImplementedError if not can_mogrify()
        '''
        if not self.can_mogrify():
            raise NotImplementedError("Can't mogrify query")
        return self.db.cursor().mogrify(query, params)
345
346 2671 aaronmk
    def print_notices(self):
        '''Logs, at level 2, each of the connection's notices that has not been
        logged before. Does nothing if the connection has no notices attr.'''
        if not hasattr(self.db, 'notices'): return

        for msg in self.db.notices:
            if msg in self._notices_seen: continue
            self._notices_seen.add(msg)
            self.log_debug(msg, level=2)
352 2671 aaronmk
353 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
354 2464 aaronmk
        debug_msg_ref=None):
355 2445 aaronmk
        '''
356 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
357
            throws one of these exceptions.
358 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
359
            this instead of being output. This allows you to filter log messages
360
            depending on the result of the query.
361 2445 aaronmk
        '''
362 2167 aaronmk
        assert query != None
363
364 3183 aaronmk
        if self.autocommit and self.src != None:
365 3206 aaronmk
            query = sql_gen.esc_comment(self.src)+'\t'+query
366 3183 aaronmk
367 2047 aaronmk
        if not self.caching: cacheable = False
368 1903 aaronmk
        used_cache = False
369 2664 aaronmk
370 3242 aaronmk
        if self.debug:
371
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
372 1903 aaronmk
        try:
373 1927 aaronmk
            # Get cursor
374
            if cacheable:
375 3238 aaronmk
                try: cur = self.query_results[query]
376 1927 aaronmk
                except KeyError: cur = self.DbCursor(self)
377 3238 aaronmk
                else: used_cache = True
378 1927 aaronmk
            else: cur = self.db.cursor()
379
380
            # Run query
381 3238 aaronmk
            try: cur.execute(query)
382 3162 aaronmk
            except Exception, e:
383
                _add_cursor_info(e, self, query)
384
                raise
385 3238 aaronmk
            else: self.do_autocommit()
386 1903 aaronmk
        finally:
387 3242 aaronmk
            if self.debug:
388 3244 aaronmk
                profiler.stop(self.profile_row_ct)
389 3242 aaronmk
390
                ## Log or return query
391
392 4491 aaronmk
                query = strings.ustr(get_cur_query(cur, query))
393 3281 aaronmk
                # Put the src comment on a separate line in the log file
394
                query = query.replace('\t', '\n', 1)
395 3239 aaronmk
396 3240 aaronmk
                msg = 'DB query: '
397 3239 aaronmk
398 3240 aaronmk
                if used_cache: msg += 'cache hit'
399
                elif cacheable: msg += 'cache miss'
400
                else: msg += 'non-cacheable'
401 3239 aaronmk
402 3241 aaronmk
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
403 3240 aaronmk
404 3237 aaronmk
                if debug_msg_ref != None: debug_msg_ref[0] = msg
405
                else: self.log_debug(msg, log_level)
406 3245 aaronmk
407
                self.print_notices()
408 1903 aaronmk
409
        return cur
410 1914 aaronmk
411 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
412 2139 aaronmk
413 2907 aaronmk
    def with_autocommit(self, func):
        '''Calls func() with the connection's isolation level set to psycopg2's
        autocommit level, restoring the previous level afterwards.
        @return func()'s return value
        '''
        import psycopg2.extensions
        autocommit_level = psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT

        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(autocommit_level)
        try:
            return func()
        finally:
            self.db.set_isolation_level(prev_isolation_level)
421 2683 aaronmk
422 2139 aaronmk
    def with_savepoint(self, func):
        '''Calls func() inside a transaction (if none is recorded as open) or
        else a savepoint, rolling it back if func() raises.
        @param func Function taking no args
        @return func()'s return value
        @throws Exception Whatever func() raises, after rolling back
        '''
        top = self._savepoint == 0 # whether no transaction is open yet
        savepoint = 'level_'+str(self._savepoint)

        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')

        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1

        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            # Save the return value instead of returning right away, so that
            # the COMMIT below is actually reached
            return_val = func()
            if top: self.run_query('COMMIT', log_level=4)
            return return_val
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)

            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)

            self._savepoint -= 1
            assert self._savepoint >= 0

            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)

            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
460 2191 aaronmk
461
    def do_autocommit(self):
462
        '''Autocommits if outside savepoint'''
463 3135 aaronmk
        assert self._savepoint >= 1
464
        if self.autocommit and self._savepoint == 1:
465 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
466 2191 aaronmk
            self.db.commit()
467 2643 aaronmk
468 3155 aaronmk
    def col_info(self, col, cacheable=True):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns.
        @param col The column; its name, table.name, and table.schema are used
        @param cacheable Whether the lookup query may use the results cache
        @return sql_gen.TypedCol for the column
        @throws sql_gen.NoUnderlyingTableException if no matching row is found
        '''
        info_table = sql_gen.Table('columns', 'information_schema')
        info_cols = [sql_gen.Col('data_type'), sql_gen.Col('udt_name'),
            'column_default', sql_gen.Cast('boolean',
            sql_gen.Col('is_nullable'))]

        conds = [('table_name', col.table.name),
            ('column_name', strings.ustr(col.name))]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))

        cur = select(self, info_table, info_cols, conds,
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4)
            # TODO: order by search_path order
        try: type_, extra_type, default, nullable = row(cur)
        except StopIteration: raise sql_gen.NoUnderlyingTableException(col)
        default = sql_gen.as_Code(default, self)

        # For these two data_type values, the actual type is in udt_name
        if type_ == 'USER-DEFINED': type_ = extra_type
        elif type_ == 'ARRAY':
            elem_type = strings.remove_prefix('_', extra_type, require=True)
            type_ = sql_gen.ArrayType(elem_type)

        return sql_gen.TypedCol(col.name, type_, default, nullable)
490 2917 aaronmk
491
    def TempFunction(self, name):
        '''Creates a sql_gen.Function in the pg_temp schema, or with no schema
        if in debug_temp mode'''
        schema = 'pg_temp'
        if self.debug_temp: schema = None
        return sql_gen.Function(name, schema)
495 1849 aaronmk
496 1869 aaronmk
connect = DbConn # alias: connect(...) creates a DbConn
497
498 832 aaronmk
##### Recoverable querying
499 15 aaronmk
500 2139 aaronmk
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func)'''
    return db.with_savepoint(func)
501 11 aaronmk
502 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
503
    log_ignore_excs=None, **kw_args):
504 2794 aaronmk
    '''For params, see DbConn.run_query()'''
505 830 aaronmk
    if recover == None: recover = False
506 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
507
    log_ignore_excs = tuple(log_ignore_excs)
508 3236 aaronmk
    debug_msg_ref = [None]
509 830 aaronmk
510 3267 aaronmk
    query = with_explain_comment(db, query)
511 3258 aaronmk
512 2148 aaronmk
    try:
513 2464 aaronmk
        try:
514 2794 aaronmk
            def run(): return db.run_query(query, cacheable, log_level,
515 2793 aaronmk
                debug_msg_ref, **kw_args)
516 2796 aaronmk
            if recover and not db.is_cached(query):
517 2464 aaronmk
                return with_savepoint(db, run)
518
            else: return run() # don't need savepoint if cached
519
        except Exception, e:
520 3095 aaronmk
            msg = strings.ustr(e.args[0])
521 4103 aaronmk
            msg = re.sub(r'^(?:PL/Python: )?ValueError: ', r'', msg)
522 2464 aaronmk
523 3095 aaronmk
            match = re.match(r'^duplicate key value violates unique constraint '
524 3338 aaronmk
                r'"(.+?)"', msg)
525 2464 aaronmk
            if match:
526 3338 aaronmk
                constraint, = match.groups()
527 3025 aaronmk
                cols = []
528
                if recover: # need auto-rollback to run index_cols()
529 3319 aaronmk
                    try: cols = index_cols(db, constraint)
530 3025 aaronmk
                    except NotImplementedError: pass
531 3345 aaronmk
                raise DuplicateKeyException(constraint, None, cols, e)
532 2464 aaronmk
533 3095 aaronmk
            match = re.match(r'^null value in column "(.+?)" violates not-null'
534 2464 aaronmk
                r' constraint', msg)
535 3345 aaronmk
            if match:
536
                col, = match.groups()
537
                raise NullValueException('NOT NULL', None, [col], e)
538 2464 aaronmk
539 3346 aaronmk
            match = re.match(r'^new row for relation "(.+?)" violates check '
540
                r'constraint "(.+?)"', msg)
541
            if match:
542
                table, constraint = match.groups()
543 3347 aaronmk
                constraint = sql_gen.Col(constraint, table)
544 3349 aaronmk
                cond = None
545
                if recover: # need auto-rollback to run constraint_cond()
546
                    try: cond = constraint_cond(db, constraint)
547
                    except NotImplementedError: pass
548
                raise CheckException(constraint.to_str(db), cond, [], e)
549 3346 aaronmk
550 3095 aaronmk
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
551 3635 aaronmk
                r'|.+? out of range): "(.+?)"', msg)
552 2464 aaronmk
            if match:
553 3109 aaronmk
                value, = match.groups()
554
                raise InvalidValueException(strings.to_unicode(value), e)
555 2464 aaronmk
556 3095 aaronmk
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
557 2523 aaronmk
                r'is of type', msg)
558
            if match:
559
                col, type_ = match.groups()
560
                raise MissingCastException(type_, col, e)
561
562 4141 aaronmk
            match = re.match(r'^could not determine polymorphic type because '
563
                r'input has type "unknown"', msg)
564
            if match: raise MissingCastException('text', None, e)
565
566 4485 aaronmk
            match = re.match(r'^.+? types .+? and .+? cannot be matched', msg)
567
            if match: raise MissingCastException('text', None, e)
568
569 4509 aaronmk
            typed_name_re = r'^(\S+) "(.+?)"(?: of relation ".+?")?'
570 3419 aaronmk
571
            match = re.match(typed_name_re+r'.*? already exists', msg)
572 2945 aaronmk
            if match:
573
                type_, name = match.groups()
574
                raise DuplicateException(type_, name, e)
575 2464 aaronmk
576 4145 aaronmk
            match = re.match(r'more than one (\S+) named ""(.+?)""', msg)
577
            if match:
578
                type_, name = match.groups()
579
                raise DuplicateException(type_, name, e)
580
581 3419 aaronmk
            match = re.match(typed_name_re+r' does not exist', msg)
582
            if match:
583
                type_, name = match.groups()
584
                raise DoesNotExistException(type_, name, e)
585
586 2464 aaronmk
            raise # no specific exception raised
587
    except log_ignore_excs:
588
        log_level += 2
589
        raise
590
    finally:
591 3236 aaronmk
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
592 830 aaronmk
593 832 aaronmk
##### Basic queries
594
595 3256 aaronmk
def is_explainable(query):
    '''Returns a match object if the query starts with a statement keyword
    that EXPLAIN accepts, or else None'''
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    explainable_re = r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
    return re.match(explainable_re, query)
599 3256 aaronmk
600 3263 aaronmk
def explain(db, query, **kw_args):
    '''
    Runs EXPLAIN on a query and returns the output as a single string.
    For params, see run_query().
    log_level defaults to 4.
    '''
    # not a higher log_level because it's useful to see what query is being
    # run before it's executed, which EXPLAIN effectively provides
    kw_args.setdefault('log_level', 4)

    cur = run_query(db, 'EXPLAIN '+query, recover=True, cacheable=True,
        **kw_args)
    return strings.ustr(strings.join_lines(values(cur)))
610
611 3265 aaronmk
def has_comment(query):
    '''Returns whether the query ends with the end of a block comment'''
    return query.endswith('*/')
612
613
def with_explain_comment(db, query, **kw_args):
    '''If db.autoexplain is on, appends the query's EXPLAIN output to it as a
    comment. Queries that already end in a comment, or that are not
    explainable, are returned unchanged.
    @param kw_args Forwarded to explain()
    '''
    if not db.autoexplain: return query
    if has_comment(query) or not is_explainable(query): return query

    plan = explain(db, query, **kw_args)
    return query+'\n'+sql_gen.esc_comment(' EXPLAIN:\n'+plan)
618
619 2153 aaronmk
def next_version(name):
    '''Returns the name with its "#<number>" suffix incremented, or with "#1"
    added if it doesn't have one'''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, prev_version = match.groups()
        version = int(prev_version)+1
    else: version = 1 # first existing name was version 0
    return sql_gen.concat(name, '#'+str(version))
626 2153 aaronmk
627 2899 aaronmk
def lock_table(db, table, mode):
    '''Runs LOCK TABLE on a table.
    @param table The table (converted with sql_gen.as_Table())
    @param mode The lock mode's SQL keywords, inserted between IN and MODE
    '''
    table_str = sql_gen.as_Table(table).to_str(db)
    run_query(db, 'LOCK TABLE '+table_str+' IN '+mode+' MODE')
630
631 3303 aaronmk
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into The sql_gen.Table to create, or None to just run the query.
        It is modified in place: it is marked as temp, its schema is removed,
        and it is renamed to the next version if the name is already taken.
    @param add_pkey_ Whether to call add_pkey() on the new table
    @return The cursor of the query that was run
    '''
    if into == None: return run_query(db, query, **kw_args)

    assert isinstance(into, sql_gen.Table)

    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None

    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))

    temp = not db.debug_temp # tables are permanent in debug_temp mode
    create_prefix = 'CREATE'
    if temp: create_prefix += ' TEMP'
    create_prefix += ' TABLE '

    # Create table
    while True:
        create_query = create_prefix+into.to_str(db)+' AS\n'+query

        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateException:
            # try again with next version of name
            into.name = next_version(into.name)
        else: break

    if add_pkey_: add_pkey(db, into)

    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)

    return cur
673 2085 aaronmk
674 2120 aaronmk
order_by_pkey = object() # sentinel (mk_select()'s order_by default): tells mk_select() to order by the pkey
675
676 2199 aaronmk
distinct_on_all = object() # sentinel for mk_select()'s distinct_on: tells mk_select() to SELECT DISTINCT ON all columns
677
678 3420 aaronmk
def mk_select(db, tables=None, fields=None, conds=None, distinct_on=[],
    limit=None, start=None, order_by=order_by_pkey, default_table=None,
    explain=True):
    '''Builds (but does not run) a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type (a dict is converted to its items)
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns. The [] default is only compared against, never mutated.
    @param limit None|int|long Value for the LIMIT clause
    @param start None|int|long Value for the OFFSET clause. The clause is
        omitted when this is None or 0.
    @param order_by The column to ORDER BY, or None for no ordering. The default
        (order_by_pkey) orders by pkey() of the first table, unless there is no
        table or distinct_on is set, in which case no ORDER BY is added.
    @param default_table Passed to sql_gen.as_Col() when parsing the fields and
        distinct_on columns
    @param explain If set, the query is passed through with_explain_comment()
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if table0 == None or distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    # Once the query spans multiple lines, each later clause gets its own line
    if '\n' in query: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if '\n' in query or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    if table0 != None: query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # WHERE, ORDER BY, LIMIT, OFFSET
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None and start != 0: query += '\nOFFSET '+str(start)
    
    if explain: query = with_explain_comment(db, query)
    
    return query
761 11 aaronmk
762 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query().
    The run_query() options recover (default None), cacheable (default True)
    and log_level (default 2) are taken out of kw_args; everything else goes to
    mk_select().
    '''
    # Remove the run_query() options first so mk_select() doesn't receive them
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
770 2054 aaronmk
771 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''Builds an INSERT INTO ... query whose rows come from select_query.
    @param cols The columns to insert into. None or [] leaves the column list
        out of the query.
    @param select_query str|None The query providing the rows. None inserts
        DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys. This forces embeddable.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    @return str The INSERT query or, when embeddable, a SELECT from a newly
        created temp function that wraps it
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        # Reads cols and returning as they are at call time
        lines = [first_line]
        if cols != None:
            lines.append('('+', '.join([c.to_str(db) for c in cols])+')')
        lines.append(select_query)
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            lines.append('RETURNING '+returning_name_col.to_str(db))
        return '\n'.join(lines)
    
    # Must be determined before returning is replaced with NULL below
    if returning != None: return_type = sql_gen.ColType(returning)
    else: return_type = sql_gen.CustomCode('unknown')
    
    if not ignore: query = mk_insert(select_query)
    else:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        
        if cols == None: row = [sql_gen.Col(sql_gen.all_cols, 'row')]
        else: row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        row_insert = mk_insert(sql_gen.Values(row).to_str(db))
        query = sql_gen.RowExcIgnore(sql_gen.RowType(table), select_query,
            sql_gen.ReturnQuery(row_insert), cols)
    
    if not embeddable: return query
    
    # Create function
    function_name = sql_gen.clean_name(first_line)
    if src != None: function_name = src+': '+function_name
    while True:
        try:
            func = db.TempFunction(function_name)
            def_ = sql_gen.FunctionDef(func, sql_gen.SetOf(return_type),
                query)
            
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
        except DuplicateException:
            # try again with next version of name
            function_name = next_version(function_name)
        else: break # this version was successful
    
    # Return query that uses function
    if returning != None: func_cols = [returning]
    else: func_cols = None
    func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(func), func_cols)
        # AS clause requires function alias
    return mk_select(db, func_table, order_by=None)
844 2066 aaronmk
845 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''Builds an INSERT with mk_insert_select() and runs it with
    run_query_into(). For params, see mk_insert_select() and run_query_into().
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in. Setting this makes the query embeddable.
    @return The cursor from run_query_into()
    '''
    # returning and ignore are only read here; mk_insert_select() also needs them
    ignore = kw_args.get('ignore', False)
    returning = kw_args.get('returning', None)
    
    # These options are consumed here rather than by mk_insert_select()
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
868 2063 aaronmk
869 2738 aaronmk
# Alias of sql_gen.default: tells insert() to use the default value for a column
default = sql_gen.default
870 2066 aaronmk
871 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select().
    @param row A sequence of values (column names unknown), or a dict-like
        mapping of column name to value
    @param ignore (keyword) If set, DuplicateKeyException and
        NullValueException are swallowed and None is returned instead. Also
        makes recover default to True.
    @return The result of insert_select(), or None if an error was ignored
    '''
    ignore = kw_args.pop('ignore', False)
    if ignore: kw_args.setdefault('recover', True)
    
    cols = None
    if not lists.is_seq(row): cols, row = row.keys(), row.values()
    row = list(row) # ensure that the emptiness check works
    
    query = None # no values = insert all defaults
    if row: query = sql_gen.Values(row).to_str(db)
    
    try: return insert_select(db, table, cols, query, *args, **kw_args)
    except (DuplicateKeyException, NullValueException):
        if ignore: return None
        raise
889 11 aaronmk
890 3152 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Builds (but does not run) a query that changes column values.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
        * Implemented as ALTER TABLE ... ALTER COLUMN ... TYPE ... USING
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        def alter_col(col, value):
            # Re-types the column to its current type, USING the new value
            col_str = col.to_str(db)
            type_ = db.col_info(sql_gen.with_default_table(col, table),
                cacheable_).type
            return ('ALTER COLUMN '+col_str+' TYPE '+type_+'\nUSING '
                +value.to_str(db))
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join([alter_col(c, v) for c, v in changes])
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join([c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes])
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return with_explain_comment(db, query)
925
926 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''Builds an update query with mk_update(), runs it with run_query(), then
    autoanalyze()s the table. For params, see mk_update() and run_query().
    Note that cacheable defaults to False here.
    @return The cursor from run_query()
    '''
    # Remove the run_query() options first so mk_update() doesn't receive them
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
936 2402 aaronmk
937 3286 aaronmk
def mk_delete(db, table, cond=None):
    '''Builds (but does not run) a DELETE query.
    @param table sql_gen.Table|str The table to delete from. It is normalized
        with sql_gen.as_Table(), as mk_update() and analyze() do.
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
        If None, the query has no WHERE clause.
    @return str query
    '''
    table = sql_gen.as_Table(table)
    
    query = 'DELETE FROM '+table.to_str(db)
    if cond != None: query += '\nWHERE '+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
948
949
def delete(db, table, *args, **kw_args):
    '''Builds a DELETE with mk_delete(), runs it with run_query(), then
    autoanalyze()s the table. For params, see mk_delete() and run_query().
    @return The cursor from run_query()
    '''
    # Remove the run_query() options first so mk_delete() doesn't receive them
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
959
960 135 aaronmk
def last_insert_id(db):
    '''Gets the ID generated by the last insert.
    @return SELECT lastval() for psycopg2, db.insert_id() for MySQLdb, and None
        for any other driver module
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': return db.insert_id()
    return None
965 13 aaronmk
966 3490 aaronmk
def define_func(db, def_):
    '''Runs a function definition, renaming the function until it doesn't
    collide with an existing one.
    @param def_ The definition to run. def_.function.name is changed in place
        to the next version of the name each time a DuplicateException occurs.
    '''
    func = def_.function
    while True:
        try:
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
        except DuplicateException:
            # try again with next version of name
            func.name = next_version(func.name)
        else: return # successful
976
977 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    mapping = []
    # Preserved cols keep their name but are moved (in place) to the into table
    for kept_col in preserve:
        unmoved_col = copy.copy(kept_col) # snapshot before changing the table
        kept_col.table = into
        mapping.append((unmoved_col, kept_col))
    
    # Remaining cols are named after the str form of the original col
    kept = set(preserve)
    for col in cols:
        if col in kept: continue
        new_col = sql_gen.Col(strings.ustr(col), into, col.srcs)
        mapping.append((col, new_col))
    
    if as_items: return mapping
    return dict(mapping)
1008 2383 aaronmk
1009 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects cols from the joined tables into the into table, under the names
    chosen by mk_flatten_mapping(). For params, see mk_flatten_mapping().
    @param joins The tables param for mk_select()
    @param limit,start Passed to mk_select()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    # Select each original col under its new name
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query = mk_select(db, joins, select_cols, order_by=None, limit=limit,
        start=start)
    run_query_into(db, query, into=into, add_pkey_=True)
        # don't cache because the temp table will usually be truncated after use
    return dict(items)
1019 2391 aaronmk
1020 3079 aaronmk
##### Database structure introspection
1021 2414 aaronmk
1022 3321 aaronmk
#### Expressions
1023
1024 5341 aaronmk
def paren_re(re_):
    '''Makes a regexp that matches re_ either bare or wrapped in one pair of
    parentheses.
    '''
    return ''.join([r'(?:', re_, r'|\(', re_, r'\))'])

# Regexps for boolean literals, optionally parenthesized
true_re = paren_re(r'true')
false_re = paren_re(r'false')
bool_re = ''.join([r'(?:', true_re, r'|', false_re, r')'])
1029 3353 aaronmk
1030 5346 aaronmk
def logic_op_re(op, value_re):
    '''Makes a regexp that matches a space-delimited logic operator together
    with an operand matching value_re on either its right or its left side.
    '''
    spaced_op = ' '+op+' '
    op_then_value = spaced_op+value_re
    value_then_op = value_re+spaced_op
    return '(?:'+op_then_value+'|'+value_then_op+')'
1033
1034 5350 aaronmk
def simplify_parens(expr):
    '''Replaces doubled parentheses around a paren-free span, "((...))", with a
    single pair, using regexp.sub_nested().
    '''
    double_parens_re = r'\((\([^()]*\))\)'
    inner_parens = r'\1' # the inner parenthesized group
    return regexp.sub_nested(double_parens_re, inner_parens, expr)
1036
1037 3353 aaronmk
def simplify_expr(expr):
    '''Simplifies an SQL expression string:
    * "(NULL IS NULL)" becomes "true" and "(NULL IS NOT NULL)" becomes "false"
    * an " OR " joined to a boolean literal is removed together with the literal
    * doubled parentheses are collapsed (see simplify_parens())
    '''
    for null_test, literal in (('(NULL IS NULL)', 'true'),
        ('(NULL IS NOT NULL)', 'false')):
        expr = expr.replace(null_test, literal)
    or_bool_re = logic_op_re('OR', bool_re)
    expr = re.sub(or_bool_re, r'', expr)
    return simplify_parens(expr)
1043
1044 3321 aaronmk
# Regexp for an SQL name: either a run of word chars, or one or more adjacent
# double-quoted segments (which is how a doubled "" inside quotes is matched)
name_re = r'(?:\w+|(?:"[^"]*")+)'
1045
1046
def parse_expr_col(str_):
    '''Extracts the column name from an expression string.
    If str_ has the form "(func(col ...))", the first name inside the function
    call is used; otherwise the whole string is. The name is then unescaped
    with sql_gen.unesc_name().
    '''
    func_call_re = r'^\('+name_re+r'\(('+name_re+r').*\)\)$'
    match = re.match(func_call_re, str_)
    if match: return sql_gen.unesc_name(match.group(1))
    return sql_gen.unesc_name(str_)
1050
1051 3351 aaronmk
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The expression
    @param mapping dict(out_col=in_col, ...)
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The mapped expression, simplified with simplify_expr()
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = sql_gen.to_name_only_col(out)
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
            ).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The unquoted name must match literally, so escape regexp
        # metacharacters in it. The replacement is passed as a function so that
        # backslashes in in_str are not treated as escapes/group references.
        unquoted_re = (r'(?<!["\'\.\[])\b'+re.escape(out.name)
            +r'\b(?!["\'\.=\]])')
        expr = re.sub(unquoted_re, lambda match: in_str, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)
1070 3351 aaronmk
1071 3079 aaronmk
#### Tables
1072
1073 4555 aaronmk
def tables(db, schema_like='public', table_like='%', exact=False,
    cacheable=True):
    '''Lists the names of the tables matching the given patterns.
    @param schema_like Schema name/pattern. Only used for psycopg2.
    @param table_like Table name/pattern
    @param exact If set, compares with = instead of LIKE. Only used for
        psycopg2; the MySQLdb query always uses SHOW TABLES LIKE.
    @param cacheable Passed to the query for both supported drivers
    @return The result of values() on the query's cursor
    @raise NotImplementedError For any other driver module
    '''
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', cacheable=cacheable, log_level=4))
    elif module == 'MySQLdb':
        # Honor the caller's cacheable setting here too
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=cacheable, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
1088
1089 4556 aaronmk
def table_exists(db, table, cacheable=True):
    '''Whether tables() finds a table with exactly this schema and name.
    @param cacheable Passed to tables()
    '''
    table = sql_gen.as_Table(table)
    matches = list(tables(db, table.schema, table.name, True, cacheable))
    return matches != []
1092 3079 aaronmk
1093 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Gets the table's row count by selecting sql_gen.row_count from it.
    @param recover Passed to run_query()
    '''
    count_query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    cur = run_query(db, count_query, recover=recover, log_level=3)
    return value(cur)
1096 2426 aaronmk
1097 5337 aaronmk
def table_col_names(db, table, recover=None):
    '''Gets the table's column names as a list, via col_names() on the cursor
    of a LIMIT 0 select from the table.
    @param recover Passed to select()
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
1100 2414 aaronmk
1101 4261 aaronmk
# NOTE(review): presumably the conventional name of a table's pkey column; not
# used anywhere in this part of the file -- confirm against callers
pkey_col = 'row_num'
1102
1103 2291 aaronmk
def pkey(db, table, recover=None):
    '''Gets the name of the (first) primary key column of a table.
    If no pkey, returns the first column in the table.
    @param recover Passed to table_col_names() in the no-pkey fallback
    @return str The column name
    '''
    table = sql_gen.as_Table(table)
    
    join_cols = ['table_schema', 'table_name', 'constraint_schema',
        'constraint_name']
    # Named join_tables so the module-level tables() is not shadowed
    join_tables = [sql_gen.Table('key_column_usage', 'information_schema'),
        sql_gen.Join(sql_gen.Table('table_constraints', 'information_schema'),
            dict(((c, sql_gen.join_same_not_null) for c in join_cols)))]
    cols = [sql_gen.Col('column_name')]
    
    conds = [('constraint_type', 'PRIMARY KEY'), ('table_name', table.name)]
    schema = table.schema
    if schema != None: conds.append(('table_schema', schema))
    # ordinal_position is the column's position within the constraint key.
    # position_in_unique_constraint is only set for foreign keys (it is NULL for
    # a PRIMARY KEY), so it can't be used to find the first pkey column.
    order_by = 'ordinal_position'
    
    try: return value(select(db, join_tables, cols, conds, order_by=order_by,
        limit=1, log_level=4))
    except StopIteration: return table_col_names(db, table, recover)[0]
1122 832 aaronmk
1123 5128 aaronmk
def pkey_col_(db, table, *args, **kw_args):
    '''Like pkey(), but returns a sql_gen.Col on the table instead of just the
    column name. For params, see pkey().
    '''
    pkey_name = pkey(db, table, *args, **kw_args)
    return sql_gen.Col(pkey_name, table)
1125
1126 2559 aaronmk
# Name of the column that table_not_null_col() looks for first
not_null_col = 'not_null_col'
1127 2340 aaronmk
1128
def table_not_null_col(db, table, recover=None):
    '''Gets the name of a column of the table that is assumed not to be NULL.
    Name assumed to be the value of not_null_col. If not found, uses pkey.
    @param recover Passed to table_col_names() and pkey()
    '''
    has_not_null_col = not_null_col in table_col_names(db, table, recover)
    if not has_not_null_col: return pkey(db, table, recover)
    return not_null_col
1132
1133 3348 aaronmk
def constraint_cond(db, constraint):
    '''Gets the source (pg_constraint.consrc) of a constraint.
    @param constraint Must have .table and .name attributes
    @return The result of value() on the query's cursor
    @raise NotImplementedError If the driver module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't get constraint condition for "+module
        +' database')
1148
1149 3319 aaronmk
def index_cols(db, index):
    '''Gets the columns of an index, with each index expression reduced to a
    column name by parse_expr_col().
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @raise NotImplementedError If the driver module is not psycopg2
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    qual_index = sql_gen.Literal(index.to_str(db))
    query = '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
    col_defs = values(run_query(db, query, cacheable=True, log_level=4))
    return [parse_expr_col(col_def) for col_def in col_defs]
1165 853 aaronmk
1166 3079 aaronmk
#### Functions
1167
1168
def function_exists(db, function):
    '''Whether a function exists, determined by casting its qualified name to
    regproc.
    @return False if the cast raises DoesNotExistException. True otherwise,
        including when it raises DuplicateException (overloaded function).
    '''
    qual_function = sql_gen.Literal(function.to_str(db))
    exists = True
    try:
        select(db, fields=[sql_gen.Cast('regproc', qual_function)],
            recover=True, cacheable=True, log_level=4)
    except DoesNotExistException: exists = False
    except DuplicateException: pass # overloaded function, so it does exist
    return exists
1176 3079 aaronmk
1177
##### Structural changes
1178
1179
#### Columns
1180
1181 5020 aaronmk
def add_col(db, table, col, comment=None, if_not_exists=False, **kw_args):
    '''Adds a column to a table, renaming it to the next version of its name
    whenever the name is already taken.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached.
        NOTE(review): this was documented as "If not set, query will not be
        cached", but the query is always run with cacheable=True -- confirm.
    @param if_not_exists If set, a DuplicateException is re-raised instead of
        retrying with the next version of the name
    @param kw_args Passed to run_query()
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    added = False
    while not added:
        # Rebuilt on each attempt because col.name may have changed
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            added = True
        except DuplicateException:
            if if_not_exists: raise
            # try again with next version of name
            col.name = next_version(col.name)
1202
1203
def add_not_null(db, col):
    '''Sets a column to NOT NULL.
    @param col The column, whose .table is the table to alter
    '''
    table_str = col.table.to_str(db)
    col_str = sql_gen.to_name_only_col(col).to_str(db)
    query = 'ALTER TABLE '+table_str+' ALTER COLUMN '+col_str+' SET NOT NULL'
    run_query(db, query, cacheable=True, log_level=3)
1208
1209 4443 aaronmk
def drop_not_null(db, col):
    '''Removes a column's NOT NULL constraint.
    @param col The column, whose .table is the table to alter
    '''
    table_str = col.table.to_str(db)
    col_str = sql_gen.to_name_only_col(col).to_str(db)
    query = 'ALTER TABLE '+table_str+' ALTER COLUMN '+col_str+' DROP NOT NULL'
    run_query(db, query, cacheable=True, log_level=3)
1214
1215 2096 aaronmk
# Default name of the column added by add_row_num()
row_num_col = '_row_num'
1216
1217 4997 aaronmk
# Template for the column added by add_row_num(): a non-nullable serial that is
# the PRIMARY KEY. The name is left empty; add_row_num() sets it on a copy.
row_num_col_def = sql_gen.TypedCol(
    '', 'serial', nullable=False, constraints='PRIMARY KEY')
1219
1220 4997 aaronmk
def add_row_num(db, table, name=row_num_col):
    '''Adds a row number column to a table. Its definition is in
    row_num_col_def. It will be the primary key.
    @param name The name of the new column
    Because add_col() is called with if_not_exists=True, a DuplicateException
    is raised (rather than the column being renamed) if the name is taken.
    '''
    named_def = copy.copy(row_num_col_def) # don't modify the shared template
    named_def.name = name
    add_col(db, table, named_def, comment='', if_not_exists=True, log_level=3)
1226 3079 aaronmk
1227
#### Indexes
1228
1229
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the column returned by pkey(), i.e. the first column in the
        table when it has no pkey.
    @param recover Passed to pkey() when cols is not given
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    
    col_list = ', '.join(
        [sql_gen.to_name_only_col(col).to_str(db) for col in cols])
    query = 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('+col_list+')'
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1242
1243 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    NOTE(review): the "if it doesn't already exist" part is not enforced here;
    presumably it relies on run_query(recover=True, cacheable=True) -- confirm.
    Currently, only function calls and literal values are supported expressions.
    @param exprs A single column/expression, or a sequence of them
    @param table The table to index. If None, it is taken from the first
        expression's column, which must then be a table column.
    @param unique Whether to create a UNIQUE index
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    # Parse exprs
    index_exprs = []
    for expr in lists.mk_seq(exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # The index definition must use unqualified column names. col is (part
        # of) the deep-copied expr, so this doesn't affect the caller's objects.
        if isinstance(col, sql_gen.Col): col.table = None
        
        index_exprs.append(expr)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in index_exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1290 2408 aaronmk
1291 3083 aaronmk
# Sentinel for add_indexes()'s has_pkey param (compared by identity): tells
# add_indexes() the pkey has already been added
already_indexed = object()
1292
1293
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    index_cols_ = table_col_names(db, table)
    if has_pkey:
        # The first column gets the pkey rather than a regular index
        if has_pkey is not already_indexed: add_pkey(db, table)
        index_cols_ = index_cols_[1:]
    for col_name in index_cols_: add_index(db, col_name, table)
1304
1305 3079 aaronmk
#### Tables
1306 2772 aaronmk
1307 3079 aaronmk
### Maintenance
1308 2772 aaronmk
1309 3079 aaronmk
def analyze(db, table):
    '''Runs ANALYZE on a table.'''
    query = 'ANALYZE '+sql_gen.as_Table(table).to_str(db)
    run_query(db, query, log_level=3)
1312 2934 aaronmk
1313 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() on a table, but only if db.autoanalyze is set.'''
    if not db.autoanalyze: return
    analyze(db, table)
1315 2935 aaronmk
1316 3079 aaronmk
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on a table, inside db.with_autocommit().'''
    table = sql_gen.as_Table(table)
    
    def run_vacuum():
        return run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    
    db.with_autocommit(run_vacuum)
1320 2086 aaronmk
1321 3079 aaronmk
### Lifecycle
1322
1323 3247 aaronmk
def drop(db, type_, name):
    '''Drops a database object, if it exists, along with everything that
    depends on it (DROP ... IF EXISTS ... CASCADE).
    @param type_ str The SQL object type keyword, e.g. 'TABLE'
    @param name The object's name
    '''
    name_str = sql_gen.as_Name(name).to_str(db)
    run_query(db, 'DROP '+type_+' IF EXISTS '+name_str+' CASCADE')
1326 2889 aaronmk
1327 3247 aaronmk
def drop_table(db, table):
    '''Drops a table using drop().'''
    drop(db, 'TABLE', table)
1328
1329 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types. Neither the
        sequence nor its elements are modified.
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like If set, a table whose structure is copied into the new table
        using `LIKE ... INCLUDING ALL`. This clause is placed before any cols.
    '''
    table = sql_gen.as_Table(table)
    
    # Work on a shallow copy, so that neither the caller's list nor the shared
    # default value is ever changed by the assignments below
    cols = list(cols)
    
    if like is not None:
        cols.insert(0,
            sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL'))
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
    if table.is_temp:
        # A temp table's name may already be taken, so keep renaming it until
        # the creation succeeds. Note that this renames the table object itself.
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # pkey was created with the table
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1376 2675 aaronmk
1377 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates dest with the same structure as src. No rows are copied.
    @param src The table passed to create_table() as `like`
    @param dest The table to create
    '''
    # No pkey or column indexes are requested from create_table() itself
    create_table(db, dest, like=src, has_pkey=False, col_indexes=False)
1380 3084 aaronmk
1381 3079 aaronmk
### Data
1382 2684 aaronmk
1383 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema Passed to sql_gen.as_Table() along with table
    @param kw_args Passed to run_query()
    @return The result of run_query()
    '''
    table_str = sql_gen.as_Table(table, schema).to_str(db)
    query = 'TRUNCATE '+table_str+' CASCADE'
    return run_query(db, query, **kw_args)
1387 2732 aaronmk
1388 2965 aaronmk
def empty_temp(db, tables):
    '''Truncates each of the given tables.
    @param tables Passed through lists.mk_seq() before iteration
    '''
    for temp_table in lists.mk_seq(tables):
        truncate(db, temp_table, log_level=3)
1391 2965 aaronmk
1392 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for a schema.
    @param kw_args Passed to tables()
    '''
    schema_tables = tables(db, schema, **kw_args)
    for schema_table in schema_tables:
        truncate(db, schema_table, schema)
1395 3094 aaronmk
1396
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on Any iterable of columns. Only the entries for which
        sql_gen.is_table_col() is true are used. If none remain, creates a table
        with one row. This is useful if your distinct_on columns are all literal
        values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    # Always build a list: filter() would return a tuple for tuple input (and an
    # iterator in Python 3), which the emptiness check below must not depend on
    distinct_on = [c for c in distinct_on if sql_gen.is_table_col(c)]
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if not distinct_on: limit = 1 # one sample row
    else:
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    # ignore=True presumably skips the rows that would violate the unique index
    # (see insert_select()), which is what leaves only the distinct rows
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table