# Project
#
# General
#
# Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Gets the query that was last executed on a cursor.
    
    Checks the driver-specific attributes where DB-API modules store the last
    query (psycopg2 uses `query`, MySQLdb uses `_last_executed`), and falls
    back to the query string the caller passed in.
    @param cur a DB-API cursor
    @param input_query the query passed to execute(), if known
    @return the raw executed query, or '[input] '+input_query if unavailable
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query # psycopg2
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed # MySQLdb
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for database errors.
    @param cur if given, the cursor's last query is appended to the message
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    '''A database error scoped to a named object'''
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    '''A database error scoped to a particular value'''
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    '''A database error scoped to a named object of a particular type'''
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    '''A violation of a named constraint on the given columns'''
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    '''A column value needed an explicit cast to the given type'''
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval

def col_names(cur):
    '''@return generator of the column names in a cursor's description'''
    return (col[0] for col in cur.description)

def rows(cur):
    '''@return iterator over the cursor's rows, ending when fetchone()
    returns None'''
    return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    '''@return the next row, consuming the rest so the result will be cached'''
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there are no rows'''
    try: return value(cur)
    except StopIteration: return None
110
111 2762 aaronmk
##### Escaping

def esc_name_by_module(module, name):
    '''Escapes an identifier using the quoting style of a DB-API module.
    @param module 'psycopg2'|'MySQLdb'|None (None defaults to double quotes)
    @raise NotImplementedError for unsupported modules
    '''
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a db_engines engine name'''
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the given DbConn's underlying driver'''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    '''@return the escaped, schema-qualified table name'''
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
130
131 1869 aaronmk
##### Database connections

# Keys recognized in a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, db_config key renames for connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All known database error classes; extended as driver modules are loaded
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    '''Registers a DB-API module's DatabaseError in DatabaseErrors'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    '''@return a human-readable description of a db_config dict'''
    return db_config['engine']+' database '+db_config['database']

log_debug_none = lambda msg, level=2: None # no-op debug logger
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
154 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
155 2915 aaronmk
        log_debug=log_debug_none, debug_temp=False):
156
        '''
157
        @param debug_temp Whether temporary objects should instead be permanent.
158
            This assists in debugging the internal objects used by the program.
159
        '''
160 1869 aaronmk
        self.db_config = db_config
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 2915 aaronmk
        self.debug_temp = debug_temp
166 3074 aaronmk
        self.autoanalyze = False
167 1869 aaronmk
168 3124 aaronmk
        self._savepoint = 0
169 3120 aaronmk
        self._reset()
170 1869 aaronmk
171
    def __getattr__(self, name):
172
        if name == '__dict__': raise Exception('getting __dict__')
173
        if name == 'db': return self._db()
174
        else: raise AttributeError()
175
176
    def __getstate__(self):
177
        state = copy.copy(self.__dict__) # shallow copy
178 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
179 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
180
        return state
181
182 3118 aaronmk
    def clear_cache(self): self.query_results = {}
183
184 3120 aaronmk
    def _reset(self):
185 3118 aaronmk
        self.clear_cache()
186 3124 aaronmk
        assert self._savepoint == 0
187 3118 aaronmk
        self._notices_seen = set()
188
        self.__db = None
189
190 2165 aaronmk
    def connected(self): return self.__db != None
191
192 3116 aaronmk
    def close(self):
193 3119 aaronmk
        if not self.connected(): return
194
195 3135 aaronmk
        # Record that the automatic transaction is now closed
196 3136 aaronmk
        self._savepoint -= 1
197 3135 aaronmk
198 3119 aaronmk
        self.db.close()
199 3120 aaronmk
        self._reset()
200 3116 aaronmk
201 3125 aaronmk
    def reconnect(self):
202
        # Do not do this in test mode as it would roll back everything
203
        if self.autocommit: self.close()
204
        # Connection will be reopened automatically on first query
205
206 1869 aaronmk
    def _db(self):
207
        if self.__db == None:
208
            # Process db_config
209
            db_config = self.db_config.copy() # don't modify input!
210 2097 aaronmk
            schemas = db_config.pop('schemas', None)
211 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
212
            module = __import__(module_name)
213
            _add_module(module)
214
            for orig, new in mappings.iteritems():
215
                try: util.rename_key(db_config, orig, new)
216
                except KeyError: pass
217
218
            # Connect
219
            self.__db = module.connect(**db_config)
220
221
            # Configure connection
222 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
223
                import psycopg2.extensions
224
                self.db.set_isolation_level(
225
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
226 2101 aaronmk
            if schemas != None:
227 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
228
                search_path.append(value(run_query(self, 'SHOW search_path',
229
                    log_level=4)))
230
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
231
                    log_level=3)
232 3135 aaronmk
233
            # Record that a transaction is already open
234 3136 aaronmk
            self._savepoint += 1
235 1869 aaronmk
236
        return self.__db
237 1889 aaronmk
238 1891 aaronmk
    class DbCursor(Proxy):
239 1927 aaronmk
        def __init__(self, outer):
240 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
241 2191 aaronmk
            self.outer = outer
242 1927 aaronmk
            self.query_results = outer.query_results
243 1894 aaronmk
            self.query_lookup = None
244 1891 aaronmk
            self.result = []
245 1889 aaronmk
246 2802 aaronmk
        def execute(self, query):
247 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
248 2797 aaronmk
            self.query_lookup = query
249 2148 aaronmk
            try:
250 2191 aaronmk
                try:
251 2802 aaronmk
                    cur = self.inner.execute(query)
252 2191 aaronmk
                    self.outer.do_autocommit()
253 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
254 1904 aaronmk
            except Exception, e:
255 2802 aaronmk
                _add_cursor_info(e, self, query)
256 1904 aaronmk
                self.result = e # cache the exception as the result
257
                self._cache_result()
258
                raise
259 3004 aaronmk
260
            # Always cache certain queries
261
            if query.startswith('CREATE') or query.startswith('ALTER'):
262 3007 aaronmk
                # structural changes
263 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
264
                # so don't cache ADD COLUMN unless it has distinguishing comment
265
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
266 3007 aaronmk
                    self._cache_result()
267 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
268 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
269 3004 aaronmk
270 2762 aaronmk
            return cur
271 1894 aaronmk
272 1891 aaronmk
        def fetchone(self):
273
            row = self.inner.fetchone()
274 1899 aaronmk
            if row != None: self.result.append(row)
275
            # otherwise, fetched all rows
276 1904 aaronmk
            else: self._cache_result()
277
            return row
278
279
        def _cache_result(self):
280 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
281
            # inserts are not idempotent. Other non-SELECT queries don't have
282
            # their result set read, so only exceptions will be cached (an
283
            # invalid query will always be invalid).
284 1930 aaronmk
            if self.query_results != None and (not self._is_insert
285 1906 aaronmk
                or isinstance(self.result, Exception)):
286
287 1894 aaronmk
                assert self.query_lookup != None
288 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
289
                    util.dict_subset(dicts.AttrsDictView(self),
290
                    ['query', 'result', 'rowcount', 'description']))
291 1906 aaronmk
292 1916 aaronmk
        class CacheCursor:
293
            def __init__(self, cached_result): self.__dict__ = cached_result
294
295 1927 aaronmk
            def execute(self, *args, **kw_args):
296 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
297
                # otherwise, result is a rows list
298
                self.iter = iter(self.result)
299
300
            def fetchone(self):
301
                try: return self.iter.next()
302
                except StopIteration: return None
303 1891 aaronmk
304 2212 aaronmk
    def esc_value(self, value):
305 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
306
        except NotImplementedError, e:
307
            module = util.root_module(self.db)
308
            if module == 'MySQLdb':
309
                import _mysql
310
                str_ = _mysql.escape_string(value)
311
            else: raise e
312 2374 aaronmk
        return strings.to_unicode(str_)
313 2212 aaronmk
314 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
315
316 2814 aaronmk
    def std_code(self, str_):
317
        '''Standardizes SQL code.
318
        * Ensures that string literals are prefixed by `E`
319
        '''
320
        if str_.startswith("'"): str_ = 'E'+str_
321
        return str_
322
323 2665 aaronmk
    def can_mogrify(self):
324 2663 aaronmk
        module = util.root_module(self.db)
325 2665 aaronmk
        return module == 'psycopg2'
326 2663 aaronmk
327 2665 aaronmk
    def mogrify(self, query, params=None):
328
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
329
        else: raise NotImplementedError("Can't mogrify query")
330
331 2671 aaronmk
    def print_notices(self):
332 2725 aaronmk
        if hasattr(self.db, 'notices'):
333
            for msg in self.db.notices:
334
                if msg not in self._notices_seen:
335
                    self._notices_seen.add(msg)
336
                    self.log_debug(msg, level=2)
337 2671 aaronmk
338 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
339 2464 aaronmk
        debug_msg_ref=None):
340 2445 aaronmk
        '''
341 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
342
            throws one of these exceptions.
343 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
344
            this instead of being output. This allows you to filter log messages
345
            depending on the result of the query.
346 2445 aaronmk
        '''
347 2167 aaronmk
        assert query != None
348
349 2047 aaronmk
        if not self.caching: cacheable = False
350 1903 aaronmk
        used_cache = False
351 2664 aaronmk
352
        def log_msg(query):
353
            if used_cache: cache_status = 'cache hit'
354
            elif cacheable: cache_status = 'cache miss'
355
            else: cache_status = 'non-cacheable'
356
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
357
358 1903 aaronmk
        try:
359 1927 aaronmk
            # Get cursor
360
            if cacheable:
361
                try:
362 2797 aaronmk
                    cur = self.query_results[query]
363 1927 aaronmk
                    used_cache = True
364
                except KeyError: cur = self.DbCursor(self)
365
            else: cur = self.db.cursor()
366
367 2664 aaronmk
            # Log query
368
            if self.debug and debug_msg_ref == None: # log before running
369
                self.log_debug(log_msg(query), log_level)
370
371 1927 aaronmk
            # Run query
372 2793 aaronmk
            cur.execute(query)
373 1903 aaronmk
        finally:
374 2671 aaronmk
            self.print_notices()
375 2664 aaronmk
            if self.debug and debug_msg_ref != None: # return after running
376 2793 aaronmk
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
377 1903 aaronmk
378
        return cur
379 1914 aaronmk
380 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
381 2139 aaronmk
382 2907 aaronmk
    def with_autocommit(self, func):
383 2801 aaronmk
        import psycopg2.extensions
384
385
        prev_isolation_level = self.db.isolation_level
386 2907 aaronmk
        self.db.set_isolation_level(
387
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
388 2683 aaronmk
        try: return func()
389 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
390 2683 aaronmk
391 2139 aaronmk
    def with_savepoint(self, func):
392 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
393 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
394 2139 aaronmk
        self._savepoint += 1
395 2930 aaronmk
        try: return func()
396 2139 aaronmk
        except:
397 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
398 2139 aaronmk
            raise
399 2930 aaronmk
        finally:
400
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
401
            # "The savepoint remains valid and can be rolled back to again"
402
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
403 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
404 2930 aaronmk
405
            self._savepoint -= 1
406
            assert self._savepoint >= 0
407
408
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
409 2191 aaronmk
410
    def do_autocommit(self):
411
        '''Autocommits if outside savepoint'''
412 3135 aaronmk
        assert self._savepoint >= 1
413
        if self.autocommit and self._savepoint == 1:
414 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
415 2191 aaronmk
            self.db.commit()
416 2643 aaronmk
417 2819 aaronmk
    def col_info(self, col):
418 2643 aaronmk
        table = sql_gen.Table('columns', 'information_schema')
419 3063 aaronmk
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
420
            'USER-DEFINED'), sql_gen.Col('udt_name'))
421 3078 aaronmk
        cols = [type_, 'column_default',
422
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
423 2643 aaronmk
424
        conds = [('table_name', col.table.name), ('column_name', col.name)]
425
        schema = col.table.schema
426
        if schema != None: conds.append(('table_schema', schema))
427
428 2819 aaronmk
        type_, default, nullable = row(select(self, table, cols, conds,
429 3059 aaronmk
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
430 2643 aaronmk
            # TODO: order_by search_path schema order
431 2819 aaronmk
        default = sql_gen.as_Code(default, self)
432
433
        return sql_gen.TypedCol(col.name, type_, default, nullable)
434 2917 aaronmk
435
    def TempFunction(self, name):
436
        if self.debug_temp: schema = None
437
        else: schema = 'pg_temp'
438
        return sql_gen.Function(name, schema)
439 1849 aaronmk
440 1869 aaronmk
connect = DbConn
441
442 832 aaronmk
##### Recoverable querying

def with_savepoint(db, func):
    '''Runs func() inside a savepoint on db. See DbConn.with_savepoint().'''
    return db.with_savepoint(func)
445 11 aaronmk
446 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
447
    log_ignore_excs=None, **kw_args):
448 2794 aaronmk
    '''For params, see DbConn.run_query()'''
449 830 aaronmk
    if recover == None: recover = False
450 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
451
    log_ignore_excs = tuple(log_ignore_excs)
452 830 aaronmk
453 2666 aaronmk
    debug_msg_ref = None # usually, db.run_query() logs query before running it
454
    # But if filtering with log_ignore_excs, wait until after exception parsing
455 2984 aaronmk
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
456 2666 aaronmk
457 2148 aaronmk
    try:
458 2464 aaronmk
        try:
459 2794 aaronmk
            def run(): return db.run_query(query, cacheable, log_level,
460 2793 aaronmk
                debug_msg_ref, **kw_args)
461 2796 aaronmk
            if recover and not db.is_cached(query):
462 2464 aaronmk
                return with_savepoint(db, run)
463
            else: return run() # don't need savepoint if cached
464
        except Exception, e:
465 3095 aaronmk
            msg = strings.ustr(e.args[0])
466 2464 aaronmk
467 3095 aaronmk
            match = re.match(r'^duplicate key value violates unique constraint '
468 3096 aaronmk
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
469 2464 aaronmk
            if match:
470
                constraint, table = match.groups()
471 3025 aaronmk
                cols = []
472
                if recover: # need auto-rollback to run index_cols()
473
                    try: cols = index_cols(db, table, constraint)
474
                    except NotImplementedError: pass
475
                raise DuplicateKeyException(constraint, cols, e)
476 2464 aaronmk
477 3095 aaronmk
            match = re.match(r'^null value in column "(.+?)" violates not-null'
478 2464 aaronmk
                r' constraint', msg)
479
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
480
481 3095 aaronmk
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
482 3109 aaronmk
                r'|.+? field value out of range): "(.+?)"', msg)
483 2464 aaronmk
            if match:
484 3109 aaronmk
                value, = match.groups()
485
                raise InvalidValueException(strings.to_unicode(value), e)
486 2464 aaronmk
487 3095 aaronmk
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
488 2523 aaronmk
                r'is of type', msg)
489
            if match:
490
                col, type_ = match.groups()
491
                raise MissingCastException(type_, col, e)
492
493 3095 aaronmk
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
494 2945 aaronmk
            if match:
495
                type_, name = match.groups()
496
                raise DuplicateException(type_, name, e)
497 2464 aaronmk
498
            raise # no specific exception raised
499
    except log_ignore_excs:
500
        log_level += 2
501
        raise
502
    finally:
503 2666 aaronmk
        if debug_msg_ref != None and debug_msg_ref[0] != None:
504
            db.log_debug(debug_msg_ref[0], log_level)
505 830 aaronmk
506 832 aaronmk
##### Basic queries

def next_version(name):
    '''Increments the "#N" version suffix of a name.
    e.g. 'table' -> 'table#1', 'table#1' -> 'table#2'
    '''
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))
515 2153 aaronmk
516 2899 aaronmk
def lock_table(db, table, mode):
    '''Locks a table for the rest of the transaction.
    @param mode a PostgreSQL lock mode, e.g. 'EXCLUSIVE'
    '''
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
519
520 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
521 2085 aaronmk
    '''Outputs a query to a temp table.
522
    For params, see run_query().
523
    '''
524 2789 aaronmk
    if into == None: return run_query(db, query, **kw_args)
525 2790 aaronmk
526
    assert isinstance(into, sql_gen.Table)
527
528 2992 aaronmk
    into.is_temp = True
529 3008 aaronmk
    # "temporary tables cannot specify a schema name", so remove schema
530
    into.schema = None
531 2992 aaronmk
532 2790 aaronmk
    kw_args['recover'] = True
533 2945 aaronmk
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
534 2790 aaronmk
535 2916 aaronmk
    temp = not db.debug_temp # tables are permanent in debug_temp mode
536 2790 aaronmk
537
    # Create table
538
    while True:
539
        create_query = 'CREATE'
540
        if temp: create_query += ' TEMP'
541
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
542 2385 aaronmk
543 2790 aaronmk
        try:
544
            cur = run_query(db, create_query, **kw_args)
545
                # CREATE TABLE AS sets rowcount to # rows in query
546
            break
547 2945 aaronmk
        except DuplicateException, e:
548 2790 aaronmk
            into.name = next_version(into.name)
549
            # try again with next version of name
550
551
    if add_indexes_: add_indexes(db, into)
552 3075 aaronmk
553
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
554
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
555
    # table is going to be used in complex queries, it is wise to run ANALYZE on
556
    # the temporary table after it is populated."
557
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
558
    # If into is not a temp table, ANALYZE is useful but not required.
559 3073 aaronmk
    analyze(db, into)
560 2790 aaronmk
561
    return cur
562 2085 aaronmk
563 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # missing tracks whether the query lacks a WHERE/LIMIT/OFFSET clause,
    # which usually indicates an accidental full-table scan
    missing = True
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query
649 11 aaronmk
650 2054 aaronmk
def select(db, *args, **kw_args):
    '''Runs a SELECT built by mk_select().
    For params, see mk_select() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)
658 2054 aaronmk
659 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''Builds an INSERT ... SELECT query string (or a wrapper SELECT on a
    generated function, if embeddable).
    @param cols The columns to insert into. [] or None = all defaults
        (unknown col names).
    @param select_query str|None The statement supplying the rows.
        None -> DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    # Normalize arguments
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        '''Builds the INSERT statement around the given source query.'''
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    # Return type of the wrapper function, if one is created below
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # plpgsql loop: insert rows one at a time, swallowing unique_violations
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function, versioning its name until there is no collision
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query
762 2066 aaronmk
763 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''Runs an INSERT ... SELECT built by mk_insert_select().
    For params, see mk_insert_select() and run_query_into().
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    # Options consumed here rather than by mk_insert_select()
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # into requires embeddable
    recover = kw_args.pop('recover', None)
    # ignore mode raises and swallows unique_violations, so must recover
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    autoanalyze(db, table) # refresh planner stats (see autoanalyze())
    return cur
779 2063 aaronmk
780 2738 aaronmk
# Re-exported sentinel; pass as a value in insert()'s row
default = sql_gen.default # tells insert() to use the default value for a column
781 2066 aaronmk
782 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select().'''
    if lists.is_seq(row):
        cols = None # positional values; column names unknown
    else: # a mapping: split into column names and values
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    query = None # no values -> DEFAULT VALUES
    if row != []: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
794 11 aaronmk
795 3056 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''Builds an UPDATE (or in-place ALTER TABLE) query string.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None # a WHERE clause is not expressible as ALTER TABLE
        
        # Rewrite each column with ALTER COLUMN ... TYPE ... USING, which
        # rewrites values in place rather than creating dead rows
        alter_cols = ['ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes]
        query = 'ALTER TABLE '+table.to_str(db)+'\n'+(',\n'.join(alter_cols))
    else:
        assignments = [c.to_str(db)+' = '+v.to_str(db) for c, v in changes]
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'+(',\n'.join(assignments))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
825
826 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''Runs an UPDATE query built by mk_update().
    For params, see mk_update() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False) # UPDATEs not cached by default
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # refresh planner stats (see autoanalyze())
    return cur
836 2402 aaronmk
837 135 aaronmk
def last_insert_id(db):
    '''Returns the last generated insert id, or None for unsupported drivers.'''
    module = util.root_module(db.db)
    if module == 'psycopg2': # PostgreSQL: ask the server directly
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': # MySQL: the connection tracks it
        return db.insert_id()
    return None # unknown driver
842 13 aaronmk
843 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    mapping = []
    # Preserved columns keep their names but are repointed at the into table
    for col in preserve:
        old_col = copy.copy(col)
        col.table = into # mutates the caller's Col, as documented above
        mapping.append((old_col, col))
    
    # Remaining columns are renamed to their full string form to avoid clashes
    preserved = set(preserve)
    mapping.extend((col, sql_gen.Col(str(col), into, col.srcs))
        for col in cols if col not in preserved)
    
    if as_items: return mapping
    return dict(mapping)
874 2383 aaronmk
875 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given joins into a new table with distinct column names.
    For params, see mk_flatten_mapping().
    @return See return value of mk_flatten_mapping()
    '''
    mapping = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in mapping]
    query = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, into=into, add_indexes_=True)
    return dict(mapping)
884 2391 aaronmk
885 3079 aaronmk
##### Database structure introspection
886 2414 aaronmk
887 3079 aaronmk
#### Tables
888
889
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over the names of tables matching the given patterns.
    @param exact If set, compare names with = instead of LIKE.
    '''
    compare = '=' if exact else 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # Query the pg_tables system view
        conds = [
            ('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare)),
        ]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    if module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    raise NotImplementedError("Can't list tables for "+module+' database')
903
904
def table_exists(db, table):
    '''Whether the given table exists in the database.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return list(matches) != []
907
908 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None, start=0)
    return value(run_query(db, query, recover=recover, log_level=3))
911 2426 aaronmk
912 2414 aaronmk
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in order.'''
    # LIMIT 0: we only need the cursor's column metadata, not any rows
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
915 2414 aaronmk
916 2291 aaronmk
def pkey(db, table, recover=None):
    '''Returns the table's pkey, assumed to be its first column.'''
    return table_cols(db, table, recover)[0]
919 832 aaronmk
920 2559 aaronmk
# Conventional name of a table's NOT NULL column (see table_not_null_col())
not_null_col = 'not_null_col'
921 2340 aaronmk
922
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
926
927 853 aaronmk
def index_cols(db, table, index):
    '''Returns the column names covered by an index, in order.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # First UNION branch: plain column indexes (attnum in pg_index.indkey).
        # Second branch: expression indexes, whose column numbers are parsed
        # out of pg_index.indexprs with a regexp.
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
967 853 aaronmk
968 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the column names covered by a constraint, in order.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # cacheable=True, log_level=4 for consistency with index_cols(), which
        # runs the same kind of deterministic system-catalog lookup
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
984
985 3079 aaronmk
#### Functions
986
987
def function_exists(db, function):
    '''Whether a directly-callable (non-trigger) function with this name
    exists.'''
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    matches = values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
        # TODO: order_by search_path schema order
    return list(matches) != []
1000
1001
##### Structural changes
1002
1003
#### Columns
1004
1005
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column to a table, versioning the column name on collision.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    # Retry with a new versioned name until the ADD COLUMN succeeds.
    # Note: mutates col.name on each collision (see docstring).
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name
1025
1026
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to the given column.'''
    table = col.table
    col = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '+col.to_str(db)
        +' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1031
1032 2096 aaronmk
# Name of the row-number column added by add_row_num()
row_num_col = '_row_num'

# Column definition used by add_row_num(): auto-incrementing primary key
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')
1036
1037
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3) # serial PRIMARY KEY
1041
1042
#### Indexes
1043
1044
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(c).to_str(db) for c in cols]
    
    query = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')')
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1057
1058 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param table The table to index; if None, it is inferred from the first
        table-qualified column in exprs.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # Strip the table qualifier: the index name below embeds the table, and
        # CREATE INDEX column references must be unqualified
        col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    # Index name is derived from the table and the joined column names
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    # Add index, versioning the index name on collision
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with next version of name
1113 2408 aaronmk
1114 2997 aaronmk
def add_index_col(db, col, suffix, expr, nullable=True):
    '''Adds a new column containing expr, for use as an index column in col's
    place, and records it in col.table.index_cols. No-op if col already has an
    index column.
    @param suffix str Appended to col's name to form the new column's name.
    '''
    if sql_gen.index_col(col) != None: return # already has index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column (same type as col; name may be versioned by add_col())
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    # Populate it with expr, in place to avoid creating dead rows
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    # Register the mapping so sql_gen.index_col() can find it
    col.table.index_cols[col.name] = new_col.name
1131 2997 aaronmk
1132 3047 aaronmk
# Controls when ensure_not_null() will use index columns
1133
not_null_index_cols_min_rows = 0 # rows; initially always use index columns
1134
1135 2997 aaronmk
def ensure_not_null(db, col):
    '''Returns an expression for col with NULLs translated to sentinel values,
    For params, see sql_gen.ensure_not_null()
    @raise KeyError propagated from sql_gen.ensure_not_null() — presumably when
        the column's type has no sentinel; callers catch this (see add_index())
    '''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col) # use the materialized column instead
    
    return expr
1147
1148 3083 aaronmk
# Unique sentinel, compared with "is" in add_indexes()
already_indexed = object() # tells add_indexes() the pkey has already been added
1149
1150
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    remaining = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        remaining = remaining[1:] # the pkey covers the first column
    for col in remaining: add_index(db, col, table)
1161
1162 3079 aaronmk
#### Tables
1163 2772 aaronmk
1164 3079 aaronmk
### Maintenance
1165 2772 aaronmk
1166 3079 aaronmk
def analyze(db, table):
    '''Runs ANALYZE on the table to refresh planner statistics.'''
    query = 'ANALYZE '+sql_gen.as_Table(table).to_str(db)
    run_query(db, query, log_level=3)
1169 2934 aaronmk
1170 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() only if autoanalyze is enabled on the connection.'''
    if not db.autoanalyze: return
    analyze(db, table)
1172 2935 aaronmk
1173 3079 aaronmk
def vacuum(db, table):
    '''VACUUM ANALYZEs the table. Runs with autocommit, since VACUUM cannot run
    inside a transaction block.'''
    table = sql_gen.as_Table(table)
    def run():
        return run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run)
1177 2086 aaronmk
1178 3079 aaronmk
### Lifecycle
1179
1180 2889 aaronmk
def drop_table(db, table):
    '''Drops the table if it exists, plus dependent objects (CASCADE).'''
    table = sql_gen.as_Table(table)
    query = 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE'
    return run_query(db, query)
1183
1184 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table, versioning its name on collision.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None If set, prepend LIKE ... INCLUDING ALL to
        copy this table's structure in addition to any explicit cols.
    '''
    table = sql_gen.as_Table(table)
    # Fix: work on a copy so we don't mutate the caller's list (or the shared
    # mutable default) when replacing cols[0] below
    cols = list(cols)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table, retrying with versioned names on collision
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # PRIMARY KEY already indexes col 0
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1227 2675 aaronmk
1228 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)
    No pkey/indexes are added here — presumably LIKE ... INCLUDING ALL (see
    create_table()) already copies them; confirm against create_table().'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1231 3084 aaronmk
1232 3079 aaronmk
### Data
1233 2684 aaronmk
1234 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Empties the table and its dependents (CASCADE).
    For remaining params, see run_query().'''
    table = sql_gen.as_Table(table, schema)
    query = 'TRUNCATE '+table.to_str(db)+' CASCADE'
    return run_query(db, query, **kw_args)
1238 2732 aaronmk
1239 2965 aaronmk
def empty_temp(db, tables):
    '''Truncates the given temp table(s), unless debug_temp mode is on.'''
    if db.debug_temp: return # leave temp tables there for debugging
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1243 2965 aaronmk
1244 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the given schema.
    For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1247 3094 aaronmk
1248
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        # The UNIQUE index makes the ignore-duplicates insert below drop
        # non-distinct rows
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    # ignore=True: rows that violate the UNIQUE index are silently skipped
    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table) # refresh planner stats for the new table
    
    return new_table