Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 3238 aaronmk
import time
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 3241 aaronmk
import profiling
13 1889 aaronmk
from Proxy import Proxy
14 1872 aaronmk
import rand
15 2217 aaronmk
import sql_gen
16 862 aaronmk
import strings
17 131 aaronmk
import util
18 11 aaronmk
19 832 aaronmk
##### Exceptions
20
21 2804 aaronmk
def get_cur_query(cur, input_query=None):
22 2168 aaronmk
    raw_query = None
23
    if hasattr(cur, 'query'): raw_query = cur.query
24
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
25 2170 aaronmk
26
    if raw_query != None: return raw_query
27 2804 aaronmk
    else: return '[input] '+strings.ustr(input_query)
28 14 aaronmk
29 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
30
    '''For params, see get_cur_query()'''
31 2771 aaronmk
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
32 135 aaronmk
33 300 aaronmk
class DbException(exc.ExceptionWithCause):
34 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
35 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
36 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
37
38 2143 aaronmk
class ExceptionWithName(DbException):
39
    def __init__(self, name, cause=None):
40 4491 aaronmk
        DbException.__init__(self, 'for name: '
41
            +strings.as_tt(strings.ustr(name)), cause)
42 2143 aaronmk
        self.name = name
43 360 aaronmk
44 3109 aaronmk
class ExceptionWithValue(DbException):
45
    def __init__(self, value, cause=None):
46 4492 aaronmk
        DbException.__init__(self, 'for value: '
47
            +strings.as_tt(strings.urepr(value)), cause)
48 2240 aaronmk
        self.value = value
49
50 2945 aaronmk
class ExceptionWithNameType(DbException):
51
    def __init__(self, type_, name, cause=None):
52 4491 aaronmk
        DbException.__init__(self, 'for type: '+strings.as_tt(strings.ustr(
53
            type_))+'; name: '+strings.as_tt(name), cause)
54 2945 aaronmk
        self.type = type_
55
        self.name = name
56
57 2306 aaronmk
class ConstraintException(DbException):
58 3345 aaronmk
    def __init__(self, name, cond, cols, cause=None):
59
        msg = 'Violated '+strings.as_tt(name)+' constraint'
60 5447 aaronmk
        if cond != None: msg += ' with condition '+strings.as_tt(cond)
61 3345 aaronmk
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
62
        DbException.__init__(self, msg, cause)
63 2306 aaronmk
        self.name = name
64 3345 aaronmk
        self.cond = cond
65 468 aaronmk
        self.cols = cols
66 11 aaronmk
67 2523 aaronmk
class MissingCastException(DbException):
68 4139 aaronmk
    def __init__(self, type_, col=None, cause=None):
69
        msg = 'Missing cast to type '+strings.as_tt(type_)
70
        if col != None: msg += ' on column: '+strings.as_tt(col)
71
        DbException.__init__(self, msg, cause)
72 2523 aaronmk
        self.type = type_
73
        self.col = col
74
75 5576 aaronmk
class EncodingException(ExceptionWithName): pass
76
77 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
78 13 aaronmk
79 2306 aaronmk
class NullValueException(ConstraintException): pass
80 13 aaronmk
81 3346 aaronmk
class CheckException(ConstraintException): pass
82
83 3109 aaronmk
class InvalidValueException(ExceptionWithValue): pass
84 2239 aaronmk
85 2945 aaronmk
class DuplicateException(ExceptionWithNameType): pass
86 2143 aaronmk
87 3419 aaronmk
class DoesNotExistException(ExceptionWithNameType): pass
88
89 89 aaronmk
class EmptyRowException(DbException): pass
90
91 865 aaronmk
##### Warnings
92
93
class DbWarning(UserWarning): pass
94
95 1930 aaronmk
##### Result retrieval
96
97
def col_names(cur): return (col[0] for col in cur.description)
98
99
def rows(cur): return iter(lambda: cur.fetchone(), None)
100
101
def consume_rows(cur):
102
    '''Used to fetch all rows so result will be cached'''
103
    iters.consume_iter(rows(cur))
104
105
def next_row(cur): return rows(cur).next()
106
107
def row(cur):
108
    row_ = next_row(cur)
109
    consume_rows(cur)
110
    return row_
111
112
def next_value(cur): return next_row(cur)[0]
113
114
def value(cur): return row(cur)[0]
115
116
def values(cur): return iters.func_iter(lambda: next_value(cur))
117
118
def value_or_none(cur):
119
    try: return value(cur)
120
    except StopIteration: return None
121
122 2762 aaronmk
##### Escaping
123 2101 aaronmk
124 2573 aaronmk
def esc_name_by_module(module, name):
125
    if module == 'psycopg2' or module == None: quote = '"'
126 2101 aaronmk
    elif module == 'MySQLdb': quote = '`'
127
    else: raise NotImplementedError("Can't escape name for "+module+' database')
128 2500 aaronmk
    return sql_gen.esc_name(name, quote)
129 2101 aaronmk
130
def esc_name_by_engine(engine, name, **kw_args):
131
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
132
133
def esc_name(db, name, **kw_args):
134
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
135
136
def qual_name(db, schema, table):
137
    def esc_name_(name): return esc_name(db, name)
138
    table = esc_name_(table)
139
    if schema != None: return esc_name_(schema)+'.'+table
140
    else: return table
141
142 1869 aaronmk
##### Database connections
143 1849 aaronmk
144 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
145 1926 aaronmk
146 1869 aaronmk
db_engines = {
147
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
148
    'PostgreSQL': ('psycopg2', {}),
149
}
150
151
DatabaseErrors_set = set([DbException])
152
DatabaseErrors = tuple(DatabaseErrors_set)
153
154
def _add_module(module):
155
    DatabaseErrors_set.add(module.DatabaseError)
156
    global DatabaseErrors
157
    DatabaseErrors = tuple(DatabaseErrors_set)
158
159
def db_config_str(db_config):
160
    return db_config['engine']+' database '+db_config['database']
161
162 2448 aaronmk
log_debug_none = lambda msg, level=2: None
163 1901 aaronmk
164 1849 aaronmk
class DbConn:
165 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
166 3183 aaronmk
        log_debug=log_debug_none, debug_temp=False, src=None):
167 2915 aaronmk
        '''
168
        @param debug_temp Whether temporary objects should instead be permanent.
169
            This assists in debugging the internal objects used by the program.
170 3183 aaronmk
        @param src In autocommit mode, will be included in a comment in every
171
            query, to help identify the data source in pg_stat_activity.
172 2915 aaronmk
        '''
173 1869 aaronmk
        self.db_config = db_config
174 2190 aaronmk
        self.autocommit = autocommit
175
        self.caching = caching
176 1901 aaronmk
        self.log_debug = log_debug
177 2193 aaronmk
        self.debug = log_debug != log_debug_none
178 2915 aaronmk
        self.debug_temp = debug_temp
179 3183 aaronmk
        self.src = src
180 3074 aaronmk
        self.autoanalyze = False
181 3269 aaronmk
        self.autoexplain = False
182
        self.profile_row_ct = None
183 1869 aaronmk
184 3124 aaronmk
        self._savepoint = 0
185 3120 aaronmk
        self._reset()
186 1869 aaronmk
187
    def __getattr__(self, name):
188
        if name == '__dict__': raise Exception('getting __dict__')
189
        if name == 'db': return self._db()
190
        else: raise AttributeError()
191
192
    def __getstate__(self):
193
        state = copy.copy(self.__dict__) # shallow copy
194 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
195 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
196
        return state
197
198 3118 aaronmk
    def clear_cache(self): self.query_results = {}
199
200 3120 aaronmk
    def _reset(self):
201 3118 aaronmk
        self.clear_cache()
202 3124 aaronmk
        assert self._savepoint == 0
203 3118 aaronmk
        self._notices_seen = set()
204
        self.__db = None
205
206 2165 aaronmk
    def connected(self): return self.__db != None
207
208 3116 aaronmk
    def close(self):
209 3119 aaronmk
        if not self.connected(): return
210
211 3135 aaronmk
        # Record that the automatic transaction is now closed
212 3136 aaronmk
        self._savepoint -= 1
213 3135 aaronmk
214 3119 aaronmk
        self.db.close()
215 3120 aaronmk
        self._reset()
216 3116 aaronmk
217 3125 aaronmk
    def reconnect(self):
218
        # Do not do this in test mode as it would roll back everything
219
        if self.autocommit: self.close()
220
        # Connection will be reopened automatically on first query
221
222 1869 aaronmk
    def _db(self):
223
        if self.__db == None:
224
            # Process db_config
225
            db_config = self.db_config.copy() # don't modify input!
226 2097 aaronmk
            schemas = db_config.pop('schemas', None)
227 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
228
            module = __import__(module_name)
229
            _add_module(module)
230
            for orig, new in mappings.iteritems():
231
                try: util.rename_key(db_config, orig, new)
232
                except KeyError: pass
233
234
            # Connect
235
            self.__db = module.connect(**db_config)
236
237 3161 aaronmk
            # Record that a transaction is already open
238
            self._savepoint += 1
239
240 1869 aaronmk
            # Configure connection
241 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
242
                import psycopg2.extensions
243
                self.db.set_isolation_level(
244
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
245 2101 aaronmk
            if schemas != None:
246 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
247
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
248
                    log_level=3)
249 1869 aaronmk
250
        return self.__db
251 1889 aaronmk
252 1891 aaronmk
    class DbCursor(Proxy):
253 1927 aaronmk
        def __init__(self, outer):
254 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
255 2191 aaronmk
            self.outer = outer
256 1927 aaronmk
            self.query_results = outer.query_results
257 1894 aaronmk
            self.query_lookup = None
258 1891 aaronmk
            self.result = []
259 1889 aaronmk
260 2802 aaronmk
        def execute(self, query):
261 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
262 2797 aaronmk
            self.query_lookup = query
263 2148 aaronmk
            try:
264 3162 aaronmk
                try: cur = self.inner.execute(query)
265 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
266 1904 aaronmk
            except Exception, e:
267
                self.result = e # cache the exception as the result
268
                self._cache_result()
269
                raise
270 3004 aaronmk
271
            # Always cache certain queries
272 3183 aaronmk
            query = sql_gen.lstrip(query)
273 3004 aaronmk
            if query.startswith('CREATE') or query.startswith('ALTER'):
274 3007 aaronmk
                # structural changes
275 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
276
                # so don't cache ADD COLUMN unless it has distinguishing comment
277
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
278 3007 aaronmk
                    self._cache_result()
279 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
280 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
281 3004 aaronmk
282 2762 aaronmk
            return cur
283 1894 aaronmk
284 1891 aaronmk
        def fetchone(self):
285
            row = self.inner.fetchone()
286 1899 aaronmk
            if row != None: self.result.append(row)
287
            # otherwise, fetched all rows
288 1904 aaronmk
            else: self._cache_result()
289
            return row
290
291
        def _cache_result(self):
292 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
293
            # inserts are not idempotent. Other non-SELECT queries don't have
294
            # their result set read, so only exceptions will be cached (an
295
            # invalid query will always be invalid).
296 1930 aaronmk
            if self.query_results != None and (not self._is_insert
297 1906 aaronmk
                or isinstance(self.result, Exception)):
298
299 1894 aaronmk
                assert self.query_lookup != None
300 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
301
                    util.dict_subset(dicts.AttrsDictView(self),
302
                    ['query', 'result', 'rowcount', 'description']))
303 1906 aaronmk
304 1916 aaronmk
        class CacheCursor:
305
            def __init__(self, cached_result): self.__dict__ = cached_result
306
307 1927 aaronmk
            def execute(self, *args, **kw_args):
308 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
309
                # otherwise, result is a rows list
310
                self.iter = iter(self.result)
311
312
            def fetchone(self):
313
                try: return self.iter.next()
314
                except StopIteration: return None
315 1891 aaronmk
316 2212 aaronmk
    def esc_value(self, value):
317 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
318
        except NotImplementedError, e:
319
            module = util.root_module(self.db)
320
            if module == 'MySQLdb':
321
                import _mysql
322
                str_ = _mysql.escape_string(value)
323
            else: raise e
324 2374 aaronmk
        return strings.to_unicode(str_)
325 2212 aaronmk
326 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
327
328 2814 aaronmk
    def std_code(self, str_):
329
        '''Standardizes SQL code.
330
        * Ensures that string literals are prefixed by `E`
331
        '''
332
        if str_.startswith("'"): str_ = 'E'+str_
333
        return str_
334
335 2665 aaronmk
    def can_mogrify(self):
336 2663 aaronmk
        module = util.root_module(self.db)
337 2665 aaronmk
        return module == 'psycopg2'
338 2663 aaronmk
339 2665 aaronmk
    def mogrify(self, query, params=None):
340
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
341
        else: raise NotImplementedError("Can't mogrify query")
342
343 5579 aaronmk
    def set_encoding(self, encoding):
344
        encoding_str = sql_gen.Literal(encoding)
345
        run_query(self, 'SET NAMES '+encoding_str.to_str(self))
346
347 2671 aaronmk
    def print_notices(self):
348 2725 aaronmk
        if hasattr(self.db, 'notices'):
349
            for msg in self.db.notices:
350
                if msg not in self._notices_seen:
351
                    self._notices_seen.add(msg)
352
                    self.log_debug(msg, level=2)
353 2671 aaronmk
354 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
355 2464 aaronmk
        debug_msg_ref=None):
356 2445 aaronmk
        '''
357 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
358
            throws one of these exceptions.
359 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
360
            this instead of being output. This allows you to filter log messages
361
            depending on the result of the query.
362 2445 aaronmk
        '''
363 2167 aaronmk
        assert query != None
364
365 3183 aaronmk
        if self.autocommit and self.src != None:
366 3206 aaronmk
            query = sql_gen.esc_comment(self.src)+'\t'+query
367 3183 aaronmk
368 2047 aaronmk
        if not self.caching: cacheable = False
369 1903 aaronmk
        used_cache = False
370 2664 aaronmk
371 3242 aaronmk
        if self.debug:
372
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
373 1903 aaronmk
        try:
374 1927 aaronmk
            # Get cursor
375
            if cacheable:
376 3238 aaronmk
                try: cur = self.query_results[query]
377 1927 aaronmk
                except KeyError: cur = self.DbCursor(self)
378 3238 aaronmk
                else: used_cache = True
379 1927 aaronmk
            else: cur = self.db.cursor()
380
381
            # Run query
382 3238 aaronmk
            try: cur.execute(query)
383 3162 aaronmk
            except Exception, e:
384
                _add_cursor_info(e, self, query)
385
                raise
386 3238 aaronmk
            else: self.do_autocommit()
387 1903 aaronmk
        finally:
388 3242 aaronmk
            if self.debug:
389 3244 aaronmk
                profiler.stop(self.profile_row_ct)
390 3242 aaronmk
391
                ## Log or return query
392
393 4491 aaronmk
                query = strings.ustr(get_cur_query(cur, query))
394 3281 aaronmk
                # Put the src comment on a separate line in the log file
395
                query = query.replace('\t', '\n', 1)
396 3239 aaronmk
397 3240 aaronmk
                msg = 'DB query: '
398 3239 aaronmk
399 3240 aaronmk
                if used_cache: msg += 'cache hit'
400
                elif cacheable: msg += 'cache miss'
401
                else: msg += 'non-cacheable'
402 3239 aaronmk
403 3241 aaronmk
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
404 3240 aaronmk
405 3237 aaronmk
                if debug_msg_ref != None: debug_msg_ref[0] = msg
406
                else: self.log_debug(msg, log_level)
407 3245 aaronmk
408
                self.print_notices()
409 1903 aaronmk
410
        return cur
411 1914 aaronmk
412 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
413 2139 aaronmk
414 2907 aaronmk
    def with_autocommit(self, func):
415 2801 aaronmk
        import psycopg2.extensions
416
417
        prev_isolation_level = self.db.isolation_level
418 2907 aaronmk
        self.db.set_isolation_level(
419
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
420 2683 aaronmk
        try: return func()
421 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
422 2683 aaronmk
423 2139 aaronmk
    def with_savepoint(self, func):
424 3137 aaronmk
        top = self._savepoint == 0
425 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
426 3137 aaronmk
427 3272 aaronmk
        if self.debug:
428 3273 aaronmk
            self.log_debug('Begin transaction', level=4)
429 3272 aaronmk
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
430
431 3160 aaronmk
        # Must happen before running queries so they don't get autocommitted
432
        self._savepoint += 1
433
434 3137 aaronmk
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
435
        else: query = 'SAVEPOINT '+savepoint
436
        self.run_query(query, log_level=4)
437
        try:
438
            return func()
439
            if top: self.run_query('COMMIT', log_level=4)
440 2139 aaronmk
        except:
441 3137 aaronmk
            if top: query = 'ROLLBACK'
442
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
443
            self.run_query(query, log_level=4)
444
445 2139 aaronmk
            raise
446 2930 aaronmk
        finally:
447
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
448
            # "The savepoint remains valid and can be rolled back to again"
449
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
450 3137 aaronmk
            if not top:
451
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
452 2930 aaronmk
453
            self._savepoint -= 1
454
            assert self._savepoint >= 0
455
456 3272 aaronmk
            if self.debug:
457
                profiler.stop(self.profile_row_ct)
458 3273 aaronmk
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
459 3272 aaronmk
460 2930 aaronmk
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
461 2191 aaronmk
462
    def do_autocommit(self):
463
        '''Autocommits if outside savepoint'''
464 3135 aaronmk
        assert self._savepoint >= 1
465
        if self.autocommit and self._savepoint == 1:
466 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
467 2191 aaronmk
            self.db.commit()
468 2643 aaronmk
469 3155 aaronmk
    def col_info(self, col, cacheable=True):
470 6900 aaronmk
        module = util.root_module(self.db)
471 6899 aaronmk
        if module == 'psycopg2':
472 6900 aaronmk
            qual_table = sql_gen.Literal(col.table.to_str(self))
473 6899 aaronmk
            col_name_str = sql_gen.Literal(col.name)
474
            try:
475 6900 aaronmk
                type_, is_array, default, nullable = row(run_query(self, '''\
476 6899 aaronmk
SELECT
477
format_type(COALESCE(NULLIF(typelem, 0), pg_type.oid), -1) AS type
478
, typcategory = 'A' AS type_is_array
479 6903 aaronmk
, pg_get_expr(pg_attrdef.adbin, attrelid, true) AS default
480 6899 aaronmk
, NOT pg_attribute.attnotnull AS nullable
481
FROM pg_attribute
482
LEFT JOIN pg_type ON pg_type.oid = atttypid
483
LEFT JOIN pg_attrdef ON adrelid = attrelid AND adnum = attnum
484
WHERE
485 6901 aaronmk
attrelid = '''+qual_table.to_str(self)+'''::regclass
486
AND attname = '''+col_name_str.to_str(self)+'''
487 6899 aaronmk
'''
488 6904 aaronmk
                    , recover=True, cacheable=cacheable, log_level=4))
489 6905 aaronmk
            except (DoesNotExistException, StopIteration):
490
                raise sql_gen.NoUnderlyingTableException(col)
491 6900 aaronmk
            if is_array: type_ = sql_gen.ArrayType(type_)
492 6899 aaronmk
        else:
493
            table = sql_gen.Table('columns', 'information_schema')
494
            cols = [sql_gen.Col('data_type'), sql_gen.Col('udt_name'),
495
                'column_default', sql_gen.Cast('boolean',
496
                sql_gen.Col('is_nullable'))]
497
498
            conds = [('table_name', col.table.name),
499
                ('column_name', strings.ustr(col.name))]
500
            schema = col.table.schema
501
            if schema != None: conds.append(('table_schema', schema))
502
503
            cur = select(self, table, cols, conds, order_by='table_schema',
504
                limit=1, cacheable=cacheable, log_level=4)
505
            try: type_, extra_type, default, nullable = row(cur)
506
            except StopIteration: raise sql_gen.NoUnderlyingTableException(col)
507
            if type_ == 'USER-DEFINED': type_ = extra_type
508
            elif type_ == 'ARRAY':
509
                type_ = sql_gen.ArrayType(strings.remove_prefix('_', extra_type,
510
                    require=True))
511 2643 aaronmk
512 6899 aaronmk
        if default != None: default = sql_gen.as_Code(default, self)
513 2819 aaronmk
        return sql_gen.TypedCol(col.name, type_, default, nullable)
514 2917 aaronmk
515
    def TempFunction(self, name):
516
        if self.debug_temp: schema = None
517
        else: schema = 'pg_temp'
518
        return sql_gen.Function(name, schema)
519 1849 aaronmk
520 1869 aaronmk
connect = DbConn
521
522 832 aaronmk
##### Recoverable querying
523 15 aaronmk
524 5577 aaronmk
def parse_exception(db, e, recover=False):
525
    msg = strings.ustr(e.args[0])
526
    msg = re.sub(r'^(?:PL/Python: )?ValueError: ', r'', msg)
527
528
    match = re.match(r'^invalid byte sequence for encoding "(.+?)":', msg)
529
    if match:
530
        encoding, = match.groups()
531
        raise EncodingException(encoding, e)
532
533 5763 aaronmk
    def make_DuplicateKeyException(constraint, e):
534 5577 aaronmk
        cols = []
535
        cond = None
536
        if recover: # need auto-rollback to run index_cols()
537
            try:
538
                cols = index_cols(db, constraint)
539
                cond = index_cond(db, constraint)
540
            except NotImplementedError: pass
541 5763 aaronmk
        return DuplicateKeyException(constraint, cond, cols, e)
542 5577 aaronmk
543 5763 aaronmk
    match = re.match(r'^duplicate key value violates unique constraint "(.+?)"',
544
        msg)
545
    if match:
546
        constraint, = match.groups()
547
        raise make_DuplicateKeyException(constraint, e)
548
549 5764 aaronmk
    match = re.match(r'^could not create unique index "(.+?)"\n'
550
        r'DETAIL:  Key .+? is duplicated', msg)
551
    if match:
552
        constraint, = match.groups()
553 5823 aaronmk
        raise DuplicateKeyException(constraint, None, [], e)
554 5764 aaronmk
555 5577 aaronmk
    match = re.match(r'^null value in column "(.+?)" violates not-null'
556
        r' constraint', msg)
557
    if match:
558
        col, = match.groups()
559
        raise NullValueException('NOT NULL', None, [col], e)
560
561
    match = re.match(r'^new row for relation "(.+?)" violates check '
562
        r'constraint "(.+?)"', msg)
563
    if match:
564
        table, constraint = match.groups()
565
        constraint = sql_gen.Col(constraint, table)
566
        cond = None
567
        if recover: # need auto-rollback to run constraint_cond()
568
            try: cond = constraint_cond(db, constraint)
569
            except NotImplementedError: pass
570
        raise CheckException(constraint.to_str(db), cond, [], e)
571
572 5723 aaronmk
    match = re.match(r'^(?:invalid input (?:syntax|value)\b[^:]*'
573 5717 aaronmk
        r'|.+? out of range)(?:: "(.+?)")?', msg)
574 5577 aaronmk
    if match:
575
        value, = match.groups()
576 5717 aaronmk
        value = util.do_ignore_none(strings.to_unicode, value)
577
        raise InvalidValueException(value, e)
578 5577 aaronmk
579
    match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
580
        r'is of type', msg)
581
    if match:
582
        col, type_ = match.groups()
583
        raise MissingCastException(type_, col, e)
584
585
    match = re.match(r'^could not determine polymorphic type because '
586
        r'input has type "unknown"', msg)
587
    if match: raise MissingCastException('text', None, e)
588
589 5724 aaronmk
    match = re.match(r'^.+? types (.+?) and (.+?) cannot be matched', msg)
590
    if match:
591
        type0, type1 = match.groups()
592
        raise MissingCastException(type0, None, e)
593 5577 aaronmk
594 5707 aaronmk
    typed_name_re = r'^(\S+) "?(.+?)"?(?: of relation ".+?")?'
595 5577 aaronmk
596
    match = re.match(typed_name_re+r'.*? already exists', msg)
597
    if match:
598
        type_, name = match.groups()
599
        raise DuplicateException(type_, name, e)
600
601
    match = re.match(r'more than one (\S+) named ""(.+?)""', msg)
602
    if match:
603
        type_, name = match.groups()
604
        raise DuplicateException(type_, name, e)
605
606
    match = re.match(typed_name_re+r' does not exist', msg)
607
    if match:
608
        type_, name = match.groups()
609 5683 aaronmk
        if type_ == 'function':
610 5709 aaronmk
            match = re.match(r'^(.+?)\(.*\)$', name)
611 5710 aaronmk
            if match: # includes params, so is call rather than cast to regproc
612 5708 aaronmk
                function_name, = match.groups()
613 5711 aaronmk
                func = sql_gen.Function(function_name)
614 5712 aaronmk
                if function_exists(db, func) and msg.find('CAST') < 0:
615 5683 aaronmk
                    # not found only because of a missing cast
616 5715 aaronmk
                    type_ = function_param0_type(db, func)
617 6301 aaronmk
                    col = None
618 5792 aaronmk
                    if type_ == 'anyelement': type_ = 'text'
619 6301 aaronmk
                    elif type_ == 'hstore': # cast just the value param
620
                        type_ = 'text'
621
                        col = 'value'
622
                    raise MissingCastException(type_, col, e)
623 5577 aaronmk
        raise DoesNotExistException(type_, name, e)
624
625
    raise # no specific exception raised
626
627 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
628 11 aaronmk
629 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
630
    log_ignore_excs=None, **kw_args):
631 2794 aaronmk
    '''For params, see DbConn.run_query()'''
632 830 aaronmk
    if recover == None: recover = False
633 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
634
    log_ignore_excs = tuple(log_ignore_excs)
635 3236 aaronmk
    debug_msg_ref = [None]
636 830 aaronmk
637 3267 aaronmk
    query = with_explain_comment(db, query)
638 3258 aaronmk
639 2148 aaronmk
    try:
640 2464 aaronmk
        try:
641 2794 aaronmk
            def run(): return db.run_query(query, cacheable, log_level,
642 2793 aaronmk
                debug_msg_ref, **kw_args)
643 2796 aaronmk
            if recover and not db.is_cached(query):
644 2464 aaronmk
                return with_savepoint(db, run)
645
            else: return run() # don't need savepoint if cached
646 5890 aaronmk
        except Exception, e:
647
            # Give failed EXPLAIN approximately the log_level of its query
648
            if query.startswith('EXPLAIN'): log_level -= 1
649
650
            parse_exception(db, e, recover)
651 2464 aaronmk
    except log_ignore_excs:
652
        log_level += 2
653
        raise
654
    finally:
655 3236 aaronmk
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
656 830 aaronmk
657 832 aaronmk
##### Basic queries
658
659 3256 aaronmk
def is_explainable(query):
660
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
661 3257 aaronmk
    return re.match(r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
662
        , query)
663 3256 aaronmk
664 3263 aaronmk
def explain(db, query, **kw_args):
665
    '''
666
    For params, see run_query().
667
    '''
668 3267 aaronmk
    kw_args.setdefault('log_level', 4)
669 3263 aaronmk
670 3750 aaronmk
    return strings.ustr(strings.join_lines(values(run_query(db,
671
        'EXPLAIN '+query, recover=True, cacheable=True, **kw_args))))
672 3256 aaronmk
        # not a higher log_level because it's useful to see what query is being
673
        # run before it's executed, which EXPLAIN effectively provides
674
675 3265 aaronmk
def has_comment(query): return query.endswith('*/')
676
677
def with_explain_comment(db, query, **kw_args):
678 3269 aaronmk
    if db.autoexplain and not has_comment(query) and is_explainable(query):
679 3265 aaronmk
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
680
            +explain(db, query, **kw_args))
681
    return query
682
683 2153 aaronmk
def next_version(name):
684 2163 aaronmk
    version = 1 # first existing name was version 0
685 2586 aaronmk
    match = re.match(r'^(.*)#(\d+)$', name)
686 2153 aaronmk
    if match:
687 2586 aaronmk
        name, version = match.groups()
688
        version = int(version)+1
689 2932 aaronmk
    return sql_gen.concat(name, '#'+str(version))
690 2153 aaronmk
691 2899 aaronmk
def lock_table(db, table, mode):
692
    table = sql_gen.as_Table(table)
693
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
694
695 3303 aaronmk
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
696 2085 aaronmk
    '''Outputs a query to a temp table.
697
    For params, see run_query().
698
    '''
699 2789 aaronmk
    if into == None: return run_query(db, query, **kw_args)
700 2790 aaronmk
701
    assert isinstance(into, sql_gen.Table)
702
703 2992 aaronmk
    into.is_temp = True
704 3008 aaronmk
    # "temporary tables cannot specify a schema name", so remove schema
705
    into.schema = None
706 2992 aaronmk
707 2790 aaronmk
    kw_args['recover'] = True
708 2945 aaronmk
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
709 2790 aaronmk
710 2916 aaronmk
    temp = not db.debug_temp # tables are permanent in debug_temp mode
711 2790 aaronmk
712
    # Create table
713
    while True:
714
        create_query = 'CREATE'
715
        if temp: create_query += ' TEMP'
716
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
717 2385 aaronmk
718 2790 aaronmk
        try:
719
            cur = run_query(db, create_query, **kw_args)
720
                # CREATE TABLE AS sets rowcount to # rows in query
721
            break
722 2945 aaronmk
        except DuplicateException, e:
723 2790 aaronmk
            into.name = next_version(into.name)
724
            # try again with next version of name
725
726 5967 aaronmk
    if add_pkey_: add_pkey_or_index(db, into, warn=True)
727 3075 aaronmk
728
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
729
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
730
    # table is going to be used in complex queries, it is wise to run ANALYZE on
731
    # the temporary table after it is populated."
732
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
733
    # If into is not a temp table, ANALYZE is useful but not required.
734 3073 aaronmk
    analyze(db, into)
735 2790 aaronmk
736
    return cur
737 2085 aaronmk
738 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
739
740 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
741
742 3420 aaronmk
def mk_select(db, tables=None, fields=None, conds=None, distinct_on=[],
743 3494 aaronmk
    limit=None, start=None, order_by=order_by_pkey, default_table=None,
744
    explain=True):
745 1981 aaronmk
    '''
746 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
747 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
748 1981 aaronmk
    @param fields Use None to select all fields in the table
749 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
750 2379 aaronmk
        * container can be any iterable type
751 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
752
        * compare_right_side: sql_gen.ValueCond|literal value
753 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
754
        use all columns
755 2786 aaronmk
    @return query
756 1981 aaronmk
    '''
757 2315 aaronmk
    # Parse tables param
758 2964 aaronmk
    tables = lists.mk_seq(tables)
759 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
760 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
761 2121 aaronmk
762 2315 aaronmk
    # Parse other params
763 2376 aaronmk
    if conds == None: conds = []
764 2650 aaronmk
    elif dicts.is_dict(conds): conds = conds.items()
765 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
766 3129 aaronmk
    assert limit == None or isinstance(limit, (int, long))
767
    assert start == None or isinstance(start, (int, long))
768 5527 aaronmk
    if limit == 0: order_by = None
769 2315 aaronmk
    if order_by is order_by_pkey:
770 5641 aaronmk
        if lists.is_seq(distinct_on) and distinct_on: order_by = distinct_on[0]
771
        elif table0 != None: order_by = table_order_by(db, table0, recover=True)
772
        else: order_by = None
773 865 aaronmk
774 2315 aaronmk
    query = 'SELECT'
775 2056 aaronmk
776 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
777 2056 aaronmk
778 2200 aaronmk
    # DISTINCT ON columns
779 2233 aaronmk
    if distinct_on != []:
780 2467 aaronmk
        query += '\nDISTINCT'
781 2254 aaronmk
        if distinct_on is not distinct_on_all:
782 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
783
784
    # Columns
785 3185 aaronmk
    if query.find('\n') >= 0: whitespace = '\n'
786
    else: whitespace = ' '
787
    if fields == None: query += whitespace+'*'
788 2765 aaronmk
    else:
789
        assert fields != []
790 3185 aaronmk
        if len(fields) > 1: whitespace = '\n'
791
        query += whitespace+('\n, '.join(map(parse_col, fields)))
792 2200 aaronmk
793
    # Main table
794 3185 aaronmk
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
795
    else: whitespace = ' '
796 3420 aaronmk
    if table0 != None: query += whitespace+'FROM '+table0.to_str(db)
797 865 aaronmk
798 2122 aaronmk
    # Add joins
799 2271 aaronmk
    left_table = table0
800 2263 aaronmk
    for join_ in tables:
801
        table = join_.table
802 2238 aaronmk
803 2343 aaronmk
        # Parse special values
804
        if join_.type_ is sql_gen.filter_out: # filter no match
805 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
806 2853 aaronmk
                sql_gen.CompareCond(None, '~=')))
807 2343 aaronmk
808 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
809 2122 aaronmk
810
        left_table = table
811
812 865 aaronmk
    missing = True
813 2376 aaronmk
    if conds != []:
814 2576 aaronmk
        if len(conds) == 1: whitespace = ' '
815
        else: whitespace = '\n'
816 2578 aaronmk
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
817
            .to_str(db) for l, r in conds], 'WHERE')
818 2227 aaronmk
    if order_by != None:
819 5642 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by).to_str(db)
820 3297 aaronmk
    if limit != None: query += '\nLIMIT '+str(limit)
821 865 aaronmk
    if start != None:
822 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
823 865 aaronmk
824 3494 aaronmk
    if explain: query = with_explain_comment(db, query)
825 3266 aaronmk
826 2786 aaronmk
    return query
827 11 aaronmk
828 2054 aaronmk
def select(db, *args, **kw_args):
829
    '''For params, see mk_select() and run_query()'''
830
    recover = kw_args.pop('recover', None)
831
    cacheable = kw_args.pop('cacheable', True)
832 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
833 2054 aaronmk
834 2791 aaronmk
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
835
        log_level=log_level)
836 2054 aaronmk
837 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
838 3181 aaronmk
    embeddable=False, ignore=False, src=None):
839 1960 aaronmk
    '''
840
    @param returning str|None An inserted column (such as pkey) to return
841 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
842 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
843
        query will be fully cached, not just if it raises an exception.
844 3009 aaronmk
    @param ignore Whether to ignore duplicate keys.
845 3181 aaronmk
    @param src Will be included in the name of any created function, to help
846
        identify the data source in pg_stat_activity.
847 1960 aaronmk
    '''
848 2754 aaronmk
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
849 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
850 3010 aaronmk
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
851 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
852 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
853 2063 aaronmk
854 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
855 2063 aaronmk
856 3009 aaronmk
    def mk_insert(select_query):
857
        query = first_line
858 3014 aaronmk
        if cols != None:
859
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
860 3009 aaronmk
        query += '\n'+select_query
861
862
        if returning != None:
863
            returning_name_col = sql_gen.to_name_only_col(returning)
864
            query += '\nRETURNING '+returning_name_col.to_str(db)
865
866
        return query
867 2063 aaronmk
868 3489 aaronmk
    return_type = sql_gen.CustomCode('unknown')
869
    if returning != None: return_type = sql_gen.ColType(returning)
870 3017 aaronmk
871 3009 aaronmk
    if ignore:
872 3017 aaronmk
        # Always return something to set the correct rowcount
873
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
874
875 3009 aaronmk
        embeddable = True # must use function
876 3010 aaronmk
877 3450 aaronmk
        if cols == None: row = [sql_gen.Col(sql_gen.all_cols, 'row')]
878
        else: row = [sql_gen.Col(c.name, 'row') for c in cols]
879 3092 aaronmk
880 3484 aaronmk
        query = sql_gen.RowExcIgnore(sql_gen.RowType(table), select_query,
881 3497 aaronmk
            sql_gen.ReturnQuery(mk_insert(sql_gen.Values(row).to_str(db))),
882
            cols)
883 3009 aaronmk
    else: query = mk_insert(select_query)
884
885 2070 aaronmk
    if embeddable:
886
        # Create function
887 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
888 3181 aaronmk
        if src != None: function_name = src+': '+function_name
889 2189 aaronmk
        while True:
890
            try:
891 3451 aaronmk
                func = db.TempFunction(function_name)
892 3489 aaronmk
                def_ = sql_gen.FunctionDef(func, sql_gen.SetOf(return_type),
893
                    query)
894 2194 aaronmk
895 3443 aaronmk
                run_query(db, def_.to_str(db), recover=True, cacheable=True,
896 2945 aaronmk
                    log_ignore_excs=(DuplicateException,))
897 2189 aaronmk
                break # this version was successful
898 2945 aaronmk
            except DuplicateException, e:
899 2189 aaronmk
                function_name = next_version(function_name)
900
                # try again with next version of name
901 2070 aaronmk
902 2337 aaronmk
        # Return query that uses function
903 3009 aaronmk
        cols = None
904
        if returning != None: cols = [returning]
905 3451 aaronmk
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(func), cols)
906
            # AS clause requires function alias
907 3298 aaronmk
        return mk_select(db, func_table, order_by=None)
908 2070 aaronmk
909 2787 aaronmk
    return query
910 2066 aaronmk
911 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
912 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
913 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
914
        values in
915 2072 aaronmk
    '''
916 3141 aaronmk
    returning = kw_args.get('returning', None)
917
    ignore = kw_args.get('ignore', False)
918
919 2386 aaronmk
    into = kw_args.pop('into', None)
920
    if into != None: kw_args['embeddable'] = True
921 2066 aaronmk
    recover = kw_args.pop('recover', None)
922 3141 aaronmk
    if ignore: recover = True
923 2066 aaronmk
    cacheable = kw_args.pop('cacheable', True)
924 2673 aaronmk
    log_level = kw_args.pop('log_level', 2)
925 2066 aaronmk
926 3141 aaronmk
    rowcount_only = ignore and returning == None # keep NULL rows on server
927
    if rowcount_only: into = sql_gen.Table('rowcount')
928
929 3074 aaronmk
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
930
        into, recover=recover, cacheable=cacheable, log_level=log_level)
931 3141 aaronmk
    if rowcount_only: empty_temp(db, into)
932 3074 aaronmk
    autoanalyze(db, table)
933
    return cur
934 2063 aaronmk
935 2738 aaronmk
default = sql_gen.default # tells insert() to use the default value for a column
936 2066 aaronmk
937 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
938 2085 aaronmk
    '''For params, see insert_select()'''
939 5050 aaronmk
    ignore = kw_args.pop('ignore', False)
940 5094 aaronmk
    if ignore: kw_args.setdefault('recover', True)
941 5050 aaronmk
942 1960 aaronmk
    if lists.is_seq(row): cols = None
943
    else:
944
        cols = row.keys()
945
        row = row.values()
946 2738 aaronmk
    row = list(row) # ensure that "== []" works
947 1960 aaronmk
948 2738 aaronmk
    if row == []: query = None
949
    else: query = sql_gen.Values(row).to_str(db)
950 1961 aaronmk
951 5050 aaronmk
    try: return insert_select(db, table, cols, query, *args, **kw_args)
952 5057 aaronmk
    except (DuplicateKeyException, NullValueException):
953 5050 aaronmk
        if not ignore: raise
954 5163 aaronmk
        return None
955 11 aaronmk
956 3152 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False,
957 3153 aaronmk
    cacheable_=True):
958 2402 aaronmk
    '''
959
    @param changes [(col, new_value),...]
960
        * container can be any iterable type
961
        * col: sql_gen.Code|str (for col name)
962
        * new_value: sql_gen.Code|literal value
963
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
964 3056 aaronmk
    @param in_place If set, locks the table and updates rows in place.
965
        This avoids creating dead rows in PostgreSQL.
966
        * cond must be None
967 3153 aaronmk
    @param cacheable_ Whether column structure information used to generate the
968 3152 aaronmk
        query can be cached
969 2402 aaronmk
    @return str query
970
    '''
971 3057 aaronmk
    table = sql_gen.as_Table(table)
972
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
973
        for c, v in changes]
974
975 3056 aaronmk
    if in_place:
976
        assert cond == None
977 3058 aaronmk
978 5398 aaronmk
        def col_type(col):
979
            return sql_gen.canon_type(db.col_info(
980
                sql_gen.with_default_table(c, table), cacheable_).type)
981
        changes = [(c, v, col_type(c)) for c, v in changes]
982 3065 aaronmk
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
983 5396 aaronmk
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '+t+'\nUSING '
984
            +v.to_str(db) for c, v, t in changes))
985 3058 aaronmk
    else:
986
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
987
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
988
            for c, v in changes))
989
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
990 3056 aaronmk
991 3266 aaronmk
    query = with_explain_comment(db, query)
992
993 2402 aaronmk
    return query
994
995 3074 aaronmk
def update(db, table, *args, **kw_args):
996 2402 aaronmk
    '''For params, see mk_update() and run_query()'''
997
    recover = kw_args.pop('recover', None)
998 3043 aaronmk
    cacheable = kw_args.pop('cacheable', False)
999 3030 aaronmk
    log_level = kw_args.pop('log_level', 2)
1000 2402 aaronmk
1001 3074 aaronmk
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
1002
        cacheable, log_level=log_level)
1003
    autoanalyze(db, table)
1004
    return cur
1005 2402 aaronmk
1006 3286 aaronmk
def mk_delete(db, table, cond=None):
1007
    '''
1008
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
1009
    @return str query
1010
    '''
1011
    query = 'DELETE FROM '+table.to_str(db)
1012
    if cond != None: query += '\nWHERE '+cond.to_str(db)
1013
1014
    query = with_explain_comment(db, query)
1015
1016
    return query
1017
1018
def delete(db, table, *args, **kw_args):
1019
    '''For params, see mk_delete() and run_query()'''
1020
    recover = kw_args.pop('recover', None)
1021 3295 aaronmk
    cacheable = kw_args.pop('cacheable', True)
1022 3286 aaronmk
    log_level = kw_args.pop('log_level', 2)
1023
1024
    cur = run_query(db, mk_delete(db, table, *args, **kw_args), recover,
1025
        cacheable, log_level=log_level)
1026
    autoanalyze(db, table)
1027
    return cur
1028
1029 135 aaronmk
def last_insert_id(db):
1030 1849 aaronmk
    module = util.root_module(db.db)
1031 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
1032
    elif module == 'MySQLdb': return db.insert_id()
1033
    else: return None
1034 13 aaronmk
1035 3490 aaronmk
def define_func(db, def_):
1036
    func = def_.function
1037
    while True:
1038
        try:
1039
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
1040
                log_ignore_excs=(DuplicateException,))
1041
            break # successful
1042
        except DuplicateException:
1043 3495 aaronmk
            func.name = next_version(func.name)
1044 3490 aaronmk
            # try again with next version of name
1045
1046 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
1047 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
1048 2415 aaronmk
    to names that will be distinct among the columns' tables.
1049 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
1050 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
1051
    @param into The table for the new columns.
1052 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
1053
        columns will be included in the mapping even if they are not in cols.
1054
        The tables of the provided Col objects will be changed to into, so make
1055
        copies of them if you want to keep the original tables.
1056
    @param as_items Whether to return a list of dict items instead of a dict
1057 2383 aaronmk
    @return dict(orig_col=new_col, ...)
1058
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
1059 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
1060
        * All mappings use the into table so its name can easily be
1061 2383 aaronmk
          changed for all columns at once
1062
    '''
1063 2415 aaronmk
    cols = lists.uniqify(cols)
1064
1065 2394 aaronmk
    items = []
1066 2389 aaronmk
    for col in preserve:
1067 2390 aaronmk
        orig_col = copy.copy(col)
1068 2392 aaronmk
        col.table = into
1069 2394 aaronmk
        items.append((orig_col, col))
1070
    preserve = set(preserve)
1071
    for col in cols:
1072 2716 aaronmk
        if col not in preserve:
1073 3750 aaronmk
            items.append((col, sql_gen.Col(strings.ustr(col), into, col.srcs)))
1074 2394 aaronmk
1075
    if not as_items: items = dict(items)
1076
    return items
1077 2383 aaronmk
1078 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
1079 2391 aaronmk
    '''For params, see mk_flatten_mapping()
1080
    @return See return value of mk_flatten_mapping()
1081
    '''
1082 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
1083
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
1084 5523 aaronmk
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
1085
        into=into, add_pkey_=True)
1086 3708 aaronmk
        # don't cache because the temp table will usually be truncated after use
1087 2394 aaronmk
    return dict(items)
1088 2391 aaronmk
1089 3079 aaronmk
##### Database structure introspection
1090 2414 aaronmk
1091 3079 aaronmk
#### Tables
1092
1093 4555 aaronmk
def tables(db, schema_like='public', table_like='%', exact=False,
1094
    cacheable=True):
1095 3079 aaronmk
    if exact: compare = '='
1096
    else: compare = 'LIKE'
1097
1098
    module = util.root_module(db.db)
1099
    if module == 'psycopg2':
1100
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
1101
            ('tablename', sql_gen.CompareCond(table_like, compare))]
1102
        return values(select(db, 'pg_tables', ['tablename'], conds,
1103 4555 aaronmk
            order_by='tablename', cacheable=cacheable, log_level=4))
1104 3079 aaronmk
    elif module == 'MySQLdb':
1105
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
1106
            , cacheable=True, log_level=4))
1107
    else: raise NotImplementedError("Can't list tables for "+module+' database')
1108
1109 4556 aaronmk
def table_exists(db, table, cacheable=True):
1110 3079 aaronmk
    table = sql_gen.as_Table(table)
1111 4556 aaronmk
    return list(tables(db, table.schema, table.name, True, cacheable)) != []
1112 3079 aaronmk
1113 2426 aaronmk
def table_row_count(db, table, recover=None):
1114 2786 aaronmk
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
1115 3298 aaronmk
        order_by=None), recover=recover, log_level=3))
1116 2426 aaronmk
1117 5337 aaronmk
def table_col_names(db, table, recover=None):
1118 5528 aaronmk
    return list(col_names(select(db, table, limit=0, recover=recover,
1119
        log_level=4)))
1120 2414 aaronmk
1121 5383 aaronmk
def table_cols(db, table, *args, **kw_args):
1122
    return [sql_gen.as_Col(strings.ustr(c), table)
1123
        for c in table_col_names(db, table, *args, **kw_args)]
1124
1125 5521 aaronmk
def table_pkey_index(db, table, recover=None):
1126
    table_str = sql_gen.Literal(table.to_str(db))
1127
    try:
1128
        return sql_gen.Table(value(run_query(db, '''\
1129
SELECT relname
1130
FROM pg_index
1131
JOIN pg_class index ON index.oid = indexrelid
1132
WHERE
1133
indrelid = '''+table_str.to_str(db)+'''::regclass
1134
AND indisprimary
1135
'''
1136
            , recover, cacheable=True, log_level=4)), table.schema)
1137
    except StopIteration: raise DoesNotExistException('primary key', '')
1138
1139 5389 aaronmk
def table_pkey_col(db, table, recover=None):
1140 5061 aaronmk
    table = sql_gen.as_Table(table)
1141
1142 5989 aaronmk
    module = util.root_module(db.db)
1143
    if module == 'psycopg2':
1144
        return sql_gen.Col(index_cols(db, table_pkey_index(db, table,
1145
            recover))[0], table)
1146
    else:
1147
        join_cols = ['table_schema', 'table_name', 'constraint_schema',
1148
            'constraint_name']
1149
        tables = [sql_gen.Table('key_column_usage', 'information_schema'),
1150
            sql_gen.Join(
1151
                sql_gen.Table('table_constraints', 'information_schema'),
1152
                dict(((c, sql_gen.join_same_not_null) for c in join_cols)))]
1153
        cols = [sql_gen.Col('column_name')]
1154
1155
        conds = [('constraint_type', 'PRIMARY KEY'), ('table_name', table.name)]
1156
        schema = table.schema
1157
        if schema != None: conds.append(('table_schema', schema))
1158
        order_by = 'position_in_unique_constraint'
1159
1160
        try: return sql_gen.Col(value(select(db, tables, cols, conds,
1161
            order_by=order_by, limit=1, log_level=4)), table)
1162
        except StopIteration: raise DoesNotExistException('primary key', '')
1163 5389 aaronmk
1164 5990 aaronmk
def table_has_pkey(db, table, recover=None):
1165
    try: table_pkey_col(db, table, recover)
1166
    except DoesNotExistException: return False
1167
    else: return True
1168
1169 5389 aaronmk
def pkey_name(db, table, recover=None):
1170
    '''If no pkey, returns the first column in the table.'''
1171 5392 aaronmk
    return pkey_col(db, table, recover).name
1172 832 aaronmk
1173 5390 aaronmk
def pkey_col(db, table, recover=None):
1174 5391 aaronmk
    '''If no pkey, returns the first column in the table.'''
1175 5392 aaronmk
    try: return table_pkey_col(db, table, recover)
1176 5393 aaronmk
    except DoesNotExistException: return table_cols(db, table, recover)[0]
1177 5128 aaronmk
1178 2559 aaronmk
not_null_col = 'not_null_col'
1179 2340 aaronmk
1180
def table_not_null_col(db, table, recover=None):
1181
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
1182 5337 aaronmk
    if not_null_col in table_col_names(db, table, recover): return not_null_col
1183 5388 aaronmk
    else: return pkey_name(db, table, recover)
1184 2340 aaronmk
1185 3348 aaronmk
def constraint_cond(db, constraint):
1186
    module = util.root_module(db.db)
1187
    if module == 'psycopg2':
1188
        table_str = sql_gen.Literal(constraint.table.to_str(db))
1189
        name_str = sql_gen.Literal(constraint.name)
1190
        return value(run_query(db, '''\
1191
SELECT consrc
1192
FROM pg_constraint
1193
WHERE
1194
conrelid = '''+table_str.to_str(db)+'''::regclass
1195
AND conname = '''+name_str.to_str(db)+'''
1196
'''
1197
            , cacheable=True, log_level=4))
1198 5443 aaronmk
    else: raise NotImplementedError("Can't get constraint condition for "
1199
        +module+' database')
1200 3348 aaronmk
1201 5520 aaronmk
def index_exprs(db, index):
1202 3322 aaronmk
    index = sql_gen.as_Table(index)
1203 1909 aaronmk
    module = util.root_module(db.db)
1204
    if module == 'psycopg2':
1205 3322 aaronmk
        qual_index = sql_gen.Literal(index.to_str(db))
1206 5520 aaronmk
        return list(values(run_query(db, '''\
1207 3322 aaronmk
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
1208
FROM pg_index
1209
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
1210 2782 aaronmk
'''
1211
            , cacheable=True, log_level=4)))
1212 5520 aaronmk
    else: raise NotImplementedError()
1213 853 aaronmk
1214 5520 aaronmk
def index_cols(db, index):
1215
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
1216
    automatically created. When you don't know whether something is a UNIQUE
1217
    constraint or a UNIQUE index, use this function.'''
1218
    return map(sql_gen.parse_expr_col, index_exprs(db, index))
1219
1220 5445 aaronmk
def index_cond(db, index):
1221
    index = sql_gen.as_Table(index)
1222
    module = util.root_module(db.db)
1223
    if module == 'psycopg2':
1224
        qual_index = sql_gen.Literal(index.to_str(db))
1225
        return value(run_query(db, '''\
1226
SELECT pg_get_expr(indpred, indrelid, true)
1227
FROM pg_index
1228
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
1229
'''
1230
            , cacheable=True, log_level=4))
1231
    else: raise NotImplementedError()
1232
1233 5521 aaronmk
def index_order_by(db, index):
1234
    return sql_gen.CustomCode(', '.join(index_exprs(db, index)))
1235
1236
def table_cluster_on(db, table, recover=None):
1237
    '''
1238
    @return The table's cluster index, or its pkey if none is set
1239
    '''
1240
    table_str = sql_gen.Literal(table.to_str(db))
1241
    try:
1242
        return sql_gen.Table(value(run_query(db, '''\
1243
SELECT relname
1244
FROM pg_index
1245
JOIN pg_class index ON index.oid = indexrelid
1246
WHERE
1247
indrelid = '''+table_str.to_str(db)+'''::regclass
1248
AND indisclustered
1249
'''
1250
            , recover, cacheable=True, log_level=4)), table.schema)
1251
    except StopIteration: return table_pkey_index(db, table, recover)
1252
1253
def table_order_by(db, table, recover=None):
1254 5525 aaronmk
    if table.order_by == None:
1255
        try: table.order_by = index_order_by(db, table_cluster_on(db, table,
1256
            recover))
1257
        except DoesNotExistException: pass
1258
    return table.order_by
1259 5521 aaronmk
1260 3079 aaronmk
#### Functions
1261
1262
def function_exists(db, function):
1263 3423 aaronmk
    qual_function = sql_gen.Literal(function.to_str(db))
1264
    try:
1265 3425 aaronmk
        select(db, fields=[sql_gen.Cast('regproc', qual_function)],
1266
            recover=True, cacheable=True, log_level=4)
1267 3423 aaronmk
    except DoesNotExistException: return False
1268 4146 aaronmk
    except DuplicateException: return True # overloaded function
1269 3423 aaronmk
    else: return True
1270 3079 aaronmk
1271 5713 aaronmk
def function_param0_type(db, function):
1272
    qual_function = sql_gen.Literal(function.to_str(db))
1273
    return value(run_query(db, '''\
1274
SELECT proargtypes[0]::regtype
1275
FROM pg_proc
1276
WHERE oid = '''+qual_function.to_str(db)+'''::regproc
1277
'''
1278
        , cacheable=True, log_level=4))
1279
1280 3079 aaronmk
##### Structural changes
1281
1282
#### Columns
1283
1284 5020 aaronmk
def add_col(db, table, col, comment=None, if_not_exists=False, **kw_args):
1285 3079 aaronmk
    '''
1286
    @param col TypedCol Name may be versioned, so be sure to propagate any
1287
        renaming back to any source column for the TypedCol.
1288
    @param comment None|str SQL comment used to distinguish columns of the same
1289
        name from each other when they contain different data, to allow the
1290
        ADD COLUMN query to be cached. If not set, query will not be cached.
1291
    '''
1292
    assert isinstance(col, sql_gen.TypedCol)
1293
1294
    while True:
1295
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
1296
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
1297
1298
        try:
1299
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
1300
            break
1301
        except DuplicateException:
1302 5020 aaronmk
            if if_not_exists: raise
1303 3079 aaronmk
            col.name = next_version(col.name)
1304
            # try again with next version of name
1305
1306
def add_not_null(db, col):
1307
    table = col.table
1308
    col = sql_gen.to_name_only_col(col)
1309
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1310
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)
1311
1312 4443 aaronmk
def drop_not_null(db, col):
1313
    table = col.table
1314
    col = sql_gen.to_name_only_col(col)
1315
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1316
        +col.to_str(db)+' DROP NOT NULL', cacheable=True, log_level=3)
1317
1318 2096 aaronmk
row_num_col = '_row_num'
1319
1320 4997 aaronmk
row_num_col_def = sql_gen.TypedCol('', 'serial', nullable=False,
1321 3079 aaronmk
    constraints='PRIMARY KEY')
1322
1323 4997 aaronmk
def add_row_num(db, table, name=row_num_col):
1324
    '''Adds a row number column to a table. Its definition is in
1325
    row_num_col_def. It will be the primary key.'''
1326
    col_def = copy.copy(row_num_col_def)
1327
    col_def.name = name
1328 5021 aaronmk
    add_col(db, table, col_def, comment='', if_not_exists=True, log_level=3)
1329 3079 aaronmk
1330
#### Indexes
1331
1332
def add_pkey(db, table, cols=None, recover=None):
1333
    '''Adds a primary key.
1334
    @param cols [sql_gen.Col,...] The columns in the primary key.
1335
        Defaults to the first column in the table.
1336
    @pre The table must not already have a primary key.
1337
    '''
1338
    table = sql_gen.as_Table(table)
1339 5388 aaronmk
    if cols == None: cols = [pkey_name(db, table, recover)]
1340 3079 aaronmk
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
1341
1342
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
1343
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
1344
        log_ignore_excs=(DuplicateException,))
1345
1346 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
1347 2688 aaronmk
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
1348 3356 aaronmk
    Currently, only function calls and literal values are supported expressions.
1349 2998 aaronmk
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
1350 2847 aaronmk
        This allows indexes to be used for comparisons where NULLs are equal.
1351 2538 aaronmk
    '''
1352 2964 aaronmk
    exprs = lists.mk_seq(exprs)
1353 2538 aaronmk
1354 2688 aaronmk
    # Parse exprs
1355
    old_exprs = exprs[:]
1356
    exprs = []
1357
    cols = []
1358
    for i, expr in enumerate(old_exprs):
1359 2823 aaronmk
        expr = sql_gen.as_Col(expr, table)
1360 2688 aaronmk
1361 2823 aaronmk
        # Handle nullable columns
1362 2998 aaronmk
        if ensure_not_null_:
1363 3164 aaronmk
            try: expr = sql_gen.ensure_not_null(db, expr)
1364 2860 aaronmk
            except KeyError: pass # unknown type, so just create plain index
1365 2823 aaronmk
1366 2688 aaronmk
        # Extract col
1367 3002 aaronmk
        expr = copy.deepcopy(expr) # don't modify input!
1368 3356 aaronmk
        col = expr
1369
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
1370
        expr = sql_gen.cast_literal(expr)
1371
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
1372 2688 aaronmk
            expr = sql_gen.Expr(expr)
1373 3356 aaronmk
1374 2688 aaronmk
1375
        # Extract table
1376
        if table == None:
1377
            assert sql_gen.is_table_col(col)
1378
            table = col.table
1379
1380 3356 aaronmk
        if isinstance(col, sql_gen.Col): col.table = None
1381 2688 aaronmk
1382
        exprs.append(expr)
1383
        cols.append(col)
1384 2408 aaronmk
1385 2688 aaronmk
    table = sql_gen.as_Table(table)
1386
1387 3005 aaronmk
    # Add index
1388 3148 aaronmk
    str_ = 'CREATE'
1389
    if unique: str_ += ' UNIQUE'
1390
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
1391
        ', '.join((v.to_str(db) for v in exprs)))+')'
1392
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1393 2408 aaronmk
1394 5987 aaronmk
def add_pkey_index(db, table): add_index(db, pkey_col(db, table), table)
1395
1396 5765 aaronmk
def add_pkey_or_index(db, table, cols=None, recover=None, warn=False):
1397
    try: add_pkey(db, table, cols, recover)
1398
    except DuplicateKeyException, e:
1399
        if warn: warnings.warn(UserWarning(exc.str_(e)))
1400 5988 aaronmk
        add_pkey_index(db, table)
1401 5765 aaronmk
1402 3083 aaronmk
already_indexed = object() # tells add_indexes() the pkey has already been added
1403
1404
def add_indexes(db, table, has_pkey=True):
1405
    '''Adds an index on all columns in a table.
1406
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
1407
        index should be added on the first column.
1408
        * If already_indexed, the pkey is assumed to have already been added
1409
    '''
1410 5337 aaronmk
    cols = table_col_names(db, table)
1411 3083 aaronmk
    if has_pkey:
1412
        if has_pkey is not already_indexed: add_pkey(db, table)
1413
        cols = cols[1:]
1414
    for col in cols: add_index(db, col, table)
1415
1416 3079 aaronmk
#### Tables
1417 2772 aaronmk
1418 3079 aaronmk
### Maintenance
1419 2772 aaronmk
1420 3079 aaronmk
def analyze(db, table):
1421
    table = sql_gen.as_Table(table)
1422
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1423 2934 aaronmk
1424 3079 aaronmk
def autoanalyze(db, table):
1425
    if db.autoanalyze: analyze(db, table)
1426 2935 aaronmk
1427 3079 aaronmk
def vacuum(db, table):
1428
    table = sql_gen.as_Table(table)
1429
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
1430
        log_level=3))
1431 2086 aaronmk
1432 3079 aaronmk
### Lifecycle
1433
1434 3247 aaronmk
def drop(db, type_, name):
1435
    name = sql_gen.as_Name(name)
1436
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')
1437 2889 aaronmk
1438 3247 aaronmk
def drop_table(db, table): drop(db, 'TABLE', table)
1439
1440 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
1441
    like=None):
1442 2675 aaronmk
    '''Creates a table.
1443 2681 aaronmk
    @param cols [sql_gen.TypedCol,...] The column names and types
1444
    @param has_pkey If set, the first column becomes the primary key.
1445 2760 aaronmk
    @param col_indexes bool|[ref]
1446
        * If True, indexes will be added on all non-pkey columns.
1447
        * If a list reference, [0] will be set to a function to do this.
1448
          This can be used to delay index creation until the table is populated.
1449 2675 aaronmk
    '''
1450
    table = sql_gen.as_Table(table)
1451
1452 3082 aaronmk
    if like != None:
1453
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
1454
            ]+cols
1455 5525 aaronmk
        table.order_by = like.order_by
1456 2681 aaronmk
    if has_pkey:
1457
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
1458 2872 aaronmk
        pkey.constraints = 'PRIMARY KEY'
1459 2681 aaronmk
1460 3085 aaronmk
    temp = table.is_temp and not db.debug_temp
1461
        # temp tables permanent in debug_temp mode
1462 2760 aaronmk
1463 3085 aaronmk
    # Create table
1464 3383 aaronmk
    def create():
1465 3085 aaronmk
        str_ = 'CREATE'
1466
        if temp: str_ += ' TEMP'
1467
        str_ += ' TABLE '+table.to_str(db)+' (\n'
1468
        str_ += '\n, '.join(c.to_str(db) for c in cols)
1469 3126 aaronmk
        str_ += '\n);'
1470 3085 aaronmk
1471 3383 aaronmk
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
1472
            log_ignore_excs=(DuplicateException,))
1473
    if table.is_temp:
1474
        while True:
1475
            try:
1476
                create()
1477
                break
1478
            except DuplicateException:
1479
                table.name = next_version(table.name)
1480
                # try again with next version of name
1481
    else: create()
1482 3085 aaronmk
1483 2760 aaronmk
    # Add indexes
1484 2773 aaronmk
    if has_pkey: has_pkey = already_indexed
1485
    def add_indexes_(): add_indexes(db, table, has_pkey)
1486
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
1487
    elif col_indexes: add_indexes_() # add now
1488 2675 aaronmk
1489 3084 aaronmk
def copy_table_struct(db, src, dest):
1490
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1491 3085 aaronmk
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1492 3084 aaronmk
1493 5529 aaronmk
def copy_table(db, src, dest):
1494
    '''Creates a copy of a table, including data'''
1495
    copy_table_struct(db, src, dest)
1496
    insert_select(db, dest, None, mk_select(db, src))
1497
1498 3079 aaronmk
### Data
1499 2684 aaronmk
1500 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
1501
    '''For params, see run_query()'''
1502 2777 aaronmk
    table = sql_gen.as_Table(table, schema)
1503 2970 aaronmk
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1504 2732 aaronmk
1505 2965 aaronmk
def empty_temp(db, tables):
1506
    tables = lists.mk_seq(tables)
1507 2971 aaronmk
    for table in tables: truncate(db, table, log_level=3)
1508 2965 aaronmk
1509 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
1510
    '''For kw_args, see tables()'''
1511
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1512 3094 aaronmk
1513
def distinct_table(db, table, distinct_on):
1514
    '''Creates a copy of a temp table which is distinct on the given columns.
1515 5902 aaronmk
    Adds an index on table's distinct_on columns, to facilitate merge joins.
1516 3097 aaronmk
    @param distinct_on If empty, creates a table with one row. This is useful if
1517
        your distinct_on columns are all literal values.
1518 3099 aaronmk
    @return The new table.
1519 3094 aaronmk
    '''
1520 3099 aaronmk
    new_table = sql_gen.suffixed_table(table, '_distinct')
1521 3411 aaronmk
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
1522 3094 aaronmk
1523 3099 aaronmk
    copy_table_struct(db, table, new_table)
1524 3097 aaronmk
1525
    limit = None
1526
    if distinct_on == []: limit = 1 # one sample row
1527 5903 aaronmk
    else: add_index(db, distinct_on, table) # for join optimization
1528 3097 aaronmk
1529 5903 aaronmk
    insert_select(db, new_table, None, mk_select(db, table,
1530
        distinct_on=distinct_on, order_by=None, limit=limit))
1531 3099 aaronmk
    analyze(db, new_table)
1532 3094 aaronmk
1533 3099 aaronmk
    return new_table