Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import time
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
import profiling
13
from Proxy import Proxy
14
import rand
15
import sql_gen
16
import strings
17
import util
18

    
19
##### Exceptions
20

    
21
def get_cur_query(cur, input_query=None):
    '''Returns the query most recently run on `cur`, or, if the cursor does
    not expose it, `input_query` tagged with an "[input]" marker.'''
    if hasattr(cur, 'query'): raw_query = cur.query
    else: raw_query = getattr(cur, '_last_executed', None)
    
    if raw_query is not None: return raw_query
    return '[input] '+strings.ustr(input_query)
28

    
29
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates exception `e` with the cursor's current query.
    For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
32

    
33
class DbException(exc.ExceptionWithCause):
    '''Base class for this module's database errors. If a cursor is given,
    the failing query is appended to the message via _add_cursor_info().'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)
37

    
38
class ExceptionWithName(DbException):
    '''A DbException tagged with the name of the DB object involved.'''
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '
            +strings.as_tt(strings.ustr(name)), cause)
        self.name = name # name of the DB object the error refers to
43

    
44
class ExceptionWithValue(DbException):
    '''A DbException tagged with the offending value.'''
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '
            +strings.as_tt(strings.urepr(value)), cause)
        self.value = value # the value that triggered the error
49

    
50
class ExceptionWithNameType(DbException):
    '''A DbException tagged with both the type and name of the DB object.'''
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(strings.ustr(
            type_))+'; name: '+strings.as_tt(name), cause)
        self.type = type_ # kind of DB object (e.g. table, function)
        self.name = name # name of the DB object
56

    
57
class ConstraintException(DbException):
    '''A database constraint was violated.
    @param cond The constraint's condition, if known (may be None)
    @param cols The columns the constraint applies to (may be [])
    '''
    def __init__(self, name, cond, cols, cause=None):
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond != None: msg += ' with condition '+strings.as_tt(cond)
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols
66

    
67
class MissingCastException(DbException):
    '''A value required an explicit cast to `type_` that was not provided.'''
    def __init__(self, type_, col=None, cause=None):
        msg = 'Missing cast to type '+strings.as_tt(type_)
        if col != None: msg += ' on column: '+strings.as_tt(col)
        DbException.__init__(self, msg, cause)
        self.type = type_ # the type that the cast should target
        self.col = col # the column needing the cast, if known
74

    
75
# Invalid byte sequence for the DB encoding (see parse_exception())
class EncodingException(ExceptionWithName): pass
76

    
77
# Unique-constraint violation (see parse_exception())
class DuplicateKeyException(ConstraintException): pass
78

    
79
# NOT NULL constraint violation (see parse_exception())
class NullValueException(ConstraintException): pass
80

    
81
# CHECK constraint violation (see parse_exception())
class CheckException(ConstraintException): pass
82

    
83
# Invalid input syntax / out-of-range value (see parse_exception())
class InvalidValueException(ExceptionWithValue): pass
84

    
85
# DB object of this type/name already exists (see parse_exception())
class DuplicateException(ExceptionWithNameType): pass
86

    
87
# DB object of this type/name does not exist (see parse_exception())
class DoesNotExistException(ExceptionWithNameType): pass
88

    
89
# presumably raised for an unexpectedly empty row — no raiser in this module
class EmptyRowException(DbException): pass
90

    
91
##### Warnings
92

    
93
# warning category for database-related warnings
class DbWarning(UserWarning): pass
94

    
95
##### Result retrieval
96

    
97
def col_names(cur):
    '''Generator over the column names in the cursor's description.'''
    return (col_desc[0] for col_desc in cur.description)
98

    
99
def rows(cur):
    '''Iterator over the cursor's rows; stops when fetchone() returns None.'''
    return iter(cur.fetchone, None)
100

    
101
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached
    (fetching to exhaustion triggers the caching in DbCursor.fetchone())'''
    iters.consume_iter(rows(cur))
104

    
105
def next_row(cur):
    '''Returns the next row; raises StopIteration when exhausted.'''
    # next() builtin (Python 2.6+) instead of the Python-2-only .next()
    # method, for forward compatibility
    return next(rows(cur))
106

    
107
def row(cur):
    '''Returns the first row, consuming the rest so the result gets cached.'''
    row_ = next_row(cur)
    consume_rows(cur)
    return row_
111

    
112
# First column of the next row
def next_value(cur): return next_row(cur)[0]
113

    
114
# First column of the first row (consumes remaining rows; see row())
def value(cur): return row(cur)[0]
115

    
116
# Iterator over the first column of each row
def values(cur): return iters.func_iter(lambda: next_value(cur))
117

    
118
def value_or_none(cur):
    '''Like value(), but returns None on an empty result set.'''
    try: return value(cur)
    except StopIteration: return None
121

    
122
##### Escaping
123

    
124
def esc_name_by_module(module, name):
    '''Escapes a DB identifier using the quote character of the given DB-API
    module ('psycopg2'/None -> ", 'MySQLdb' -> `).
    @raise NotImplementedError for unsupported modules'''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2' or module == None: quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
129

    
130
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a DB identifier given an engine name (a key of db_engines).'''
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
132

    
133
def esc_name(db, name, **kw_args):
    '''Escapes a DB identifier using the quoting style of db's driver module.'''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
135

    
136
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified, table name.'''
    table = esc_name(db, table)
    if schema == None: return table
    return esc_name(db, schema)+'.'+table
141

    
142
##### Database connections
143

    
144
# Keys recognized in a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
145

    
146
# engine name -> (DB-API module name, db_config key renames for that module)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
150

    
151
DatabaseErrors_set = set([DbException])
152
DatabaseErrors = tuple(DatabaseErrors_set)
153

    
154
def _add_module(module):
    '''Registers a DB-API module's DatabaseError in DatabaseErrors.'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)
158

    
159
def db_config_str(db_config):
    '''One-line description of a connection config, for use in messages.'''
    return '%(engine)s database %(database)s' % db_config
161

    
162
def log_debug_none(msg, level=2):
    '''Default no-op debug-logging callback.'''
    pass
163

    
164
class DbConn:
165
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param db_config Connection settings dict (see db_config_names)
        @param autocommit Whether to commit after each top-level query
        @param caching Whether query results may be cached (see run_query())
        @param log_debug Callback for debug messages: (msg, level=2)
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none # debug iff a callback given
        self.debug_temp = debug_temp
        self.src = src
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None # row count passed to the query profiler
        
        self._savepoint = 0 # nesting depth of open transactions/savepoints
        self._reset()
186
    
187
    def __getattr__(self, name):
        '''Lazily opens the DB connection on first access of `self.db`.'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        # Fix: include the attribute name so the error is debuggable
        # (a bare AttributeError() gives no clue which lookup failed)
        else: raise AttributeError(name)
191
    
192
    def __getstate__(self):
        '''Pickle support: drops members that can't be pickled.'''
        state = dict(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
197
    
198
    # Discards all cached query results
    def clear_cache(self): self.query_results = {}
199
    
200
    def _reset(self):
        '''Returns the wrapper to its initial, disconnected state.'''
        self.clear_cache()
        assert self._savepoint == 0 # no transaction may be left open
        self._notices_seen = set()
        self.__db = None # connection reopens lazily via _db()
205
    
206
    # Whether a DB connection is currently open
    def connected(self): return self.__db != None
207
    
208
    def close(self):
        '''Closes the connection, if open, and resets cached state.'''
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        self._savepoint -= 1
        
        self.db.close()
        self._reset()
216
    
217
    def reconnect(self):
        '''Closes the connection so it gets reopened on the next query.'''
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
220
        # Connection will be reopened automatically on first query
221
    
222
    def _db(self):
        '''Returns the underlying DB-API connection, opening it on first use.'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename config keys to what this driver module expects
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Record that a transaction is already open
            self._savepoint += 1
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                # Set the search path to the configured schemas, escaped
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
251
    
252
    class DbCursor(Proxy):
        '''Cursor proxy that records queries and their results (or the
        exception they raised) in the outer connection's query cache
        (outer.query_results).'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None # cache key: the query as passed in
            self.result = [] # accumulated rows, or the raised Exception
        
        def execute(self, query):
            '''Runs the query on the inner cursor, caching per the rules in
            _cache_result() below.'''
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            '''Fetches one row, recording it; caches once rows run out.'''
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result set (or re-raises a cached exception)
            through the cursor interface.'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
315
    
316
    def esc_value(self, value):
        '''Escapes a literal value for inclusion in SQL, using the driver's
        own quoting; returns unicode.'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                # MySQLdb has no mogrify; use the C helper directly
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
325
    
326
    # Escapes a DB identifier using this connection's quoting style
    def esc_name(self, name): return esc_name(self, name) # calls global func
327
    
328
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): return 'E'+str_
        return str_
334
    
335
    def can_mogrify(self):
        '''Whether the driver supports cursor.mogrify() (psycopg2 only).'''
        module = util.root_module(self.db)
        return module == 'psycopg2'
338
    
339
    def mogrify(self, query, params=None):
        '''Returns the query with params interpolated by the driver.
        @raise NotImplementedError if the driver can't mogrify'''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
342
    
343
    def set_encoding(self, encoding):
        '''Runs SET NAMES to change the connection's client encoding.'''
        encoding_str = sql_gen.Literal(encoding)
        run_query(self, 'SET NAMES '+encoding_str.to_str(self))
346
    
347
    def print_notices(self):
        '''Logs (level 2) any notices from the connection not yet seen.'''
        for msg in getattr(self.db, 'notices', ()):
            if msg in self._notices_seen: continue
            self._notices_seen.add(msg)
            self.log_debug(msg, level=2)
353
    
354
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a query, optionally through the caching cursor, and returns
        the cursor.
        @param cacheable Whether the result may come from/go to the query
            cache (ignored when self.caching is false)
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        # Tag the query with its data source, for pg_stat_activity
        if self.autocommit and self.src != None:
            query = sql_gen.esc_comment(self.src)+'\t'+query
        
        if not self.caching: cacheable = False
        used_cache = False
        
        if self.debug:
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        try:
            # Get cursor
            if cacheable:
                try: cur = self.query_results[query]
                except KeyError: cur = self.DbCursor(self)
                else: used_cache = True
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                raise
            else: self.do_autocommit()
        finally:
            if self.debug:
                profiler.stop(self.profile_row_ct)
                
                ## Log or return query
                
                query = strings.ustr(get_cur_query(cur, query))
                # Put the src comment on a separate line in the log file
                query = query.replace('\t', '\n', 1)
                
                msg = 'DB query: '
                
                if used_cache: msg += 'cache hit'
                elif cacheable: msg += 'cache miss'
                else: msg += 'non-cacheable'
                
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
                
                if debug_msg_ref != None: debug_msg_ref[0] = msg
                else: self.log_debug(msg, log_level)
                
                self.print_notices()
        
        return cur
411
    
412
    # Whether this query's result is already in the cache
    def is_cached(self, query): return query in self.query_results
413
    
414
    def with_autocommit(self, func):
        '''Runs func() with the connection at AUTOCOMMIT isolation level,
        restoring the previous level afterwards (psycopg2 only).'''
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
422
    
423
    def with_savepoint(self, func):
        '''Runs func() inside a new transaction (at top level) or savepoint,
        rolling back if it raises, and returns func()'s result.
        
        Note: committing the top-level transaction is handled by
        do_autocommit() in the finally clause. (A `COMMIT` statement that
        previously followed `return func()` was unreachable dead code and
        has been removed; behavior is unchanged.)
        '''
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)
        
        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        
        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1
        
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            return func()
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)
            
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
461
    
462
    def do_autocommit(self):
        '''Autocommits if outside savepoint
        (i.e. in autocommit mode with only the automatic top-level
        transaction open).'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
468
    
469
    def col_info(self, col, cacheable=True):
        '''Looks up a column's type info in the system catalogs.
        @return sql_gen.TypedCol with the column's type, default, nullability
        @raise sql_gen.NoUnderlyingTableException if the column's table can't
            be found
        '''
        module = util.root_module(self.db)
        if module == 'psycopg2':
            # PostgreSQL: query pg_attribute/pg_type/pg_attrdef directly
            qual_table = sql_gen.Literal(col.table.to_str(self))
            col_name_str = sql_gen.Literal(col.name)
            try:
                type_, is_array, default, nullable = row(run_query(self, '''\
SELECT
format_type(COALESCE(NULLIF(typelem, 0), pg_type.oid), -1) AS type
, typcategory = 'A' AS type_is_array
, pg_get_expr(pg_attrdef.adbin, attrelid, true) AS default
, NOT pg_attribute.attnotnull AS nullable
FROM pg_attribute
LEFT JOIN pg_type ON pg_type.oid = atttypid
LEFT JOIN pg_attrdef ON adrelid = attrelid AND adnum = attnum
WHERE
attrelid = '''+qual_table.to_str(self)+'''::regclass
AND attname = '''+col_name_str.to_str(self)+'''
'''
                    , recover=True, cacheable=cacheable, log_level=4))
            except (DoesNotExistException, StopIteration):
                raise sql_gen.NoUnderlyingTableException(col)
            if is_array: type_ = sql_gen.ArrayType(type_)
        else:
            # Other DBs: fall back to information_schema.columns
            table = sql_gen.Table('columns', 'information_schema')
            cols = [sql_gen.Col('data_type'), sql_gen.Col('udt_name'),
                'column_default', sql_gen.Cast('boolean',
                sql_gen.Col('is_nullable'))]
            
            conds = [('table_name', col.table.name),
                ('column_name', strings.ustr(col.name))]
            schema = col.table.schema
            if schema != None: conds.append(('table_schema', schema))
            
            cur = select(self, table, cols, conds, order_by='table_schema',
                limit=1, cacheable=cacheable, log_level=4)
            try: type_, extra_type, default, nullable = row(cur)
            except StopIteration: raise sql_gen.NoUnderlyingTableException(col)
            if type_ == 'USER-DEFINED': type_ = extra_type
            elif type_ == 'ARRAY':
                type_ = sql_gen.ArrayType(strings.remove_prefix('_', extra_type,
                    require=True))
        
        if default != None: default = sql_gen.as_Code(default, self)
        return sql_gen.TypedCol(col.name, type_, default, nullable)
514
    
515
    def TempFunction(self, name):
        '''Returns a Function in the temp schema (unqualified, and therefore
        permanent, in debug_temp mode).'''
        schema = 'pg_temp'
        if self.debug_temp: schema = None
        return sql_gen.Function(name, schema)
519

    
520
# alias, to mirror the DB-API style `module.connect(**db_config)` usage
connect = DbConn
521

    
522
##### Recoverable querying
523

    
524
def parse_exception(db, e, recover=False):
    '''Translates a DB error into one of this module's specific exceptions by
    pattern-matching the error message; re-raises the active exception
    unchanged (bare `raise`) when no pattern matches.
    @param recover Whether extra queries may be run to gather details
        (requires auto-rollback, i.e. running inside a savepoint)
    '''
    msg = strings.ustr(e.args[0])
    # strip PL/Python wrapper prefix, if present
    msg = re.sub(r'^(?:PL/Python: )?ValueError: ', r'', msg)
    
    match = re.match(r'^invalid byte sequence for encoding "(.+?)":', msg)
    if match:
        encoding, = match.groups()
        raise EncodingException(encoding, e)
    
    def make_DuplicateKeyException(constraint, e):
        # Enriches the exception with the index's columns/condition if possible
        cols = []
        cond = None
        if recover: # need auto-rollback to run index_cols()
            try:
                cols = index_cols(db, constraint)
                cond = index_cond(db, constraint)
            except NotImplementedError: pass
        return DuplicateKeyException(constraint, cond, cols, e)
    
    match = re.match(r'^duplicate key value violates unique constraint "(.+?)"',
        msg)
    if match:
        constraint, = match.groups()
        raise make_DuplicateKeyException(constraint, e)
    
    match = re.match(r'^could not create unique index "(.+?)"\n'
        r'DETAIL:  Key .+? is duplicated', msg)
    if match:
        constraint, = match.groups()
        raise DuplicateKeyException(constraint, None, [], e)
    
    match = re.match(r'^null value in column "(.+?)" violates not-null'
        r' constraint', msg)
    if match:
        col, = match.groups()
        raise NullValueException('NOT NULL', None, [col], e)
    
    match = re.match(r'^new row for relation "(.+?)" violates check '
        r'constraint "(.+?)"', msg)
    if match:
        table, constraint = match.groups()
        constraint = sql_gen.Col(constraint, table)
        cond = None
        if recover: # need auto-rollback to run constraint_cond()
            try: cond = constraint_cond(db, constraint)
            except NotImplementedError: pass
        raise CheckException(constraint.to_str(db), cond, [], e)
    
    match = re.match(r'^(?:invalid input (?:syntax|value)\b[^:]*'
        r'|.+? out of range)(?:: "(.+?)")?', msg)
    if match:
        value, = match.groups()
        value = util.do_ignore_none(strings.to_unicode, value)
        raise InvalidValueException(value, e)
    
    match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
        r'is of type', msg)
    if match:
        col, type_ = match.groups()
        raise MissingCastException(type_, col, e)
    
    match = re.match(r'^could not determine polymorphic type because '
        r'input has type "unknown"', msg)
    if match: raise MissingCastException('text', None, e)
    
    match = re.match(r'^.+? types (.+?) and (.+?) cannot be matched', msg)
    if match:
        type0, type1 = match.groups()
        raise MissingCastException(type0, None, e)
    
    # matches messages of the form: <type> "<name>" [of relation "..."]
    typed_name_re = r'^(\S+) "?(.+?)"?(?: of relation ".+?")?'
    
    match = re.match(typed_name_re+r'.*? already exists', msg)
    if match:
        type_, name = match.groups()
        raise DuplicateException(type_, name, e)
    
    match = re.match(r'more than one (\S+) named ""(.+?)""', msg)
    if match:
        type_, name = match.groups()
        raise DuplicateException(type_, name, e)
    
    match = re.match(typed_name_re+r' does not exist', msg)
    if match:
        type_, name = match.groups()
        if type_ == 'function':
            match = re.match(r'^(.+?)\(.*\)$', name)
            if match: # includes params, so is call rather than cast to regproc
                function_name, = match.groups()
                func = sql_gen.Function(function_name)
                if function_exists(db, func) and msg.find('CAST') < 0:
                    # not found only because of a missing cast
                    type_ = function_param0_type(db, func)
                    col = None
                    if type_ == 'anyelement': type_ = 'text'
                    elif type_ == 'hstore': # cast just the value param
                        type_ = 'text'
                        col = 'value'
                    raise MissingCastException(type_, col, e)
        raise DoesNotExistException(type_, name, e)
    
    raise # no specific exception raised
626

    
627
# Module-level convenience wrapper for DbConn.with_savepoint()
def with_savepoint(db, func): return db.with_savepoint(func)
628

    
629
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint for error recovery.
    @param recover Whether to run in a savepoint (so a failure can be rolled
        back) and to translate errors via parse_exception()
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    For other params, see DbConn.run_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    debug_msg_ref = [None] # mutable holder for the deferred log message
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # Give failed EXPLAIN approximately the log_level of its query
            if query.startswith('EXPLAIN'): log_level -= 1
            
            parse_exception(db, e, recover)
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # log the deferred message at the (possibly adjusted) log_level
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
656

    
657
##### Basic queries
658

    
659
def is_explainable(query):
    '''Whether the statement type is supported by EXPLAIN.
    See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    '''
    explainable = ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'VALUES', 'EXECUTE',
        'DECLARE')
    return re.match(r'^(?:'+'|'.join(explainable)+r')\b', query)
663

    
664
def explain(db, query, **kw_args):
    '''Returns the EXPLAIN plan for the query as a single string.
    For params, see run_query().
    '''
    kw_args.setdefault('log_level', 4)
    
    return strings.ustr(strings.join_lines(values(run_query(db,
        'EXPLAIN '+query, recover=True, cacheable=True, **kw_args))))
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides
674

    
675
def has_comment(query):
    '''Whether the query already ends with a SQL block comment.'''
    return query[-2:] == '*/'
676

    
677
def with_explain_comment(db, query, **kw_args):
    '''In autoexplain mode, appends the query's EXPLAIN plan to it as a SQL
    comment (unless it already has one or isn't EXPLAINable).'''
    if db.autoexplain and not has_comment(query) and is_explainable(query):
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
            +explain(db, query, **kw_args))
    return query
682

    
683
def next_version(name):
    '''Adds or increments a "#<n>" version suffix on a DB object name.'''
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version_str = match.groups()
        version = int(version_str)+1
    return sql_gen.concat(name, '#'+str(version))
690

    
691
def lock_table(db, table, mode):
    '''Runs LOCK TABLE on the table in the given lock mode (e.g. 'EXCLUSIVE').'''
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
694

    
695
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    Note: mutates `into` (sets is_temp, clears schema, and bumps the name's
    version suffix on a name collision).
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_pkey_: add_pkey_or_index(db, into, warn=True)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
737

    
738
order_by_pkey = object() # tells mk_select() to order by the pkey
739

    
740
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
741

    
742
def mk_select(db, tables=None, fields=None, conds=None, distinct_on=[],
    limit=None, start=None, order_by=order_by_pkey, default_table=None,
    explain=True):
    '''Generates (but does not run) a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit int|None LIMIT value
    @param start int|None OFFSET value (0 emits no OFFSET clause)
    @param order_by order_by_pkey|None|col ORDER BY column
    @param default_table table used to qualify unqualified column names
    @param explain Whether to prepend the EXPLAIN output as a comment
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if limit == 0: order_by = None # empty result: ordering is pointless
    if order_by is order_by_pkey:
        # resolve the sentinel to a concrete ordering
        if lists.is_seq(distinct_on) and distinct_on: order_by = distinct_on[0]
        elif table0 != None: order_by = table_order_by(db, table0, recover=True)
        else: order_by = None
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    if table0 != None: query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # (removed: dead `missing = True` and an unused `whitespace` computation
    # for the conds branch -- neither value was ever read)
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
    
    if explain: query = with_explain_comment(db, query)
    
    return query
827

    
828
def select(db, *args, **kw_args):
    '''Generates and runs a SELECT query.
    For params, see mk_select() and run_query().
    '''
    # Peel off the run_query()-only options before forwarding to mk_select()
    run_opts = dict(
        recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True),
        log_level=kw_args.pop('log_level', 2),
    )
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, **run_opts)
836

    
837
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''Generates (and for embeddable queries, partially runs) an INSERT SELECT.
    @param cols [col,...]|None The columns to insert into; None = all defaults
    @param select_query str|None The SELECT (or VALUES) providing the rows
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    @return str query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    # Builds the INSERT statement around a given row source (closure over
    # cols/returning so it can be reused for the per-row ignore variant)
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = sql_gen.CustomCode('unknown')
    if returning != None: return_type = sql_gen.ColType(returning)
    
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        
        # the per-row values, as seen inside the exception-ignoring row loop
        if cols == None: row = [sql_gen.Col(sql_gen.all_cols, 'row')]
        else: row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = sql_gen.RowExcIgnore(sql_gen.RowType(table), select_query,
            sql_gen.ReturnQuery(mk_insert(sql_gen.Values(row).to_str(db))),
            cols)
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function; retry with a versioned name on collision
        function_name = sql_gen.clean_name(first_line)
        if src != None: function_name = src+': '+function_name
        while True:
            try:
                func = db.TempFunction(function_name)
                def_ = sql_gen.FunctionDef(func, sql_gen.SetOf(return_type),
                    query)
                
                run_query(db, def_.to_str(db), recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(func), cols)
            # AS clause requires function alias
        return mk_select(db, func_table, order_by=None)
    
    return query
910

    
911
def insert_select(db, table, *args, **kw_args):
    '''Runs an INSERT SELECT, optionally capturing RETURNING values.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    @return cursor of the executed query
    '''
    # peek (get, not pop): these are also consumed by mk_insert_select()
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    # pop: these belong to run_query_into() only
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # needed to nest the INSERT
    recover = kw_args.pop('recover', None)
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    # the rowcount temp table was only needed to count rows; free it now
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
934

    
935
default = sql_gen.default # tells insert() to use the default value for a column
936

    
937
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select()'''
    ignore = kw_args.pop('ignore', False)
    if ignore: kw_args.setdefault('recover', True)
    
    if lists.is_seq(row): cols = None # positional values: column names unknown
    else:
        # mapping: split into parallel column/value lists
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    query = None
    if row != []: query = sql_gen.Values(row).to_str(db)
    
    try: return insert_select(db, table, cols, query, *args, **kw_args)
    except (DuplicateKeyException, NullValueException):
        if ignore: return None
        raise
955

    
956
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Generates (but does not run) an UPDATE query.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        def col_type(col):
            # fixed: previously read the leaked list-comprehension variable
            # `c` instead of the `col` parameter (only worked due to Python 2
            # comprehension scope leakage); use the parameter explicitly
            return sql_gen.canon_type(db.col_info(
                sql_gen.with_default_table(col, table), cacheable_).type)
        changes = [(c, v, col_type(c)) for c, v in changes]
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '+t+'\nUSING '
            +v.to_str(db) for c, v, t in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
994

    
995
def update(db, table, *args, **kw_args):
    '''Generates and runs an UPDATE query.
    For params, see mk_update() and run_query().
    '''
    # split off the run_query()-only options
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # stats likely stale after a bulk update
    return cur
1005

    
1006
def mk_delete(db, table, cond=None):
    '''Generates (but does not run) a DELETE query.
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    parts = ['DELETE FROM '+table.to_str(db)]
    if cond != None: parts.append('WHERE '+cond.to_str(db))
    return with_explain_comment(db, '\n'.join(parts))
1017

    
1018
def delete(db, table, *args, **kw_args):
    '''Generates and runs a DELETE query.
    For params, see mk_delete() and run_query().
    '''
    # split off the run_query()-only options
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # stats likely stale after a bulk delete
    return cur
1028

    
1029
def last_insert_id(db):
    '''Returns the last autogenerated id, or None if the driver is unknown.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb':
        return db.insert_id()
    return None
1034

    
1035
def define_func(db, def_):
    '''Creates a function, bumping its name to the next version on collision.
    Mutates def_.function.name if renaming was needed.
    '''
    func = def_.function
    created = False
    while not created:
        try:
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
            created = True # successful
        except DuplicateException:
            # name taken: bump the version suffix and retry
            func.name = next_version(func.name)
1045

    
1046
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep a snapshot before mutating
        col.table = into # intentional in-place retarget (see docstring)
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            # rename to the full qualified string so it is collision-free
            items.append((col, sql_gen.Col(strings.ustr(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
1077

    
1078
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Materializes a join into a flat temp table with collision-free columns.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    mapping = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    renamed = [sql_gen.NamedCol(new.name, old) for old, new in mapping]
    select_query = mk_select(db, joins, renamed, limit=limit, start=start)
    # don't cache because the temp table will usually be truncated after use
    run_query_into(db, select_query, into=into, add_pkey_=True)
    return dict(mapping)
1088

    
1089
##### Database structure introspection
1090

    
1091
#### Tables
1092

    
1093
def tables(db, schema_like='public', table_like='%', exact=False,
    cacheable=True):
    '''Lists tables whose schema and name match the given patterns.
    @param exact If set, compare with = instead of LIKE
    @return iter of table names
    @raise NotImplementedError for unsupported database drivers
    '''
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', cacheable=cacheable, log_level=4))
    elif module == 'MySQLdb':
        # fixed: honor the caller's cacheable setting (was hardcoded to True,
        # inconsistent with the psycopg2 branch)
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=cacheable, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')
1108

    
1109
def table_exists(db, table, cacheable=True):
    '''Returns whether the table exists in the database.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, True, cacheable)
    return len(list(matches)) > 0
1112

    
1113
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table.'''
    count_query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    return value(run_query(db, count_query, recover=recover, log_level=3))
1116

    
1117
def table_col_names(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    # LIMIT 0 fetches no rows but still exposes the column metadata
    cur = select(db, table, limit=0, recover=recover, log_level=4)
    return list(col_names(cur))
1120

    
1121
def table_cols(db, table, *args, **kw_args):
    '''Returns the table's columns as sql_gen.Col objects.
    For params, see table_col_names().
    '''
    names = table_col_names(db, table, *args, **kw_args)
    return [sql_gen.as_Col(strings.ustr(n), table) for n in names]
1124

    
1125
def table_pkey_index(db, table, recover=None):
    '''Looks up the table's primary-key index in the pg_index catalog.
    @return sql_gen.Table the pkey index, placed in the table's schema
    @raise DoesNotExistException if the table has no primary key
    '''
    table_str = sql_gen.Literal(table.to_str(db))
    # value() raises StopIteration when the query returns no rows
    try:
        return sql_gen.Table(value(run_query(db, '''\
SELECT relname
FROM pg_index
JOIN pg_class index ON index.oid = indexrelid
WHERE
indrelid = '''+table_str.to_str(db)+'''::regclass
AND indisprimary
'''
            , recover, cacheable=True, log_level=4)), table.schema)
    except StopIteration: raise DoesNotExistException('primary key', '')
1138

    
1139
def table_pkey_col(db, table, recover=None):
    '''Returns the table's primary-key column as a sql_gen.Col.
    @raise DoesNotExistException if the table has no primary key
    '''
    table = sql_gen.as_Table(table)
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # PostgreSQL: read the first column of the pkey index
        return sql_gen.Col(index_cols(db, table_pkey_index(db, table,
            recover))[0], table)
    else:
        # generic path: query the information_schema views
        join_cols = ['table_schema', 'table_name', 'constraint_schema',
            'constraint_name']
        tables = [sql_gen.Table('key_column_usage', 'information_schema'),
            sql_gen.Join(
                sql_gen.Table('table_constraints', 'information_schema'),
                dict(((c, sql_gen.join_same_not_null) for c in join_cols)))]
        cols = [sql_gen.Col('column_name')]
        
        conds = [('constraint_type', 'PRIMARY KEY'), ('table_name', table.name)]
        schema = table.schema
        if schema != None: conds.append(('table_schema', schema))
        order_by = 'position_in_unique_constraint'
        
        # value() raises StopIteration when no pkey constraint matched
        try: return sql_gen.Col(value(select(db, tables, cols, conds,
            order_by=order_by, limit=1, log_level=4)), table)
        except StopIteration: raise DoesNotExistException('primary key', '')
1163

    
1164
def table_has_pkey(db, table, recover=None):
    '''Returns whether the table has a primary key.'''
    try:
        table_pkey_col(db, table, recover)
        return True
    except DoesNotExistException:
        return False
1168

    
1169
def pkey_name(db, table, recover=None):
    '''Returns the pkey column's name.
    If no pkey, returns the first column in the table.
    '''
    return pkey_col(db, table, recover).name
1172

    
1173
def pkey_col(db, table, recover=None):
    '''Returns the pkey column.
    If no pkey, returns the first column in the table.
    '''
    try:
        return table_pkey_col(db, table, recover)
    except DoesNotExistException:
        # no declared pkey: fall back to the first column
        return table_cols(db, table, recover)[0]
1177

    
1178
not_null_col = 'not_null_col'
1179

    
1180
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_col_names(db, table, recover):
        return not_null_col
    return pkey_name(db, table, recover)
1184

    
1185
def constraint_cond(db, constraint):
    '''Returns the source text of a CHECK constraint's condition, from the
    pg_constraint catalog.
    @param constraint object with .table and .name attributes
    @raise NotImplementedError for non-PostgreSQL drivers
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't get constraint condition for "
        +module+' database')
1200

    
1201
def index_exprs(db, index):
    '''Returns the definitions of each of an index's columns/expressions,
    one string per indexed column (via pg_get_indexdef with a column number).
    @raise NotImplementedError for non-PostgreSQL drivers
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        return list(values(run_query(db, '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError()
1213

    
1214
def index_cols(db, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    return [sql_gen.parse_expr_col(e) for e in index_exprs(db, index)]
1219

    
1220
def index_cond(db, index):
    '''Returns the index's predicate (the WHERE clause of a partial index),
    or None if pg_get_expr returns none.
    @raise NotImplementedError for non-PostgreSQL drivers
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        return value(run_query(db, '''\
SELECT pg_get_expr(indpred, indrelid, true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError()
1232

    
1233
def index_order_by(db, index):
    '''Returns an ORDER BY expression covering the index's columns.'''
    exprs = index_exprs(db, index)
    return sql_gen.CustomCode(', '.join(exprs))
1235

    
1236
def table_cluster_on(db, table, recover=None):
    '''
    @return The table's cluster index, or its pkey if none is set
    @raise DoesNotExistException (from table_pkey_index) if the table has
        neither a clustered index nor a pkey
    '''
    table_str = sql_gen.Literal(table.to_str(db))
    # value() raises StopIteration when no index is clustered
    try:
        return sql_gen.Table(value(run_query(db, '''\
SELECT relname
FROM pg_index
JOIN pg_class index ON index.oid = indexrelid
WHERE
indrelid = '''+table_str.to_str(db)+'''::regclass
AND indisclustered
'''
            , recover, cacheable=True, log_level=4)), table.schema)
    except StopIteration: return table_pkey_index(db, table, recover)
1252

    
1253
def table_order_by(db, table, recover=None):
    '''Returns the table's natural ordering (its cluster index's columns),
    caching the result on table.order_by. Returns None if the table has no
    cluster index or pkey.
    '''
    if table.order_by == None:
        try: table.order_by = index_order_by(db, table_cluster_on(db, table,
            recover))
        except DoesNotExistException: pass # no index to order by; stay None
    return table.order_by
1259

    
1260
#### Functions
1261

    
1262
def function_exists(db, function):
    '''Returns whether the function exists, by casting its name to regproc.'''
    qual_function = sql_gen.Literal(function.to_str(db))
    try:
        select(db, fields=[sql_gen.Cast('regproc', qual_function)],
            recover=True, cacheable=True, log_level=4)
    except DoesNotExistException: return False
    except DuplicateException: return True # overloaded function
    else: return True
1270

    
1271
def function_param0_type(db, function):
    '''Returns the type name of the function's first parameter, from the
    pg_proc catalog.'''
    qual_function = sql_gen.Literal(function.to_str(db))
    return value(run_query(db, '''\
SELECT proargtypes[0]::regtype
FROM pg_proc
WHERE oid = '''+qual_function.to_str(db)+'''::regproc
'''
        , cacheable=True, log_level=4))
1279

    
1280
##### Structural changes
1281

    
1282
#### Columns
1283

    
1284
def add_col(db, table, col, comment=None, if_not_exists=False, **kw_args):
    '''Adds a column, renaming it to the next version on a name collision.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    @param if_not_exists If set, a duplicate column re-raises instead of being
        renamed to a new version.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            # NOTE(review): `if_not_exists` re-raises on duplicate, which
            # inverts the usual "IF NOT EXISTS" meaning -- presumably callers
            # rely on catching the exception; confirm before changing
            if if_not_exists: raise
            col.name = next_version(col.name)
            # try again with next version of name
1305

    
1306
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to a column.'''
    table = col.table
    col = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1311

    
1312
def drop_not_null(db, col):
    '''Removes a NOT NULL constraint from a column.'''
    table = col.table
    col = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' DROP NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1317

    
1318
# default name of the generated row-number column
row_num_col = '_row_num'

# Template TypedCol for add_row_num(); copied there (never used directly) so
# the shared default is not mutated. Its name is filled in by add_row_num().
row_num_col_def = sql_gen.TypedCol('', 'serial', nullable=False,
    constraints='PRIMARY KEY')
1322

    
1323
def add_row_num(db, table, name=row_num_col):
    '''Adds a row number column to a table. Its definition is in
    row_num_col_def. It will be the primary key.'''
    col_def = copy.copy(row_num_col_def) # don't mutate the shared template
    col_def.name = name
    add_col(db, table, col_def, comment='', if_not_exists=True, log_level=3)
1329

    
1330
#### Indexes
1331

    
1332
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @param recover passed through to pkey_name() when cols is defaulted
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey_name(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1345

    
1346
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported expressions.
    @param exprs col|expr or list of them; the things to index
    @param table The indexed table; if None, inferred from the first column
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr) # wrap literal values for the index expr
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # strip the table prefix: the index expression is table-relative
        if isinstance(col, sql_gen.Col): col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1393

    
1394
def add_pkey_index(db, table): add_index(db, pkey_col(db, table), table)
1395

    
1396
def add_pkey_or_index(db, table, cols=None, recover=None, warn=False):
    '''Adds a pkey, falling back to a plain index if a pkey already exists.
    @param warn Whether to emit a warning when falling back
    '''
    try: add_pkey(db, table, cols, recover)
    except DuplicateKeyException as e: # modernized except syntax (2.6+/3)
        if warn: warnings.warn(UserWarning(exc.str_(e)))
        add_pkey_index(db, table)
1401

    
1402
already_indexed = object() # tells add_indexes() the pkey has already been added
1403

    
1404
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    col_names_ = table_col_names(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        col_names_ = col_names_[1:] # first column is covered by the pkey
    for name in col_names_: add_index(db, name, table)
1415

    
1416
#### Tables
1417

    
1418
### Maintenance
1419

    
1420
def analyze(db, table):
    '''Refreshes the planner statistics for a table.'''
    table = sql_gen.as_Table(table)
    query = 'ANALYZE '+table.to_str(db)
    run_query(db, query, log_level=3)
1423

    
1424
def autoanalyze(db, table):
    '''Runs analyze() only if it's enabled on this connection.'''
    if db.autoanalyze:
        analyze(db, table)
1426

    
1427
def vacuum(db, table):
    '''VACUUM ANALYZEs a table, in autocommit mode via db.with_autocommit().'''
    table = sql_gen.as_Table(table)
    def run():
        run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run)
1431

    
1432
### Lifecycle
1433

    
1434
def drop(db, type_, name):
    '''Drops a database object (and its dependents) if it exists.
    @param type_ str e.g. 'TABLE', 'FUNCTION'
    '''
    name = sql_gen.as_Name(name)
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')
1437

    
1438
def drop_table(db, table): drop(db, 'TABLE', table)
1439

    
1440
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None table to copy the structure of
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        # prepend a LIKE clause; rebinding cols leaves the default list alone
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
        table.order_by = like.order_by
    if has_pkey:
        # NOTE(review): with has_pkey and no cols/like this IndexErrors --
        # presumably callers always pass cols or like when has_pkey is set
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
    if table.is_temp:
        # temp tables: rename to the next version on a name collision
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # pkey was added by the constraint
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1488

    
1489
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1492

    
1493
def copy_table(db, src, dest):
    '''Creates a copy of a table, including data'''
    copy_table_struct(db, src, dest)
    select_all = mk_select(db, src)
    insert_select(db, dest, None, select_all)
1497

    
1498
### Data
1499

    
1500
def truncate(db, table, schema='public', **kw_args):
    '''Empties a table and its dependents via TRUNCATE ... CASCADE.
    For params, see run_query()
    '''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1504

    
1505
def empty_temp(db, tables):
    '''Truncates the given temp table(s), quietly (high log level).'''
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1508

    
1509
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema.
    For kw_args, see tables()
    '''
    for name in tables(db, schema, **kw_args):
        truncate(db, name, schema)
1512

    
1513
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    Adds an index on table's distinct_on columns, to facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    # only real table columns can be DISTINCTed on; drop literals etc.
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else: add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table,
        distinct_on=distinct_on, order_by=None, limit=limit))
    analyze(db, new_table) # new table has no stats yet
    
    return new_table
(33-33/47)