Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Returns the query most recently run on a cursor.
    @param cur The DB-API cursor
    @param input_query The query that was passed in, used as a fallback when
        the cursor does not expose the executed query
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query # psycopg2
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed # MySQLdb
    
    # use identity comparison with None (PEP 8) rather than !=
    if raw_query is not None: return raw_query
    else: return '[input] '+strings.ustr(input_query)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query().
    '''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(query))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for all database-related exceptions in this module.
    @param cur If given, the cursor's query is appended to the message.
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        # identity comparison with None (PEP 8) rather than !=
        if cur is not None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''DbException that records the name of the database object involved.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
40 360 aaronmk
41 3109 aaronmk
class ExceptionWithValue(DbException):
    '''DbException that records the offending value.'''
    def __init__(self, value, cause=None):
        msg = 'for value: '+strings.as_tt(repr(value))
        DbException.__init__(self, msg, cause)
        self.value = value
46
47 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''DbException that records both a database object's type and its name.'''
    def __init__(self, type_, name, cause=None):
        msg = ('for type: '+strings.as_tt(str(type_))+'; name: '
            +strings.as_tt(name))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.name = name
53
54 2306 aaronmk
class ConstraintException(DbException):
    '''Raised when a database constraint is violated.
    @param name The constraint's name
    @param cols The columns the constraint applies to
    '''
    def __init__(self, name, cols, cause=None):
        msg = ('Violated '+strings.as_tt(name)+' constraint on columns: '
            +strings.as_tt(', '.join(cols)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols
60 11 aaronmk
61 2523 aaronmk
class MissingCastException(DbException):
    '''Raised when a value cannot be converted to a column's type.'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
67
68 2143 aaronmk
# Raised when a database object name is invalid
class NameException(DbException): pass

# Raised on unique-constraint violations
class DuplicateKeyException(ConstraintException): pass

# Raised on NOT NULL constraint violations
class NullValueException(ConstraintException): pass

# Raised when a value is rejected by a column's type (e.g. out of range)
class InvalidValueException(ExceptionWithValue): pass

# Raised when creating a database object that already exists
class DuplicateException(ExceptionWithNameType): pass

# Raised when a query unexpectedly returns no rows
class EmptyRowException(DbException): pass
79
80 865 aaronmk
##### Warnings
81
82
# Base class for database-related warnings emitted via warnings.warn()
class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval
85
86
def col_names(cur):
    '''Yields the column names of a cursor's result set, in order.'''
    for col in cur.description: yield col[0]
87
88
def rows(cur):
    '''Iterates over a cursor's rows until fetchone() returns None.'''
    # iter(callable, sentinel) can call the bound method directly;
    # the lambda wrapper was redundant
    return iter(cur.fetchone, None)
89
90
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached.
    Exhausts the cursor's row iterator without keeping the rows.
    '''
    iters.consume_iter(rows(cur))
93
94
# Returns the next row; uses the Python 2 iterator protocol (.next())
def next_row(cur): return rows(cur).next()
95
96
def row(cur):
    '''Returns the first row, then exhausts the cursor so the full result set
    can be cached.
    '''
    first = next_row(cur)
    consume_rows(cur)
    return first
100
101
# Returns the first column of the next row
def next_value(cur): return next_row(cur)[0]
102
103
# Returns the first column of the first row, consuming all remaining rows
def value(cur): return row(cur)[0]
104
105
# Returns an iterator that repeatedly calls next_value(cur)
def values(cur): return iters.func_iter(lambda: next_value(cur))
106
107
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    there are no rows.
    '''
    try:
        return value(cur)
    except StopIteration:
        return None
110
111 2762 aaronmk
##### Escaping
112 2101 aaronmk
113 2573 aaronmk
def esc_name_by_module(module, name):
    '''Quotes an identifier using the quote character of the given DB-API
    module.
    @param module The root module name (e.g. 'psycopg2'), or None for the
        default (PostgreSQL-style double quotes)
    @raises NotImplementedError if the module is not supported
    '''
    # identity comparison with None (PEP 8) rather than ==
    if module == 'psycopg2' or module is None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
118 2101 aaronmk
119
def esc_name_by_engine(engine, name, **kw_args):
    '''Quotes an identifier based on the engine name (a key of db_engines).'''
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
121
122
def esc_name(db, name, **kw_args):
    '''Quotes an identifier using the quote character of db's DB-API module.'''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
124
125
def qual_name(db, schema, table):
    '''Returns the escaped, schema-qualified table name.
    @param schema May be None, in which case only the table name is returned.
    '''
    table = esc_name(db, table)
    if schema is None: return table
    return esc_name(db, schema)+'.'+table
130
131 1869 aaronmk
##### Database connections
132 1849 aaronmk
133 2097 aaronmk
# Recognized keys of a db_config dict, in canonical order
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
134 1926 aaronmk
135 1869 aaronmk
# Maps engine name -> (DB-API module name, db_config-key renamings needed by
# that module's connect() parameters)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
139
140
DatabaseErrors_set = set([DbException])
141
DatabaseErrors = tuple(DatabaseErrors_set)
142
143
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class in DatabaseErrors.'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    # rebuild the tuple so existing except clauses see the new class
    DatabaseErrors = tuple(DatabaseErrors_set)
147
148
def db_config_str(db_config):
    '''Returns a human-readable description of a db_config dict.'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
150
151 2448 aaronmk
# Default no-op debug logger. PEP 8: use def rather than assigning a lambda
# to a name. Callers compare against this object to detect "debugging off".
def log_debug_none(msg, level=2): pass
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
154 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
155 2915 aaronmk
        log_debug=log_debug_none, debug_temp=False):
156
        '''
157
        @param debug_temp Whether temporary objects should instead be permanent.
158
            This assists in debugging the internal objects used by the program.
159
        '''
160 1869 aaronmk
        self.db_config = db_config
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 2915 aaronmk
        self.debug_temp = debug_temp
166 3074 aaronmk
        self.autoanalyze = False
167 1869 aaronmk
168 3124 aaronmk
        self._savepoint = 0
169 3120 aaronmk
        self._reset()
170 1869 aaronmk
171
    def __getattr__(self, name):
172
        if name == '__dict__': raise Exception('getting __dict__')
173
        if name == 'db': return self._db()
174
        else: raise AttributeError()
175
176
    def __getstate__(self):
177
        state = copy.copy(self.__dict__) # shallow copy
178 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
179 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
180
        return state
181
182 3118 aaronmk
    def clear_cache(self): self.query_results = {}
183
184 3120 aaronmk
    def _reset(self):
185 3118 aaronmk
        self.clear_cache()
186 3124 aaronmk
        assert self._savepoint == 0
187 3118 aaronmk
        self._notices_seen = set()
188
        self.__db = None
189
190 2165 aaronmk
    def connected(self): return self.__db != None
191
192 3116 aaronmk
    def close(self):
193 3119 aaronmk
        if not self.connected(): return
194
195 3135 aaronmk
        # Record that the automatic transaction is now closed
196 3136 aaronmk
        self._savepoint -= 1
197 3135 aaronmk
198 3119 aaronmk
        self.db.close()
199 3120 aaronmk
        self._reset()
200 3116 aaronmk
201 3125 aaronmk
    def reconnect(self):
202
        # Do not do this in test mode as it would roll back everything
203
        if self.autocommit: self.close()
204
        # Connection will be reopened automatically on first query
205
206 1869 aaronmk
    def _db(self):
207
        if self.__db == None:
208
            # Process db_config
209
            db_config = self.db_config.copy() # don't modify input!
210 2097 aaronmk
            schemas = db_config.pop('schemas', None)
211 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
212
            module = __import__(module_name)
213
            _add_module(module)
214
            for orig, new in mappings.iteritems():
215
                try: util.rename_key(db_config, orig, new)
216
                except KeyError: pass
217
218
            # Connect
219
            self.__db = module.connect(**db_config)
220
221
            # Configure connection
222 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
223
                import psycopg2.extensions
224
                self.db.set_isolation_level(
225
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
226 2101 aaronmk
            if schemas != None:
227 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
228
                search_path.append(value(run_query(self, 'SHOW search_path',
229
                    log_level=4)))
230
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
231
                    log_level=3)
232 3135 aaronmk
233
            # Record that a transaction is already open
234 3136 aaronmk
            self._savepoint += 1
235 1869 aaronmk
236
        return self.__db
237 1889 aaronmk
238 1891 aaronmk
    class DbCursor(Proxy):
239 1927 aaronmk
        def __init__(self, outer):
240 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
241 2191 aaronmk
            self.outer = outer
242 1927 aaronmk
            self.query_results = outer.query_results
243 1894 aaronmk
            self.query_lookup = None
244 1891 aaronmk
            self.result = []
245 1889 aaronmk
246 2802 aaronmk
        def execute(self, query):
247 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
248 2797 aaronmk
            self.query_lookup = query
249 2148 aaronmk
            try:
250 2191 aaronmk
                try:
251 2802 aaronmk
                    cur = self.inner.execute(query)
252 2191 aaronmk
                    self.outer.do_autocommit()
253 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
254 1904 aaronmk
            except Exception, e:
255 2802 aaronmk
                _add_cursor_info(e, self, query)
256 1904 aaronmk
                self.result = e # cache the exception as the result
257
                self._cache_result()
258
                raise
259 3004 aaronmk
260
            # Always cache certain queries
261
            if query.startswith('CREATE') or query.startswith('ALTER'):
262 3007 aaronmk
                # structural changes
263 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
264
                # so don't cache ADD COLUMN unless it has distinguishing comment
265
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
266 3007 aaronmk
                    self._cache_result()
267 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
268 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
269 3004 aaronmk
270 2762 aaronmk
            return cur
271 1894 aaronmk
272 1891 aaronmk
        def fetchone(self):
273
            row = self.inner.fetchone()
274 1899 aaronmk
            if row != None: self.result.append(row)
275
            # otherwise, fetched all rows
276 1904 aaronmk
            else: self._cache_result()
277
            return row
278
279
        def _cache_result(self):
280 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
281
            # inserts are not idempotent. Other non-SELECT queries don't have
282
            # their result set read, so only exceptions will be cached (an
283
            # invalid query will always be invalid).
284 1930 aaronmk
            if self.query_results != None and (not self._is_insert
285 1906 aaronmk
                or isinstance(self.result, Exception)):
286
287 1894 aaronmk
                assert self.query_lookup != None
288 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
289
                    util.dict_subset(dicts.AttrsDictView(self),
290
                    ['query', 'result', 'rowcount', 'description']))
291 1906 aaronmk
292 1916 aaronmk
        class CacheCursor:
293
            def __init__(self, cached_result): self.__dict__ = cached_result
294
295 1927 aaronmk
            def execute(self, *args, **kw_args):
296 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
297
                # otherwise, result is a rows list
298
                self.iter = iter(self.result)
299
300
            def fetchone(self):
301
                try: return self.iter.next()
302
                except StopIteration: return None
303 1891 aaronmk
304 2212 aaronmk
    def esc_value(self, value):
305 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
306
        except NotImplementedError, e:
307
            module = util.root_module(self.db)
308
            if module == 'MySQLdb':
309
                import _mysql
310
                str_ = _mysql.escape_string(value)
311
            else: raise e
312 2374 aaronmk
        return strings.to_unicode(str_)
313 2212 aaronmk
314 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
315
316 2814 aaronmk
    def std_code(self, str_):
317
        '''Standardizes SQL code.
318
        * Ensures that string literals are prefixed by `E`
319
        '''
320
        if str_.startswith("'"): str_ = 'E'+str_
321
        return str_
322
323 2665 aaronmk
    def can_mogrify(self):
324 2663 aaronmk
        module = util.root_module(self.db)
325 2665 aaronmk
        return module == 'psycopg2'
326 2663 aaronmk
327 2665 aaronmk
    def mogrify(self, query, params=None):
328
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
329
        else: raise NotImplementedError("Can't mogrify query")
330
331 2671 aaronmk
    def print_notices(self):
332 2725 aaronmk
        if hasattr(self.db, 'notices'):
333
            for msg in self.db.notices:
334
                if msg not in self._notices_seen:
335
                    self._notices_seen.add(msg)
336
                    self.log_debug(msg, level=2)
337 2671 aaronmk
338 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
339 2464 aaronmk
        debug_msg_ref=None):
340 2445 aaronmk
        '''
341 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
342
            throws one of these exceptions.
343 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
344
            this instead of being output. This allows you to filter log messages
345
            depending on the result of the query.
346 2445 aaronmk
        '''
347 2167 aaronmk
        assert query != None
348
349 2047 aaronmk
        if not self.caching: cacheable = False
350 1903 aaronmk
        used_cache = False
351 2664 aaronmk
352
        def log_msg(query):
353
            if used_cache: cache_status = 'cache hit'
354
            elif cacheable: cache_status = 'cache miss'
355
            else: cache_status = 'non-cacheable'
356
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
357
358 1903 aaronmk
        try:
359 1927 aaronmk
            # Get cursor
360
            if cacheable:
361
                try:
362 2797 aaronmk
                    cur = self.query_results[query]
363 1927 aaronmk
                    used_cache = True
364
                except KeyError: cur = self.DbCursor(self)
365
            else: cur = self.db.cursor()
366
367 2664 aaronmk
            # Log query
368
            if self.debug and debug_msg_ref == None: # log before running
369
                self.log_debug(log_msg(query), log_level)
370
371 1927 aaronmk
            # Run query
372 2793 aaronmk
            cur.execute(query)
373 1903 aaronmk
        finally:
374 2671 aaronmk
            self.print_notices()
375 2664 aaronmk
            if self.debug and debug_msg_ref != None: # return after running
376 2793 aaronmk
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
377 1903 aaronmk
378
        return cur
379 1914 aaronmk
380 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
381 2139 aaronmk
382 2907 aaronmk
    def with_autocommit(self, func):
383 2801 aaronmk
        import psycopg2.extensions
384
385
        prev_isolation_level = self.db.isolation_level
386 2907 aaronmk
        self.db.set_isolation_level(
387
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
388 2683 aaronmk
        try: return func()
389 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
390 2683 aaronmk
391 2139 aaronmk
    def with_savepoint(self, func):
392 3137 aaronmk
        top = self._savepoint == 0
393 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
394 3137 aaronmk
395
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
396
        else: query = 'SAVEPOINT '+savepoint
397
        self.run_query(query, log_level=4)
398
399 2139 aaronmk
        self._savepoint += 1
400 3137 aaronmk
        try:
401
            return func()
402
            if top: self.run_query('COMMIT', log_level=4)
403 2139 aaronmk
        except:
404 3137 aaronmk
            if top: query = 'ROLLBACK'
405
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
406
            self.run_query(query, log_level=4)
407
408 2139 aaronmk
            raise
409 2930 aaronmk
        finally:
410
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
411
            # "The savepoint remains valid and can be rolled back to again"
412
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
413 3137 aaronmk
            if not top:
414
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
415 2930 aaronmk
416
            self._savepoint -= 1
417
            assert self._savepoint >= 0
418
419
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
420 2191 aaronmk
421
    def do_autocommit(self):
422
        '''Autocommits if outside savepoint'''
423 3135 aaronmk
        assert self._savepoint >= 1
424
        if self.autocommit and self._savepoint == 1:
425 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
426 2191 aaronmk
            self.db.commit()
427 2643 aaronmk
428 3150 aaronmk
    def col_info(self, col, cacheable=False):
429 2643 aaronmk
        table = sql_gen.Table('columns', 'information_schema')
430 3063 aaronmk
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
431
            'USER-DEFINED'), sql_gen.Col('udt_name'))
432 3078 aaronmk
        cols = [type_, 'column_default',
433
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
434 2643 aaronmk
435
        conds = [('table_name', col.table.name), ('column_name', col.name)]
436
        schema = col.table.schema
437
        if schema != None: conds.append(('table_schema', schema))
438
439 2819 aaronmk
        type_, default, nullable = row(select(self, table, cols, conds,
440 3150 aaronmk
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
441 2643 aaronmk
            # TODO: order_by search_path schema order
442 2819 aaronmk
        default = sql_gen.as_Code(default, self)
443
444
        return sql_gen.TypedCol(col.name, type_, default, nullable)
445 2917 aaronmk
446
    def TempFunction(self, name):
447
        if self.debug_temp: schema = None
448
        else: schema = 'pg_temp'
449
        return sql_gen.Function(name, schema)
450 1849 aaronmk
451 1869 aaronmk
connect = DbConn # alias: constructing a DbConn opens a (lazy) connection
452
453 832 aaronmk
##### Recoverable querying
454 15 aaronmk
455 2139 aaronmk
# Module-level convenience wrapper for DbConn.with_savepoint()
def with_savepoint(db, func): return db.with_savepoint(func)
456 11 aaronmk
457 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint for auto-rollback, and
    translates common PostgreSQL error messages into typed exceptions.
    For params, see DbConn.run_query()
    @param recover Whether to wrap the query in a savepoint so a failure can
        be rolled back without aborting the enclosing transaction
    @param log_ignore_excs Exceptions whose log message should be demoted
        (log_level increased by 2)
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # Parse the server's error message into a typed exception
            msg = strings.ustr(e.args[0])
            
            # unique-constraint violation -> DuplicateKeyException
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)
            
            # NOT NULL violation -> NullValueException
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            # bad input syntax / out-of-range value -> InvalidValueException
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            # type mismatch needing a cast -> MissingCastException
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            # object already exists -> DuplicateException
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
516 830 aaronmk
517 832 aaronmk
##### Basic queries
518
519 2153 aaronmk
def next_version(name):
    '''Increments the "#N" version suffix of a name.
    An unsuffixed name counts as version 0, so it becomes name#1.
    '''
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, old_version = match.groups()
        version = int(old_version)+1
    return sql_gen.concat(name, '#'+str(version))
526 2153 aaronmk
527 2899 aaronmk
def lock_table(db, table, mode):
    '''Acquires a table-level lock.
    @param mode An SQL lock mode string, spliced into LOCK TABLE ... IN <mode> MODE
    '''
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
530
531 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table|None The table to create; if None the query is
        run directly. On name collision the name is versioned (name#N) and
        creation is retried.
    @param add_indexes_ Whether to add indexes on the new table
    @return the cursor of the executed query
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException: # exception object unused; don't bind it
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_indexes_: add_indexes(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
573 2085 aaronmk
574 2120 aaronmk
# Sentinel objects for mk_select()'s special parameter values
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
577
578 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns. NOTE: the [] default is never mutated; it doubles as
        the "no DISTINCT" sentinel (compared with != []).
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # missing tracks whether the query lacks any row-limiting clause
    missing = True
    if conds != []:
        # (an unused `whitespace` computation that was dead code here has been
        # removed)
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query
660 11 aaronmk
661 2054 aaronmk
def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True) # SELECTs are cacheable by default
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)
669 2054 aaronmk
670 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''Generates an INSERT ... SELECT query string.
    @param cols [sql_gen.Col...]|None The columns to insert into. [] is treated
        as None (unknown column names, so all columns get their defaults).
    @param select_query str|None The row source. Defaults to 'DEFAULT VALUES'.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)

    first_line = 'INSERT INTO '+table.to_str(db)

    def mk_insert(select_query):
        # Builds the plain INSERT statement for the given row source.
        # (The parameter deliberately shadows the outer select_query.)
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query

        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)

        return query

    # Return type of the generated wrapper function, used when embeddable
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'

    lang = 'sql'
    if ignore:
        # ignore mode inserts rows one at a time in PL/pgSQL, discarding any
        # row that raises unique_violation
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)

        embeddable = True # must use function
        lang = 'plpgsql'

        if cols == None:
            # Unknown column names: loop over whole-row records instead
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]

        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)

    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)

                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name

        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)

    return query
773 2066 aaronmk
774 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    # Peek at (but leave in place) options mk_insert_select() also needs
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)

    # Options consumed here rather than by mk_insert_select()
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if ignore: recover = True # duplicate-key errors are expected
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)

    # With ignore but no RETURNING col, only the rowcount matters
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')

    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    if rowcount_only: empty_temp(db, into) # discard the placeholder rows
    autoanalyze(db, table)
    return cur
797 2063 aaronmk
798 2738 aaronmk
default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()
    @param row sequence of values, or a mapping of {col_name: value}
    '''
    cols = None
    if not lists.is_seq(row): # mapping: split into column names and values
        cols = row.keys()
        row = row.values()
    values_ = list(row) # ensure that "== []" works

    if values_ == []: query = None # row of all default values
    else: query = sql_gen.Values(values_).to_str(db)

    return insert_select(db, table, cols, query, *args, **kw_args)
812 11 aaronmk
813 3056 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''Generates an UPDATE (or, for in_place, an ALTER TABLE) query.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(col, table), sql_gen.as_Value(value))
        for col, value in changes]

    if not in_place:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((col.to_str(db)+' = '+value.to_str(db)
            for col, value in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    else:
        assert cond == None
        # Re-declare each column's (unchanged) type with a USING expression,
        # which rewrites the rows in place
        clauses = []
        for col, value in changes:
            type_ = db.col_info(sql_gen.with_default_table(col, table)).type
            clauses.append('ALTER COLUMN '+col.to_str(db)+' TYPE '+type_
                +'\nUSING '+value.to_str(db))
        query = 'ALTER TABLE '+table.to_str(db)+'\n'+',\n'.join(clauses)

    return query
843
844 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    # Options for run_query() rather than mk_update()
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)

    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # keep the planner's statistics current
    return cur
854 2402 aaronmk
855 135 aaronmk
def last_insert_id(db):
    '''Returns the last autogenerated id, or None if the driver is unknown'''
    driver = util.root_module(db.db)
    if driver == 'MySQLdb': return db.insert_id()
    elif driver == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    return None
860 13 aaronmk
861 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=None, as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    # Default preserve to None rather than a mutable [] default, which would
    # be shared across calls
    if preserve == None: preserve = []
    cols = lists.uniqify(cols)

    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into # in-place retargeting, as documented above
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))

    if not as_items: items = dict(items)
    return items
892 2383 aaronmk
893 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    # Select each old column under its new (collision-free) name
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, into=into, add_indexes_=True)
    return dict(items)
902 2391 aaronmk
903 3079 aaronmk
##### Database structure introspection
904 2414 aaronmk
905 3079 aaronmk
#### Tables
906
907
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Lists table names matching the given LIKE patterns (or exact names)'''
    compare = '=' if exact else 'LIKE'

    module = util.root_module(db.db)
    if module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    elif module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
921
922
def table_exists(db, table):
    '''Whether a table with exactly this schema and name exists'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return len(list(matches)) > 0
925
926 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None, start=0)
    return value(run_query(db, query, recover=recover, log_level=3))
929 2426 aaronmk
930 2414 aaronmk
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order'''
    # A LIMIT 0 select fetches no rows but still exposes the column names
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
933 2414 aaronmk
934 2291 aaronmk
def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    cols = table_cols(db, table, recover)
    return cols[0]
937 832 aaronmk
938 2559 aaronmk
not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col not in table_cols(db, table, recover):
        return pkey(db, table, recover)
    return not_null_col
944
945 853 aaronmk
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return [str,...] The index's column names, ordered by column position
        (attnum) in the table
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The UNION's first branch handles plain column indexes, whose column
        # numbers are listed directly in pg_index.indkey. The second branch
        # handles expression indexes by regex-extracting the ":varattno N"
        # column references from the serialized expression tree in indexprs.
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
985 853 aaronmk
986 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the names of the columns in a table constraint.
    @return [str,...] Column names, ordered by column position (attnum)
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # pg_constraint.conkey holds the constrained columns' attnums; join
        # pg_attribute to translate them into names
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
1002
1003 3079 aaronmk
#### Functions
1004
1005
def function_exists(db, function):
    '''Whether a callable (non-trigger) function of this name exists'''
    function = sql_gen.as_Function(function)

    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))

    matches = values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
        # TODO: order_by search_path schema order
    return len(list(matches)) > 0
1018
1019
##### Structural changes
1020
1021
#### Columns
1022
1023
def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)

    while True:
        query = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: query += ' '+sql_gen.esc_comment(comment)

        try: run_query(db, query, recover=True, cacheable=True, **kw_args)
        except DuplicateException:
            # Name collision: retry with the next version of the name
            col.name = next_version(col.name)
        else: break # this version succeeded
1043
1044
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to the given column'''
    table = col.table
    col = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '+col.to_str(db)
        +' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1049
1050 2096 aaronmk
row_num_col = '_row_num' # name of the surrogate row-number column

# serial PRIMARY KEY column definition used by add_row_num()
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)
1059
1060
#### Indexes
1061
1062
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(c).to_str(db) for c in cols]

    query = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +', '.join(col_strs)+')')
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1075
1076 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)

    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    # NOTE: the original loop used enumerate() without using the index, and
    # accumulated each col into a `cols` list that was never read; both pieces
    # of dead code have been removed.
    for expr in old_exprs:
        expr = sql_gen.as_Col(expr, table)

        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index

        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)

        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table

        # Unqualify the column: the index is created on the table itself.
        # (col aliases into expr, so this also rewrites function-call args.)
        col.table = None

        exprs.append(expr)

    table = sql_gen.as_Table(table)

    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1122 2408 aaronmk
1123 2997 aaronmk
def add_index_col(db, col, suffix, expr, nullable=True):
    '''Adds, populates, and indexes a new column that stores expr, and records
    it as col's index column.
    @param suffix str Appended to col's name to form the new column's name
    @param expr sql_gen.Code The value to store in the new column
    @param nullable If False, a NOT NULL constraint is added to the new column
    '''
    if sql_gen.index_col(col) != None: return # already has index col

    new_col = sql_gen.suffixed_col(col, suffix)

    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name,
        db.col_info(col, cacheable=nullable).type)
        # if not nullable, col_info will be changed later by add_not_null()
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming

    # Populate the new column from expr, in place to avoid creating dead rows
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)

    # Record the mapping on the table (presumably read back via
    # sql_gen.index_col(), as checked at the top -- confirm in sql_gen)
    col.table.index_cols[col.name] = new_col.name
1142 2997 aaronmk
1143 3047 aaronmk
# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()
    @return sql_gen.Code An expression (or precomputed index column) that
        stands in for col with NULLs translated to sentinel values
    '''
    expr = sql_gen.ensure_not_null(db, col)

    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)

    return expr
1158
1159 3083 aaronmk
already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    remaining = table_cols(db, table)
    if has_pkey:
        # The first column is covered by the pkey instead of a regular index
        if has_pkey is not already_indexed: add_pkey(db, table)
        remaining = remaining[1:]
    for col in remaining: add_index(db, col, table)
1172
1173 3079 aaronmk
#### Tables
1174 2772 aaronmk
1175 3079 aaronmk
### Maintenance
1176 2772 aaronmk
1177 3079 aaronmk
def analyze(db, table):
    '''Refreshes the planner statistics for the table'''
    run_query(db, 'ANALYZE '+sql_gen.as_Table(table).to_str(db), log_level=3)
1180 2934 aaronmk
1181 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() on the table, but only if enabled on the connection'''
    if not db.autoanalyze: return
    analyze(db, table)
1183 2935 aaronmk
1184 3079 aaronmk
def vacuum(db, table):
    '''VACUUM ANALYZEs the table (run in autocommit mode)'''
    table = sql_gen.as_Table(table)
    def do_vacuum():
        run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(do_vacuum)
1188 2086 aaronmk
1189 3079 aaronmk
### Lifecycle
1190
1191 2889 aaronmk
def drop_table(db, table):
    '''Drops the table if it exists, cascading to dependent objects'''
    table = sql_gen.as_Table(table)
    query = 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE'
    return run_query(db, query)
1194
1195 3082 aaronmk
def create_table(db, table, cols=None, has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None If set, copy this table's structure.
    '''
    table = sql_gen.as_Table(table)

    # Copy the cols list: the original code used a mutable default ([]) and
    # assigned cols[0] below, which mutated the caller's list (despite its own
    # "don't modify input!" comment) and would corrupt the shared default
    if cols == None: cols = []
    else: cols = list(cols)

    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'

    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode

    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'

        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name

    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1238 2675 aaronmk
1239 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)
    Uses create_table()'s like= support (CREATE TABLE ... LIKE src
    INCLUDING ALL), with no pkey or column indexes added on top.
    '''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1242 3084 aaronmk
1243 3079 aaronmk
### Data
1244 2684 aaronmk
1245 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    query = 'TRUNCATE '+table.to_str(db)+' CASCADE'
    return run_query(db, query, **kw_args)
1249 2732 aaronmk
1250 2965 aaronmk
def empty_temp(db, tables):
    '''Truncates the given temp table(s), unless debug_temp mode is on'''
    if db.debug_temp: return # leave temp tables there for debugging
    for table in lists.mk_seq(tables): truncate(db, table, log_level=3)
1254 2965 aaronmk
1255 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table_ in tables(db, schema, **kw_args):
        truncate(db, table_, schema)
1258 3094 aaronmk
1259
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    copy_table_struct(db, table, new_table)

    if distinct_on == []:
        limit = 1 # one sample row
    else:
        limit = None
        # Index both tables on the distinct columns to enable merge joins
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization

    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table)

    return new_table