Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
20 2168 aaronmk
    raw_query = None
21
    if hasattr(cur, 'query'): raw_query = cur.query
22
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
23 2170 aaronmk
24
    if raw_query != None: return raw_query
25 2804 aaronmk
    else: return '[input] '+strings.ustr(input_query)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
28
    '''For params, see get_cur_query()'''
29 2771 aaronmk
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
32 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
33 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
34 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
37
    def __init__(self, name, cause=None):
38 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
39 2143 aaronmk
        self.name = name
40 360 aaronmk
41 3109 aaronmk
class ExceptionWithValue(DbException):
42
    def __init__(self, value, cause=None):
43
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
44
            cause)
45 2240 aaronmk
        self.value = value
46
47 2945 aaronmk
class ExceptionWithNameType(DbException):
48
    def __init__(self, type_, name, cause=None):
49
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
50
            +'; name: '+strings.as_tt(name), cause)
51
        self.type = type_
52
        self.name = name
53
54 2306 aaronmk
class ConstraintException(DbException):
55
    def __init__(self, name, cols, cause=None):
56 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
57
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
58 2306 aaronmk
        self.name = name
59 468 aaronmk
        self.cols = cols
60 11 aaronmk
61 2523 aaronmk
class MissingCastException(DbException):
62
    def __init__(self, type_, col, cause=None):
63
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
64
            +' on column: '+strings.as_tt(col), cause)
65
        self.type = type_
66
        self.col = col
67
68 2143 aaronmk
class NameException(DbException): pass
69
70 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
71 13 aaronmk
72 2306 aaronmk
class NullValueException(ConstraintException): pass
73 13 aaronmk
74 3109 aaronmk
class InvalidValueException(ExceptionWithValue): pass
75 2239 aaronmk
76 2945 aaronmk
class DuplicateException(ExceptionWithNameType): pass
77 2143 aaronmk
78 89 aaronmk
class EmptyRowException(DbException): pass
79
80 865 aaronmk
##### Warnings
81
82
class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval
85
86
def col_names(cur): return (col[0] for col in cur.description)
87
88
def rows(cur): return iter(lambda: cur.fetchone(), None)
89
90
def consume_rows(cur):
91
    '''Used to fetch all rows so result will be cached'''
92
    iters.consume_iter(rows(cur))
93
94
def next_row(cur): return rows(cur).next()
95
96
def row(cur):
97
    row_ = next_row(cur)
98
    consume_rows(cur)
99
    return row_
100
101
def next_value(cur): return next_row(cur)[0]
102
103
def value(cur): return row(cur)[0]
104
105
def values(cur): return iters.func_iter(lambda: next_value(cur))
106
107
def value_or_none(cur):
108
    try: return value(cur)
109
    except StopIteration: return None
110
111 2762 aaronmk
##### Escaping
112 2101 aaronmk
113 2573 aaronmk
def esc_name_by_module(module, name):
114
    if module == 'psycopg2' or module == None: quote = '"'
115 2101 aaronmk
    elif module == 'MySQLdb': quote = '`'
116
    else: raise NotImplementedError("Can't escape name for "+module+' database')
117 2500 aaronmk
    return sql_gen.esc_name(name, quote)
118 2101 aaronmk
119
def esc_name_by_engine(engine, name, **kw_args):
120
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
121
122
def esc_name(db, name, **kw_args):
123
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
124
125
def qual_name(db, schema, table):
126
    def esc_name_(name): return esc_name(db, name)
127
    table = esc_name_(table)
128
    if schema != None: return esc_name_(schema)+'.'+table
129
    else: return table
130
131 1869 aaronmk
##### Database connections
132 1849 aaronmk
133 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
134 1926 aaronmk
135 1869 aaronmk
db_engines = {
136
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
137
    'PostgreSQL': ('psycopg2', {}),
138
}
139
140
DatabaseErrors_set = set([DbException])
141
DatabaseErrors = tuple(DatabaseErrors_set)
142
143
def _add_module(module):
144
    DatabaseErrors_set.add(module.DatabaseError)
145
    global DatabaseErrors
146
    DatabaseErrors = tuple(DatabaseErrors_set)
147
148
def db_config_str(db_config):
149
    return db_config['engine']+' database '+db_config['database']
150
151 2448 aaronmk
log_debug_none = lambda msg, level=2: None
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
154 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
155 2915 aaronmk
        log_debug=log_debug_none, debug_temp=False):
156
        '''
157
        @param debug_temp Whether temporary objects should instead be permanent.
158
            This assists in debugging the internal objects used by the program.
159
        '''
160 1869 aaronmk
        self.db_config = db_config
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 2915 aaronmk
        self.debug_temp = debug_temp
166 3074 aaronmk
        self.autoanalyze = False
167 1869 aaronmk
168
        self.__db = None
169 1889 aaronmk
        self.query_results = {}
170 2139 aaronmk
        self._savepoint = 0
171 2671 aaronmk
        self._notices_seen = set()
172 1869 aaronmk
173
    def __getattr__(self, name):
174
        if name == '__dict__': raise Exception('getting __dict__')
175
        if name == 'db': return self._db()
176
        else: raise AttributeError()
177
178
    def __getstate__(self):
179
        state = copy.copy(self.__dict__) # shallow copy
180 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
181 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
182
        return state
183
184 2165 aaronmk
    def connected(self): return self.__db != None
185
186 1869 aaronmk
    def _db(self):
187
        if self.__db == None:
188
            # Process db_config
189
            db_config = self.db_config.copy() # don't modify input!
190 2097 aaronmk
            schemas = db_config.pop('schemas', None)
191 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
192
            module = __import__(module_name)
193
            _add_module(module)
194
            for orig, new in mappings.iteritems():
195
                try: util.rename_key(db_config, orig, new)
196
                except KeyError: pass
197
198
            # Connect
199
            self.__db = module.connect(**db_config)
200
201
            # Configure connection
202 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
203
                import psycopg2.extensions
204
                self.db.set_isolation_level(
205
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
206 2101 aaronmk
            if schemas != None:
207 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
208
                search_path.append(value(run_query(self, 'SHOW search_path',
209
                    log_level=4)))
210
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
211
                    log_level=3)
212 1869 aaronmk
213
        return self.__db
214 1889 aaronmk
215 1891 aaronmk
    class DbCursor(Proxy):
216 1927 aaronmk
        def __init__(self, outer):
217 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
218 2191 aaronmk
            self.outer = outer
219 1927 aaronmk
            self.query_results = outer.query_results
220 1894 aaronmk
            self.query_lookup = None
221 1891 aaronmk
            self.result = []
222 1889 aaronmk
223 2802 aaronmk
        def execute(self, query):
224 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
225 2797 aaronmk
            self.query_lookup = query
226 2148 aaronmk
            try:
227 2191 aaronmk
                try:
228 2802 aaronmk
                    cur = self.inner.execute(query)
229 2191 aaronmk
                    self.outer.do_autocommit()
230 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
231 1904 aaronmk
            except Exception, e:
232 2802 aaronmk
                _add_cursor_info(e, self, query)
233 1904 aaronmk
                self.result = e # cache the exception as the result
234
                self._cache_result()
235
                raise
236 3004 aaronmk
237
            # Always cache certain queries
238
            if query.startswith('CREATE') or query.startswith('ALTER'):
239 3007 aaronmk
                # structural changes
240 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
241
                # so don't cache ADD COLUMN unless it has distinguishing comment
242
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
243 3007 aaronmk
                    self._cache_result()
244 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
245 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
246 3004 aaronmk
247 2762 aaronmk
            return cur
248 1894 aaronmk
249 1891 aaronmk
        def fetchone(self):
250
            row = self.inner.fetchone()
251 1899 aaronmk
            if row != None: self.result.append(row)
252
            # otherwise, fetched all rows
253 1904 aaronmk
            else: self._cache_result()
254
            return row
255
256
        def _cache_result(self):
257 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
258
            # inserts are not idempotent. Other non-SELECT queries don't have
259
            # their result set read, so only exceptions will be cached (an
260
            # invalid query will always be invalid).
261 1930 aaronmk
            if self.query_results != None and (not self._is_insert
262 1906 aaronmk
                or isinstance(self.result, Exception)):
263
264 1894 aaronmk
                assert self.query_lookup != None
265 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
266
                    util.dict_subset(dicts.AttrsDictView(self),
267
                    ['query', 'result', 'rowcount', 'description']))
268 1906 aaronmk
269 1916 aaronmk
        class CacheCursor:
270
            def __init__(self, cached_result): self.__dict__ = cached_result
271
272 1927 aaronmk
            def execute(self, *args, **kw_args):
273 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
274
                # otherwise, result is a rows list
275
                self.iter = iter(self.result)
276
277
            def fetchone(self):
278
                try: return self.iter.next()
279
                except StopIteration: return None
280 1891 aaronmk
281 2212 aaronmk
    def esc_value(self, value):
282 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
283
        except NotImplementedError, e:
284
            module = util.root_module(self.db)
285
            if module == 'MySQLdb':
286
                import _mysql
287
                str_ = _mysql.escape_string(value)
288
            else: raise e
289 2374 aaronmk
        return strings.to_unicode(str_)
290 2212 aaronmk
291 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
292
293 2814 aaronmk
    def std_code(self, str_):
294
        '''Standardizes SQL code.
295
        * Ensures that string literals are prefixed by `E`
296
        '''
297
        if str_.startswith("'"): str_ = 'E'+str_
298
        return str_
299
300 2665 aaronmk
    def can_mogrify(self):
301 2663 aaronmk
        module = util.root_module(self.db)
302 2665 aaronmk
        return module == 'psycopg2'
303 2663 aaronmk
304 2665 aaronmk
    def mogrify(self, query, params=None):
305
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
306
        else: raise NotImplementedError("Can't mogrify query")
307
308 2671 aaronmk
    def print_notices(self):
309 2725 aaronmk
        if hasattr(self.db, 'notices'):
310
            for msg in self.db.notices:
311
                if msg not in self._notices_seen:
312
                    self._notices_seen.add(msg)
313
                    self.log_debug(msg, level=2)
314 2671 aaronmk
315 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
316 2464 aaronmk
        debug_msg_ref=None):
317 2445 aaronmk
        '''
318 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
319
            throws one of these exceptions.
320 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
321
            this instead of being output. This allows you to filter log messages
322
            depending on the result of the query.
323 2445 aaronmk
        '''
324 2167 aaronmk
        assert query != None
325
326 2047 aaronmk
        if not self.caching: cacheable = False
327 1903 aaronmk
        used_cache = False
328 2664 aaronmk
329
        def log_msg(query):
330
            if used_cache: cache_status = 'cache hit'
331
            elif cacheable: cache_status = 'cache miss'
332
            else: cache_status = 'non-cacheable'
333
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
334
335 1903 aaronmk
        try:
336 1927 aaronmk
            # Get cursor
337
            if cacheable:
338
                try:
339 2797 aaronmk
                    cur = self.query_results[query]
340 1927 aaronmk
                    used_cache = True
341
                except KeyError: cur = self.DbCursor(self)
342
            else: cur = self.db.cursor()
343
344 2664 aaronmk
            # Log query
345
            if self.debug and debug_msg_ref == None: # log before running
346
                self.log_debug(log_msg(query), log_level)
347
348 1927 aaronmk
            # Run query
349 2793 aaronmk
            cur.execute(query)
350 1903 aaronmk
        finally:
351 2671 aaronmk
            self.print_notices()
352 2664 aaronmk
            if self.debug and debug_msg_ref != None: # return after running
353 2793 aaronmk
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
354 1903 aaronmk
355
        return cur
356 1914 aaronmk
357 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
358 2139 aaronmk
359 2907 aaronmk
    def with_autocommit(self, func):
360 2801 aaronmk
        import psycopg2.extensions
361
362
        prev_isolation_level = self.db.isolation_level
363 2907 aaronmk
        self.db.set_isolation_level(
364
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
365 2683 aaronmk
        try: return func()
366 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
367 2683 aaronmk
368 2139 aaronmk
    def with_savepoint(self, func):
369 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
370 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
371 2139 aaronmk
        self._savepoint += 1
372 2930 aaronmk
        try: return func()
373 2139 aaronmk
        except:
374 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
375 2139 aaronmk
            raise
376 2930 aaronmk
        finally:
377
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
378
            # "The savepoint remains valid and can be rolled back to again"
379
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
380 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
381 2930 aaronmk
382
            self._savepoint -= 1
383
            assert self._savepoint >= 0
384
385
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
386 2191 aaronmk
387
    def do_autocommit(self):
388
        '''Autocommits if outside savepoint'''
389
        assert self._savepoint >= 0
390
        if self.autocommit and self._savepoint == 0:
391 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
392 2191 aaronmk
            self.db.commit()
393 2643 aaronmk
394 2819 aaronmk
    def col_info(self, col):
395 2643 aaronmk
        table = sql_gen.Table('columns', 'information_schema')
396 3063 aaronmk
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
397
            'USER-DEFINED'), sql_gen.Col('udt_name'))
398 3078 aaronmk
        cols = [type_, 'column_default',
399
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
400 2643 aaronmk
401
        conds = [('table_name', col.table.name), ('column_name', col.name)]
402
        schema = col.table.schema
403
        if schema != None: conds.append(('table_schema', schema))
404
405 2819 aaronmk
        type_, default, nullable = row(select(self, table, cols, conds,
406 3059 aaronmk
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
407 2643 aaronmk
            # TODO: order_by search_path schema order
408 2819 aaronmk
        default = sql_gen.as_Code(default, self)
409
410
        return sql_gen.TypedCol(col.name, type_, default, nullable)
411 2917 aaronmk
412
    def TempFunction(self, name):
413
        if self.debug_temp: schema = None
414
        else: schema = 'pg_temp'
415
        return sql_gen.Function(name, schema)
416 1849 aaronmk
417 1869 aaronmk
connect = DbConn
418
419 832 aaronmk
##### Recoverable querying
420 15 aaronmk
421 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
422 11 aaronmk
423 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
424
    log_ignore_excs=None, **kw_args):
425 2794 aaronmk
    '''For params, see DbConn.run_query()'''
426 830 aaronmk
    if recover == None: recover = False
427 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
428
    log_ignore_excs = tuple(log_ignore_excs)
429 830 aaronmk
430 2666 aaronmk
    debug_msg_ref = None # usually, db.run_query() logs query before running it
431
    # But if filtering with log_ignore_excs, wait until after exception parsing
432 2984 aaronmk
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
433 2666 aaronmk
434 2148 aaronmk
    try:
435 2464 aaronmk
        try:
436 2794 aaronmk
            def run(): return db.run_query(query, cacheable, log_level,
437 2793 aaronmk
                debug_msg_ref, **kw_args)
438 2796 aaronmk
            if recover and not db.is_cached(query):
439 2464 aaronmk
                return with_savepoint(db, run)
440
            else: return run() # don't need savepoint if cached
441
        except Exception, e:
442 3095 aaronmk
            msg = strings.ustr(e.args[0])
443 2464 aaronmk
444 3095 aaronmk
            match = re.match(r'^duplicate key value violates unique constraint '
445 3096 aaronmk
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
446 2464 aaronmk
            if match:
447
                constraint, table = match.groups()
448 3025 aaronmk
                cols = []
449
                if recover: # need auto-rollback to run index_cols()
450
                    try: cols = index_cols(db, table, constraint)
451
                    except NotImplementedError: pass
452
                raise DuplicateKeyException(constraint, cols, e)
453 2464 aaronmk
454 3095 aaronmk
            match = re.match(r'^null value in column "(.+?)" violates not-null'
455 2464 aaronmk
                r' constraint', msg)
456
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
457
458 3095 aaronmk
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
459 3109 aaronmk
                r'|.+? field value out of range): "(.+?)"', msg)
460 2464 aaronmk
            if match:
461 3109 aaronmk
                value, = match.groups()
462
                raise InvalidValueException(strings.to_unicode(value), e)
463 2464 aaronmk
464 3095 aaronmk
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
465 2523 aaronmk
                r'is of type', msg)
466
            if match:
467
                col, type_ = match.groups()
468
                raise MissingCastException(type_, col, e)
469
470 3095 aaronmk
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
471 2945 aaronmk
            if match:
472
                type_, name = match.groups()
473
                raise DuplicateException(type_, name, e)
474 2464 aaronmk
475
            raise # no specific exception raised
476
    except log_ignore_excs:
477
        log_level += 2
478
        raise
479
    finally:
480 2666 aaronmk
        if debug_msg_ref != None and debug_msg_ref[0] != None:
481
            db.log_debug(debug_msg_ref[0], log_level)
482 830 aaronmk
483 832 aaronmk
##### Basic queries
484
485 2153 aaronmk
def next_version(name):
486 2163 aaronmk
    version = 1 # first existing name was version 0
487 2586 aaronmk
    match = re.match(r'^(.*)#(\d+)$', name)
488 2153 aaronmk
    if match:
489 2586 aaronmk
        name, version = match.groups()
490
        version = int(version)+1
491 2932 aaronmk
    return sql_gen.concat(name, '#'+str(version))
492 2153 aaronmk
493 2899 aaronmk
def lock_table(db, table, mode):
494
    table = sql_gen.as_Table(table)
495
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
496
497 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
498 2085 aaronmk
    '''Outputs a query to a temp table.
499
    For params, see run_query().
500
    '''
501 2789 aaronmk
    if into == None: return run_query(db, query, **kw_args)
502 2790 aaronmk
503
    assert isinstance(into, sql_gen.Table)
504
505 2992 aaronmk
    into.is_temp = True
506 3008 aaronmk
    # "temporary tables cannot specify a schema name", so remove schema
507
    into.schema = None
508 2992 aaronmk
509 2790 aaronmk
    kw_args['recover'] = True
510 2945 aaronmk
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
511 2790 aaronmk
512 2916 aaronmk
    temp = not db.debug_temp # tables are permanent in debug_temp mode
513 2790 aaronmk
514
    # Create table
515
    while True:
516
        create_query = 'CREATE'
517
        if temp: create_query += ' TEMP'
518
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
519 2385 aaronmk
520 2790 aaronmk
        try:
521
            cur = run_query(db, create_query, **kw_args)
522
                # CREATE TABLE AS sets rowcount to # rows in query
523
            break
524 2945 aaronmk
        except DuplicateException, e:
525 2790 aaronmk
            into.name = next_version(into.name)
526
            # try again with next version of name
527
528
    if add_indexes_: add_indexes(db, into)
529 3075 aaronmk
530
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
531
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
532
    # table is going to be used in complex queries, it is wise to run ANALYZE on
533
    # the temporary table after it is populated."
534
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
535
    # If into is not a temp table, ANALYZE is useful but not required.
536 3073 aaronmk
    analyze(db, into)
537 2790 aaronmk
538
    return cur
539 2085 aaronmk
540 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
541
542 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
543
544 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
545 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
546 1981 aaronmk
    '''
547 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
548 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
549 1981 aaronmk
    @param fields Use None to select all fields in the table
550 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
551 2379 aaronmk
        * container can be any iterable type
552 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
553
        * compare_right_side: sql_gen.ValueCond|literal value
554 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
555
        use all columns
556 2786 aaronmk
    @return query
557 1981 aaronmk
    '''
558 2315 aaronmk
    # Parse tables param
559 2964 aaronmk
    tables = lists.mk_seq(tables)
560 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
561 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
562 2121 aaronmk
563 2315 aaronmk
    # Parse other params
564 2376 aaronmk
    if conds == None: conds = []
565 2650 aaronmk
    elif dicts.is_dict(conds): conds = conds.items()
566 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
567 135 aaronmk
    assert limit == None or type(limit) == int
568 865 aaronmk
    assert start == None or type(start) == int
569 2315 aaronmk
    if order_by is order_by_pkey:
570
        if distinct_on != []: order_by = None
571
        else: order_by = pkey(db, table0, recover=True)
572 865 aaronmk
573 2315 aaronmk
    query = 'SELECT'
574 2056 aaronmk
575 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
576 2056 aaronmk
577 2200 aaronmk
    # DISTINCT ON columns
578 2233 aaronmk
    if distinct_on != []:
579 2467 aaronmk
        query += '\nDISTINCT'
580 2254 aaronmk
        if distinct_on is not distinct_on_all:
581 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
582
583
    # Columns
584 3027 aaronmk
    if fields == None:
585
        if query.find('\n') >= 0: whitespace = '\n'
586
        else: whitespace = ' '
587
        query += whitespace+'*'
588 2765 aaronmk
    else:
589
        assert fields != []
590 3027 aaronmk
        query += '\n'+('\n, '.join(map(parse_col, fields)))
591 2200 aaronmk
592
    # Main table
593 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
594 865 aaronmk
595 2122 aaronmk
    # Add joins
596 2271 aaronmk
    left_table = table0
597 2263 aaronmk
    for join_ in tables:
598
        table = join_.table
599 2238 aaronmk
600 2343 aaronmk
        # Parse special values
601
        if join_.type_ is sql_gen.filter_out: # filter no match
602 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
603 2853 aaronmk
                sql_gen.CompareCond(None, '~=')))
604 2343 aaronmk
605 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
606 2122 aaronmk
607
        left_table = table
608
609 865 aaronmk
    missing = True
610 2376 aaronmk
    if conds != []:
611 2576 aaronmk
        if len(conds) == 1: whitespace = ' '
612
        else: whitespace = '\n'
613 2578 aaronmk
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
614
            .to_str(db) for l, r in conds], 'WHERE')
615 865 aaronmk
        missing = False
616 2227 aaronmk
    if order_by != None:
617 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
618
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
619 865 aaronmk
    if start != None:
620 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
621 865 aaronmk
        missing = False
622
    if missing: warnings.warn(DbWarning(
623
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
624
625 2786 aaronmk
    return query
626 11 aaronmk
627 2054 aaronmk
def select(db, *args, **kw_args):
628
    '''For params, see mk_select() and run_query()'''
629
    recover = kw_args.pop('recover', None)
630
    cacheable = kw_args.pop('cacheable', True)
631 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
632 2054 aaronmk
633 2791 aaronmk
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
634
        log_level=log_level)
635 2054 aaronmk
636 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
637 3009 aaronmk
    embeddable=False, ignore=False):
638 1960 aaronmk
    '''
639
    @param returning str|None An inserted column (such as pkey) to return
640 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
641 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
642
        query will be fully cached, not just if it raises an exception.
643 3009 aaronmk
    @param ignore Whether to ignore duplicate keys.
644 1960 aaronmk
    '''
645 2754 aaronmk
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
646 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
647 3010 aaronmk
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
648 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
649 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
650 2063 aaronmk
651 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
652 2063 aaronmk
653 3009 aaronmk
    def mk_insert(select_query):
654
        query = first_line
655 3014 aaronmk
        if cols != None:
656
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
657 3009 aaronmk
        query += '\n'+select_query
658
659
        if returning != None:
660
            returning_name_col = sql_gen.to_name_only_col(returning)
661
            query += '\nRETURNING '+returning_name_col.to_str(db)
662
663
        return query
664 2063 aaronmk
665 3017 aaronmk
    return_type = 'unknown'
666
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
667
668 3009 aaronmk
    lang = 'sql'
669
    if ignore:
670 3017 aaronmk
        # Always return something to set the correct rowcount
671
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
672
673 3009 aaronmk
        embeddable = True # must use function
674
        lang = 'plpgsql'
675 3010 aaronmk
676 3092 aaronmk
        if cols == None:
677
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
678
            row_vars = [sql_gen.Table('row')]
679
        else:
680
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
681
682 3009 aaronmk
        query = '''\
683 3010 aaronmk
DECLARE
684 3014 aaronmk
    row '''+table.to_str(db)+'''%ROWTYPE;
685 3009 aaronmk
BEGIN
686 3019 aaronmk
    /* Need an EXCEPTION block for each individual row because "When an error is
687
    caught by an EXCEPTION clause, [...] all changes to persistent database
688
    state within the block are rolled back."
689
    This is unfortunate because "A block containing an EXCEPTION clause is
690
    significantly more expensive to enter and exit than a block without one."
691 3015 aaronmk
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
692
#PLPGSQL-ERROR-TRAPPING)
693
    */
694 3092 aaronmk
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
695 3034 aaronmk
'''+select_query+'''
696
    LOOP
697 3015 aaronmk
        BEGIN
698 3019 aaronmk
            RETURN QUERY
699 3014 aaronmk
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
700 3010 aaronmk
;
701 3015 aaronmk
        EXCEPTION
702 3019 aaronmk
            WHEN unique_violation THEN NULL; -- continue to next row
703 3015 aaronmk
        END;
704 3010 aaronmk
    END LOOP;
705
END;\
706 3009 aaronmk
'''
707
    else: query = mk_insert(select_query)
708
709 2070 aaronmk
    if embeddable:
710
        # Create function
711 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
712 2189 aaronmk
        while True:
713
            try:
714 2918 aaronmk
                function = db.TempFunction(function_name)
715 2194 aaronmk
716 2189 aaronmk
                function_query = '''\
717 2698 aaronmk
CREATE FUNCTION '''+function.to_str(db)+'''()
718 3017 aaronmk
RETURNS SETOF '''+return_type+'''
719 3009 aaronmk
LANGUAGE '''+lang+'''
720 2467 aaronmk
AS $$
721 3009 aaronmk
'''+query+'''
722 2467 aaronmk
$$;
723 2070 aaronmk
'''
724 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
725 2945 aaronmk
                    log_ignore_excs=(DuplicateException,))
726 2189 aaronmk
                break # this version was successful
727 2945 aaronmk
            except DuplicateException, e:
728 2189 aaronmk
                function_name = next_version(function_name)
729
                # try again with next version of name
730 2070 aaronmk
731 2337 aaronmk
        # Return query that uses function
732 3009 aaronmk
        cols = None
733
        if returning != None: cols = [returning]
734 2698 aaronmk
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
735 3009 aaronmk
            cols) # AS clause requires function alias
736 2787 aaronmk
        return mk_select(db, func_table, start=0, order_by=None)
737 2070 aaronmk
738 2787 aaronmk
    return query
739 2066 aaronmk
740 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
741 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
742 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
743
        values in
744 2072 aaronmk
    '''
745 2386 aaronmk
    into = kw_args.pop('into', None)
746
    if into != None: kw_args['embeddable'] = True
747 2066 aaronmk
    recover = kw_args.pop('recover', None)
748 3011 aaronmk
    if kw_args.get('ignore', False): recover = True
749 2066 aaronmk
    cacheable = kw_args.pop('cacheable', True)
750 2673 aaronmk
    log_level = kw_args.pop('log_level', 2)
751 2066 aaronmk
752 3074 aaronmk
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
753
        into, recover=recover, cacheable=cacheable, log_level=log_level)
754
    autoanalyze(db, table)
755
    return cur
756 2063 aaronmk
757 2738 aaronmk
default = sql_gen.default # tells insert() to use the default value for a column
758 2066 aaronmk
759 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
760 2085 aaronmk
    '''For params, see insert_select()'''
761 1960 aaronmk
    if lists.is_seq(row): cols = None
762
    else:
763
        cols = row.keys()
764
        row = row.values()
765 2738 aaronmk
    row = list(row) # ensure that "== []" works
766 1960 aaronmk
767 2738 aaronmk
    if row == []: query = None
768
    else: query = sql_gen.Values(row).to_str(db)
769 1961 aaronmk
770 2788 aaronmk
    return insert_select(db, table, cols, query, *args, **kw_args)
771 11 aaronmk
772 3056 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False):
773 2402 aaronmk
    '''
774
    @param changes [(col, new_value),...]
775
        * container can be any iterable type
776
        * col: sql_gen.Code|str (for col name)
777
        * new_value: sql_gen.Code|literal value
778
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
779 3056 aaronmk
    @param in_place If set, locks the table and updates rows in place.
780
        This avoids creating dead rows in PostgreSQL.
781
        * cond must be None
782 2402 aaronmk
    @return str query
783
    '''
784 3057 aaronmk
    table = sql_gen.as_Table(table)
785
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
786
        for c, v in changes]
787
788 3056 aaronmk
    if in_place:
789
        assert cond == None
790 3058 aaronmk
791 3065 aaronmk
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
792
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
793
            +db.col_info(sql_gen.with_default_table(c, table)).type
794
            +'\nUSING '+v.to_str(db) for c, v in changes))
795 3058 aaronmk
    else:
796
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
797
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
798
            for c, v in changes))
799
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
800 3056 aaronmk
801 2402 aaronmk
    return query
802
803 3074 aaronmk
def update(db, table, *args, **kw_args):
804 2402 aaronmk
    '''For params, see mk_update() and run_query()'''
805
    recover = kw_args.pop('recover', None)
806 3043 aaronmk
    cacheable = kw_args.pop('cacheable', False)
807 3030 aaronmk
    log_level = kw_args.pop('log_level', 2)
808 2402 aaronmk
809 3074 aaronmk
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
810
        cacheable, log_level=log_level)
811
    autoanalyze(db, table)
812
    return cur
813 2402 aaronmk
814 135 aaronmk
def last_insert_id(db):
815 1849 aaronmk
    module = util.root_module(db.db)
816 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
817
    elif module == 'MySQLdb': return db.insert_id()
818
    else: return None
819 13 aaronmk
820 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
821 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
822 2415 aaronmk
    to names that will be distinct among the columns' tables.
823 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
824 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
825
    @param into The table for the new columns.
826 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
827
        columns will be included in the mapping even if they are not in cols.
828
        The tables of the provided Col objects will be changed to into, so make
829
        copies of them if you want to keep the original tables.
830
    @param as_items Whether to return a list of dict items instead of a dict
831 2383 aaronmk
    @return dict(orig_col=new_col, ...)
832
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
833 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
834
        * All mappings use the into table so its name can easily be
835 2383 aaronmk
          changed for all columns at once
836
    '''
837 2415 aaronmk
    cols = lists.uniqify(cols)
838
839 2394 aaronmk
    items = []
840 2389 aaronmk
    for col in preserve:
841 2390 aaronmk
        orig_col = copy.copy(col)
842 2392 aaronmk
        col.table = into
843 2394 aaronmk
        items.append((orig_col, col))
844
    preserve = set(preserve)
845
    for col in cols:
846 2716 aaronmk
        if col not in preserve:
847
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
848 2394 aaronmk
849
    if not as_items: items = dict(items)
850
    return items
851 2383 aaronmk
852 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
853 2391 aaronmk
    '''For params, see mk_flatten_mapping()
854
    @return See return value of mk_flatten_mapping()
855
    '''
856 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
857
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
858 2786 aaronmk
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
859 2846 aaronmk
        into=into, add_indexes_=True)
860 2394 aaronmk
    return dict(items)
861 2391 aaronmk
862 3079 aaronmk
##### Database structure introspection
863 2414 aaronmk
864 3079 aaronmk
#### Tables
865
866
def tables(db, schema_like='public', table_like='%', exact=False):
867
    if exact: compare = '='
868
    else: compare = 'LIKE'
869
870
    module = util.root_module(db.db)
871
    if module == 'psycopg2':
872
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
873
            ('tablename', sql_gen.CompareCond(table_like, compare))]
874
        return values(select(db, 'pg_tables', ['tablename'], conds,
875
            order_by='tablename', log_level=4))
876
    elif module == 'MySQLdb':
877
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
878
            , cacheable=True, log_level=4))
879
    else: raise NotImplementedError("Can't list tables for "+module+' database')
880
881
def table_exists(db, table):
882
    table = sql_gen.as_Table(table)
883
    return list(tables(db, table.schema, table.name, exact=True)) != []
884
885 2426 aaronmk
def table_row_count(db, table, recover=None):
886 2786 aaronmk
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
887 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
888 2426 aaronmk
889 2414 aaronmk
def table_cols(db, table, recover=None):
890
    return list(col_names(select(db, table, limit=0, order_by=None,
891 2443 aaronmk
        recover=recover, log_level=4)))
892 2414 aaronmk
893 2291 aaronmk
def pkey(db, table, recover=None):
894 832 aaronmk
    '''Assumed to be first column in table'''
895 2339 aaronmk
    return table_cols(db, table, recover)[0]
896 832 aaronmk
897 2559 aaronmk
not_null_col = 'not_null_col'
898 2340 aaronmk
899
def table_not_null_col(db, table, recover=None):
900
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
901
    if not_null_col in table_cols(db, table, recover): return not_null_col
902
    else: return pkey(db, table, recover)
903
904 853 aaronmk
def index_cols(db, table, index):
905
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
906
    automatically created. When you don't know whether something is a UNIQUE
907
    constraint or a UNIQUE index, use this function.'''
908 1909 aaronmk
    module = util.root_module(db.db)
909
    if module == 'psycopg2':
910
        return list(values(run_query(db, '''\
911 853 aaronmk
SELECT attname
912 866 aaronmk
FROM
913
(
914
        SELECT attnum, attname
915
        FROM pg_index
916
        JOIN pg_class index ON index.oid = indexrelid
917
        JOIN pg_class table_ ON table_.oid = indrelid
918
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
919
        WHERE
920 2782 aaronmk
            table_.relname = '''+db.esc_value(table)+'''
921
            AND index.relname = '''+db.esc_value(index)+'''
922 866 aaronmk
    UNION
923
        SELECT attnum, attname
924
        FROM
925
        (
926
            SELECT
927
                indrelid
928
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
929
                    AS indkey
930
            FROM pg_index
931
            JOIN pg_class index ON index.oid = indexrelid
932
            JOIN pg_class table_ ON table_.oid = indrelid
933
            WHERE
934 2782 aaronmk
                table_.relname = '''+db.esc_value(table)+'''
935
                AND index.relname = '''+db.esc_value(index)+'''
936 866 aaronmk
        ) s
937
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
938
) s
939 853 aaronmk
ORDER BY attnum
940 2782 aaronmk
'''
941
            , cacheable=True, log_level=4)))
942 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
943
        ' database')
944 853 aaronmk
945 464 aaronmk
def constraint_cols(db, table, constraint):
946 1849 aaronmk
    module = util.root_module(db.db)
947 464 aaronmk
    if module == 'psycopg2':
948
        return list(values(run_query(db, '''\
949
SELECT attname
950
FROM pg_constraint
951
JOIN pg_class ON pg_class.oid = conrelid
952
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
953
WHERE
954 2783 aaronmk
    relname = '''+db.esc_value(table)+'''
955
    AND conname = '''+db.esc_value(constraint)+'''
956 464 aaronmk
ORDER BY attnum
957 2783 aaronmk
'''
958
            )))
959 464 aaronmk
    else: raise NotImplementedError("Can't list constraint columns for "+module+
960
        ' database')
961
962 3079 aaronmk
#### Functions
963
964
def function_exists(db, function):
965
    function = sql_gen.as_Function(function)
966
967
    info_table = sql_gen.Table('routines', 'information_schema')
968
    conds = [('routine_name', function.name)]
969
    schema = function.schema
970
    if schema != None: conds.append(('routine_schema', schema))
971
    # Exclude trigger functions, since they cannot be called directly
972
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
973
974
    return list(values(select(db, info_table, ['routine_name'], conds,
975
        order_by='routine_schema', limit=1, log_level=4))) != []
976
        # TODO: order_by search_path schema order
977
978
##### Structural changes
979
980
#### Columns
981
982
def add_col(db, table, col, comment=None, **kw_args):
983
    '''
984
    @param col TypedCol Name may be versioned, so be sure to propagate any
985
        renaming back to any source column for the TypedCol.
986
    @param comment None|str SQL comment used to distinguish columns of the same
987
        name from each other when they contain different data, to allow the
988
        ADD COLUMN query to be cached. If not set, query will not be cached.
989
    '''
990
    assert isinstance(col, sql_gen.TypedCol)
991
992
    while True:
993
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
994
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
995
996
        try:
997
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
998
            break
999
        except DuplicateException:
1000
            col.name = next_version(col.name)
1001
            # try again with next version of name
1002
1003
def add_not_null(db, col):
1004
    table = col.table
1005
    col = sql_gen.to_name_only_col(col)
1006
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1007
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)
1008
1009 2096 aaronmk
row_num_col = '_row_num'
1010
1011 3079 aaronmk
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
1012
    constraints='PRIMARY KEY')
1013
1014
def add_row_num(db, table):
1015
    '''Adds a row number column to a table. Its name is in row_num_col. It will
1016
    be the primary key.'''
1017
    add_col(db, table, row_num_typed_col, log_level=3)
1018
1019
#### Indexes
1020
1021
def add_pkey(db, table, cols=None, recover=None):
1022
    '''Adds a primary key.
1023
    @param cols [sql_gen.Col,...] The columns in the primary key.
1024
        Defaults to the first column in the table.
1025
    @pre The table must not already have a primary key.
1026
    '''
1027
    table = sql_gen.as_Table(table)
1028
    if cols == None: cols = [pkey(db, table, recover)]
1029
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
1030
1031
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
1032
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
1033
        log_ignore_excs=(DuplicateException,))
1034
1035 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
1036 2688 aaronmk
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
1037 2538 aaronmk
    Currently, only function calls are supported as expressions.
1038 2998 aaronmk
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
1039 2847 aaronmk
        This allows indexes to be used for comparisons where NULLs are equal.
1040 2538 aaronmk
    '''
1041 2964 aaronmk
    exprs = lists.mk_seq(exprs)
1042 2538 aaronmk
1043 2688 aaronmk
    # Parse exprs
1044
    old_exprs = exprs[:]
1045
    exprs = []
1046
    cols = []
1047
    for i, expr in enumerate(old_exprs):
1048 2823 aaronmk
        expr = sql_gen.as_Col(expr, table)
1049 2688 aaronmk
1050 2823 aaronmk
        # Handle nullable columns
1051 2998 aaronmk
        if ensure_not_null_:
1052
            try: expr = ensure_not_null(db, expr)
1053 2860 aaronmk
            except KeyError: pass # unknown type, so just create plain index
1054 2823 aaronmk
1055 2688 aaronmk
        # Extract col
1056 3002 aaronmk
        expr = copy.deepcopy(expr) # don't modify input!
1057 2688 aaronmk
        if isinstance(expr, sql_gen.FunctionCall):
1058
            col = expr.args[0]
1059
            expr = sql_gen.Expr(expr)
1060
        else: col = expr
1061 2823 aaronmk
        assert isinstance(col, sql_gen.Col)
1062 2688 aaronmk
1063
        # Extract table
1064
        if table == None:
1065
            assert sql_gen.is_table_col(col)
1066
            table = col.table
1067
1068
        col.table = None
1069
1070
        exprs.append(expr)
1071
        cols.append(col)
1072 2408 aaronmk
1073 2688 aaronmk
    table = sql_gen.as_Table(table)
1074
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
1075
1076 3005 aaronmk
    # Add index
1077
    while True:
1078
        str_ = 'CREATE'
1079
        if unique: str_ += ' UNIQUE'
1080
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
1081
            ', '.join((v.to_str(db) for v in exprs)))+')'
1082
1083
        try:
1084
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
1085
                log_ignore_excs=(DuplicateException,))
1086
            break
1087
        except DuplicateException:
1088
            index.name = next_version(index.name)
1089
            # try again with next version of name
1090 2408 aaronmk
1091 2997 aaronmk
def add_index_col(db, col, suffix, expr, nullable=True):
1092 3000 aaronmk
    if sql_gen.index_col(col) != None: return # already has index col
1093 2997 aaronmk
1094
    new_col = sql_gen.suffixed_col(col, suffix)
1095
1096 3006 aaronmk
    # Add column
1097 3038 aaronmk
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
1098 3045 aaronmk
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
1099
        log_level=3)
1100 3037 aaronmk
    new_col.name = new_typed_col.name # propagate any renaming
1101 3006 aaronmk
1102 3064 aaronmk
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
1103
        log_level=3)
1104 2997 aaronmk
    if not nullable: add_not_null(db, new_col)
1105
    add_index(db, new_col)
1106
1107 3104 aaronmk
    col.table.index_cols[col.name] = new_col.name
1108 2997 aaronmk
1109 3047 aaronmk
# Controls when ensure_not_null() will use index columns
1110
not_null_index_cols_min_rows = 0 # rows; initially always use index columns
1111
1112 2997 aaronmk
def ensure_not_null(db, col):
1113
    '''For params, see sql_gen.ensure_not_null()'''
1114
    expr = sql_gen.ensure_not_null(db, col)
1115
1116 3047 aaronmk
    # If a nullable column in a temp table, add separate index column instead.
1117
    # Note that for small datasources, this adds 6-25% to the total import time.
1118
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
1119
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
1120 2997 aaronmk
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
1121 3000 aaronmk
        expr = sql_gen.index_col(col)
1122 2997 aaronmk
1123
    return expr
1124
1125 3083 aaronmk
already_indexed = object() # tells add_indexes() the pkey has already been added
1126
1127
def add_indexes(db, table, has_pkey=True):
1128
    '''Adds an index on all columns in a table.
1129
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
1130
        index should be added on the first column.
1131
        * If already_indexed, the pkey is assumed to have already been added
1132
    '''
1133
    cols = table_cols(db, table)
1134
    if has_pkey:
1135
        if has_pkey is not already_indexed: add_pkey(db, table)
1136
        cols = cols[1:]
1137
    for col in cols: add_index(db, col, table)
1138
1139 3079 aaronmk
#### Tables
1140 2772 aaronmk
1141 3079 aaronmk
### Maintenance
1142 2772 aaronmk
1143 3079 aaronmk
def analyze(db, table):
1144
    table = sql_gen.as_Table(table)
1145
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1146 2934 aaronmk
1147 3079 aaronmk
def autoanalyze(db, table):
1148
    if db.autoanalyze: analyze(db, table)
1149 2935 aaronmk
1150 3079 aaronmk
def vacuum(db, table):
1151
    table = sql_gen.as_Table(table)
1152
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
1153
        log_level=3))
1154 2086 aaronmk
1155 3079 aaronmk
### Lifecycle
1156
1157 2889 aaronmk
def drop_table(db, table):
1158
    table = sql_gen.as_Table(table)
1159
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')
1160
1161 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
1162
    like=None):
1163 2675 aaronmk
    '''Creates a table.
1164 2681 aaronmk
    @param cols [sql_gen.TypedCol,...] The column names and types
1165
    @param has_pkey If set, the first column becomes the primary key.
1166 2760 aaronmk
    @param col_indexes bool|[ref]
1167
        * If True, indexes will be added on all non-pkey columns.
1168
        * If a list reference, [0] will be set to a function to do this.
1169
          This can be used to delay index creation until the table is populated.
1170 2675 aaronmk
    '''
1171
    table = sql_gen.as_Table(table)
1172
1173 3082 aaronmk
    if like != None:
1174
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
1175
            ]+cols
1176 2681 aaronmk
    if has_pkey:
1177
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
1178 2872 aaronmk
        pkey.constraints = 'PRIMARY KEY'
1179 2681 aaronmk
1180 3085 aaronmk
    temp = table.is_temp and not db.debug_temp
1181
        # temp tables permanent in debug_temp mode
1182 2760 aaronmk
1183 3085 aaronmk
    # Create table
1184
    while True:
1185
        str_ = 'CREATE'
1186
        if temp: str_ += ' TEMP'
1187
        str_ += ' TABLE '+table.to_str(db)+' (\n'
1188
        str_ += '\n, '.join(c.to_str(db) for c in cols)
1189
        str_ += '\n);\n'
1190
1191
        try:
1192
            run_query(db, str_, cacheable=True, log_level=2,
1193
                log_ignore_excs=(DuplicateException,))
1194
            break
1195
        except DuplicateException:
1196
            table.name = next_version(table.name)
1197
            # try again with next version of name
1198
1199 2760 aaronmk
    # Add indexes
1200 2773 aaronmk
    if has_pkey: has_pkey = already_indexed
1201
    def add_indexes_(): add_indexes(db, table, has_pkey)
1202
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
1203
    elif col_indexes: add_indexes_() # add now
1204 2675 aaronmk
1205 3084 aaronmk
def copy_table_struct(db, src, dest):
1206
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1207 3085 aaronmk
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1208 3084 aaronmk
1209 3079 aaronmk
### Data
1210 2684 aaronmk
1211 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
1212
    '''For params, see run_query()'''
1213 2777 aaronmk
    table = sql_gen.as_Table(table, schema)
1214 2970 aaronmk
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1215 2732 aaronmk
1216 2965 aaronmk
def empty_temp(db, tables):
1217 2972 aaronmk
    if db.debug_temp: return # leave temp tables there for debugging
1218 2965 aaronmk
    tables = lists.mk_seq(tables)
1219 2971 aaronmk
    for table in tables: truncate(db, table, log_level=3)
1220 2965 aaronmk
1221 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
1222
    '''For kw_args, see tables()'''
1223
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1224 3094 aaronmk
1225
def distinct_table(db, table, distinct_on):
1226
    '''Creates a copy of a temp table which is distinct on the given columns.
1227 3099 aaronmk
    The old and new tables will both get an index on these columns, to
1228
    facilitate merge joins.
1229 3097 aaronmk
    @param distinct_on If empty, creates a table with one row. This is useful if
1230
        your distinct_on columns are all literal values.
1231 3099 aaronmk
    @return The new table.
1232 3094 aaronmk
    '''
1233 3099 aaronmk
    new_table = sql_gen.suffixed_table(table, '_distinct')
1234 3094 aaronmk
1235 3099 aaronmk
    copy_table_struct(db, table, new_table)
1236 3097 aaronmk
1237
    limit = None
1238
    if distinct_on == []: limit = 1 # one sample row
1239 3099 aaronmk
    else:
1240
        add_index(db, distinct_on, new_table, unique=True)
1241
        add_index(db, distinct_on, table) # for join optimization
1242 3097 aaronmk
1243 3099 aaronmk
    insert_select(db, new_table, None, mk_select(db, table, start=0,
1244 3097 aaronmk
        limit=limit), ignore=True)
1245 3099 aaronmk
    analyze(db, new_table)
1246 3094 aaronmk
1247 3099 aaronmk
    return new_table