Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import time
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
import profiling
13
from Proxy import Proxy
14
import rand
15
import sql_gen
16
import strings
17
import util
18

    
19
##### Exceptions
20

    
21
def get_cur_query(cur, input_query=None):
22
    raw_query = None
23
    if hasattr(cur, 'query'): raw_query = cur.query
24
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
25
    
26
    if raw_query != None: return raw_query
27
    else: return '[input] '+strings.ustr(input_query)
28

    
29
def _add_cursor_info(e, *args, **kw_args):
30
    '''For params, see get_cur_query()'''
31
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
32

    
33
class DbException(exc.ExceptionWithCause):
34
    def __init__(self, msg, cause=None, cur=None):
35
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
36
        if cur != None: _add_cursor_info(self, cur)
37

    
38
class ExceptionWithName(DbException):
39
    def __init__(self, name, cause=None):
40
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
41
        self.name = name
42

    
43
class ExceptionWithValue(DbException):
44
    def __init__(self, value, cause=None):
45
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
46
            cause)
47
        self.value = value
48

    
49
class ExceptionWithNameType(DbException):
50
    def __init__(self, type_, name, cause=None):
51
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
52
            +'; name: '+strings.as_tt(name), cause)
53
        self.type = type_
54
        self.name = name
55

    
56
class ConstraintException(DbException):
57
    def __init__(self, name, cond, cols, cause=None):
58
        msg = 'Violated '+strings.as_tt(name)+' constraint'
59
        if cond != None: msg += ' with condition '+cond
60
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
61
        DbException.__init__(self, msg, cause)
62
        self.name = name
63
        self.cond = cond
64
        self.cols = cols
65

    
66
class MissingCastException(DbException):
67
    def __init__(self, type_, col, cause=None):
68
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
69
            +' on column: '+strings.as_tt(col), cause)
70
        self.type = type_
71
        self.col = col
72

    
73
class NameException(DbException): pass
74

    
75
class DuplicateKeyException(ConstraintException): pass
76

    
77
class NullValueException(ConstraintException): pass
78

    
79
class CheckException(ConstraintException): pass
80

    
81
class InvalidValueException(ExceptionWithValue): pass
82

    
83
class DuplicateException(ExceptionWithNameType): pass
84

    
85
class EmptyRowException(DbException): pass
86

    
87
##### Warnings
88

    
89
class DbWarning(UserWarning): pass
90

    
91
##### Result retrieval
92

    
93
def col_names(cur): return (col[0] for col in cur.description)
94

    
95
def rows(cur): return iter(lambda: cur.fetchone(), None)
96

    
97
def consume_rows(cur):
98
    '''Used to fetch all rows so result will be cached'''
99
    iters.consume_iter(rows(cur))
100

    
101
def next_row(cur): return rows(cur).next()
102

    
103
def row(cur):
104
    row_ = next_row(cur)
105
    consume_rows(cur)
106
    return row_
107

    
108
def next_value(cur): return next_row(cur)[0]
109

    
110
def value(cur): return row(cur)[0]
111

    
112
def values(cur): return iters.func_iter(lambda: next_value(cur))
113

    
114
def value_or_none(cur):
115
    try: return value(cur)
116
    except StopIteration: return None
117

    
118
##### Escaping
119

    
120
def esc_name_by_module(module, name):
121
    if module == 'psycopg2' or module == None: quote = '"'
122
    elif module == 'MySQLdb': quote = '`'
123
    else: raise NotImplementedError("Can't escape name for "+module+' database')
124
    return sql_gen.esc_name(name, quote)
125

    
126
def esc_name_by_engine(engine, name, **kw_args):
127
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
128

    
129
def esc_name(db, name, **kw_args):
130
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
131

    
132
def qual_name(db, schema, table):
133
    def esc_name_(name): return esc_name(db, name)
134
    table = esc_name_(table)
135
    if schema != None: return esc_name_(schema)+'.'+table
136
    else: return table
137

    
138
##### Database connections
139

    
140
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
141

    
142
db_engines = {
143
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
144
    'PostgreSQL': ('psycopg2', {}),
145
}
146

    
147
DatabaseErrors_set = set([DbException])
148
DatabaseErrors = tuple(DatabaseErrors_set)
149

    
150
def _add_module(module):
151
    DatabaseErrors_set.add(module.DatabaseError)
152
    global DatabaseErrors
153
    DatabaseErrors = tuple(DatabaseErrors_set)
154

    
155
def db_config_str(db_config):
156
    return db_config['engine']+' database '+db_config['database']
157

    
158
log_debug_none = lambda msg, level=2: None
159

    
160
class DbConn:
161
    def __init__(self, db_config, autocommit=True, caching=True,
162
        log_debug=log_debug_none, debug_temp=False, src=None):
163
        '''
164
        @param debug_temp Whether temporary objects should instead be permanent.
165
            This assists in debugging the internal objects used by the program.
166
        @param src In autocommit mode, will be included in a comment in every
167
            query, to help identify the data source in pg_stat_activity.
168
        '''
169
        self.db_config = db_config
170
        self.autocommit = autocommit
171
        self.caching = caching
172
        self.log_debug = log_debug
173
        self.debug = log_debug != log_debug_none
174
        self.debug_temp = debug_temp
175
        self.src = src
176
        self.autoanalyze = False
177
        self.autoexplain = False
178
        self.profile_row_ct = None
179
        
180
        self._savepoint = 0
181
        self._reset()
182
    
183
    def __getattr__(self, name):
184
        if name == '__dict__': raise Exception('getting __dict__')
185
        if name == 'db': return self._db()
186
        else: raise AttributeError()
187
    
188
    def __getstate__(self):
189
        state = copy.copy(self.__dict__) # shallow copy
190
        state['log_debug'] = None # don't pickle the debug callback
191
        state['_DbConn__db'] = None # don't pickle the connection
192
        return state
193
    
194
    def clear_cache(self): self.query_results = {}
195
    
196
    def _reset(self):
197
        self.clear_cache()
198
        assert self._savepoint == 0
199
        self._notices_seen = set()
200
        self.__db = None
201
    
202
    def connected(self): return self.__db != None
203
    
204
    def close(self):
205
        if not self.connected(): return
206
        
207
        # Record that the automatic transaction is now closed
208
        self._savepoint -= 1
209
        
210
        self.db.close()
211
        self._reset()
212
    
213
    def reconnect(self):
214
        # Do not do this in test mode as it would roll back everything
215
        if self.autocommit: self.close()
216
        # Connection will be reopened automatically on first query
217
    
218
    def _db(self):
219
        if self.__db == None:
220
            # Process db_config
221
            db_config = self.db_config.copy() # don't modify input!
222
            schemas = db_config.pop('schemas', None)
223
            module_name, mappings = db_engines[db_config.pop('engine')]
224
            module = __import__(module_name)
225
            _add_module(module)
226
            for orig, new in mappings.iteritems():
227
                try: util.rename_key(db_config, orig, new)
228
                except KeyError: pass
229
            
230
            # Connect
231
            self.__db = module.connect(**db_config)
232
            
233
            # Record that a transaction is already open
234
            self._savepoint += 1
235
            
236
            # Configure connection
237
            if hasattr(self.db, 'set_isolation_level'):
238
                import psycopg2.extensions
239
                self.db.set_isolation_level(
240
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
241
            if schemas != None:
242
                search_path = [self.esc_name(s) for s in schemas.split(',')]
243
                search_path.append(value(run_query(self, 'SHOW search_path',
244
                    log_level=4)))
245
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
246
                    log_level=3)
247
        
248
        return self.__db
249
    
250
    class DbCursor(Proxy):
251
        def __init__(self, outer):
252
            Proxy.__init__(self, outer.db.cursor())
253
            self.outer = outer
254
            self.query_results = outer.query_results
255
            self.query_lookup = None
256
            self.result = []
257
        
258
        def execute(self, query):
259
            self._is_insert = query.startswith('INSERT')
260
            self.query_lookup = query
261
            try:
262
                try: cur = self.inner.execute(query)
263
                finally: self.query = get_cur_query(self.inner, query)
264
            except Exception, e:
265
                self.result = e # cache the exception as the result
266
                self._cache_result()
267
                raise
268
            
269
            # Always cache certain queries
270
            query = sql_gen.lstrip(query)
271
            if query.startswith('CREATE') or query.startswith('ALTER'):
272
                # structural changes
273
                # Rest of query must be unique in the face of name collisions,
274
                # so don't cache ADD COLUMN unless it has distinguishing comment
275
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
276
                    self._cache_result()
277
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
278
                consume_rows(self) # fetch all rows so result will be cached
279
            
280
            return cur
281
        
282
        def fetchone(self):
283
            row = self.inner.fetchone()
284
            if row != None: self.result.append(row)
285
            # otherwise, fetched all rows
286
            else: self._cache_result()
287
            return row
288
        
289
        def _cache_result(self):
290
            # For inserts that return a result set, don't cache result set since
291
            # inserts are not idempotent. Other non-SELECT queries don't have
292
            # their result set read, so only exceptions will be cached (an
293
            # invalid query will always be invalid).
294
            if self.query_results != None and (not self._is_insert
295
                or isinstance(self.result, Exception)):
296
                
297
                assert self.query_lookup != None
298
                self.query_results[self.query_lookup] = self.CacheCursor(
299
                    util.dict_subset(dicts.AttrsDictView(self),
300
                    ['query', 'result', 'rowcount', 'description']))
301
        
302
        class CacheCursor:
303
            def __init__(self, cached_result): self.__dict__ = cached_result
304
            
305
            def execute(self, *args, **kw_args):
306
                if isinstance(self.result, Exception): raise self.result
307
                # otherwise, result is a rows list
308
                self.iter = iter(self.result)
309
            
310
            def fetchone(self):
311
                try: return self.iter.next()
312
                except StopIteration: return None
313
    
314
    def esc_value(self, value):
315
        try: str_ = self.mogrify('%s', [value])
316
        except NotImplementedError, e:
317
            module = util.root_module(self.db)
318
            if module == 'MySQLdb':
319
                import _mysql
320
                str_ = _mysql.escape_string(value)
321
            else: raise e
322
        return strings.to_unicode(str_)
323
    
324
    def esc_name(self, name): return esc_name(self, name) # calls global func
325
    
326
    def std_code(self, str_):
327
        '''Standardizes SQL code.
328
        * Ensures that string literals are prefixed by `E`
329
        '''
330
        if str_.startswith("'"): str_ = 'E'+str_
331
        return str_
332
    
333
    def can_mogrify(self):
334
        module = util.root_module(self.db)
335
        return module == 'psycopg2'
336
    
337
    def mogrify(self, query, params=None):
338
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
339
        else: raise NotImplementedError("Can't mogrify query")
340
    
341
    def print_notices(self):
342
        if hasattr(self.db, 'notices'):
343
            for msg in self.db.notices:
344
                if msg not in self._notices_seen:
345
                    self._notices_seen.add(msg)
346
                    self.log_debug(msg, level=2)
347
    
348
    def run_query(self, query, cacheable=False, log_level=2,
349
        debug_msg_ref=None):
350
        '''
351
        @param log_ignore_excs The log_level will be increased by 2 if the query
352
            throws one of these exceptions.
353
        @param debug_msg_ref If specified, the log message will be returned in
354
            this instead of being output. This allows you to filter log messages
355
            depending on the result of the query.
356
        '''
357
        assert query != None
358
        
359
        if self.autocommit and self.src != None:
360
            query = sql_gen.esc_comment(self.src)+'\t'+query
361
        
362
        if not self.caching: cacheable = False
363
        used_cache = False
364
        
365
        if self.debug:
366
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
367
        try:
368
            # Get cursor
369
            if cacheable:
370
                try: cur = self.query_results[query]
371
                except KeyError: cur = self.DbCursor(self)
372
                else: used_cache = True
373
            else: cur = self.db.cursor()
374
            
375
            # Run query
376
            try: cur.execute(query)
377
            except Exception, e:
378
                _add_cursor_info(e, self, query)
379
                raise
380
            else: self.do_autocommit()
381
        finally:
382
            if self.debug:
383
                profiler.stop(self.profile_row_ct)
384
                
385
                ## Log or return query
386
                
387
                query = str(get_cur_query(cur, query))
388
                # Put the src comment on a separate line in the log file
389
                query = query.replace('\t', '\n', 1)
390
                
391
                msg = 'DB query: '
392
                
393
                if used_cache: msg += 'cache hit'
394
                elif cacheable: msg += 'cache miss'
395
                else: msg += 'non-cacheable'
396
                
397
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
398
                
399
                if debug_msg_ref != None: debug_msg_ref[0] = msg
400
                else: self.log_debug(msg, log_level)
401
                
402
                self.print_notices()
403
        
404
        return cur
405
    
406
    def is_cached(self, query): return query in self.query_results
407
    
408
    def with_autocommit(self, func):
409
        import psycopg2.extensions
410
        
411
        prev_isolation_level = self.db.isolation_level
412
        self.db.set_isolation_level(
413
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
414
        try: return func()
415
        finally: self.db.set_isolation_level(prev_isolation_level)
416
    
417
    def with_savepoint(self, func):
418
        top = self._savepoint == 0
419
        savepoint = 'level_'+str(self._savepoint)
420
        
421
        if self.debug:
422
            self.log_debug('Begin transaction', level=4)
423
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
424
        
425
        # Must happen before running queries so they don't get autocommitted
426
        self._savepoint += 1
427
        
428
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
429
        else: query = 'SAVEPOINT '+savepoint
430
        self.run_query(query, log_level=4)
431
        try:
432
            return func()
433
            if top: self.run_query('COMMIT', log_level=4)
434
        except:
435
            if top: query = 'ROLLBACK'
436
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
437
            self.run_query(query, log_level=4)
438
            
439
            raise
440
        finally:
441
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
442
            # "The savepoint remains valid and can be rolled back to again"
443
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
444
            if not top:
445
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
446
            
447
            self._savepoint -= 1
448
            assert self._savepoint >= 0
449
            
450
            if self.debug:
451
                profiler.stop(self.profile_row_ct)
452
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
453
            
454
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
455
    
456
    def do_autocommit(self):
457
        '''Autocommits if outside savepoint'''
458
        assert self._savepoint >= 1
459
        if self.autocommit and self._savepoint == 1:
460
            self.log_debug('Autocommitting', level=4)
461
            self.db.commit()
462
    
463
    def col_info(self, col, cacheable=True):
464
        table = sql_gen.Table('columns', 'information_schema')
465
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
466
            'USER-DEFINED'), sql_gen.Col('udt_name'))
467
        cols = [type_, 'column_default',
468
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
469
        
470
        conds = [('table_name', col.table.name), ('column_name', col.name)]
471
        schema = col.table.schema
472
        if schema != None: conds.append(('table_schema', schema))
473
        
474
        type_, default, nullable = row(select(self, table, cols, conds,
475
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
476
            # TODO: order_by search_path schema order
477
        default = sql_gen.as_Code(default, self)
478
        
479
        return sql_gen.TypedCol(col.name, type_, default, nullable)
480
    
481
    def TempFunction(self, name):
482
        if self.debug_temp: schema = None
483
        else: schema = 'pg_temp'
484
        return sql_gen.Function(name, schema)
485

    
486
connect = DbConn
487

    
488
##### Recoverable querying
489

    
490
def with_savepoint(db, func): return db.with_savepoint(func)
491

    
492
def run_query(db, query, recover=None, cacheable=False, log_level=2,
493
    log_ignore_excs=None, **kw_args):
494
    '''For params, see DbConn.run_query()'''
495
    if recover == None: recover = False
496
    if log_ignore_excs == None: log_ignore_excs = ()
497
    log_ignore_excs = tuple(log_ignore_excs)
498
    debug_msg_ref = [None]
499
    
500
    query = with_explain_comment(db, query)
501
    
502
    try:
503
        try:
504
            def run(): return db.run_query(query, cacheable, log_level,
505
                debug_msg_ref, **kw_args)
506
            if recover and not db.is_cached(query):
507
                return with_savepoint(db, run)
508
            else: return run() # don't need savepoint if cached
509
        except Exception, e:
510
            msg = strings.ustr(e.args[0])
511
            
512
            match = re.match(r'^duplicate key value violates unique constraint '
513
                r'"(.+?)"', msg)
514
            if match:
515
                constraint, = match.groups()
516
                cols = []
517
                if recover: # need auto-rollback to run index_cols()
518
                    try: cols = index_cols(db, constraint)
519
                    except NotImplementedError: pass
520
                raise DuplicateKeyException(constraint, None, cols, e)
521
            
522
            match = re.match(r'^null value in column "(.+?)" violates not-null'
523
                r' constraint', msg)
524
            if match:
525
                col, = match.groups()
526
                raise NullValueException('NOT NULL', None, [col], e)
527
            
528
            match = re.match(r'^new row for relation "(.+?)" violates check '
529
                r'constraint "(.+?)"', msg)
530
            if match:
531
                table, constraint = match.groups()
532
                constraint = sql_gen.Col(constraint, table)
533
                cond = None
534
                if recover: # need auto-rollback to run constraint_cond()
535
                    try: cond = constraint_cond(db, constraint)
536
                    except NotImplementedError: pass
537
                raise CheckException(constraint.to_str(db), cond, [], e)
538
            
539
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
540
                r'|.+? field value out of range): "(.+?)"', msg)
541
            if match:
542
                value, = match.groups()
543
                raise InvalidValueException(strings.to_unicode(value), e)
544
            
545
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
546
                r'is of type', msg)
547
            if match:
548
                col, type_ = match.groups()
549
                raise MissingCastException(type_, col, e)
550
            
551
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
552
            if match:
553
                type_, name = match.groups()
554
                raise DuplicateException(type_, name, e)
555
            
556
            raise # no specific exception raised
557
    except log_ignore_excs:
558
        log_level += 2
559
        raise
560
    finally:
561
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
562

    
563
##### Basic queries
564

    
565
def is_explainable(query):
566
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
567
    return re.match(r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
568
        , query)
569

    
570
def explain(db, query, **kw_args):
571
    '''
572
    For params, see run_query().
573
    '''
574
    kw_args.setdefault('log_level', 4)
575
    
576
    return strings.join_lines(values(run_query(db, 'EXPLAIN '+query,
577
        recover=True, cacheable=True, **kw_args)))
578
        # not a higher log_level because it's useful to see what query is being
579
        # run before it's executed, which EXPLAIN effectively provides
580

    
581
def has_comment(query): return query.endswith('*/')
582

    
583
def with_explain_comment(db, query, **kw_args):
584
    if db.autoexplain and not has_comment(query) and is_explainable(query):
585
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
586
            +explain(db, query, **kw_args))
587
    return query
588

    
589
def next_version(name):
590
    version = 1 # first existing name was version 0
591
    match = re.match(r'^(.*)#(\d+)$', name)
592
    if match:
593
        name, version = match.groups()
594
        version = int(version)+1
595
    return sql_gen.concat(name, '#'+str(version))
596

    
597
def lock_table(db, table, mode):
598
    table = sql_gen.as_Table(table)
599
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
600

    
601
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
602
    '''Outputs a query to a temp table.
603
    For params, see run_query().
604
    '''
605
    if into == None: return run_query(db, query, **kw_args)
606
    
607
    assert isinstance(into, sql_gen.Table)
608
    
609
    into.is_temp = True
610
    # "temporary tables cannot specify a schema name", so remove schema
611
    into.schema = None
612
    
613
    kw_args['recover'] = True
614
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
615
    
616
    temp = not db.debug_temp # tables are permanent in debug_temp mode
617
    
618
    # Create table
619
    while True:
620
        create_query = 'CREATE'
621
        if temp: create_query += ' TEMP'
622
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
623
        
624
        try:
625
            cur = run_query(db, create_query, **kw_args)
626
                # CREATE TABLE AS sets rowcount to # rows in query
627
            break
628
        except DuplicateException, e:
629
            into.name = next_version(into.name)
630
            # try again with next version of name
631
    
632
    if add_pkey_: add_pkey(db, into)
633
    
634
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
635
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
636
    # table is going to be used in complex queries, it is wise to run ANALYZE on
637
    # the temporary table after it is populated."
638
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
639
    # If into is not a temp table, ANALYZE is useful but not required.
640
    analyze(db, into)
641
    
642
    return cur
643

    
644
order_by_pkey = object() # tells mk_select() to order by the pkey
645

    
646
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
647

    
648
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
649
    start=None, order_by=order_by_pkey, default_table=None):
650
    '''
651
    @param tables The single table to select from, or a list of tables to join
652
        together, with tables after the first being sql_gen.Join objects
653
    @param fields Use None to select all fields in the table
654
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
655
        * container can be any iterable type
656
        * compare_left_side: sql_gen.Code|str (for col name)
657
        * compare_right_side: sql_gen.ValueCond|literal value
658
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
659
        use all columns
660
    @return query
661
    '''
662
    # Parse tables param
663
    tables = lists.mk_seq(tables)
664
    tables = list(tables) # don't modify input! (list() copies input)
665
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
666
    
667
    # Parse other params
668
    if conds == None: conds = []
669
    elif dicts.is_dict(conds): conds = conds.items()
670
    conds = list(conds) # don't modify input! (list() copies input)
671
    assert limit == None or isinstance(limit, (int, long))
672
    assert start == None or isinstance(start, (int, long))
673
    if order_by is order_by_pkey:
674
        if distinct_on != []: order_by = None
675
        else: order_by = pkey(db, table0, recover=True)
676
    
677
    query = 'SELECT'
678
    
679
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
680
    
681
    # DISTINCT ON columns
682
    if distinct_on != []:
683
        query += '\nDISTINCT'
684
        if distinct_on is not distinct_on_all:
685
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
686
    
687
    # Columns
688
    if query.find('\n') >= 0: whitespace = '\n'
689
    else: whitespace = ' '
690
    if fields == None: query += whitespace+'*'
691
    else:
692
        assert fields != []
693
        if len(fields) > 1: whitespace = '\n'
694
        query += whitespace+('\n, '.join(map(parse_col, fields)))
695
    
696
    # Main table
697
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
698
    else: whitespace = ' '
699
    query += whitespace+'FROM '+table0.to_str(db)
700
    
701
    # Add joins
702
    left_table = table0
703
    for join_ in tables:
704
        table = join_.table
705
        
706
        # Parse special values
707
        if join_.type_ is sql_gen.filter_out: # filter no match
708
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
709
                sql_gen.CompareCond(None, '~=')))
710
        
711
        query += '\n'+join_.to_str(db, left_table)
712
        
713
        left_table = table
714
    
715
    missing = True
716
    if conds != []:
717
        if len(conds) == 1: whitespace = ' '
718
        else: whitespace = '\n'
719
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
720
            .to_str(db) for l, r in conds], 'WHERE')
721
    if order_by != None:
722
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
723
    if limit != None: query += '\nLIMIT '+str(limit)
724
    if start != None:
725
        if start != 0: query += '\nOFFSET '+str(start)
726
    
727
    query = with_explain_comment(db, query)
728
    
729
    return query
730

    
731
def select(db, *args, **kw_args):
732
    '''For params, see mk_select() and run_query()'''
733
    recover = kw_args.pop('recover', None)
734
    cacheable = kw_args.pop('cacheable', True)
735
    log_level = kw_args.pop('log_level', 2)
736
    
737
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
738
        log_level=log_level)
739

    
740
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
741
    embeddable=False, ignore=False, src=None):
742
    '''
743
    @param returning str|None An inserted column (such as pkey) to return
744
    @param embeddable Whether the query should be embeddable as a nested SELECT.
745
        Warning: If you set this and cacheable=True when the query is run, the
746
        query will be fully cached, not just if it raises an exception.
747
    @param ignore Whether to ignore duplicate keys.
748
    @param src Will be included in the name of any created function, to help
749
        identify the data source in pg_stat_activity.
750
    '''
751
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
752
    if cols == []: cols = None # no cols (all defaults) = unknown col names
753
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
754
    if select_query == None: select_query = 'DEFAULT VALUES'
755
    if returning != None: returning = sql_gen.as_Col(returning, table)
756
    
757
    first_line = 'INSERT INTO '+table.to_str(db)
758
    
759
    def mk_insert(select_query):
760
        query = first_line
761
        if cols != None:
762
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
763
        query += '\n'+select_query
764
        
765
        if returning != None:
766
            returning_name_col = sql_gen.to_name_only_col(returning)
767
            query += '\nRETURNING '+returning_name_col.to_str(db)
768
        
769
        return query
770
    
771
    return_type = 'unknown'
772
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
773
    
774
    lang = 'sql'
775
    if ignore:
776
        # Always return something to set the correct rowcount
777
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
778
        
779
        embeddable = True # must use function
780
        lang = 'plpgsql'
781
        
782
        if cols == None:
783
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
784
            row_vars = [sql_gen.Table('row')]
785
        else:
786
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
787
        
788
        query = '''\
789
DECLARE
790
    row '''+table.to_str(db)+'''%ROWTYPE;
791
BEGIN
792
    /* Need an EXCEPTION block for each individual row because "When an error is
793
    caught by an EXCEPTION clause, [...] all changes to persistent database
794
    state within the block are rolled back."
795
    This is unfortunate because "A block containing an EXCEPTION clause is
796
    significantly more expensive to enter and exit than a block without one."
797
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
798
#PLPGSQL-ERROR-TRAPPING)
799
    */
800
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
801
'''+select_query+'''
802
    LOOP
803
        BEGIN
804
            RETURN QUERY
805
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
806
;
807
        EXCEPTION
808
            WHEN unique_violation THEN NULL; -- continue to next row
809
        END;
810
    END LOOP;
811
END;\
812
'''
813
    else: query = mk_insert(select_query)
814
    
815
    if embeddable:
816
        # Create function
817
        function_name = sql_gen.clean_name(first_line)
818
        if src != None: function_name = src+': '+function_name
819
        while True:
820
            try:
821
                function = db.TempFunction(function_name)
822
                
823
                function_query = '''\
824
CREATE FUNCTION '''+function.to_str(db)+'''()
825
RETURNS SETOF '''+return_type+'''
826
LANGUAGE '''+lang+'''
827
AS $$
828
'''+query+'''
829
$$;
830
'''
831
                run_query(db, function_query, recover=True, cacheable=True,
832
                    log_ignore_excs=(DuplicateException,))
833
                break # this version was successful
834
            except DuplicateException, e:
835
                function_name = next_version(function_name)
836
                # try again with next version of name
837
        
838
        # Return query that uses function
839
        cols = None
840
        if returning != None: cols = [returning]
841
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
842
            cols) # AS clause requires function alias
843
        return mk_select(db, func_table, order_by=None)
844
    
845
    return query
846

    
847
def insert_select(db, table, *args, **kw_args):
848
    '''For params, see mk_insert_select() and run_query_into()
849
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
850
        values in
851
    '''
852
    returning = kw_args.get('returning', None)
853
    ignore = kw_args.get('ignore', False)
854
    
855
    into = kw_args.pop('into', None)
856
    if into != None: kw_args['embeddable'] = True
857
    recover = kw_args.pop('recover', None)
858
    if ignore: recover = True
859
    cacheable = kw_args.pop('cacheable', True)
860
    log_level = kw_args.pop('log_level', 2)
861
    
862
    rowcount_only = ignore and returning == None # keep NULL rows on server
863
    if rowcount_only: into = sql_gen.Table('rowcount')
864
    
865
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
866
        into, recover=recover, cacheable=cacheable, log_level=log_level)
867
    if rowcount_only: empty_temp(db, into)
868
    autoanalyze(db, table)
869
    return cur
870

    
871
default = sql_gen.default # tells insert() to use the default value for a column
872

    
873
def insert(db, table, row, *args, **kw_args):
874
    '''For params, see insert_select()'''
875
    if lists.is_seq(row): cols = None
876
    else:
877
        cols = row.keys()
878
        row = row.values()
879
    row = list(row) # ensure that "== []" works
880
    
881
    if row == []: query = None
882
    else: query = sql_gen.Values(row).to_str(db)
883
    
884
    return insert_select(db, table, cols, query, *args, **kw_args)
885

    
886
def mk_update(db, table, changes=None, cond=None, in_place=False,
887
    cacheable_=True):
888
    '''
889
    @param changes [(col, new_value),...]
890
        * container can be any iterable type
891
        * col: sql_gen.Code|str (for col name)
892
        * new_value: sql_gen.Code|literal value
893
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
894
    @param in_place If set, locks the table and updates rows in place.
895
        This avoids creating dead rows in PostgreSQL.
896
        * cond must be None
897
    @param cacheable_ Whether column structure information used to generate the
898
        query can be cached
899
    @return str query
900
    '''
901
    table = sql_gen.as_Table(table)
902
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
903
        for c, v in changes]
904
    
905
    if in_place:
906
        assert cond == None
907
        
908
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
909
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
910
            +db.col_info(sql_gen.with_default_table(c, table), cacheable_).type
911
            +'\nUSING '+v.to_str(db) for c, v in changes))
912
    else:
913
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
914
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
915
            for c, v in changes))
916
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
917
    
918
    query = with_explain_comment(db, query)
919
    
920
    return query
921

    
922
def update(db, table, *args, **kw_args):
923
    '''For params, see mk_update() and run_query()'''
924
    recover = kw_args.pop('recover', None)
925
    cacheable = kw_args.pop('cacheable', False)
926
    log_level = kw_args.pop('log_level', 2)
927
    
928
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
929
        cacheable, log_level=log_level)
930
    autoanalyze(db, table)
931
    return cur
932

    
933
def mk_delete(db, table, cond=None):
934
    '''
935
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
936
    @return str query
937
    '''
938
    query = 'DELETE FROM '+table.to_str(db)
939
    if cond != None: query += '\nWHERE '+cond.to_str(db)
940
    
941
    query = with_explain_comment(db, query)
942
    
943
    return query
944

    
945
def delete(db, table, *args, **kw_args):
946
    '''For params, see mk_delete() and run_query()'''
947
    recover = kw_args.pop('recover', None)
948
    cacheable = kw_args.pop('cacheable', True)
949
    log_level = kw_args.pop('log_level', 2)
950
    
951
    cur = run_query(db, mk_delete(db, table, *args, **kw_args), recover,
952
        cacheable, log_level=log_level)
953
    autoanalyze(db, table)
954
    return cur
955

    
956
def last_insert_id(db):
957
    module = util.root_module(db.db)
958
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
959
    elif module == 'MySQLdb': return db.insert_id()
960
    else: return None
961

    
962
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
963
    '''Creates a mapping from original column names (which may have collisions)
964
    to names that will be distinct among the columns' tables.
965
    This is meant to be used for several tables that are being joined together.
966
    @param cols The columns to combine. Duplicates will be removed.
967
    @param into The table for the new columns.
968
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
969
        columns will be included in the mapping even if they are not in cols.
970
        The tables of the provided Col objects will be changed to into, so make
971
        copies of them if you want to keep the original tables.
972
    @param as_items Whether to return a list of dict items instead of a dict
973
    @return dict(orig_col=new_col, ...)
974
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
975
        * new_col: sql_gen.Col(orig_col_name, into)
976
        * All mappings use the into table so its name can easily be
977
          changed for all columns at once
978
    '''
979
    cols = lists.uniqify(cols)
980
    
981
    items = []
982
    for col in preserve:
983
        orig_col = copy.copy(col)
984
        col.table = into
985
        items.append((orig_col, col))
986
    preserve = set(preserve)
987
    for col in cols:
988
        if col not in preserve:
989
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
990
    
991
    if not as_items: items = dict(items)
992
    return items
993

    
994
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
995
    '''For params, see mk_flatten_mapping()
996
    @return See return value of mk_flatten_mapping()
997
    '''
998
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
999
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
1000
    run_query_into(db, mk_select(db, joins, cols, order_by=None, limit=limit,
1001
        start=start), into=into, add_pkey_=True)
1002
    return dict(items)
1003

    
1004
##### Database structure introspection
1005

    
1006
#### Expressions
1007

    
1008
bool_re = r'(?:true|false)'
1009

    
1010
def simplify_expr(expr):
1011
    expr = expr.replace('(NULL IS NULL)', 'true')
1012
    expr = expr.replace('(NULL IS NOT NULL)', 'false')
1013
    expr = re.sub(r' OR '+bool_re, r'', expr)
1014
    expr = re.sub(bool_re+r' OR ', r'', expr)
1015
    while True:
1016
        expr, n = re.subn(r'\((\([^()]*\))\)', r'\1', expr)
1017
        if n == 0: break
1018
    return expr
1019

    
1020
name_re = r'(?:\w+|(?:"[^"]*")+)'
1021

    
1022
def parse_expr_col(str_):
1023
    match = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
1024
    if match: str_ = match.group(1)
1025
    return sql_gen.unesc_name(str_)
1026

    
1027
def map_expr(db, expr, mapping, in_cols_found=None):
1028
    '''Replaces output columns with input columns in an expression.
1029
    @param in_cols_found If set, will be filled in with the expr's (input) cols
1030
    '''
1031
    for out, in_ in mapping.iteritems():
1032
        orig_expr = expr
1033
        out = sql_gen.to_name_only_col(out)
1034
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
1035
            ).to_str(db)
1036
        
1037
        # Replace out both with and without quotes
1038
        expr = expr.replace(out.to_str(db), in_str)
1039
        expr = re.sub(r'\b'+out.name+r'\b', in_str, expr)
1040
        
1041
        if in_cols_found != None and expr != orig_expr: # replaced something
1042
            in_cols_found.append(in_)
1043
    
1044
    return simplify_expr(expr)
1045

    
1046
#### Tables
1047

    
1048
def tables(db, schema_like='public', table_like='%', exact=False):
1049
    if exact: compare = '='
1050
    else: compare = 'LIKE'
1051
    
1052
    module = util.root_module(db.db)
1053
    if module == 'psycopg2':
1054
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
1055
            ('tablename', sql_gen.CompareCond(table_like, compare))]
1056
        return values(select(db, 'pg_tables', ['tablename'], conds,
1057
            order_by='tablename', log_level=4))
1058
    elif module == 'MySQLdb':
1059
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
1060
            , cacheable=True, log_level=4))
1061
    else: raise NotImplementedError("Can't list tables for "+module+' database')
1062

    
1063
def table_exists(db, table):
1064
    table = sql_gen.as_Table(table)
1065
    return list(tables(db, table.schema, table.name, exact=True)) != []
1066

    
1067
def table_row_count(db, table, recover=None):
1068
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
1069
        order_by=None), recover=recover, log_level=3))
1070

    
1071
def table_cols(db, table, recover=None):
1072
    return list(col_names(select(db, table, limit=0, order_by=None,
1073
        recover=recover, log_level=4)))
1074

    
1075
def pkey(db, table, recover=None):
1076
    '''Assumed to be first column in table'''
1077
    return table_cols(db, table, recover)[0]
1078

    
1079
not_null_col = 'not_null_col'
1080

    
1081
def table_not_null_col(db, table, recover=None):
1082
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
1083
    if not_null_col in table_cols(db, table, recover): return not_null_col
1084
    else: return pkey(db, table, recover)
1085

    
1086
def constraint_cond(db, constraint):
1087
    module = util.root_module(db.db)
1088
    if module == 'psycopg2':
1089
        table_str = sql_gen.Literal(constraint.table.to_str(db))
1090
        name_str = sql_gen.Literal(constraint.name)
1091
        return value(run_query(db, '''\
1092
SELECT consrc
1093
FROM pg_constraint
1094
WHERE
1095
conrelid = '''+table_str.to_str(db)+'''::regclass
1096
AND conname = '''+name_str.to_str(db)+'''
1097
'''
1098
            , cacheable=True, log_level=4))
1099
    else: raise NotImplementedError("Can't list index columns for "+module+
1100
        ' database')
1101

    
1102
def index_cols(db, index):
1103
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
1104
    automatically created. When you don't know whether something is a UNIQUE
1105
    constraint or a UNIQUE index, use this function.'''
1106
    index = sql_gen.as_Table(index)
1107
    module = util.root_module(db.db)
1108
    if module == 'psycopg2':
1109
        qual_index = sql_gen.Literal(index.to_str(db))
1110
        return map(parse_expr_col, values(run_query(db, '''\
1111
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
1112
FROM pg_index
1113
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
1114
'''
1115
            , cacheable=True, log_level=4)))
1116
    else: raise NotImplementedError("Can't list index columns for "+module+
1117
        ' database')
1118

    
1119
#### Functions
1120

    
1121
def function_exists(db, function):
1122
    function = sql_gen.as_Function(function)
1123
    
1124
    info_table = sql_gen.Table('routines', 'information_schema')
1125
    conds = [('routine_name', function.name)]
1126
    schema = function.schema
1127
    if schema != None: conds.append(('routine_schema', schema))
1128
    # Exclude trigger functions, since they cannot be called directly
1129
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
1130
    
1131
    return list(values(select(db, info_table, ['routine_name'], conds,
1132
        order_by='routine_schema', limit=1, log_level=4))) != []
1133
        # TODO: order_by search_path schema order
1134

    
1135
##### Structural changes
1136

    
1137
#### Columns
1138

    
1139
def add_col(db, table, col, comment=None, **kw_args):
1140
    '''
1141
    @param col TypedCol Name may be versioned, so be sure to propagate any
1142
        renaming back to any source column for the TypedCol.
1143
    @param comment None|str SQL comment used to distinguish columns of the same
1144
        name from each other when they contain different data, to allow the
1145
        ADD COLUMN query to be cached. If not set, query will not be cached.
1146
    '''
1147
    assert isinstance(col, sql_gen.TypedCol)
1148
    
1149
    while True:
1150
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
1151
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
1152
        
1153
        try:
1154
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
1155
            break
1156
        except DuplicateException:
1157
            col.name = next_version(col.name)
1158
            # try again with next version of name
1159

    
1160
def add_not_null(db, col):
1161
    table = col.table
1162
    col = sql_gen.to_name_only_col(col)
1163
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1164
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)
1165

    
1166
row_num_col = '_row_num'
1167

    
1168
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
1169
    constraints='PRIMARY KEY')
1170

    
1171
def add_row_num(db, table):
1172
    '''Adds a row number column to a table. Its name is in row_num_col. It will
1173
    be the primary key.'''
1174
    add_col(db, table, row_num_typed_col, log_level=3)
1175

    
1176
#### Indexes
1177

    
1178
def add_pkey(db, table, cols=None, recover=None):
1179
    '''Adds a primary key.
1180
    @param cols [sql_gen.Col,...] The columns in the primary key.
1181
        Defaults to the first column in the table.
1182
    @pre The table must not already have a primary key.
1183
    '''
1184
    table = sql_gen.as_Table(table)
1185
    if cols == None: cols = [pkey(db, table, recover)]
1186
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
1187
    
1188
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
1189
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
1190
        log_ignore_excs=(DuplicateException,))
1191

    
1192
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
1193
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
1194
    Currently, only function calls and literal values are supported expressions.
1195
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
1196
        This allows indexes to be used for comparisons where NULLs are equal.
1197
    '''
1198
    exprs = lists.mk_seq(exprs)
1199
    
1200
    # Parse exprs
1201
    old_exprs = exprs[:]
1202
    exprs = []
1203
    cols = []
1204
    for i, expr in enumerate(old_exprs):
1205
        expr = sql_gen.as_Col(expr, table)
1206
        
1207
        # Handle nullable columns
1208
        if ensure_not_null_:
1209
            try: expr = sql_gen.ensure_not_null(db, expr)
1210
            except KeyError: pass # unknown type, so just create plain index
1211
        
1212
        # Extract col
1213
        expr = copy.deepcopy(expr) # don't modify input!
1214
        col = expr
1215
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
1216
        expr = sql_gen.cast_literal(expr)
1217
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
1218
            expr = sql_gen.Expr(expr)
1219
            
1220
        
1221
        # Extract table
1222
        if table == None:
1223
            assert sql_gen.is_table_col(col)
1224
            table = col.table
1225
        
1226
        if isinstance(col, sql_gen.Col): col.table = None
1227
        
1228
        exprs.append(expr)
1229
        cols.append(col)
1230
    
1231
    table = sql_gen.as_Table(table)
1232
    
1233
    # Add index
1234
    str_ = 'CREATE'
1235
    if unique: str_ += ' UNIQUE'
1236
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
1237
        ', '.join((v.to_str(db) for v in exprs)))+')'
1238
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1239

    
1240
already_indexed = object() # tells add_indexes() the pkey has already been added
1241

    
1242
def add_indexes(db, table, has_pkey=True):
1243
    '''Adds an index on all columns in a table.
1244
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
1245
        index should be added on the first column.
1246
        * If already_indexed, the pkey is assumed to have already been added
1247
    '''
1248
    cols = table_cols(db, table)
1249
    if has_pkey:
1250
        if has_pkey is not already_indexed: add_pkey(db, table)
1251
        cols = cols[1:]
1252
    for col in cols: add_index(db, col, table)
1253

    
1254
#### Tables
1255

    
1256
### Maintenance
1257

    
1258
def analyze(db, table):
1259
    table = sql_gen.as_Table(table)
1260
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1261

    
1262
def autoanalyze(db, table):
1263
    if db.autoanalyze: analyze(db, table)
1264

    
1265
def vacuum(db, table):
1266
    table = sql_gen.as_Table(table)
1267
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
1268
        log_level=3))
1269

    
1270
### Lifecycle
1271

    
1272
def drop(db, type_, name):
1273
    name = sql_gen.as_Name(name)
1274
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')
1275

    
1276
def drop_table(db, table): drop(db, 'TABLE', table)
1277

    
1278
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
1279
    like=None):
1280
    '''Creates a table.
1281
    @param cols [sql_gen.TypedCol,...] The column names and types
1282
    @param has_pkey If set, the first column becomes the primary key.
1283
    @param col_indexes bool|[ref]
1284
        * If True, indexes will be added on all non-pkey columns.
1285
        * If a list reference, [0] will be set to a function to do this.
1286
          This can be used to delay index creation until the table is populated.
1287
    '''
1288
    table = sql_gen.as_Table(table)
1289
    
1290
    if like != None:
1291
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
1292
            ]+cols
1293
    if has_pkey:
1294
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
1295
        pkey.constraints = 'PRIMARY KEY'
1296
    
1297
    temp = table.is_temp and not db.debug_temp
1298
        # temp tables permanent in debug_temp mode
1299
    
1300
    # Create table
1301
    def create():
1302
        str_ = 'CREATE'
1303
        if temp: str_ += ' TEMP'
1304
        str_ += ' TABLE '+table.to_str(db)+' (\n'
1305
        str_ += '\n, '.join(c.to_str(db) for c in cols)
1306
        str_ += '\n);'
1307
        
1308
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
1309
            log_ignore_excs=(DuplicateException,))
1310
    if table.is_temp:
1311
        while True:
1312
            try:
1313
                create()
1314
                break
1315
            except DuplicateException:
1316
                table.name = next_version(table.name)
1317
                # try again with next version of name
1318
    else: create()
1319
    
1320
    # Add indexes
1321
    if has_pkey: has_pkey = already_indexed
1322
    def add_indexes_(): add_indexes(db, table, has_pkey)
1323
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
1324
    elif col_indexes: add_indexes_() # add now
1325

    
1326
def copy_table_struct(db, src, dest):
1327
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1328
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1329

    
1330
### Data
1331

    
1332
def truncate(db, table, schema='public', **kw_args):
1333
    '''For params, see run_query()'''
1334
    table = sql_gen.as_Table(table, schema)
1335
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1336

    
1337
def empty_temp(db, tables):
1338
    tables = lists.mk_seq(tables)
1339
    for table in tables: truncate(db, table, log_level=3)
1340

    
1341
def empty_db(db, schema='public', **kw_args):
1342
    '''For kw_args, see tables()'''
1343
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1344

    
1345
def distinct_table(db, table, distinct_on):
1346
    '''Creates a copy of a temp table which is distinct on the given columns.
1347
    The old and new tables will both get an index on these columns, to
1348
    facilitate merge joins.
1349
    @param distinct_on If empty, creates a table with one row. This is useful if
1350
        your distinct_on columns are all literal values.
1351
    @return The new table.
1352
    '''
1353
    new_table = sql_gen.suffixed_table(table, '_distinct')
1354
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
1355
    
1356
    copy_table_struct(db, table, new_table)
1357
    
1358
    limit = None
1359
    if distinct_on == []: limit = 1 # one sample row
1360
    else:
1361
        add_index(db, distinct_on, new_table, unique=True)
1362
        add_index(db, distinct_on, table) # for join optimization
1363
    
1364
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
1365
        limit=limit), ignore=True)
1366
    analyze(db, new_table)
1367
    
1368
    return new_table
(24-24/37)