Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 2217 aaronmk
import sql_gen
14 862 aaronmk
import strings
15 131 aaronmk
import util
16 11 aaronmk
17 832 aaronmk
##### Exceptions
18
19 2804 aaronmk
def get_cur_query(cur, input_query=None):
    '''Returns the query last executed on a cursor.

    Prefers the driver-recorded query (psycopg2's `query` attr, MySQLdb's
    `_last_executed`); falls back to the caller-supplied input_query.
    '''
    recorded = None
    if hasattr(cur, 'query'): recorded = cur.query
    elif hasattr(cur, '_last_executed'): recorded = cur._last_executed

    if recorded != None: return recorded
    return '[input] '+strings.ustr(input_query)
26 14 aaronmk
27 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    # Annotates exception e with the query that caused it
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
30 135 aaronmk
31 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class for all database errors raised by this module.

    @param cur If specified, the cursor's current query is appended to the
        exception message via _add_cursor_info().
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)
35
36 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException tagged with the name of the object involved.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
40 360 aaronmk
41 3109 aaronmk
class ExceptionWithValue(DbException):
    '''A DbException tagged with the offending value.'''
    def __init__(self, value, cause=None):
        msg = 'for value: '+strings.as_tt(repr(value))
        DbException.__init__(self, msg, cause)
        self.value = value
46
47 2945 aaronmk
class ExceptionWithNameType(DbException):
    '''A DbException tagged with a database object's type and name.'''
    def __init__(self, type_, name, cause=None):
        msg = ('for type: '+strings.as_tt(str(type_))+'; name: '
            +strings.as_tt(name))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.name = name
53
54 2306 aaronmk
class ConstraintException(DbException):
    '''Raised when a database constraint is violated.

    @param name Name of the violated constraint
    @param cols The columns the constraint covers
    '''
    def __init__(self, name, cols, cause=None):
        msg = ('Violated '+strings.as_tt(name)+' constraint on columns: '
            +strings.as_tt(', '.join(cols)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols
60 11 aaronmk
61 2523 aaronmk
class MissingCastException(DbException):
    '''Raised when a value needs an explicit cast to the column's type.'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
67
68 2143 aaronmk
class NameException(DbException): pass

# Constraint violations; raised by run_query() after parsing the driver error
class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

# Invalid input syntax / field value out of range
class InvalidValueException(ExceptionWithValue): pass

# An object (table, function, ...) with this type and name already exists
class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass
79
80 865 aaronmk
##### Warnings
81
82
# Warning category for this module, so callers can filter with
# warnings.filterwarnings()
class DbWarning(UserWarning): pass
83
84 1930 aaronmk
##### Result retrieval
85
86
def col_names(cur):
    '''Yields the name (first element of each description entry) of every
    column in a cursor's result set.'''
    return (desc[0] for desc in cur.description)
87
88
def rows(cur):
    '''Iterates over a cursor's rows until fetchone() returns None.'''
    return iter(cur.fetchone, None) # two-arg iter(): call until sentinel
89
90
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    # Draining via fetchone() triggers DbCursor._cache_result() on exhaustion
    iters.consume_iter(rows(cur))
93
94
def next_row(cur):
    '''Returns the next row, raising StopIteration if there is none.'''
    return next(rows(cur))
95
96
def row(cur):
    '''Returns the first row, then drains the cursor so the result set gets
    cached (see consume_rows()).'''
    first = next_row(cur)
    consume_rows(cur)
    return first
100
101
# First column of the next row
def next_value(cur): return next_row(cur)[0]
102
103
# First column of the first row (drains the cursor; see row())
def value(cur): return row(cur)[0]
104
105
# Iterates over the first column of each row
def values(cur): return iters.func_iter(lambda: next_value(cur))
106
107
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the result set is empty.'''
    try:
        return value(cur)
    except StopIteration:
        return None
110
111 2762 aaronmk
##### Escaping
112 2101 aaronmk
113 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quoting rules of the given DB-API
    module ('"' for psycopg2/unknown, '`' for MySQLdb).'''
    quote_by_module = {'psycopg2': '"', None: '"', 'MySQLdb': '`'}
    try: quote = quote_by_module[module]
    except KeyError:
        raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
118 2101 aaronmk
119
def esc_name_by_engine(engine, name, **kw_args):
    # Looks up the engine's DB-API module name, then escapes with its rules
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
121
122
def esc_name(db, name, **kw_args):
    # Determines the DB-API module from the live connection object
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
124
125
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified table name.'''
    table = esc_name(db, table)
    if schema != None: return esc_name(db, schema)+'.'+table
    else: return table
130
131 1869 aaronmk
##### Database connections
132 1849 aaronmk
133 2097 aaronmk
# Recognized keys of a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, db_config key renames for connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All known driver-level DatabaseError classes; extended by _add_module()
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set) # tuple form usable in `except`
142
143
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class in DatabaseErrors'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild the `except` tuple
147
148
def db_config_str(db_config):
    '''Human-readable summary of a db_config, e.g. "PostgreSQL database x".'''
    engine = db_config['engine']
    return engine+' database '+db_config['database']
150
151 2448 aaronmk
# Default no-op debug logger; DbConn.debug is False when this one is in use
log_debug_none = lambda msg, level=2: None
152 1901 aaronmk
153 1849 aaronmk
class DbConn:
    '''A lazily-opened database connection with query caching, savepoint
    management, and autocommit support.

    The underlying DB-API connection is not opened until the first attribute
    access of self.db (see __getattr__()/_db()).
    '''
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False):
        '''
        @param db_config dict with keys in db_config_names
        @param autocommit Whether to commit after each top-level query
        @param caching Whether to cache query results (see DbCursor)
        @param log_debug Callback for debug messages; the default no-op
            callback disables debug logging entirely
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.autoanalyze = False

        self._savepoint = 0 # nesting depth; 1 = automatic transaction open
        self._reset()

    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db() # opens the connection on demand
        else: raise AttributeError()

    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def clear_cache(self): self.query_results = {}

    def _reset(self):
        '''Returns the object to its just-constructed, disconnected state'''
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None

    def connected(self): return self.__db != None

    def close(self):
        if not self.connected(): return

        # Record that the automatic transaction is now closed
        self._savepoint -= 1

        self.db.close()
        self._reset()

    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query

    def _db(self):
        '''Opens the DB-API connection on first use and configures it'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename config keys to the names this driver's connect() expects
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass

            # Connect
            self.__db = module.connect(**db_config)

            # Record that a transaction is already open
            self._savepoint += 1

            # Configure connection
            if hasattr(self.db, 'set_isolation_level'): # psycopg2 only
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                # Prepend the configured schemas to the existing search_path
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)

        return self.__db

    class DbCursor(Proxy):
        '''A cursor proxy that records queries and caches their results in
        the owning DbConn's query_results dict.'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []

        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise

            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached

            return cur

        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        class CacheCursor:
            '''Replays a previously-cached result set (or exception)'''
            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None

    def esc_value(self, value):
        '''Escapes a Python value as a SQL literal'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError as e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)

    def esc_name(self, name): return esc_name(self, name) # calls global func

    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_

    def can_mogrify(self):
        # Only psycopg2 cursors support mogrify()
        module = util.root_module(self.db)
        return module == 'psycopg2'

    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")

    def print_notices(self):
        '''Logs server notices not yet seen on this connection'''
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)

    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a query, optionally via the caching DbCursor.

        For exception parsing and the log_ignore_excs param, see the
        module-level run_query().
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None

        if not self.caching: cacheable = False
        used_cache = False

        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')

        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()

            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)

            # Run query
            try:
                cur.execute(query)
                self.do_autocommit()
            except Exception as e:
                _add_cursor_info(e, self, query)
                raise
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))

        return cur

    def is_cached(self, query): return query in self.query_results

    def with_autocommit(self, func):
        '''Runs func() with the connection temporarily in autocommit mode'''
        import psycopg2.extensions

        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)

    def with_savepoint(self, func):
        '''Runs func() inside a transaction (at top level) or a savepoint,
        rolling back if it raises. Returns func()'s return value.'''
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)

        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1

        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            return_val = func()
            # Bug fix: this COMMIT previously followed a `return func()` and
            # was unreachable, so top-level transactions were never committed
            # here. Bind the result first, then commit.
            if top: self.run_query('COMMIT', log_level=4)
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)

            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)

            self._savepoint -= 1
            assert self._savepoint >= 0

            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT

        return return_val

    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()

    def col_info(self, col, cacheable=True):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns. Returns a sql_gen.TypedCol.'''
        table = sql_gen.Table('columns', 'information_schema')
        # USER-DEFINED types report their real name in udt_name instead
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]

        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))

        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)

        return sql_gen.TypedCol(col.name, type_, default, nullable)

    def TempFunction(self, name):
        # In debug_temp mode, create the function permanently (no pg_temp)
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)
453 1849 aaronmk
454 1869 aaronmk
connect = DbConn # connect(db_config, ...) is an alias for DbConn(...)
455
456 832 aaronmk
##### Recoverable querying
457 15 aaronmk
458 2139 aaronmk
def with_savepoint(db, func):
    '''Module-level convenience wrapper for DbConn.with_savepoint()'''
    return db.with_savepoint(func)
459 11 aaronmk
460 2791 aaronmk
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()

    Translates driver-level errors into this module's exception hierarchy by
    parsing the error message (the patterns match PostgreSQL's wording).
    @param recover Whether to run inside a savepoint so a failure can be
        rolled back without aborting the enclosing transaction
    @param log_ignore_excs Exceptions whose queries should be logged at a
        quieter level (log_level+2)
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs) # must be a tuple for `except`

    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]

    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # NOTE(review): assumes the driver always provides e.args[0];
            # an empty args would raise IndexError here — confirm
            msg = strings.ustr(e.args[0])

            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)

            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)

            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)

            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)

            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)

            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # Deferred logging: output now that the final log_level is known
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
519 830 aaronmk
520 832 aaronmk
##### Basic queries
521
522 2153 aaronmk
def next_version(name):
    '''Increments a name's version suffix: "t" -> "t#1", "t#1" -> "t#2".'''
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, old_version = match.groups()
        version = int(old_version)+1
    return sql_gen.concat(name, '#'+str(version))
529 2153 aaronmk
530 2899 aaronmk
def lock_table(db, table, mode):
    '''Acquires a table-level lock. mode is a lock mode name, e.g. 'EXCLUSIVE'.
    '''
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
533
534 2789 aaronmk
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().

    @param into sql_gen.Table to create, or None to just run the query.
        NOTE: mutated in place (is_temp, schema, and possibly name change).
    @param add_indexes_ Whether to add indexes on the new table's columns
    '''
    if into == None: return run_query(db, query, **kw_args)

    assert isinstance(into, sql_gen.Table)

    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None

    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))

    temp = not db.debug_temp # tables are permanent in debug_temp mode

    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query

        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name

    if add_indexes_: add_indexes(db, into)

    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)

    return cur
576 2085 aaronmk
577 2120 aaronmk
# Unique sentinel objects (compared with `is`) for mk_select() params
order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
581 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds (but does not run) a SELECT statement.

    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit Max # rows to return, or None for no limit
    @param start Row offset, or None for no offset
    @param order_by Column to ORDER BY, order_by_pkey (default), or None
    @param default_table Table to resolve unqualified column names against
    @return query
    '''
    # NOTE: the distinct_on=[] default is shared across calls but never
    # mutated here, so it is safe.
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate

    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)

    query = 'SELECT'

    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)

    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'

    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))

    # Main table
    query += '\nFROM '+table0.to_str(db)

    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table

        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))

        query += '\n'+join_.to_str(db, left_table)

        left_table = table

    # A SELECT with no WHERE/LIMIT/OFFSET is probably an accidental full scan
    missing = True
    if conds != []:
        # NOTE(review): this `whitespace` value is computed but never used
        # (dead code) — the clause below always joins with '\n'
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return query
663 11 aaronmk
664 2054 aaronmk
def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True) # SELECTs cacheable by default
    log_level = kw_args.pop('log_level', 2)

    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
672 2054 aaronmk
673 2788 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''Generates an INSERT ... SELECT query.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @return str query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        '''Builds the plain INSERT statement around the given source query'''
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            # fixed: exception instance was bound ("except X, e") but never
            # used; the bare form is also portable to Python 3
            except DuplicateException:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query
776 2066 aaronmk
777 3074 aaronmk
def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    # Peek at (don't pop) options that mk_insert_select() also needs
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    into = kw_args.pop('into', None)
    # Capturing results into a table requires the query to be embeddable
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    # NOTE(review): ignore implies recover, presumably because duplicate-key
    # errors are expected in ignore mode — confirm against run_query()
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    # The NULL rows were only needed to produce the rowcount; discard them
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
800 2063 aaronmk
801 2738 aaronmk
default = sql_gen.default # sentinel: tells insert() to use the column's default value
802 2066 aaronmk
803 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select().'''
    # A dict row supplies its own column names; a sequence row does not
    if lists.is_seq(row):
        cols = None
        values = list(row)
    else:
        cols = row.keys()
        values = list(row.values())
    # list() above ensures that the "== []" emptiness test below works
    
    if values == []:
        query = None
    else:
        query = sql_gen.Values(values).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
815 11 aaronmk
816 3152 aaronmk
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Generates an UPDATE (or, for in_place, an ALTER TABLE) query.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    pairs = [(sql_gen.to_name_only_col(col, table), sql_gen.as_Value(val))
        for col, val in changes]
    
    if in_place:
        assert cond == None
        
        # Re-declare each column with its own type, computing the new value in
        # the USING clause, so the rows are rewritten in place
        alterations = []
        for col, val in pairs:
            type_ = db.col_info(sql_gen.with_default_table(col, table),
                cacheable_).type
            alterations.append('ALTER COLUMN '+col.to_str(db)+' TYPE '+type_
                +'\nUSING '+val.to_str(db))
        query = 'ALTER TABLE '+table.to_str(db)+'\n'+',\n'.join(alterations)
    else:
        assignments = [col.to_str(db)+' = '+val.to_str(db)
            for col, val in pairs]
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'+',\n'.join(assignments)
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
849
850 3074 aaronmk
def update(db, table, *args, **kw_args):
    '''Runs an UPDATE query. For params, see mk_update() and run_query().'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
860 2402 aaronmk
861 135 aaronmk
def last_insert_id(db):
    '''Returns the id generated by the last insert, or None if the underlying
    driver is not supported.'''
    driver = util.root_module(db.db)
    if driver == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if driver == 'MySQLdb':
        return db.insert_id()
    return None
866 13 aaronmk
867 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # snapshot before the in-place retable below
        col.table = into # mutates the caller's Col, as documented above
        items.append((orig_col, col))
    preserve = set(preserve) # for fast membership tests in the loop below
    for col in cols:
        if col not in preserve:
            # str(col) presumably yields a table-qualified, collision-free
            # name — TODO confirm against sql_gen.Col.__str__
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
898 2383 aaronmk
899 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given columns under their flattened (collision-free) names
    into a table. For params, see mk_flatten_mapping().
    @return See return value of mk_flatten_mapping()
    '''
    mapping = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in mapping]
    query = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, into=into)
    return dict(mapping)
908 2391 aaronmk
909 3079 aaronmk
##### Database structure introspection
910 2414 aaronmk
911 3079 aaronmk
#### Tables
912
913
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over the names of the tables matching the given patterns.
    @param exact If set, compares with = instead of LIKE
    '''
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [
            ('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare)),
        ]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    if module == 'MySQLdb':
        return values(run_query(db,
            'SHOW TABLES LIKE '+db.esc_value(table_like), cacheable=True,
            log_level=4))
    raise NotImplementedError("Can't list tables for "+module+' database')
927
928
def table_exists(db, table):
    '''Whether the given table exists in the database.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return list(matches) != []
931
932 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None, start=0)
    return value(run_query(db, query, recover=recover, log_level=3))
935 2426 aaronmk
936 2414 aaronmk
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in order. Uses a LIMIT 0 query so no
    rows are actually fetched.'''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
939 2414 aaronmk
940 2291 aaronmk
def pkey(db, table, recover=None):
    '''Returns the table's primary key, assumed to be its first column.'''
    return table_cols(db, table, recover)[0]
943 832 aaronmk
944 2559 aaronmk
not_null_col = 'not_null_col' # column name that table_not_null_col() looks for
945 2340 aaronmk
946
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
950
951 853 aaronmk
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # First UNION branch: plain column indexes, whose column numbers are
        # listed directly in pg_index.indkey.
        # Second branch: expression indexes, whose column numbers must instead
        # be parsed out of pg_index.indexprs (":varattno N" tokens).
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
991 853 aaronmk
992 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Returns the names of the columns in the given constraint, in column
    (attnum) order.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
1008
1009 3079 aaronmk
#### Functions
1010
1011
def function_exists(db, function):
    '''Whether the given (non-trigger) function exists in the database.'''
    function = sql_gen.as_Function(function)
    
    routines = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    matches = values(select(db, routines, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
        # TODO: order_by search_path schema order
    return list(matches) != []
1024
1025
##### Structural changes
1026
1027
#### Columns
1028
1029
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column to a table, versioning its name on collision.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    # Retry with successive versions of the name until there's no collision
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name) # mutates the caller's TypedCol
            # try again with next version of name
1049
1050
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to the given table column.'''
    table = col.table
    name_only = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +name_only.to_str(db)+' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1055
1056 2096 aaronmk
row_num_col = '_row_num' # name of the column created by add_row_num()
1057
1058 3079 aaronmk
# Column definition used by add_row_num(): a serial primary-key column
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')
1060
1061
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3) # serial PRIMARY KEY
1065
1066
#### Indexes
1067
1068
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    
    col_sql = ', '.join(sql_gen.to_name_only_col(c).to_str(db) for c in cols)
    query = 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('+col_sql+')'
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1081
1082 2998 aaronmk
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    # fixed: removed dead code — a `cols` list (and the enumerate() index) was
    # accumulated here but never used after the loop
    for expr in old_exprs:
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # Index exprs are relative to the indexed table. When col is the expr
        # itself (non-function case), this also unqualifies the expr.
        col.table = None
        
        exprs.append(expr)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1128 2408 aaronmk
1129 3083 aaronmk
already_indexed = object() # sentinel: tells add_indexes() the pkey was already added
1130
1131
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        # The first column gets the pkey (added here unless it already exists)
        # rather than a regular index
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols:
        add_index(db, col, table)
1142
1143 3079 aaronmk
#### Tables
1144 2772 aaronmk
1145 3079 aaronmk
### Maintenance
1146 2772 aaronmk
1147 3079 aaronmk
def analyze(db, table):
    '''Updates the planner statistics for the table (ANALYZE).'''
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1150 2934 aaronmk
1151 3079 aaronmk
def autoanalyze(db, table):
    '''Runs analyze() on the table, but only if autoanalyze mode is enabled on
    the connection.'''
    if db.autoanalyze: analyze(db, table)
1153 2935 aaronmk
1154 3079 aaronmk
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on the table. Uses autocommit mode, since VACUUM
    cannot run inside a transaction block.'''
    table = sql_gen.as_Table(table)
    def run(): run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run)
1158 2086 aaronmk
1159 3079 aaronmk
### Lifecycle
1160
1161 2889 aaronmk
def drop_table(db, table):
    '''Drops the table if it exists, along with dependent objects (CASCADE).'''
    table = sql_gen.as_Table(table)
    query = 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE'
    return run_query(db, query)
1164
1165 3082 aaronmk
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table, versioning its name on collision.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None Table whose structure to copy (LIKE ...
        INCLUDING ALL)
    '''
    table = sql_gen.as_Table(table)
    
    # fixed: copy cols before modifying it. The original assigned cols[0]
    # in place, mutating the caller's list (and, via the mutable default
    # argument, leaking state between calls).
    cols = list(cols)
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table, retrying with versioned names until there's no collision
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1208 2675 aaronmk
1209 3084 aaronmk
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    # has_pkey/col_indexes are disabled because LIKE ... INCLUDING ALL (see
    # create_table()) presumably copies constraints and indexes — no extras
    # should be added here
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1212 3084 aaronmk
1213 3079 aaronmk
### Data
1214 2684 aaronmk
1215 2970 aaronmk
def truncate(db, table, schema='public', **kw_args):
    '''Empties the table (TRUNCATE ... CASCADE). For params, see run_query().'''
    table = sql_gen.as_Table(table, schema)
    query = 'TRUNCATE '+table.to_str(db)+' CASCADE'
    return run_query(db, query, **kw_args)
1219 2732 aaronmk
1220 2965 aaronmk
def empty_temp(db, tables):
    '''Empties one temp table or a sequence of them.'''
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1223 2965 aaronmk
1224 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Empties every table in the given schema.
    For kw_args, see tables()
    '''
    for table_name in tables(db, schema, **kw_args):
        truncate(db, table_name, schema)
1227 3094 aaronmk
1228
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        # The unique index on the new table is what makes the copy distinct:
        # ignore=True below skips rows whose keys are already present
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table) # refresh planner stats for the newly filled table
    
    return new_table