1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
20 2168 aaronmk
    raw_query = None
21
    if hasattr(cur, 'query'): raw_query = cur.query
22
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
23 2170 aaronmk
24
    if raw_query != None: return raw_query
25 2804 aaronmk
    else: return '[input] '+strings.ustr(input_query)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
28
    '''For params, see get_cur_query()'''
29 2771 aaronmk
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
32 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
33 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
34 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
37
    def __init__(self, name, cause=None):
38 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
39 2143 aaronmk
        self.name = name
40 360 aaronmk
41 3109 aaronmk
class ExceptionWithValue(DbException):
42
    def __init__(self, value, cause=None):
43
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
44
            cause)
45 2240 aaronmk
        self.value = value
46
47 2945 aaronmk
class ExceptionWithNameType(DbException):
48
    def __init__(self, type_, name, cause=None):
49
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
50
            +'; name: '+strings.as_tt(name), cause)
51
        self.type = type_
52
        self.name = name
53
54 2306 aaronmk
class ConstraintException(DbException):
55
    def __init__(self, name, cols, cause=None):
56 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
57
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
58 2306 aaronmk
        self.name = name
59 468 aaronmk
        self.cols = cols
60 11 aaronmk
61 2523 aaronmk
class MissingCastException(DbException):
62
    def __init__(self, type_, col, cause=None):
63
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
64
            +' on column: '+strings.as_tt(col), cause)
65
        self.type = type_
66
        self.col = col
67
68 2143 aaronmk
class NameException(DbException): pass
69
70 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
71 13 aaronmk
72 2306 aaronmk
class NullValueException(ConstraintException): pass
73 13 aaronmk
74 3109 aaronmk
class InvalidValueException(ExceptionWithValue): pass
75 2239 aaronmk
76 2945 aaronmk
class DuplicateException(ExceptionWithNameType): pass
77 2143 aaronmk
78 89 aaronmk
class EmptyRowException(DbException): pass
79
80 865 aaronmk
##### Warnings
81
82
class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval
85
86
def col_names(cur): return (col[0] for col in cur.description)
87
88
def rows(cur): return iter(lambda: cur.fetchone(), None)
89
90
def consume_rows(cur):
91
    '''Fetches all rows so the result will be cached'''
92
    iters.consume_iter(rows(cur))
93
94
def next_row(cur): return rows(cur).next()
95
96
def row(cur):
97
    row_ = next_row(cur)
98
    consume_rows(cur)
99
    return row_
100
101
def next_value(cur): return next_row(cur)[0]
102
103
def value(cur): return row(cur)[0]
104
105
def values(cur): return iters.func_iter(lambda: next_value(cur))
106
107
def value_or_none(cur):
108
    try: return value(cur)
109
    except StopIteration: return None
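# Example (sketch): typical use of the retrieval helpers above, assuming a
# connected DbConn `db` and a hypothetical 'plants' table; select() and
# run_query() are defined later in this module.
#     cur = select(db, 'plants', ['genus', 'species'], limit=5)
#     print list(col_names(cur)) # e.g. ['genus', 'species']
#     for genus, species in rows(cur): print genus, species
#     count = value(run_query(db, 'SELECT count(*) FROM plants'))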
110
111 2762 aaronmk
##### Escaping
112 2101 aaronmk
113 2573 aaronmk
def esc_name_by_module(module, name):
114
    if module == 'psycopg2' or module == None: quote = '"'
115 2101 aaronmk
    elif module == 'MySQLdb': quote = '`'
116
    else: raise NotImplementedError("Can't escape name for "+module+' database')
117 2500 aaronmk
    return sql_gen.esc_name(name, quote)
118 2101 aaronmk
119
def esc_name_by_engine(engine, name, **kw_args):
120
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
121
122
def esc_name(db, name, **kw_args):
123
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
124
125
def qual_name(db, schema, table):
126
    def esc_name_(name): return esc_name(db, name)
127
    table = esc_name_(table)
128
    if schema != None: return esc_name_(schema)+'.'+table
129
    else: return table
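# Example (sketch): qual_name() returns a schema-qualified, escaped name. With
# a psycopg2-backed DbConn `db` (double-quote quoting), roughly:
#     qual_name(db, 'public', 'plants') # 'public.plants', quoted as needed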
130
131 1869 aaronmk
##### Database connections
132 1849 aaronmk
133 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
134 1926 aaronmk
135 1869 aaronmk
db_engines = {
136
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
137
    'PostgreSQL': ('psycopg2', {}),
138
}
139
140
DatabaseErrors_set = set([DbException])
141
DatabaseErrors = tuple(DatabaseErrors_set)
142
143
def _add_module(module):
144
    DatabaseErrors_set.add(module.DatabaseError)
145
    global DatabaseErrors
146
    DatabaseErrors = tuple(DatabaseErrors_set)
147
148
def db_config_str(db_config):
149
    return db_config['engine']+' database '+db_config['database']
150
151 2448 aaronmk
log_debug_none = lambda msg, level=2: None
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
154 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
155 2915 aaronmk
        log_debug=log_debug_none, debug_temp=False):
156
        '''
157
        @param debug_temp Whether temporary objects should instead be permanent.
158
            This assists in debugging the internal objects used by the program.
159
        '''
160 1869 aaronmk
        self.db_config = db_config
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 2915 aaronmk
        self.debug_temp = debug_temp
166 3074 aaronmk
        self.autoanalyze = False
167 1869 aaronmk
168 3120 aaronmk
        self._reset()
169 1869 aaronmk
170
    def __getattr__(self, name):
171
        if name == '__dict__': raise Exception('getting __dict__')
172
        if name == 'db': return self._db()
173
        else: raise AttributeError()
174
175
    def __getstate__(self):
176
        state = copy.copy(self.__dict__) # shallow copy
177 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
178 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
179
        return state
180
181 3118 aaronmk
    def clear_cache(self): self.query_results = {}
182
183 3120 aaronmk
    def _reset(self):
184 3118 aaronmk
        self.clear_cache()
185
        self._savepoint = 0
186
        self._notices_seen = set()
187
        self.__db = None
188
189 2165 aaronmk
    def connected(self): return self.__db != None
190
191 3116 aaronmk
    def close(self):
192 3119 aaronmk
        if not self.connected(): return
193
194
        self.db.close()
195 3120 aaronmk
        self._reset()
196 3116 aaronmk
197 1869 aaronmk
    def _db(self):
198
        if self.__db == None:
199
            # Process db_config
200
            db_config = self.db_config.copy() # don't modify input!
201 2097 aaronmk
            schemas = db_config.pop('schemas', None)
202 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
203
            module = __import__(module_name)
204
            _add_module(module)
205
            for orig, new in mappings.iteritems():
206
                try: util.rename_key(db_config, orig, new)
207
                except KeyError: pass
208
209
            # Connect
210
            self.__db = module.connect(**db_config)
211
212
            # Configure connection
213 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
214
                import psycopg2.extensions
215
                self.db.set_isolation_level(
216
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
217 2101 aaronmk
            if schemas != None:
218 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
219
                search_path.append(value(run_query(self, 'SHOW search_path',
220
                    log_level=4)))
221
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
222
                    log_level=3)
223 1869 aaronmk
224
        return self.__db
225 1889 aaronmk
226 1891 aaronmk
    class DbCursor(Proxy):
227 1927 aaronmk
        def __init__(self, outer):
228 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
229 2191 aaronmk
            self.outer = outer
230 1927 aaronmk
            self.query_results = outer.query_results
231 1894 aaronmk
            self.query_lookup = None
232 1891 aaronmk
            self.result = []
233 1889 aaronmk
234 2802 aaronmk
        def execute(self, query):
235 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
236 2797 aaronmk
            self.query_lookup = query
237 2148 aaronmk
            try:
238 2191 aaronmk
                try:
239 2802 aaronmk
                    cur = self.inner.execute(query)
240 2191 aaronmk
                    self.outer.do_autocommit()
241 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
242 1904 aaronmk
            except Exception, e:
243 2802 aaronmk
                _add_cursor_info(e, self, query)
244 1904 aaronmk
                self.result = e # cache the exception as the result
245
                self._cache_result()
246
                raise
247 3004 aaronmk
248
            # Always cache certain queries
249
            if query.startswith('CREATE') or query.startswith('ALTER'):
250 3007 aaronmk
                # structural changes
251 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
252
                # so only cache ADD COLUMN if it has a distinguishing comment
253
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
254 3007 aaronmk
                    self._cache_result()
255 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
256 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
257 3004 aaronmk
258 2762 aaronmk
            return cur
259 1894 aaronmk
260 1891 aaronmk
        def fetchone(self):
261
            row = self.inner.fetchone()
262 1899 aaronmk
            if row != None: self.result.append(row)
263
            # otherwise, fetched all rows
264 1904 aaronmk
            else: self._cache_result()
265
            return row
266
267
        def _cache_result(self):
268 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
269
            # inserts are not idempotent. Other non-SELECT queries don't have
270
            # their result set read, so only exceptions will be cached (an
271
            # invalid query will always be invalid).
272 1930 aaronmk
            if self.query_results != None and (not self._is_insert
273 1906 aaronmk
                or isinstance(self.result, Exception)):
274
275 1894 aaronmk
                assert self.query_lookup != None
276 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
277
                    util.dict_subset(dicts.AttrsDictView(self),
278
                    ['query', 'result', 'rowcount', 'description']))
279 1906 aaronmk
280 1916 aaronmk
        class CacheCursor:
281
            def __init__(self, cached_result): self.__dict__ = cached_result
282
283 1927 aaronmk
            def execute(self, *args, **kw_args):
284 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
285
                # otherwise, result is a rows list
286
                self.iter = iter(self.result)
287
288
            def fetchone(self):
289
                try: return self.iter.next()
290
                except StopIteration: return None
291 1891 aaronmk
292 2212 aaronmk
    def esc_value(self, value):
293 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
294
        except NotImplementedError, e:
295
            module = util.root_module(self.db)
296
            if module == 'MySQLdb':
297
                import _mysql
298
                str_ = _mysql.escape_string(value)
299
            else: raise e
300 2374 aaronmk
        return strings.to_unicode(str_)
301 2212 aaronmk
302 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
303
304 2814 aaronmk
    def std_code(self, str_):
305
        '''Standardizes SQL code.
306
        * Ensures that string literals are prefixed by `E`
307
        '''
308
        if str_.startswith("'"): str_ = 'E'+str_
309
        return str_
310
311 2665 aaronmk
    def can_mogrify(self):
312 2663 aaronmk
        module = util.root_module(self.db)
313 2665 aaronmk
        return module == 'psycopg2'
314 2663 aaronmk
315 2665 aaronmk
    def mogrify(self, query, params=None):
316
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
317
        else: raise NotImplementedError("Can't mogrify query")
318
319 2671 aaronmk
    def print_notices(self):
320 2725 aaronmk
        if hasattr(self.db, 'notices'):
321
            for msg in self.db.notices:
322
                if msg not in self._notices_seen:
323
                    self._notices_seen.add(msg)
324
                    self.log_debug(msg, level=2)
325 2671 aaronmk
326 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
327 2464 aaronmk
        debug_msg_ref=None):
328 2445 aaronmk
        '''
329 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
330
            throws one of these exceptions.
331 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
332
            this instead of being output. This allows you to filter log messages
333
            depending on the result of the query.
334 2445 aaronmk
        '''
335 2167 aaronmk
        assert query != None
336
337 2047 aaronmk
        if not self.caching: cacheable = False
338 1903 aaronmk
        used_cache = False
339 2664 aaronmk
340
        def log_msg(query):
341
            if used_cache: cache_status = 'cache hit'
342
            elif cacheable: cache_status = 'cache miss'
343
            else: cache_status = 'non-cacheable'
344
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
345
346 1903 aaronmk
        try:
347 1927 aaronmk
            # Get cursor
348
            if cacheable:
349
                try:
350 2797 aaronmk
                    cur = self.query_results[query]
351 1927 aaronmk
                    used_cache = True
352
                except KeyError: cur = self.DbCursor(self)
353
            else: cur = self.db.cursor()
354
355 2664 aaronmk
            # Log query
356
            if self.debug and debug_msg_ref == None: # log before running
357
                self.log_debug(log_msg(query), log_level)
358
359 1927 aaronmk
            # Run query
360 2793 aaronmk
            cur.execute(query)
361 1903 aaronmk
        finally:
362 2671 aaronmk
            self.print_notices()
363 2664 aaronmk
            if self.debug and debug_msg_ref != None: # return after running
364 2793 aaronmk
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
365 1903 aaronmk
366
        return cur
367 1914 aaronmk
368 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
369 2139 aaronmk
370 2907 aaronmk
    def with_autocommit(self, func):
371 2801 aaronmk
        import psycopg2.extensions
372
373
        prev_isolation_level = self.db.isolation_level
374 2907 aaronmk
        self.db.set_isolation_level(
375
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
376 2683 aaronmk
        try: return func()
377 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
378 2683 aaronmk
379 2139 aaronmk
    def with_savepoint(self, func):
380 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
381 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
382 2139 aaronmk
        self._savepoint += 1
383 2930 aaronmk
        try: return func()
384 2139 aaronmk
        except:
385 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
386 2139 aaronmk
            raise
387 2930 aaronmk
        finally:
388
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
389
            # "The savepoint remains valid and can be rolled back to again"
390
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
391 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
392 2930 aaronmk
393
            self._savepoint -= 1
394
            assert self._savepoint >= 0
395
396
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
397 2191 aaronmk
398
    def do_autocommit(self):
399
        '''Autocommits if outside savepoint'''
400
        assert self._savepoint >= 0
401
        if self.autocommit and self._savepoint == 0:
402 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
403 2191 aaronmk
            self.db.commit()
404 2643 aaronmk
405 2819 aaronmk
    def col_info(self, col):
406 2643 aaronmk
        table = sql_gen.Table('columns', 'information_schema')
407 3063 aaronmk
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
408
            'USER-DEFINED'), sql_gen.Col('udt_name'))
409 3078 aaronmk
        cols = [type_, 'column_default',
410
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
411 2643 aaronmk
412
        conds = [('table_name', col.table.name), ('column_name', col.name)]
413
        schema = col.table.schema
414
        if schema != None: conds.append(('table_schema', schema))
415
416 2819 aaronmk
        type_, default, nullable = row(select(self, table, cols, conds,
417 3059 aaronmk
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
418 2643 aaronmk
            # TODO: order_by search_path schema order
419 2819 aaronmk
        default = sql_gen.as_Code(default, self)
420
421
        return sql_gen.TypedCol(col.name, type_, default, nullable)
422 2917 aaronmk
423
    def TempFunction(self, name):
424
        if self.debug_temp: schema = None
425
        else: schema = 'pg_temp'
426
        return sql_gen.Function(name, schema)
427 1849 aaronmk
428 1869 aaronmk
connect = DbConn
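# Example (sketch): building a config from the db_config_names keys above and
# connecting. The host/user/database values here are placeholders.
#     db_config = dict(engine='PostgreSQL', host='localhost', user='bien',
#         password='...', database='vegbien', schemas='public')
#     db = connect(db_config, autocommit=True, caching=True)
#     print db_config_str(db_config) # 'PostgreSQL database vegbien'
#     db.close()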
429
430 832 aaronmk
##### Recoverable querying
431 15 aaronmk
432 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
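# Example (sketch): run a statement inside a savepoint so a failure rolls back
# just that statement rather than the whole transaction. Assumes a connected
# DbConn `db`; insert() and DuplicateKeyException are defined in this module.
#     def try_insert(): return insert(db, 'plants', {'genus': 'Quercus'})
#     try: with_savepoint(db, try_insert)
#     except DuplicateKeyException: pass # row already present; keep going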
433 11 aaronmk
434 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
435
    log_ignore_excs=None, **kw_args):
436 2794 aaronmk
    '''For params, see DbConn.run_query()'''
437 830 aaronmk
    if recover == None: recover = False
438 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
439
    log_ignore_excs = tuple(log_ignore_excs)
440 830 aaronmk
441 2666 aaronmk
    debug_msg_ref = None # usually, db.run_query() logs query before running it
442
    # But if filtering with log_ignore_excs, wait until after exception parsing
443 2984 aaronmk
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
444 2666 aaronmk
445 2148 aaronmk
    try:
446 2464 aaronmk
        try:
447 2794 aaronmk
            def run(): return db.run_query(query, cacheable, log_level,
448 2793 aaronmk
                debug_msg_ref, **kw_args)
449 2796 aaronmk
            if recover and not db.is_cached(query):
450 2464 aaronmk
                return with_savepoint(db, run)
451
            else: return run() # don't need savepoint if cached
452
        except Exception, e:
453 3095 aaronmk
            msg = strings.ustr(e.args[0])
454 2464 aaronmk
455 3095 aaronmk
            match = re.match(r'^duplicate key value violates unique constraint '
456 3096 aaronmk
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
457 2464 aaronmk
            if match:
458
                constraint, table = match.groups()
459 3025 aaronmk
                cols = []
460
                if recover: # need auto-rollback to run index_cols()
461
                    try: cols = index_cols(db, table, constraint)
462
                    except NotImplementedError: pass
463
                raise DuplicateKeyException(constraint, cols, e)
464 2464 aaronmk
465 3095 aaronmk
            match = re.match(r'^null value in column "(.+?)" violates not-null'
466 2464 aaronmk
                r' constraint', msg)
467
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
468
469 3095 aaronmk
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
470 3109 aaronmk
                r'|.+? field value out of range): "(.+?)"', msg)
471 2464 aaronmk
            if match:
472 3109 aaronmk
                value, = match.groups()
473
                raise InvalidValueException(strings.to_unicode(value), e)
474 2464 aaronmk
475 3095 aaronmk
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
476 2523 aaronmk
                r'is of type', msg)
477
            if match:
478
                col, type_ = match.groups()
479
                raise MissingCastException(type_, col, e)
480
481 3095 aaronmk
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
482 2945 aaronmk
            if match:
483
                type_, name = match.groups()
484
                raise DuplicateException(type_, name, e)
485 2464 aaronmk
486
            raise # no specific exception raised
487
    except log_ignore_excs:
488
        log_level += 2
489
        raise
490
    finally:
491 2666 aaronmk
        if debug_msg_ref != None and debug_msg_ref[0] != None:
492
            db.log_debug(debug_msg_ref[0], log_level)
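# Example (sketch): run_query() translates common PostgreSQL errors into the
# exception classes defined above. Table and column names are hypothetical.
#     try: run_query(db, 'INSERT INTO plants (plant_id) VALUES (1)',
#         recover=True)
#     except DuplicateKeyException, e: print 'duplicate key on', e.cols
#     except NullValueException, e: print 'NULL in NOT NULL column', e.cols
#     except InvalidValueException, e: print 'invalid value', e.value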
493 830 aaronmk
494 832 aaronmk
##### Basic queries
495
496 2153 aaronmk
def next_version(name):
497 2163 aaronmk
    version = 1 # first existing name was version 0
498 2586 aaronmk
    match = re.match(r'^(.*)#(\d+)$', name)
499 2153 aaronmk
    if match:
500 2586 aaronmk
        name, version = match.groups()
501
        version = int(version)+1
502 2932 aaronmk
    return sql_gen.concat(name, '#'+str(version))
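# For example (sketch, assuming sql_gen.concat() simply appends here):
#     next_version('plants')   # 'plants#1'
#     next_version('plants#1') # 'plants#2'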
503 2153 aaronmk
504 2899 aaronmk
def lock_table(db, table, mode):
505
    table = sql_gen.as_Table(table)
506
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
507
508 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
509 2085 aaronmk
    '''Outputs a query to a temp table.
510
    For params, see run_query().
511
    '''
512 2789 aaronmk
    if into == None: return run_query(db, query, **kw_args)
513 2790 aaronmk
514
    assert isinstance(into, sql_gen.Table)
515
516 2992 aaronmk
    into.is_temp = True
517 3008 aaronmk
    # "temporary tables cannot specify a schema name", so remove schema
518
    into.schema = None
519 2992 aaronmk
520 2790 aaronmk
    kw_args['recover'] = True
521 2945 aaronmk
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
522 2790 aaronmk
523 2916 aaronmk
    temp = not db.debug_temp # tables are permanent in debug_temp mode
524 2790 aaronmk
525
    # Create table
526
    while True:
527
        create_query = 'CREATE'
528
        if temp: create_query += ' TEMP'
529
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
530 2385 aaronmk
531 2790 aaronmk
        try:
532
            cur = run_query(db, create_query, **kw_args)
533
                # CREATE TABLE AS sets rowcount to # rows in query
534
            break
535 2945 aaronmk
        except DuplicateException, e:
536 2790 aaronmk
            into.name = next_version(into.name)
537
            # try again with next version of name
538
539
    if add_indexes_: add_indexes(db, into)
540 3075 aaronmk
541
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
542
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
543
    # table is going to be used in complex queries, it is wise to run ANALYZE on
544
    # the temporary table after it is populated."
545
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
546
    # If into is not a temp table, ANALYZE is useful but not required.
547 3073 aaronmk
    analyze(db, into)
548 2790 aaronmk
549
    return cur
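# Example (sketch): materialize a SELECT into a temp table; the suggested name
# is versioned if it collides. Assumes a connected DbConn `db` and a
# hypothetical 'plants' table; mk_select() is defined below.
#     into = sql_gen.Table('plants_sample')
#     run_query_into(db, mk_select(db, 'plants', limit=100, order_by=None),
#         into=into, add_indexes_=True)
#     print into.name # 'plants_sample' or e.g. 'plants_sample#1'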
550 2085 aaronmk
551 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
552
553 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
554
555 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
556 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
557 1981 aaronmk
    '''
558 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
559 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
560 1981 aaronmk
    @param fields Use None to select all fields in the table
561 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
562 2379 aaronmk
        * container can be any iterable type
563 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
564
        * compare_right_side: sql_gen.ValueCond|literal value
565 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
566
        use all columns
567 2786 aaronmk
    @return query
568 1981 aaronmk
    '''
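    # Example (sketch) of a call, with hypothetical table and column names:
    #     mk_select(db, 'plants', ['genus', 'species'],
    #         conds={'family': 'Rosaceae'}, limit=10, order_by=None)
    # returns a SELECT ... FROM ... WHERE ... LIMIT 10 string for run_query().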
569 2315 aaronmk
    # Parse tables param
570 2964 aaronmk
    tables = lists.mk_seq(tables)
571 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
572 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
573 2121 aaronmk
574 2315 aaronmk
    # Parse other params
575 2376 aaronmk
    if conds == None: conds = []
576 2650 aaronmk
    elif dicts.is_dict(conds): conds = conds.items()
577 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
578 135 aaronmk
    assert limit == None or type(limit) == int
579 865 aaronmk
    assert start == None or type(start) == int
580 2315 aaronmk
    if order_by is order_by_pkey:
581
        if distinct_on != []: order_by = None
582
        else: order_by = pkey(db, table0, recover=True)
583 865 aaronmk
584 2315 aaronmk
    query = 'SELECT'
585 2056 aaronmk
586 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
587 2056 aaronmk
588 2200 aaronmk
    # DISTINCT ON columns
589 2233 aaronmk
    if distinct_on != []:
590 2467 aaronmk
        query += '\nDISTINCT'
591 2254 aaronmk
        if distinct_on is not distinct_on_all:
592 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
593
594
    # Columns
595 3027 aaronmk
    if fields == None:
596
        if query.find('\n') >= 0: whitespace = '\n'
597
        else: whitespace = ' '
598
        query += whitespace+'*'
599 2765 aaronmk
    else:
600
        assert fields != []
601 3027 aaronmk
        query += '\n'+('\n, '.join(map(parse_col, fields)))
602 2200 aaronmk
603
    # Main table
604 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
605 865 aaronmk
606 2122 aaronmk
    # Add joins
607 2271 aaronmk
    left_table = table0
608 2263 aaronmk
    for join_ in tables:
609
        table = join_.table
610 2238 aaronmk
611 2343 aaronmk
        # Parse special values
612
        if join_.type_ is sql_gen.filter_out: # filter no match
613 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
614 2853 aaronmk
                sql_gen.CompareCond(None, '~=')))
615 2343 aaronmk
616 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
617 2122 aaronmk
618
        left_table = table
619
620 865 aaronmk
    missing = True
621 2376 aaronmk
    if conds != []:
622 2576 aaronmk
        if len(conds) == 1: whitespace = ' '
623
        else: whitespace = '\n'
624 2578 aaronmk
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
625
            .to_str(db) for l, r in conds], 'WHERE')
626 865 aaronmk
        missing = False
627 2227 aaronmk
    if order_by != None:
628 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
629
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
630 865 aaronmk
    if start != None:
631 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
632 865 aaronmk
        missing = False
633
    if missing: warnings.warn(DbWarning(
634
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
635
636 2786 aaronmk
    return query
637 11 aaronmk
638 2054 aaronmk
def select(db, *args, **kw_args):
639
    '''For params, see mk_select() and run_query()'''
640
    recover = kw_args.pop('recover', None)
641
    cacheable = kw_args.pop('cacheable', True)
642 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
643 2054 aaronmk
644 2791 aaronmk
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
645
        log_level=log_level)
646 2054 aaronmk
647 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
648 3009 aaronmk
    embeddable=False, ignore=False):
649 1960 aaronmk
    '''
650
    @param returning str|None An inserted column (such as pkey) to return
651 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
652 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
653
        query will be fully cached, not just if it raises an exception.
654 3009 aaronmk
    @param ignore Whether to ignore duplicate keys.
655 1960 aaronmk
    '''
656 2754 aaronmk
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
657 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
658 3010 aaronmk
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
659 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
660 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
661 2063 aaronmk
662 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
663 2063 aaronmk
664 3009 aaronmk
    def mk_insert(select_query):
665
        query = first_line
666 3014 aaronmk
        if cols != None:
667
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
668 3009 aaronmk
        query += '\n'+select_query
669
670
        if returning != None:
671
            returning_name_col = sql_gen.to_name_only_col(returning)
672
            query += '\nRETURNING '+returning_name_col.to_str(db)
673
674
        return query
675 2063 aaronmk
676 3017 aaronmk
    return_type = 'unknown'
677
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
678
679 3009 aaronmk
    lang = 'sql'
680
    if ignore:
681 3017 aaronmk
        # Always return something to set the correct rowcount
682
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
683
684 3009 aaronmk
        embeddable = True # must use function
685
        lang = 'plpgsql'
686 3010 aaronmk
687 3092 aaronmk
        if cols == None:
688
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
689
            row_vars = [sql_gen.Table('row')]
690
        else:
691
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
692
693 3009 aaronmk
        query = '''\
694 3010 aaronmk
DECLARE
695 3014 aaronmk
    row '''+table.to_str(db)+'''%ROWTYPE;
696 3009 aaronmk
BEGIN
697 3019 aaronmk
    /* Need an EXCEPTION block for each individual row because "When an error is
698
    caught by an EXCEPTION clause, [...] all changes to persistent database
699
    state within the block are rolled back."
700
    This is unfortunate because "A block containing an EXCEPTION clause is
701
    significantly more expensive to enter and exit than a block without one."
702 3015 aaronmk
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
703
#PLPGSQL-ERROR-TRAPPING)
704
    */
705 3092 aaronmk
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
706 3034 aaronmk
'''+select_query+'''
707
    LOOP
708 3015 aaronmk
        BEGIN
709 3019 aaronmk
            RETURN QUERY
710 3014 aaronmk
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
711 3010 aaronmk
;
712 3015 aaronmk
        EXCEPTION
713 3019 aaronmk
            WHEN unique_violation THEN NULL; -- continue to next row
714 3015 aaronmk
        END;
715 3010 aaronmk
    END LOOP;
716
END;\
717 3009 aaronmk
'''
718
    else: query = mk_insert(select_query)
719
720 2070 aaronmk
    if embeddable:
721
        # Create function
722 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
723 2189 aaronmk
        while True:
724
            try:
725 2918 aaronmk
                function = db.TempFunction(function_name)
726 2194 aaronmk
727 2189 aaronmk
                function_query = '''\
728 2698 aaronmk
CREATE FUNCTION '''+function.to_str(db)+'''()
729 3017 aaronmk
RETURNS SETOF '''+return_type+'''
730 3009 aaronmk
LANGUAGE '''+lang+'''
731 2467 aaronmk
AS $$
732 3009 aaronmk
'''+query+'''
733 2467 aaronmk
$$;
734 2070 aaronmk
'''
735 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
736 2945 aaronmk
                    log_ignore_excs=(DuplicateException,))
737 2189 aaronmk
                break # this version was successful
738 2945 aaronmk
            except DuplicateException, e:
739 2189 aaronmk
                function_name = next_version(function_name)
740
                # try again with next version of name
741 2070 aaronmk
742 2337 aaronmk
        # Return query that uses function
743 3009 aaronmk
        cols = None
744
        if returning != None: cols = [returning]
745 2698 aaronmk
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
746 3009 aaronmk
            cols) # AS clause requires function alias
747 2787 aaronmk
        return mk_select(db, func_table, start=0, order_by=None)
748 2070 aaronmk
749 2787 aaronmk
    return query
750 2066 aaronmk
751 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
752 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
753 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
754
        values in
755 2072 aaronmk
    '''
756 2386 aaronmk
    into = kw_args.pop('into', None)
757
    if into != None: kw_args['embeddable'] = True
758 2066 aaronmk
    recover = kw_args.pop('recover', None)
759 3011 aaronmk
    if kw_args.get('ignore', False): recover = True
760 2066 aaronmk
    cacheable = kw_args.pop('cacheable', True)
761 2673 aaronmk
    log_level = kw_args.pop('log_level', 2)
762 2066 aaronmk
763 3074 aaronmk
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
764
        into, recover=recover, cacheable=cacheable, log_level=log_level)
765
    autoanalyze(db, table)
766
    return cur
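# Example (sketch): copy rows from a staging table, collecting the generated
# pkeys in a temp table and skipping duplicate keys. All names are
# hypothetical; assumes a connected DbConn `db`.
#     pkeys = sql_gen.Table('plants_pkeys')
#     insert_select(db, 'plants', ['genus', 'species'],
#         mk_select(db, 'staging_plants', ['genus', 'species'], order_by=None,
#         start=0), returning='plant_id', into=pkeys, ignore=True)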
767 2063 aaronmk
768 2738 aaronmk
default = sql_gen.default # tells insert() to use the default value for a column
769 2066 aaronmk
770 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
771 2085 aaronmk
    '''For params, see insert_select()'''
772 1960 aaronmk
    if lists.is_seq(row): cols = None
773
    else:
774
        cols = row.keys()
775
        row = row.values()
776 2738 aaronmk
    row = list(row) # ensure that "== []" works
777 1960 aaronmk
778 2738 aaronmk
    if row == []: query = None
779
    else: query = sql_gen.Values(row).to_str(db)
780 1961 aaronmk
781 2788 aaronmk
    return insert_select(db, table, cols, query, *args, **kw_args)
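# Example (sketch), with hypothetical names: a dict row supplies both columns
# and values; a sequence supplies values in the table's column order.
#     new_id = value(insert(db, 'plants',
#         {'genus': 'Quercus', 'species': 'alba'}, returning='plant_id'))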
782 11 aaronmk
783 3056 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False):
784 2402 aaronmk
    '''
785
    @param changes [(col, new_value),...]
786
        * container can be any iterable type
787
        * col: sql_gen.Code|str (for col name)
788
        * new_value: sql_gen.Code|literal value
789
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
790 3056 aaronmk
    @param in_place If set, locks the table and updates rows in place.
791
        This avoids creating dead rows in PostgreSQL.
792
        * cond must be None
793 2402 aaronmk
    @return str query
794
    '''
795 3057 aaronmk
    table = sql_gen.as_Table(table)
796
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
797
        for c, v in changes]
798
799 3056 aaronmk
    if in_place:
800
        assert cond == None
801 3058 aaronmk
802 3065 aaronmk
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
803
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
804
            +db.col_info(sql_gen.with_default_table(c, table)).type
805
            +'\nUSING '+v.to_str(db) for c, v in changes))
806 3058 aaronmk
    else:
807
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
808
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
809
            for c, v in changes))
810
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
811 3056 aaronmk
812 2402 aaronmk
    return query
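# Example (sketch), with hypothetical names; update() below runs the generated
# query:
#     update(db, 'plants', [('genus', 'Quercus')],
#         cond=sql_gen.ColValueCond('family', 'Fagaceae'))
#     # in-place form (cond must be None); avoids creating dead rows:
#     update(db, 'plants', [('genus', sql_gen.Col('genus_verbatim'))],
#         in_place=True)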
813
814 3074 aaronmk
def update(db, table, *args, **kw_args):
815 2402 aaronmk
    '''For params, see mk_update() and run_query()'''
816
    recover = kw_args.pop('recover', None)
817 3043 aaronmk
    cacheable = kw_args.pop('cacheable', False)
818 3030 aaronmk
    log_level = kw_args.pop('log_level', 2)
819 2402 aaronmk
820 3074 aaronmk
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
821
        cacheable, log_level=log_level)
822
    autoanalyze(db, table)
823
    return cur
824 2402 aaronmk
825 135 aaronmk
def last_insert_id(db):
826 1849 aaronmk
    module = util.root_module(db.db)
827 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
828
    elif module == 'MySQLdb': return db.insert_id()
829
    else: return None
830 13 aaronmk
831 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
832 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
833 2415 aaronmk
    to names that will be distinct among the columns' tables.
834 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
835 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
836
    @param into The table for the new columns.
837 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
838
        columns will be included in the mapping even if they are not in cols.
839
        The tables of the provided Col objects will be changed to into, so make
840
        copies of them if you want to keep the original tables.
841
    @param as_items Whether to return a list of dict items instead of a dict
842 2383 aaronmk
    @return dict(orig_col=new_col, ...)
843
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
844 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
845
        * All mappings use the into table so its name can easily be
846 2383 aaronmk
          changed for all columns at once
847
    '''
848 2415 aaronmk
    cols = lists.uniqify(cols)
849
850 2394 aaronmk
    items = []
851 2389 aaronmk
    for col in preserve:
852 2390 aaronmk
        orig_col = copy.copy(col)
853 2392 aaronmk
        col.table = into
854 2394 aaronmk
        items.append((orig_col, col))
855
    preserve = set(preserve)
856
    for col in cols:
857 2716 aaronmk
        if col not in preserve:
858
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
859 2394 aaronmk
860
    if not as_items: items = dict(items)
861
    return items
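# Example (sketch): map same-named columns of two hypothetical tables to
# distinct names in a new table (assumes sql_gen.Col accepts a table name, as
# it is used elsewhere in this module):
#     into = sql_gen.Table('plants_flat')
#     cols = [sql_gen.Col('name', 'plants'), sql_gen.Col('name', 'taxa')]
#     mapping = mk_flatten_mapping(db, into, cols)
#     # mapping's values are sql_gen.Col objects on `into`, renamed as needed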
862 2383 aaronmk
863 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
864 2391 aaronmk
    '''For params, see mk_flatten_mapping()
865
    @return See return value of mk_flatten_mapping()
866
    '''
867 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
868
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
869 2786 aaronmk
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
870 2846 aaronmk
        into=into, add_indexes_=True)
871 2394 aaronmk
    return dict(items)
872 2391 aaronmk
873 3079 aaronmk
##### Database structure introspection
874 2414 aaronmk
875 3079 aaronmk
#### Tables
876
877
def tables(db, schema_like='public', table_like='%', exact=False):
878
    if exact: compare = '='
879
    else: compare = 'LIKE'
880
881
    module = util.root_module(db.db)
882
    if module == 'psycopg2':
883
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
884
            ('tablename', sql_gen.CompareCond(table_like, compare))]
885
        return values(select(db, 'pg_tables', ['tablename'], conds,
886
            order_by='tablename', log_level=4))
887
    elif module == 'MySQLdb':
888
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
889
            , cacheable=True, log_level=4))
890
    else: raise NotImplementedError("Can't list tables for "+module+' database')
891
892
def table_exists(db, table):
893
    table = sql_gen.as_Table(table)
894
    return list(tables(db, table.schema, table.name, exact=True)) != []
895
896 2426 aaronmk
def table_row_count(db, table, recover=None):
897 2786 aaronmk
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
898 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
899 2426 aaronmk
900 2414 aaronmk
def table_cols(db, table, recover=None):
901
    return list(col_names(select(db, table, limit=0, order_by=None,
902 2443 aaronmk
        recover=recover, log_level=4)))
903 2414 aaronmk
904 2291 aaronmk
def pkey(db, table, recover=None):
905 832 aaronmk
    '''Assumed to be first column in table'''
906 2339 aaronmk
    return table_cols(db, table, recover)[0]
907 832 aaronmk
908 2559 aaronmk
not_null_col = 'not_null_col'
909 2340 aaronmk
910
def table_not_null_col(db, table, recover=None):
911
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
912
    if not_null_col in table_cols(db, table, recover): return not_null_col
913
    else: return pkey(db, table, recover)
914
915 853 aaronmk
def index_cols(db, table, index):
916
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
917
    automatically created. When you don't know whether something is a UNIQUE
918
    constraint or a UNIQUE index, use this function.'''
919 1909 aaronmk
    module = util.root_module(db.db)
920
    if module == 'psycopg2':
921
        return list(values(run_query(db, '''\
922 853 aaronmk
SELECT attname
923 866 aaronmk
FROM
924
(
925
        SELECT attnum, attname
926
        FROM pg_index
927
        JOIN pg_class index ON index.oid = indexrelid
928
        JOIN pg_class table_ ON table_.oid = indrelid
929
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
930
        WHERE
931 2782 aaronmk
            table_.relname = '''+db.esc_value(table)+'''
932
            AND index.relname = '''+db.esc_value(index)+'''
933 866 aaronmk
    UNION
934
        SELECT attnum, attname
935
        FROM
936
        (
937
            SELECT
938
                indrelid
939
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
940
                    AS indkey
941
            FROM pg_index
942
            JOIN pg_class index ON index.oid = indexrelid
943
            JOIN pg_class table_ ON table_.oid = indrelid
944
            WHERE
945 2782 aaronmk
                table_.relname = '''+db.esc_value(table)+'''
946
                AND index.relname = '''+db.esc_value(index)+'''
947 866 aaronmk
        ) s
948
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
949
) s
950 853 aaronmk
ORDER BY attnum
951 2782 aaronmk
'''
952
            , cacheable=True, log_level=4)))
953 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
954
        ' database')
955 853 aaronmk
956 464 aaronmk
def constraint_cols(db, table, constraint):
957 1849 aaronmk
    module = util.root_module(db.db)
958 464 aaronmk
    if module == 'psycopg2':
959
        return list(values(run_query(db, '''\
960
SELECT attname
961
FROM pg_constraint
962
JOIN pg_class ON pg_class.oid = conrelid
963
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
964
WHERE
965 2783 aaronmk
    relname = '''+db.esc_value(table)+'''
966
    AND conname = '''+db.esc_value(constraint)+'''
967 464 aaronmk
ORDER BY attnum
968 2783 aaronmk
'''
969
            )))
970 464 aaronmk
    else: raise NotImplementedError("Can't list constraint columns for "+module+
971
        ' database')
972
973 3079 aaronmk
#### Functions
974
975
def function_exists(db, function):
976
    function = sql_gen.as_Function(function)
977
978
    info_table = sql_gen.Table('routines', 'information_schema')
979
    conds = [('routine_name', function.name)]
980
    schema = function.schema
981
    if schema != None: conds.append(('routine_schema', schema))
982
    # Exclude trigger functions, since they cannot be called directly
983
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
984
985
    return list(values(select(db, info_table, ['routine_name'], conds,
986
        order_by='routine_schema', limit=1, log_level=4))) != []
987
        # TODO: order_by search_path schema order
988
989
##### Structural changes
990
991
#### Columns
992
993
def add_col(db, table, col, comment=None, **kw_args):
994
    '''
995
    @param col TypedCol Name may be versioned, so be sure to propagate any
996
        renaming back to any source column for the TypedCol.
997
    @param comment None|str SQL comment used to distinguish columns of the same
998
        name from each other when they contain different data, to allow the
999
        ADD COLUMN query to be cached. If not set, the query is not cached.
1000
    '''
1001
    assert isinstance(col, sql_gen.TypedCol)
1002
1003
    while True:
1004
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
1005
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
1006
1007
        try:
1008
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
1009
            break
1010
        except DuplicateException:
1011
            col.name = next_version(col.name)
1012
            # try again with next version of name
1013
1014
def add_not_null(db, col):
1015
    table = col.table
1016
    col = sql_gen.to_name_only_col(col)
1017
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1018
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)
1019
1020 2096 aaronmk
row_num_col = '_row_num'
1021
1022 3079 aaronmk
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
1023
    constraints='PRIMARY KEY')
1024
1025
def add_row_num(db, table):
1026
    '''Adds a row number column to a table. Its name is in row_num_col. It will
1027
    be the primary key.'''
1028
    add_col(db, table, row_num_typed_col, log_level=3)
1029
1030
#### Indexes
1031
1032
def add_pkey(db, table, cols=None, recover=None):
1033
    '''Adds a primary key.
1034
    @param cols [sql_gen.Col,...] The columns in the primary key.
1035
        Defaults to the first column in the table.
1036
    @pre The table must not already have a primary key.
1037
    '''
1038
    table = sql_gen.as_Table(table)
1039
    if cols == None: cols = [pkey(db, table, recover)]
1040
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
1041
1042
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
1043
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
1044
        log_ignore_excs=(DuplicateException,))
1045
1046 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
1047 2688 aaronmk
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
1048 2538 aaronmk
    Currently, only function calls are supported as expressions.
1049 2998 aaronmk
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
1050 2847 aaronmk
        This allows indexes to be used for comparisons where NULLs are equal.
1051 2538 aaronmk
    '''
1052 2964 aaronmk
    exprs = lists.mk_seq(exprs)
1053 2538 aaronmk
1054 2688 aaronmk
    # Parse exprs
1055
    old_exprs = exprs[:]
1056
    exprs = []
1057
    cols = []
1058
    for i, expr in enumerate(old_exprs):
1059 2823 aaronmk
        expr = sql_gen.as_Col(expr, table)
1060 2688 aaronmk
1061 2823 aaronmk
        # Handle nullable columns
1062 2998 aaronmk
        if ensure_not_null_:
1063
            try: expr = ensure_not_null(db, expr)
1064 2860 aaronmk
            except KeyError: pass # unknown type, so just create plain index
1065 2823 aaronmk
1066 2688 aaronmk
        # Extract col
1067 3002 aaronmk
        expr = copy.deepcopy(expr) # don't modify input!
1068 2688 aaronmk
        if isinstance(expr, sql_gen.FunctionCall):
1069
            col = expr.args[0]
1070
            expr = sql_gen.Expr(expr)
1071
        else: col = expr
1072 2823 aaronmk
        assert isinstance(col, sql_gen.Col)
1073 2688 aaronmk
1074
        # Extract table
1075
        if table == None:
1076
            assert sql_gen.is_table_col(col)
1077
            table = col.table
1078
1079
        col.table = None
1080
1081
        exprs.append(expr)
1082
        cols.append(col)
1083 2408 aaronmk
1084 2688 aaronmk
    table = sql_gen.as_Table(table)
1085
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
1086
1087 3005 aaronmk
    # Add index
1088
    while True:
1089
        str_ = 'CREATE'
1090
        if unique: str_ += ' UNIQUE'
1091
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
1092
            ', '.join((v.to_str(db) for v in exprs)))+')'
1093
1094
        try:
1095
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
1096
                log_ignore_excs=(DuplicateException,))
1097
            break
1098
        except DuplicateException:
1099
            index.name = next_version(index.name)
1100
            # try again with next version of name
1101 2408 aaronmk
1102 2997 aaronmk
def add_index_col(db, col, suffix, expr, nullable=True):
1103 3000 aaronmk
    if sql_gen.index_col(col) != None: return # already has index col
1104 2997 aaronmk
1105
    new_col = sql_gen.suffixed_col(col, suffix)
1106
1107 3006 aaronmk
    # Add column
1108 3038 aaronmk
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
1109 3045 aaronmk
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
1110
        log_level=3)
1111 3037 aaronmk
    new_col.name = new_typed_col.name # propagate any renaming
1112 3006 aaronmk
1113 3064 aaronmk
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
1114
        log_level=3)
1115 2997 aaronmk
    if not nullable: add_not_null(db, new_col)
1116
    add_index(db, new_col)
1117
1118 3104 aaronmk
    col.table.index_cols[col.name] = new_col.name
1119 2997 aaronmk
1120 3047 aaronmk
# Controls when ensure_not_null() will use index columns
1121
not_null_index_cols_min_rows = 0 # rows; initially always use index columns
1122
1123 2997 aaronmk
def ensure_not_null(db, col):
1124
    '''For params, see sql_gen.ensure_not_null()'''
1125
    expr = sql_gen.ensure_not_null(db, col)
1126
1127 3047 aaronmk
    # If a nullable column in a temp table, add separate index column instead.
1128
    # Note that for small datasources, this adds 6-25% to the total import time.
1129
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
1130
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
1131 2997 aaronmk
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
1132 3000 aaronmk
        expr = sql_gen.index_col(col)
1133 2997 aaronmk
1134
    return expr
1135
1136 3083 aaronmk
already_indexed = object() # tells add_indexes() the pkey has already been added
1137
1138
def add_indexes(db, table, has_pkey=True):
1139
    '''Adds an index on all columns in a table.
1140
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
1141
        index should be added on the first column.
1142
        * If already_indexed, the pkey is assumed to have already been added
1143
    '''
1144
    cols = table_cols(db, table)
1145
    if has_pkey:
1146
        if has_pkey is not already_indexed: add_pkey(db, table)
1147
        cols = cols[1:]
1148
    for col in cols: add_index(db, col, table)
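# Example (sketch), with hypothetical names:
#     plants = sql_gen.Table('plants')
#     add_pkey(db, plants)                  # PRIMARY KEY on the first column
#     add_index(db, 'genus', plants)        # single-column index
#     add_index(db, ['genus', 'species'], plants) # multi-column index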
1149
1150 3079 aaronmk
#### Tables
1151 2772 aaronmk
1152 3079 aaronmk
### Maintenance
1153 2772 aaronmk
1154 3079 aaronmk
def analyze(db, table):
1155
    table = sql_gen.as_Table(table)
1156
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1157 2934 aaronmk
1158 3079 aaronmk
def autoanalyze(db, table):
1159
    if db.autoanalyze: analyze(db, table)
1160 2935 aaronmk
1161 3079 aaronmk
def vacuum(db, table):
1162
    table = sql_gen.as_Table(table)
1163
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
1164
        log_level=3))
1165 2086 aaronmk
1166 3079 aaronmk
### Lifecycle
1167
1168 2889 aaronmk
def drop_table(db, table):
1169
    table = sql_gen.as_Table(table)
1170
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')
1171
1172 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
1173
    like=None):
1174 2675 aaronmk
    '''Creates a table.
1175 2681 aaronmk
    @param cols [sql_gen.TypedCol,...] The column names and types
1176
    @param has_pkey If set, the first column becomes the primary key.
1177 2760 aaronmk
    @param col_indexes bool|[ref]
1178
        * If True, indexes will be added on all non-pkey columns.
1179
        * If a list reference, [0] will be set to a function to do this.
1180
          This can be used to delay index creation until the table is populated.
1181 2675 aaronmk
    '''
1182
    table = sql_gen.as_Table(table)
1183
1184 3082 aaronmk
    if like != None:
1185
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
1186
            ]+cols
1187 2681 aaronmk
    if has_pkey:
1188
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
1189 2872 aaronmk
        pkey.constraints = 'PRIMARY KEY'
1190 2681 aaronmk
1191 3085 aaronmk
    temp = table.is_temp and not db.debug_temp
1192
        # temp tables permanent in debug_temp mode
1193 2760 aaronmk
1194 3085 aaronmk
    # Create table
1195
    while True:
1196
        str_ = 'CREATE'
1197
        if temp: str_ += ' TEMP'
1198
        str_ += ' TABLE '+table.to_str(db)+' (\n'
1199
        str_ += '\n, '.join(c.to_str(db) for c in cols)
1200
        str_ += '\n);\n'
1201
1202
        try:
1203
            run_query(db, str_, cacheable=True, log_level=2,
1204
                log_ignore_excs=(DuplicateException,))
1205
            break
1206
        except DuplicateException:
1207
            table.name = next_version(table.name)
1208
            # try again with next version of name
1209
1210 2760 aaronmk
    # Add indexes
1211 2773 aaronmk
    if has_pkey: has_pkey = already_indexed
1212
    def add_indexes_(): add_indexes(db, table, has_pkey)
1213
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
1214
    elif col_indexes: add_indexes_() # add now
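# Example (sketch), with hypothetical names: create a table whose first column
# becomes a serial primary key and whose other columns get indexes.
#     create_table(db, sql_gen.Table('plants'), [
#         sql_gen.TypedCol('plant_id', 'serial', nullable=False),
#         sql_gen.TypedCol('genus', 'text'),
#         sql_gen.TypedCol('species', 'text'),
#     ])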
1215 2675 aaronmk
1216 3084 aaronmk
def copy_table_struct(db, src, dest):
1217
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1218 3085 aaronmk
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1219 3084 aaronmk
1220 3079 aaronmk
### Data
1221 2684 aaronmk
1222 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
1223
    '''For params, see run_query()'''
1224 2777 aaronmk
    table = sql_gen.as_Table(table, schema)
1225 2970 aaronmk
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1226 2732 aaronmk
1227 2965 aaronmk
def empty_temp(db, tables):
1228 2972 aaronmk
    if db.debug_temp: return # leave temp tables there for debugging
1229 2965 aaronmk
    tables = lists.mk_seq(tables)
1230 2971 aaronmk
    for table in tables: truncate(db, table, log_level=3)
1231 2965 aaronmk
1232 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
1233
    '''For kw_args, see tables()'''
1234
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1235 3094 aaronmk
1236
def distinct_table(db, table, distinct_on):
1237
    '''Creates a copy of a temp table which is distinct on the given columns.
1238 3099 aaronmk
    The old and new tables will both get an index on these columns, to
1239
    facilitate merge joins.
1240 3097 aaronmk
    @param distinct_on If empty, creates a table with one row. This is useful if
1241
        your distinct_on columns are all literal values.
1242 3099 aaronmk
    @return The new table.
1243 3094 aaronmk
    '''
1244 3099 aaronmk
    new_table = sql_gen.suffixed_table(table, '_distinct')
1245 3094 aaronmk
1246 3099 aaronmk
    copy_table_struct(db, table, new_table)
1247 3097 aaronmk
1248
    limit = None
1249
    if distinct_on == []: limit = 1 # one sample row
1250 3099 aaronmk
    else:
1251
        add_index(db, distinct_on, new_table, unique=True)
1252
        add_index(db, distinct_on, table) # for join optimization
1253 3097 aaronmk
1254 3099 aaronmk
    insert_select(db, new_table, None, mk_select(db, table, start=0,
1255 3097 aaronmk
        limit=limit), ignore=True)
1256 3099 aaronmk
    analyze(db, new_table)
1257 3094 aaronmk
1258 3099 aaronmk
    return new_table
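# Example (sketch), with hypothetical names: deduplicate a temp table on a
# column. Assumes `staging` is a temp sql_gen.Table produced earlier (e.g. by
# run_query_into()).
#     deduped = distinct_table(db, staging, [sql_gen.Col('scientific_name',
#         staging)])
#     # `deduped` has the same structure as `staging` and is unique on
#     # scientific_name; both tables get an index on it for merge joins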