# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
import lists
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed

    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None
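
# Example (illustrative sketch): reading results from a cursor returned by the
# module-level run_query() defined later in this file. Assumes `db` is an open
# DbConn; the 'plants' table name is hypothetical.
#     cur = run_query(db, 'SELECT genus, species FROM plants')
#     first = row(cur) # first row; remaining rows are fetched so they cache
#     genera = list(values(run_query(db, 'SELECT genus FROM plants')))
#     count = value(run_query(db, 'SELECT count(*) FROM plants'))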

##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
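
# Example (illustrative): qualified-name escaping. The exact quoting rules live
# in sql_gen.esc_name(); for a psycopg2 (PostgreSQL) connection the quote
# character is '"', so this produces something like '"public"."plants"'. The
# schema and table names here are hypothetical.
#     qual_name(db, 'public', 'plants')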

##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.autoanalyze = False

        self.reset()

    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()

    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def clear_cache(self): self.query_results = {}

    def reset(self):
        self.clear_cache()
        self._savepoint = 0
        self._notices_seen = set()
        self.__db = None

    def connected(self): return self.__db != None

    def close(self):
        if self.connected(): self.db.close()
        self.__db = None

    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass

            # Connect
            self.__db = module.connect(**db_config)

            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)

        return self.__db

    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []

        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise

            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached

            return cur

        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None

    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)

    def esc_name(self, name): return esc_name(self, name) # calls global func

    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_

    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'

    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")

    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)

    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions. (Accepted by the module-level
            run_query() wrapper, which performs the exception parsing.)
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None

        if not self.caching: cacheable = False
        used_cache = False

        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')

        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()

            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)

            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))

        return cur

    def is_cached(self, query): return query in self.query_results

    def with_autocommit(self, func):
        import psycopg2.extensions

        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)

    def with_savepoint(self, func):
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: return func()
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)

            self._savepoint -= 1
            assert self._savepoint >= 0

            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT

    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()

    def col_info(self, col):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]

        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))

        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)

        return sql_gen.TypedCol(col.name, type_, default, nullable)

    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn
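
# Example (illustrative sketch): opening a connection. The config keys follow
# db_config_names above; the host/user/database values here are hypothetical.
#     import sys
#     db = connect({'engine': 'PostgreSQL', 'host': 'localhost',
#         'user': 'someuser', 'password': '...', 'database': 'somedb',
#         'schemas': 'public'},
#         log_debug=lambda msg, level=2: sys.stderr.write(msg+'\n'))
#     print value(db.run_query('SELECT 1'))
#     db.close()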

##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)
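
# Example (illustrative): running a statement inside a savepoint so that a
# failure only rolls back this statement, not the whole transaction. The
# 'plants' table is hypothetical; insert() is defined later in this module.
#     with_savepoint(db, lambda: insert(db, 'plants', {'genus': 'Quercus'}))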

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)

    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]

    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            msg = strings.ustr(e.args[0])

            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)

            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)

            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)

            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)

            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)

            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
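
# Example (illustrative): errors raised by the database are parsed into the
# typed exceptions defined above, so callers can handle specific failures. The
# table and column names are hypothetical.
#     try: run_query(db, "INSERT INTO plants (id) VALUES (1)", recover=True)
#     except DuplicateKeyException, e: print 'already inserted:', e.cols
#     except NullValueException, e: print 'missing required value:', e.cols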

##### Basic queries

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))

def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)

    assert isinstance(into, sql_gen.Table)

    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None

    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))

    temp = not db.debug_temp # tables are permanent in debug_temp mode

    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query

        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name

    if add_indexes_: add_indexes(db, into)

    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)

    return cur
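
# Example (illustrative sketch): materializing a query into a temp table,
# retrying under a versioned name (via next_version()) if the name is taken.
# The table names are hypothetical; mk_select() is defined just below.
#     into = sql_gen.Table('plants_subset')
#     run_query_into(db, mk_select(db, 'plants', order_by=None, limit=1000),
#         into=into, add_indexes_=True)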

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate

    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)

    query = 'SELECT'

    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)

    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'

    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))

    # Main table
    query += '\nFROM '+table0.to_str(db)

    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table

        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))

        query += '\n'+join_.to_str(db, left_table)

        left_table = table

    missing = True
    if conds != []:
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return query

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)

    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)
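
# Example (illustrative): a simple SELECT. conds may also be a dict of
# col -> value. The 'plants' table and its columns are hypothetical.
#     cur = select(db, 'plants', ['genus', 'species'],
#         {'family': 'Fagaceae'}, limit=100)
#     for row_ in rows(cur): print row_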

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)

    first_line = 'INSERT INTO '+table.to_str(db)

    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query

        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)

        return query

    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'

    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)

        embeddable = True # must use function
        lang = 'plpgsql'

        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]

        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)

    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)

                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name

        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)

    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)

    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works

    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)

    return insert_select(db, table, cols, query, *args, **kw_args)
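
# Example (illustrative sketch): inserting a single row from a dict, and
# inserting the result of a SELECT while capturing the RETURNING values in a
# temp table. All table and column names are hypothetical.
#     insert(db, 'plants', {'genus': 'Quercus', 'species': 'alba'})
#     insert_select(db, 'plants_archive', ['genus', 'species'],
#         mk_select(db, 'plants', ['genus', 'species'], order_by=None),
#         returning='id', into=sql_gen.Table('archived_ids'))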

def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]

    if in_place:
        assert cond == None

        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)

    return query

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)

    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
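
# Example (illustrative sketch): a plain UPDATE with a condition, and an
# in-place rewrite of a column (no dead rows). The table, columns, and
# condition are hypothetical; the *Cond classes come from sql_gen.
#     update(db, 'plants', [('verified', True)],
#         cond=sql_gen.ColValueCond('family', 'Fagaceae'))
#     update(db, 'plants', [('genus', sql_gen.Col('genus_raw'))], in_place=True)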

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)

    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))

    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_indexes_=True)
    return dict(items)
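
# Example (illustrative sketch): building a collision-free column mapping for a
# join. str(col) is used as the new name, so same-named columns from different
# tables stay distinct in the flattened table. Names are hypothetical.
#     into = sql_gen.Table('joined')
#     mapping = mk_flatten_mapping(db, into,
#         [sql_gen.Col('id', 'plants'), sql_gen.Col('id', 'specimens')])
#     # each original Col maps to a sql_gen.Col on `into` with a distinct name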

##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'

    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)

    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))

    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)

    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)

        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]

    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)

    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)

        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index

        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)

        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table

        col.table = None

        exprs.append(expr)
        cols.append(col)

    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))

    # Add index
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'

        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with next version of name
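
# Example (illustrative): indexing columns. With ensure_not_null_=True (the
# default), a nullable column is wrapped so that NULLs compare as equal; pass
# ensure_not_null_=False for a plain index. The table/columns are hypothetical.
#     add_index(db, 'genus', 'plants')
#     add_index(db, ['genus', 'species'], 'plants', unique=True)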

def add_index_col(db, col, suffix, expr, nullable=True):
    if sql_gen.index_col(col) != None: return # already has index col

    new_col = sql_gen.suffixed_col(col, suffix)

    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming

    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)

    col.table.index_cols[col.name] = new_col.name

# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()'''
    expr = sql_gen.ensure_not_null(db, col)

    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)

    return expr

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop_table(db, table):
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)

    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'

    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode

    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);\n'

        try:
            run_query(db, str_, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name

    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
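
# Example (illustrative sketch): creating a table with a serial pkey, deferring
# the column indexes until the table has been populated. Names are hypothetical.
#     deferred = [None]
#     create_table(db, 'plant_names', [sql_gen.TypedCol('id', 'serial'),
#         sql_gen.TypedCol('name', 'text')], col_indexes=deferred)
#     # ... populate plant_names ...
#     deferred[0]() # now add the indexes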

def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    if db.debug_temp: return # leave temp tables there for debugging
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')

    copy_table_struct(db, table, new_table)

    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization

    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table)

    return new_table
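
# Example (illustrative sketch): deduplicating a temp table on a column before
# joining on it. The temp table (e.g. one produced earlier by run_query_into())
# and column name are hypothetical.
#     temp = sql_gen.Table('staging_plants')
#     uniq = distinct_table(db, temp, [sql_gen.Col('scientific_name', temp)])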