Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Returns the query that was actually executed on a cursor, falling back
    to the caller-supplied query text when the driver doesn't record it.'''
    # psycopg2 exposes the executed query as `query`; MySQLdb as `_last_executed`
    if hasattr(cur, 'query'): executed = cur.query
    elif hasattr(cur, '_last_executed'): executed = cur._last_executed
    else: executed = None

    if executed is None: return '[input] '+strings.ustr(input_query)
    return executed
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query()'''
    cur_query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(cur_query))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for database errors. If a cursor is provided, the query that
    caused the error is appended to the message.'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''An error pertaining to a named database object.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
40 360 aaronmk
41 3109 aaronmk
class ExceptionWithValue(DbException):
    '''An error pertaining to a particular value.'''
    def __init__(self, value, cause=None):
        msg = 'for value: '+strings.as_tt(repr(value))
        DbException.__init__(self, msg, cause)
        self.value = value
46
47 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''An error pertaining to a typed, named database object.'''
    def __init__(self, type_, name, cause=None):
        msg = ('for type: '+strings.as_tt(str(type_))+'; name: '
            +strings.as_tt(name))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.name = name
53
54 2306 aaronmk
class ConstraintException(DbException):
    '''A constraint violation on one or more columns.'''
    def __init__(self, name, cols, cause=None):
        msg = ('Violated '+strings.as_tt(name)+' constraint on columns: '
            +strings.as_tt(', '.join(cols)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols
60 11 aaronmk
61 2523 aaronmk
class MissingCastException(DbException):
    '''Raised when a value needs an explicit cast to the column's type.'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
67
68 2143 aaronmk
class NameException(DbException): pass

# unique-constraint violation (see run_query()'s "duplicate key value" parsing)
class DuplicateKeyException(ConstraintException): pass

# NOT NULL constraint violation
class NullValueException(ConstraintException): pass

# a value that can't be parsed or stored as its column's type
class InvalidValueException(ExceptionWithValue): pass

# creating an object (table, function, etc.) that already exists
class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass
79
80 865 aaronmk
##### Warnings
81
82
# Base class for warnings emitted by this module (see mk_select())
class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval
85
86
def col_names(cur):
    '''Generates the column names from a cursor's result description.'''
    return (desc[0] for desc in cur.description)
87
88
def rows(cur):
    '''Iterates over a cursor's rows until the result set is exhausted
    (i.e. until fetchone() returns None).'''
    return iter(cur.fetchone, None)
89
90
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    # DbCursor.fetchone() caches the result once the last row has been read
    iters.consume_iter(rows(cur))
93
94
# Returns the next row; raises StopIteration if there are no rows left
def next_row(cur): return rows(cur).next()
95
96
def row(cur):
    '''Returns the first row, then consumes the rest so the result is cached.'''
    first = next_row(cur)
    consume_rows(cur)
    return first
100
101
# Returns the first column of the next row
def next_value(cur): return next_row(cur)[0]
102
103
# Returns the first column of the first row, consuming all remaining rows
def value(cur): return row(cur)[0]
104
105
# Iterates over the first column of each row
def values(cur): return iters.func_iter(lambda: next_value(cur))
106
107
def value_or_none(cur):
    '''Like value(), but returns None when the query produced no rows.'''
    try: return value(cur)
    except StopIteration: return None
110
111 2762 aaronmk
##### Escaping
112 2101 aaronmk
113 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quoting style of the given DB-API module
    ('psycopg2', 'MySQLdb', or None for the default double-quote style).
    @raise NotImplementedError if the module is not supported
    '''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2' or module == None: quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
118 2101 aaronmk
119
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier based on the engine name (a db_engines key).'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
121
122
def esc_name(db, name, **kw_args):
    '''Escapes an identifier using the quoting style of db's DB-API driver.'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
124
125
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified name of a table.'''
    table = esc_name(db, table)
    if schema is None: return table
    return esc_name(db, schema)+'.'+table
130
131 1869 aaronmk
##### Database connections

# Keys recognized in a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, db_config key renames for connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# Exception types treated as database errors; extended by _add_module()
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
142
143
def _add_module(module):
    '''Registers a DB-API module's DatabaseError in DatabaseErrors'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    # rebuild the tuple so `except DatabaseErrors` picks up the new type
    DatabaseErrors = tuple(DatabaseErrors_set)
147
148
def db_config_str(db_config):
    '''Returns a human-readable description of a DB config.'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
150
151 2448 aaronmk
# Default debug logger: discards all messages (also used by DbConn to detect
# whether debugging is enabled)
log_debug_none = lambda msg, level=2: None
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
154 2923 aaronmk
    def __init__(self, db_config, autocommit=True, caching=True,
155 2915 aaronmk
        log_debug=log_debug_none, debug_temp=False):
156
        '''
157
        @param debug_temp Whether temporary objects should instead be permanent.
158
            This assists in debugging the internal objects used by the program.
159
        '''
160 1869 aaronmk
        self.db_config = db_config
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 2915 aaronmk
        self.debug_temp = debug_temp
166 3074 aaronmk
        self.autoanalyze = False
167 1869 aaronmk
168 3124 aaronmk
        self._savepoint = 0
169 3120 aaronmk
        self._reset()
170 1869 aaronmk
171
    def __getattr__(self, name):
172
        if name == '__dict__': raise Exception('getting __dict__')
173
        if name == 'db': return self._db()
174
        else: raise AttributeError()
175
176
    def __getstate__(self):
177
        state = copy.copy(self.__dict__) # shallow copy
178 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
179 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
180
        return state
181
182 3118 aaronmk
    def clear_cache(self): self.query_results = {}
183
184 3120 aaronmk
    def _reset(self):
185 3118 aaronmk
        self.clear_cache()
186 3124 aaronmk
        assert self._savepoint == 0
187 3118 aaronmk
        self._notices_seen = set()
188
        self.__db = None
189
190 2165 aaronmk
    def connected(self): return self.__db != None
191
192 3116 aaronmk
    def close(self):
193 3119 aaronmk
        if not self.connected(): return
194
195 3135 aaronmk
        # Record that the automatic transaction is now closed
196 3136 aaronmk
        self._savepoint -= 1
197 3135 aaronmk
198 3119 aaronmk
        self.db.close()
199 3120 aaronmk
        self._reset()
200 3116 aaronmk
201 3125 aaronmk
    def reconnect(self):
202
        # Do not do this in test mode as it would roll back everything
203
        if self.autocommit: self.close()
204
        # Connection will be reopened automatically on first query
205
206 1869 aaronmk
    def _db(self):
207
        if self.__db == None:
208
            # Process db_config
209
            db_config = self.db_config.copy() # don't modify input!
210 2097 aaronmk
            schemas = db_config.pop('schemas', None)
211 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
212
            module = __import__(module_name)
213
            _add_module(module)
214
            for orig, new in mappings.iteritems():
215
                try: util.rename_key(db_config, orig, new)
216
                except KeyError: pass
217
218
            # Connect
219
            self.__db = module.connect(**db_config)
220
221
            # Configure connection
222 2906 aaronmk
            if hasattr(self.db, 'set_isolation_level'):
223
                import psycopg2.extensions
224
                self.db.set_isolation_level(
225
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
226 2101 aaronmk
            if schemas != None:
227 2893 aaronmk
                search_path = [self.esc_name(s) for s in schemas.split(',')]
228
                search_path.append(value(run_query(self, 'SHOW search_path',
229
                    log_level=4)))
230
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
231
                    log_level=3)
232 3135 aaronmk
233
            # Record that a transaction is already open
234 3136 aaronmk
            self._savepoint += 1
235 1869 aaronmk
236
        return self.__db
237 1889 aaronmk
238 1891 aaronmk
    class DbCursor(Proxy):
239 1927 aaronmk
        def __init__(self, outer):
240 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
241 2191 aaronmk
            self.outer = outer
242 1927 aaronmk
            self.query_results = outer.query_results
243 1894 aaronmk
            self.query_lookup = None
244 1891 aaronmk
            self.result = []
245 1889 aaronmk
246 2802 aaronmk
        def execute(self, query):
247 2764 aaronmk
            self._is_insert = query.startswith('INSERT')
248 2797 aaronmk
            self.query_lookup = query
249 2148 aaronmk
            try:
250 2191 aaronmk
                try:
251 2802 aaronmk
                    cur = self.inner.execute(query)
252 2191 aaronmk
                    self.outer.do_autocommit()
253 2802 aaronmk
                finally: self.query = get_cur_query(self.inner, query)
254 1904 aaronmk
            except Exception, e:
255 2802 aaronmk
                _add_cursor_info(e, self, query)
256 1904 aaronmk
                self.result = e # cache the exception as the result
257
                self._cache_result()
258
                raise
259 3004 aaronmk
260
            # Always cache certain queries
261
            if query.startswith('CREATE') or query.startswith('ALTER'):
262 3007 aaronmk
                # structural changes
263 3040 aaronmk
                # Rest of query must be unique in the face of name collisions,
264
                # so don't cache ADD COLUMN unless it has distinguishing comment
265
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
266 3007 aaronmk
                    self._cache_result()
267 3004 aaronmk
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
268 2800 aaronmk
                consume_rows(self) # fetch all rows so result will be cached
269 3004 aaronmk
270 2762 aaronmk
            return cur
271 1894 aaronmk
272 1891 aaronmk
        def fetchone(self):
273
            row = self.inner.fetchone()
274 1899 aaronmk
            if row != None: self.result.append(row)
275
            # otherwise, fetched all rows
276 1904 aaronmk
            else: self._cache_result()
277
            return row
278
279
        def _cache_result(self):
280 2948 aaronmk
            # For inserts that return a result set, don't cache result set since
281
            # inserts are not idempotent. Other non-SELECT queries don't have
282
            # their result set read, so only exceptions will be cached (an
283
            # invalid query will always be invalid).
284 1930 aaronmk
            if self.query_results != None and (not self._is_insert
285 1906 aaronmk
                or isinstance(self.result, Exception)):
286
287 1894 aaronmk
                assert self.query_lookup != None
288 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
289
                    util.dict_subset(dicts.AttrsDictView(self),
290
                    ['query', 'result', 'rowcount', 'description']))
291 1906 aaronmk
292 1916 aaronmk
        class CacheCursor:
293
            def __init__(self, cached_result): self.__dict__ = cached_result
294
295 1927 aaronmk
            def execute(self, *args, **kw_args):
296 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
297
                # otherwise, result is a rows list
298
                self.iter = iter(self.result)
299
300
            def fetchone(self):
301
                try: return self.iter.next()
302
                except StopIteration: return None
303 1891 aaronmk
304 2212 aaronmk
    def esc_value(self, value):
305 2663 aaronmk
        try: str_ = self.mogrify('%s', [value])
306
        except NotImplementedError, e:
307
            module = util.root_module(self.db)
308
            if module == 'MySQLdb':
309
                import _mysql
310
                str_ = _mysql.escape_string(value)
311
            else: raise e
312 2374 aaronmk
        return strings.to_unicode(str_)
313 2212 aaronmk
314 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
315
316 2814 aaronmk
    def std_code(self, str_):
317
        '''Standardizes SQL code.
318
        * Ensures that string literals are prefixed by `E`
319
        '''
320
        if str_.startswith("'"): str_ = 'E'+str_
321
        return str_
322
323 2665 aaronmk
    def can_mogrify(self):
324 2663 aaronmk
        module = util.root_module(self.db)
325 2665 aaronmk
        return module == 'psycopg2'
326 2663 aaronmk
327 2665 aaronmk
    def mogrify(self, query, params=None):
328
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
329
        else: raise NotImplementedError("Can't mogrify query")
330
331 2671 aaronmk
    def print_notices(self):
332 2725 aaronmk
        if hasattr(self.db, 'notices'):
333
            for msg in self.db.notices:
334
                if msg not in self._notices_seen:
335
                    self._notices_seen.add(msg)
336
                    self.log_debug(msg, level=2)
337 2671 aaronmk
338 2793 aaronmk
    def run_query(self, query, cacheable=False, log_level=2,
339 2464 aaronmk
        debug_msg_ref=None):
340 2445 aaronmk
        '''
341 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
342
            throws one of these exceptions.
343 2664 aaronmk
        @param debug_msg_ref If specified, the log message will be returned in
344
            this instead of being output. This allows you to filter log messages
345
            depending on the result of the query.
346 2445 aaronmk
        '''
347 2167 aaronmk
        assert query != None
348
349 2047 aaronmk
        if not self.caching: cacheable = False
350 1903 aaronmk
        used_cache = False
351 2664 aaronmk
352
        def log_msg(query):
353
            if used_cache: cache_status = 'cache hit'
354
            elif cacheable: cache_status = 'cache miss'
355
            else: cache_status = 'non-cacheable'
356
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
357
358 1903 aaronmk
        try:
359 1927 aaronmk
            # Get cursor
360
            if cacheable:
361
                try:
362 2797 aaronmk
                    cur = self.query_results[query]
363 1927 aaronmk
                    used_cache = True
364
                except KeyError: cur = self.DbCursor(self)
365
            else: cur = self.db.cursor()
366
367 2664 aaronmk
            # Log query
368
            if self.debug and debug_msg_ref == None: # log before running
369
                self.log_debug(log_msg(query), log_level)
370
371 1927 aaronmk
            # Run query
372 2793 aaronmk
            cur.execute(query)
373 1903 aaronmk
        finally:
374 2671 aaronmk
            self.print_notices()
375 2664 aaronmk
            if self.debug and debug_msg_ref != None: # return after running
376 2793 aaronmk
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
377 1903 aaronmk
378
        return cur
379 1914 aaronmk
380 2797 aaronmk
    def is_cached(self, query): return query in self.query_results
381 2139 aaronmk
382 2907 aaronmk
    def with_autocommit(self, func):
383 2801 aaronmk
        import psycopg2.extensions
384
385
        prev_isolation_level = self.db.isolation_level
386 2907 aaronmk
        self.db.set_isolation_level(
387
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
388 2683 aaronmk
        try: return func()
389 2801 aaronmk
        finally: self.db.set_isolation_level(prev_isolation_level)
390 2683 aaronmk
391 2139 aaronmk
    def with_savepoint(self, func):
392 3137 aaronmk
        top = self._savepoint == 0
393 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
394 3137 aaronmk
395
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
396
        else: query = 'SAVEPOINT '+savepoint
397
        self.run_query(query, log_level=4)
398
399 2139 aaronmk
        self._savepoint += 1
400 3137 aaronmk
        try:
401
            return func()
402
            if top: self.run_query('COMMIT', log_level=4)
403 2139 aaronmk
        except:
404 3137 aaronmk
            if top: query = 'ROLLBACK'
405
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
406
            self.run_query(query, log_level=4)
407
408 2139 aaronmk
            raise
409 2930 aaronmk
        finally:
410
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
411
            # "The savepoint remains valid and can be rolled back to again"
412
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
413 3137 aaronmk
            if not top:
414
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
415 2930 aaronmk
416
            self._savepoint -= 1
417
            assert self._savepoint >= 0
418
419
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
420 2191 aaronmk
421
    def do_autocommit(self):
422
        '''Autocommits if outside savepoint'''
423 3135 aaronmk
        assert self._savepoint >= 1
424
        if self.autocommit and self._savepoint == 1:
425 2924 aaronmk
            self.log_debug('Autocommitting', level=4)
426 2191 aaronmk
            self.db.commit()
427 2643 aaronmk
428 3155 aaronmk
    def col_info(self, col, cacheable=True):
429 2643 aaronmk
        table = sql_gen.Table('columns', 'information_schema')
430 3063 aaronmk
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
431
            'USER-DEFINED'), sql_gen.Col('udt_name'))
432 3078 aaronmk
        cols = [type_, 'column_default',
433
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
434 2643 aaronmk
435
        conds = [('table_name', col.table.name), ('column_name', col.name)]
436
        schema = col.table.schema
437
        if schema != None: conds.append(('table_schema', schema))
438
439 2819 aaronmk
        type_, default, nullable = row(select(self, table, cols, conds,
440 3150 aaronmk
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
441 2643 aaronmk
            # TODO: order_by search_path schema order
442 2819 aaronmk
        default = sql_gen.as_Code(default, self)
443
444
        return sql_gen.TypedCol(col.name, type_, default, nullable)
445 2917 aaronmk
446
    def TempFunction(self, name):
447
        if self.debug_temp: schema = None
448
        else: schema = 'pg_temp'
449
        return sql_gen.Function(name, schema)
450 1849 aaronmk
451 1869 aaronmk
# The public constructor for database connections
connect = DbConn
452
453 832 aaronmk
##### Recoverable querying
454 15 aaronmk
455 2139 aaronmk
def with_savepoint(db, func):
    '''Delegates to DbConn.with_savepoint()'''
    return db.with_savepoint(func)
456 11 aaronmk
457 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, and translates common
    PostgreSQL error messages into this module's typed exceptions.
    For params, see DbConn.run_query()
    @param recover Whether to run inside a savepoint so the transaction can
        continue after an error.
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)

    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]

    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # Parse the server's error message into a typed exception
            msg = strings.ustr(e.args[0])

            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)

            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)

            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)

            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)

            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)

            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # Deferred logging (see debug_msg_ref above)
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
516 830 aaronmk
517 832 aaronmk
##### Basic queries
518
519 2153 aaronmk
def next_version(name):
    '''Increments the trailing '#<version>' suffix of a name. A name with no
    suffix counts as version 0, so its next version is '<name>#1'.'''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        base, version_str = match.groups()
        name = base
        version = int(version_str)+1
    else: version = 1 # first existing name was version 0
    return sql_gen.concat(name, '#'+str(version))
526 2153 aaronmk
527 2899 aaronmk
def lock_table(db, table, mode):
    '''Runs LOCK TABLE on a table, in the given PostgreSQL lock mode.'''
    table_sql = sql_gen.as_Table(table).to_str(db)
    run_query(db, 'LOCK TABLE '+table_sql+' IN '+mode+' MODE')
530
531 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table|None The table to create. If None, the query is
        simply run and its cursor returned. NOTE: mutated in place (is_temp,
        schema, and possibly name are changed).
    @param add_indexes_ Whether to add indexes on the created table.
    '''
    if into == None: return run_query(db, query, **kw_args)

    assert isinstance(into, sql_gen.Table)

    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None

    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))

    temp = not db.debug_temp # tables are permanent in debug_temp mode

    # Create table, retrying with a new versioned name on a name collision
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query

        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name

    if add_indexes_: add_indexes(db, into)

    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)

    return cur
573 2085 aaronmk
574 2120 aaronmk
# Sentinel values (compared by identity in mk_select())
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
577
578 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds (but does not run) a SELECT statement.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns. (The default [] is never mutated, only compared.)
    @param order_by The column to ORDER BY, order_by_pkey for the pkey, or None
        for no ordering
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate

    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)

    query = 'SELECT'

    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)

    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'

    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))

    # Main table
    query += '\nFROM '+table0.to_str(db)

    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table

        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))

        query += '\n'+join_.to_str(db, left_table)

        left_table = table

    # Warn if the query has no row-limiting clause at all, since it would
    # scan/return the whole table
    missing = True
    if conds != []:
        # (removed: a `whitespace` value was computed here but never used)
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return query
660 11 aaronmk
661 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT statement.
    For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)

    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
669 2054 aaronmk
670 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''Builds an INSERT ... SELECT query (as a string).
    
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
        This forces embeddable, because the row-by-row plpgsql loop below must
        be wrapped in a function.
    @return str The INSERT query, or (if embeddable) a SELECT on a temp
        function that performs the INSERT.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        # Builds a plain INSERT statement around the given source query
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        # Choose the loop variable(s): either the whole row record, or one
        # variable per known column
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        # Insert one row at a time so a unique_violation only rolls back that
        # row, not the whole INSERT
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function to wrap the query, retrying with versioned names
        # until one doesn't collide with an existing function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query
773 2066 aaronmk
774 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''Runs an INSERT ... SELECT query built by mk_insert_select().
    
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    @return The cursor for the executed query
    '''
    # Peek (don't pop) at options that mk_insert_select() also needs
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # RETURNING needs a SELECT
    recover = kw_args.pop('recover', None)
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    # With ignore but no returning col, mk_insert_select() returns NULL rows
    # just for the rowcount; put them in a throwaway temp table
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    if rowcount_only: empty_temp(db, into) # discard the placeholder rows
    autoanalyze(db, table)
    return cur
797 2063 aaronmk
798 2738 aaronmk
default = sql_gen.default # tells insert() to use the default value for a column
799 2066 aaronmk
800 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row, given as a sequence or a dict.
    
    For params, see insert_select().
    '''
    # A dict provides both column names and values; a sequence just values
    if lists.is_seq(row):
        cols = None
        values = list(row)
    else:
        cols = row.keys()
        values = list(row.values())
    
    if values:
        query = sql_gen.Values(values).to_str(db)
    else:
        query = None # no values: insert all defaults
    
    return insert_select(db, table, cols, query, *args, **kw_args)
812 11 aaronmk
813 3152 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Builds an UPDATE query (as a string).
    
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        # Rewrite each column via ALTER COLUMN ... TYPE <same type> USING
        # <new value>, which rewrites the table instead of creating dead rows
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table), cacheable_).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        # Ordinary UPDATE ... SET ... [WHERE ...]
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
846
847 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''Runs an UPDATE query and analyzes the table afterwards.
    
    For params, see mk_update() and run_query().
    '''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # keep planner statistics up to date
    return cur
857 2402 aaronmk
858 135 aaronmk
def last_insert_id(db):
    '''Returns the last autogenerated ID, or None if the driver is unknown.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb':
        return db.insert_id()
    return None
863 13 aaronmk
864 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    # NOTE(review): `preserve` has a mutable default ([]); harmless here since
    # the list itself is only read, never mutated (its Col elements are)
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep a snapshot before retargeting the Col
        col.table = into # mutates the caller's Col, as documented above
        items.append((orig_col, col))
    preserve = set(preserve) # for fast membership tests below
    for col in cols:
        if col not in preserve:
            # str(col) flattens table-qualified names into distinct new names
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
895 2383 aaronmk
896 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Flattens the joined tables' columns into a single new table.
    
    For params, see mk_flatten_mapping().
    @return See return value of mk_flatten_mapping()
    '''
    mapping = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    # Select each original column under its new, collision-free name
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in mapping]
    run_query_into(db, mk_select(db, joins, select_cols, limit=limit,
        start=start), into=into, add_indexes_=True)
    return dict(mapping)
905 2391 aaronmk
906 3079 aaronmk
##### Database structure introspection
907 2414 aaronmk
908 3079 aaronmk
#### Tables
909
910
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over the names of matching tables.
    
    @param schema_like SQL LIKE pattern for the schema name (PostgreSQL only)
    @param table_like SQL LIKE pattern for the table name
    @param exact If set, match the names exactly instead of as patterns
    @return An iterable of table names, ordered by name on PostgreSQL
    @raise NotImplementedError For unsupported database drivers
    '''
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        # SHOW TABLES only supports LIKE, so for an exact match, escape the
        # LIKE wildcards so they match literally (previously `exact` was
        # silently ignored on MySQL)
        if exact: table_like = re.sub(r'([%_\\])', r'\\\1', table_like)
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
924
925
def table_exists(db, table):
    '''Returns whether the table exists in its schema.'''
    table = sql_gen.as_Table(table)
    matches = list(tables(db, table.schema, table.name, exact=True))
    return matches != []
928
929 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None, start=0)
    return value(run_query(db, query, recover=recover, log_level=3))
932 2426 aaronmk
933 2414 aaronmk
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    # LIMIT 0 fetches just the result header, not any rows
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
936 2414 aaronmk
937 2291 aaronmk
def pkey(db, table, recover=None):
    '''Returns the table's pkey, assumed to be its first column.'''
    cols = table_cols(db, table, recover)
    return cols[0]
940 832 aaronmk
941 2559 aaronmk
not_null_col = 'not_null_col'
942 2340 aaronmk
943
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
947
948 853 aaronmk
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    
    @return [str,...] The index's column names, in index order
    @raise NotImplementedError For non-PostgreSQL drivers
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The UNION's first branch handles plain column indexes (attnum in
        # indkey); the second parses expression indexes' varattno references
        # out of indexprs to recover the underlying columns
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
988 853 aaronmk
989 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the column names covered by a table constraint.
    
    @return [str,...] The constraint's column names, in column order
    @raise NotImplementedError For non-PostgreSQL drivers
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
1005
1006 3079 aaronmk
#### Functions
1007
1008
def function_exists(db, function):
    '''Returns whether the function exists (excluding trigger functions).'''
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    matches = list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4)))
        # TODO: order_by search_path schema order
    return matches != []
1021
1022
##### Structural changes
1023
1024
#### Columns
1025
1026
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column, retrying with versioned names on duplicates.
    
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    # On a name collision, bump the version suffix and retry until it succeeds
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name) # mutates the caller's TypedCol
            # try again with next version of name
1046
1047
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to the given column.'''
    table = col.table
    name_only = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +name_only.to_str(db)+' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1052
1053 2096 aaronmk
row_num_col = '_row_num'
1054
1055 3079 aaronmk
# Column definition used by add_row_num(): a non-nullable serial primary key
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')
1057
1058
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)
1062
1063
#### Indexes
1064
1065
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    # DuplicateException is ignored for logging since a pre-existing pkey is
    # an expected condition here
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1078
1079 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            # indexed expression: the underlying col is the first function arg
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table from the first column if not provided explicitly
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None # index DDL uses bare column names
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1125 2408 aaronmk
1126 2997 aaronmk
def add_index_col(db, col, suffix, expr, nullable=True):
    '''Adds a materialized, indexed copy of an expression on a column.
    
    The new column's name is recorded in col.table.index_cols so that
    sql_gen.index_col() can find it later.
    @param suffix Appended to col's name to form the index column's name
    @param expr The value to store in the new column
    '''
    if sql_gen.index_col(col) != None: return # already has index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name,
        db.col_info(col, cacheable=nullable).type)
        # if not nullable, col_info will be changed later by add_not_null()
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    # Populate the new column, then constrain and index it
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable_=nullable,
        cacheable=True, log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    # Register the index col so sql_gen.index_col() can look it up
    col.table.index_cols[col.name] = new_col.name
1145 2997 aaronmk
1146 3047 aaronmk
# Controls when ensure_not_null() will use index columns
1147
not_null_index_cols_min_rows = 0 # rows; initially always use index columns
1148
1149 2997 aaronmk
def ensure_not_null(db, col):
    '''Returns an expression for col with NULLs mapped to sentinel values.
    
    For params, see sql_gen.ensure_not_null()
    @return sql_gen code: either the wrapped expression itself, or (for large
        temp tables) a reference to a materialized index column holding it
    '''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)
    
    return expr
1161
1162 3083 aaronmk
already_indexed = object() # tells add_indexes() the pkey has already been added
1163
1164
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        # The first column gets the pkey (unless it already has one)...
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:] # ...so skip it when adding regular indexes
    for col in cols:
        add_index(db, col, table)
1175
1176 3079 aaronmk
#### Tables
1177 2772 aaronmk
1178 3079 aaronmk
### Maintenance
1179 2772 aaronmk
1180 3079 aaronmk
def analyze(db, table):
    '''Updates the planner statistics for the table.'''
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1183 2934 aaronmk
1184 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() only if autoanalyze is enabled on the connection.'''
    if db.autoanalyze: analyze(db, table)
1186 2935 aaronmk
1187 3079 aaronmk
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on the table.
    VACUUM cannot run inside a transaction block, hence with_autocommit().
    '''
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))
1191 2086 aaronmk
1192 3079 aaronmk
### Lifecycle
1193
1194 2889 aaronmk
def drop_table(db, table):
    '''Drops the table if it exists, along with dependent objects (CASCADE).'''
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')
1197
1198 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like None|sql_gen.Table Existing table to copy the structure of
        (via LIKE ... INCLUDING ALL), prepended before cols
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        # NOTE(review): this replaces the caller's cols[0] list slot (the
        # TypedCol itself is copied, but the passed-in list is mutated)
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table, retrying with versioned names on name collisions
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name) # mutates the caller's Table
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # pkey was created inline above
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1241 2675 aaronmk
1242 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)
    Uses CREATE TABLE ... (LIKE src INCLUDING ALL) via create_table().
    '''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1245 3084 aaronmk
1246 3079 aaronmk
### Data
1247 2684 aaronmk
1248 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Empties the table (TRUNCATE ... CASCADE). For params, see run_query().'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1252 2732 aaronmk
1253 2965 aaronmk
def empty_temp(db, tables):
    '''Truncates the given temp table(s), unless debug_temp mode is on.'''
    if db.debug_temp: return # leave temp tables there for debugging
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1257 2965 aaronmk
1258 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema. For kw_args, see tables().'''
    for table_name in tables(db, schema, **kw_args):
        truncate(db, table_name, schema)
1261 3094 aaronmk
1262
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        # The unique index makes the ignore=True insert below drop duplicates
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    # ignore=True skips rows that violate the unique index, i.e. duplicates
    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table