# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
import lists
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None

##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
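
# Illustrative sketch (not part of the original module): how the escaping
# helpers above are typically used. The names below are hypothetical.
#
#     esc_name_by_module('psycopg2', 'user table') # double-quoted for PostgreSQL
#     esc_name_by_module('MySQLdb', 'user table') # backquoted for MySQL
#     qual_name(db, 'public', 'plots') # schema-qualified, each part escaped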

##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.autoanalyze = False
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self): return self.__db != None
    
    def close(self):
        if not self.connected(): return
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions.
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: return func()
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn
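
# Illustrative sketch (not part of the original module): opening a connection
# and running a query. The connection settings are hypothetical; the keys come
# from db_config_names above. The actual connection is opened lazily, on the
# first query.
#
#     db = connect({'engine': 'PostgreSQL', 'host': 'localhost', 'user': 'bien',
#         'password': 'secret', 'database': 'vegbien', 'schemas': 'public'})
#     print value(db.run_query('SELECT 1'))
#     db.close()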

##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            msg = strings.ustr(e.args[0])
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
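
# Illustrative sketch (not part of the original module): with recover=True,
# run_query() wraps the query in a savepoint and translates common PostgreSQL
# errors into the exception classes above. The table and column names are
# hypothetical.
#
#     try: run_query(db, 'INSERT INTO plots (id) VALUES (1)', recover=True)
#     except DuplicateKeyException, e: print 'duplicate on columns:', e.cols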

##### Basic queries

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))

def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_indexes_: add_indexes(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    missing = True
    if conds != []:
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)
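
# Illustrative sketch (not part of the original module): building and running a
# SELECT. mk_select() only returns the SQL string; select() also runs it. The
# table and column names are hypothetical.
#
#     query = mk_select(db, 'plots', ['id', 'name'], {'site': 'A'}, limit=10)
#     for row_ in rows(select(db, 'plots', ['id', 'name'], {'site': 'A'},
#         limit=10)): print row_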

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
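
# Illustrative sketch (not part of the original module): inserting a row given
# as a dict (column -> value) and reading back the generated pkey. The table
# and column names are hypothetical.
#
#     cur = insert(db, 'plots', {'name': 'plot 1', 'site': 'A'}, returning='id')
#     plot_id = value(cur)
#     insert(db, 'plots', {'name': 'plot 1', 'site': 'A'}, ignore=True)
#         # silently skips duplicate keys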

def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
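
# Illustrative sketch (not part of the original module): updating a column. The
# names are hypothetical; pass in_place=True to rewrite values via ALTER COLUMN
# ... USING instead of an UPDATE (see mk_update()).
#
#     update(db, 'plots', [('site', 'B')],
#         cond=sql_gen.ColValueCond('site', 'A'))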

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_indexes_=True)
    return dict(items)
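
# Illustrative sketch (not part of the original module): flattening joined
# tables into one temp table with collision-free column names. The tables,
# columns, and join mapping are hypothetical, as is the sql_gen.Join signature
# assumed here.
#
#     into = sql_gen.Table('plots_flat')
#     mapping = flatten(db, into, ['plots', sql_gen.Join('sites',
#         {'site_id': 'id'})], [sql_gen.Col('name', 'plots'),
#         sql_gen.Col('name', 'sites')], limit=1000, start=0)
#     # mapping maps each original Col to its renamed Col in plots_flat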

##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []
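
# Illustrative sketch (not part of the original module): listing tables and
# checking for existence. The schema and table names are hypothetical.
#
#     print list(tables(db, 'public', 'plot%')) # names matching a LIKE pattern
#     print table_exists(db, sql_gen.Table('plots', 'public'))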

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    # Add index
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with next version of name

def add_index_col(db, col, suffix, expr, nullable=True):
    if sql_gen.index_col(col) != None: return # already has index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    col.table.index_cols[col.name] = new_col.name

# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()'''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If this is a nullable column in a temp table, add a separate index column
    # instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)
    
    return expr

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on each column in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop_table(db, table):
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
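
# Illustrative sketch (not part of the original module): creating a table from
# TypedCol definitions. The names and types are hypothetical. The first column
# becomes the pkey; the remaining columns get plain indexes.
#
#     create_table(db, sql_gen.Table('plots', 'public'), [
#         sql_gen.TypedCol('id', 'serial'),
#         sql_gen.TypedCol('name', 'text', nullable=False),
#     ])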

def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    if db.debug_temp: return # leave temp tables there for debugging
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table
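
# Illustrative sketch (not part of the original module): deduplicating a temp
# table on selected columns. The table and column names are hypothetical, and
# the is_temp flag is assumed to be settable on sql_gen.Table.
#
#     staging = sql_gen.Table('staging_rows')
#     staging.is_temp = True
#     uniq = distinct_table(db, staging, [sql_gen.Col('accession_no', staging)])
#     # uniq is a new temp table, unique on accession_no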