Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import time
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
import profiling
13
from Proxy import Proxy
14
import rand
15
import sql_gen
16
import strings
17
import util
18

    
19
##### Exceptions
20

    
21
def get_cur_query(cur, input_query=None):
22
    raw_query = None
23
    if hasattr(cur, 'query'): raw_query = cur.query
24
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
25
    
26
    if raw_query != None: return raw_query
27
    else: return '[input] '+strings.ustr(input_query)
28

    
29
def _add_cursor_info(e, *args, **kw_args):
30
    '''For params, see get_cur_query()'''
31
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
32

    
33
class DbException(exc.ExceptionWithCause):
34
    def __init__(self, msg, cause=None, cur=None):
35
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
36
        if cur != None: _add_cursor_info(self, cur)
37

    
38
class ExceptionWithName(DbException):
39
    def __init__(self, name, cause=None):
40
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
41
        self.name = name
42

    
43
class ExceptionWithValue(DbException):
44
    def __init__(self, value, cause=None):
45
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
46
            cause)
47
        self.value = value
48

    
49
class ExceptionWithNameType(DbException):
50
    def __init__(self, type_, name, cause=None):
51
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
52
            +'; name: '+strings.as_tt(name), cause)
53
        self.type = type_
54
        self.name = name
55

    
56
class ConstraintException(DbException):
57
    def __init__(self, name, cond, cols, cause=None):
58
        msg = 'Violated '+strings.as_tt(name)+' constraint'
59
        if cond != None: msg += ' with condition '+cond
60
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
61
        DbException.__init__(self, msg, cause)
62
        self.name = name
63
        self.cond = cond
64
        self.cols = cols
65

    
66
class MissingCastException(DbException):
67
    def __init__(self, type_, col, cause=None):
68
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
69
            +' on column: '+strings.as_tt(col), cause)
70
        self.type = type_
71
        self.col = col
72

    
73
class NameException(DbException): pass
74

    
75
class DuplicateKeyException(ConstraintException): pass
76

    
77
class NullValueException(ConstraintException): pass
78

    
79
class CheckException(ConstraintException): pass
80

    
81
class InvalidValueException(ExceptionWithValue): pass
82

    
83
class DuplicateException(ExceptionWithNameType): pass
84

    
85
class EmptyRowException(DbException): pass
86

    
87
##### Warnings
88

    
89
class DbWarning(UserWarning): pass
90

    
91
##### Result retrieval
92

    
93
def col_names(cur): return (col[0] for col in cur.description)
94

    
95
def rows(cur): return iter(lambda: cur.fetchone(), None)
96

    
97
def consume_rows(cur):
98
    '''Used to fetch all rows so result will be cached'''
99
    iters.consume_iter(rows(cur))
100

    
101
def next_row(cur): return rows(cur).next()
102

    
103
def row(cur):
104
    row_ = next_row(cur)
105
    consume_rows(cur)
106
    return row_
107

    
108
def next_value(cur): return next_row(cur)[0]
109

    
110
def value(cur): return row(cur)[0]
111

    
112
def values(cur): return iters.func_iter(lambda: next_value(cur))
113

    
114
def value_or_none(cur):
115
    try: return value(cur)
116
    except StopIteration: return None
117

    
118
##### Escaping
119

    
120
def esc_name_by_module(module, name):
121
    if module == 'psycopg2' or module == None: quote = '"'
122
    elif module == 'MySQLdb': quote = '`'
123
    else: raise NotImplementedError("Can't escape name for "+module+' database')
124
    return sql_gen.esc_name(name, quote)
125

    
126
def esc_name_by_engine(engine, name, **kw_args):
127
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
128

    
129
def esc_name(db, name, **kw_args):
130
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
131

    
132
def qual_name(db, schema, table):
133
    def esc_name_(name): return esc_name(db, name)
134
    table = esc_name_(table)
135
    if schema != None: return esc_name_(schema)+'.'+table
136
    else: return table
137

    
138
##### Database connections
139

    
140
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
141

    
142
db_engines = {
143
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
144
    'PostgreSQL': ('psycopg2', {}),
145
}
146

    
147
DatabaseErrors_set = set([DbException])
148
DatabaseErrors = tuple(DatabaseErrors_set)
149

    
150
def _add_module(module):
151
    DatabaseErrors_set.add(module.DatabaseError)
152
    global DatabaseErrors
153
    DatabaseErrors = tuple(DatabaseErrors_set)
154

    
155
def db_config_str(db_config):
156
    return db_config['engine']+' database '+db_config['database']
157

    
158
log_debug_none = lambda msg, level=2: None
159

    
160
class DbConn:
161
    def __init__(self, db_config, autocommit=True, caching=True,
162
        log_debug=log_debug_none, debug_temp=False, src=None):
163
        '''
164
        @param debug_temp Whether temporary objects should instead be permanent.
165
            This assists in debugging the internal objects used by the program.
166
        @param src In autocommit mode, will be included in a comment in every
167
            query, to help identify the data source in pg_stat_activity.
168
        '''
169
        self.db_config = db_config
170
        self.autocommit = autocommit
171
        self.caching = caching
172
        self.log_debug = log_debug
173
        self.debug = log_debug != log_debug_none
174
        self.debug_temp = debug_temp
175
        self.src = src
176
        self.autoanalyze = False
177
        self.autoexplain = False
178
        self.profile_row_ct = None
179
        
180
        self._savepoint = 0
181
        self._reset()
182
    
183
    def __getattr__(self, name):
184
        if name == '__dict__': raise Exception('getting __dict__')
185
        if name == 'db': return self._db()
186
        else: raise AttributeError()
187
    
188
    def __getstate__(self):
189
        state = copy.copy(self.__dict__) # shallow copy
190
        state['log_debug'] = None # don't pickle the debug callback
191
        state['_DbConn__db'] = None # don't pickle the connection
192
        return state
193
    
194
    def clear_cache(self): self.query_results = {}
195
    
196
    def _reset(self):
197
        self.clear_cache()
198
        assert self._savepoint == 0
199
        self._notices_seen = set()
200
        self.__db = None
201
    
202
    def connected(self): return self.__db != None
203
    
204
    def close(self):
205
        if not self.connected(): return
206
        
207
        # Record that the automatic transaction is now closed
208
        self._savepoint -= 1
209
        
210
        self.db.close()
211
        self._reset()
212
    
213
    def reconnect(self):
214
        # Do not do this in test mode as it would roll back everything
215
        if self.autocommit: self.close()
216
        # Connection will be reopened automatically on first query
217
    
218
    def _db(self):
219
        if self.__db == None:
220
            # Process db_config
221
            db_config = self.db_config.copy() # don't modify input!
222
            schemas = db_config.pop('schemas', None)
223
            module_name, mappings = db_engines[db_config.pop('engine')]
224
            module = __import__(module_name)
225
            _add_module(module)
226
            for orig, new in mappings.iteritems():
227
                try: util.rename_key(db_config, orig, new)
228
                except KeyError: pass
229
            
230
            # Connect
231
            self.__db = module.connect(**db_config)
232
            
233
            # Record that a transaction is already open
234
            self._savepoint += 1
235
            
236
            # Configure connection
237
            if hasattr(self.db, 'set_isolation_level'):
238
                import psycopg2.extensions
239
                self.db.set_isolation_level(
240
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
241
            if schemas != None:
242
                search_path = [self.esc_name(s) for s in schemas.split(',')]
243
                search_path.append(value(run_query(self, 'SHOW search_path',
244
                    log_level=4)))
245
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
246
                    log_level=3)
247
        
248
        return self.__db
249
    
250
    class DbCursor(Proxy):
251
        def __init__(self, outer):
252
            Proxy.__init__(self, outer.db.cursor())
253
            self.outer = outer
254
            self.query_results = outer.query_results
255
            self.query_lookup = None
256
            self.result = []
257
        
258
        def execute(self, query):
259
            self._is_insert = query.startswith('INSERT')
260
            self.query_lookup = query
261
            try:
262
                try: cur = self.inner.execute(query)
263
                finally: self.query = get_cur_query(self.inner, query)
264
            except Exception, e:
265
                self.result = e # cache the exception as the result
266
                self._cache_result()
267
                raise
268
            
269
            # Always cache certain queries
270
            query = sql_gen.lstrip(query)
271
            if query.startswith('CREATE') or query.startswith('ALTER'):
272
                # structural changes
273
                # Rest of query must be unique in the face of name collisions,
274
                # so don't cache ADD COLUMN unless it has distinguishing comment
275
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
276
                    self._cache_result()
277
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
278
                consume_rows(self) # fetch all rows so result will be cached
279
            
280
            return cur
281
        
282
        def fetchone(self):
283
            row = self.inner.fetchone()
284
            if row != None: self.result.append(row)
285
            # otherwise, fetched all rows
286
            else: self._cache_result()
287
            return row
288
        
289
        def _cache_result(self):
290
            # For inserts that return a result set, don't cache result set since
291
            # inserts are not idempotent. Other non-SELECT queries don't have
292
            # their result set read, so only exceptions will be cached (an
293
            # invalid query will always be invalid).
294
            if self.query_results != None and (not self._is_insert
295
                or isinstance(self.result, Exception)):
296
                
297
                assert self.query_lookup != None
298
                self.query_results[self.query_lookup] = self.CacheCursor(
299
                    util.dict_subset(dicts.AttrsDictView(self),
300
                    ['query', 'result', 'rowcount', 'description']))
301
        
302
        class CacheCursor:
303
            def __init__(self, cached_result): self.__dict__ = cached_result
304
            
305
            def execute(self, *args, **kw_args):
306
                if isinstance(self.result, Exception): raise self.result
307
                # otherwise, result is a rows list
308
                self.iter = iter(self.result)
309
            
310
            def fetchone(self):
311
                try: return self.iter.next()
312
                except StopIteration: return None
313
    
314
    def esc_value(self, value):
315
        try: str_ = self.mogrify('%s', [value])
316
        except NotImplementedError, e:
317
            module = util.root_module(self.db)
318
            if module == 'MySQLdb':
319
                import _mysql
320
                str_ = _mysql.escape_string(value)
321
            else: raise e
322
        return strings.to_unicode(str_)
323
    
324
    def esc_name(self, name): return esc_name(self, name) # calls global func
325
    
326
    def std_code(self, str_):
327
        '''Standardizes SQL code.
328
        * Ensures that string literals are prefixed by `E`
329
        '''
330
        if str_.startswith("'"): str_ = 'E'+str_
331
        return str_
332
    
333
    def can_mogrify(self):
334
        module = util.root_module(self.db)
335
        return module == 'psycopg2'
336
    
337
    def mogrify(self, query, params=None):
338
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
339
        else: raise NotImplementedError("Can't mogrify query")
340
    
341
    def print_notices(self):
342
        if hasattr(self.db, 'notices'):
343
            for msg in self.db.notices:
344
                if msg not in self._notices_seen:
345
                    self._notices_seen.add(msg)
346
                    self.log_debug(msg, level=2)
347
    
348
    def run_query(self, query, cacheable=False, log_level=2,
349
        debug_msg_ref=None):
350
        '''
351
        @param log_ignore_excs The log_level will be increased by 2 if the query
352
            throws one of these exceptions.
353
        @param debug_msg_ref If specified, the log message will be returned in
354
            this instead of being output. This allows you to filter log messages
355
            depending on the result of the query.
356
        '''
357
        assert query != None
358
        
359
        if self.autocommit and self.src != None:
360
            query = sql_gen.esc_comment(self.src)+'\t'+query
361
        
362
        if not self.caching: cacheable = False
363
        used_cache = False
364
        
365
        if self.debug:
366
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
367
        try:
368
            # Get cursor
369
            if cacheable:
370
                try: cur = self.query_results[query]
371
                except KeyError: cur = self.DbCursor(self)
372
                else: used_cache = True
373
            else: cur = self.db.cursor()
374
            
375
            # Run query
376
            try: cur.execute(query)
377
            except Exception, e:
378
                _add_cursor_info(e, self, query)
379
                raise
380
            else: self.do_autocommit()
381
        finally:
382
            if self.debug:
383
                profiler.stop(self.profile_row_ct)
384
                
385
                ## Log or return query
386
                
387
                query = str(get_cur_query(cur, query))
388
                # Put the src comment on a separate line in the log file
389
                query = query.replace('\t', '\n', 1)
390
                
391
                msg = 'DB query: '
392
                
393
                if used_cache: msg += 'cache hit'
394
                elif cacheable: msg += 'cache miss'
395
                else: msg += 'non-cacheable'
396
                
397
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
398
                
399
                if debug_msg_ref != None: debug_msg_ref[0] = msg
400
                else: self.log_debug(msg, log_level)
401
                
402
                self.print_notices()
403
        
404
        return cur
405
    
406
    def is_cached(self, query): return query in self.query_results
407
    
408
    def with_autocommit(self, func):
409
        import psycopg2.extensions
410
        
411
        prev_isolation_level = self.db.isolation_level
412
        self.db.set_isolation_level(
413
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
414
        try: return func()
415
        finally: self.db.set_isolation_level(prev_isolation_level)
416
    
417
    def with_savepoint(self, func):
418
        top = self._savepoint == 0
419
        savepoint = 'level_'+str(self._savepoint)
420
        
421
        if self.debug:
422
            self.log_debug('Begin transaction', level=4)
423
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
424
        
425
        # Must happen before running queries so they don't get autocommitted
426
        self._savepoint += 1
427
        
428
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
429
        else: query = 'SAVEPOINT '+savepoint
430
        self.run_query(query, log_level=4)
431
        try:
432
            return func()
433
            if top: self.run_query('COMMIT', log_level=4)
434
        except:
435
            if top: query = 'ROLLBACK'
436
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
437
            self.run_query(query, log_level=4)
438
            
439
            raise
440
        finally:
441
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
442
            # "The savepoint remains valid and can be rolled back to again"
443
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
444
            if not top:
445
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
446
            
447
            self._savepoint -= 1
448
            assert self._savepoint >= 0
449
            
450
            if self.debug:
451
                profiler.stop(self.profile_row_ct)
452
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
453
            
454
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
455
    
456
    def do_autocommit(self):
457
        '''Autocommits if outside savepoint'''
458
        assert self._savepoint >= 1
459
        if self.autocommit and self._savepoint == 1:
460
            self.log_debug('Autocommitting', level=4)
461
            self.db.commit()
462
    
463
    def col_info(self, col, cacheable=True):
464
        table = sql_gen.Table('columns', 'information_schema')
465
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
466
            'USER-DEFINED'), sql_gen.Col('udt_name'))
467
        cols = [type_, 'column_default',
468
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
469
        
470
        conds = [('table_name', col.table.name), ('column_name', col.name)]
471
        schema = col.table.schema
472
        if schema != None: conds.append(('table_schema', schema))
473
        
474
        type_, default, nullable = row(select(self, table, cols, conds,
475
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4))
476
            # TODO: order_by search_path schema order
477
        default = sql_gen.as_Code(default, self)
478
        
479
        return sql_gen.TypedCol(col.name, type_, default, nullable)
480
    
481
    def TempFunction(self, name):
482
        if self.debug_temp: schema = None
483
        else: schema = 'pg_temp'
484
        return sql_gen.Function(name, schema)
485

    
486
connect = DbConn
487

    
488
##### Recoverable querying
489

    
490
def with_savepoint(db, func): return db.with_savepoint(func)
491

    
492
def run_query(db, query, recover=None, cacheable=False, log_level=2,
493
    log_ignore_excs=None, **kw_args):
494
    '''For params, see DbConn.run_query()'''
495
    if recover == None: recover = False
496
    if log_ignore_excs == None: log_ignore_excs = ()
497
    log_ignore_excs = tuple(log_ignore_excs)
498
    debug_msg_ref = [None]
499
    
500
    query = with_explain_comment(db, query)
501
    
502
    try:
503
        try:
504
            def run(): return db.run_query(query, cacheable, log_level,
505
                debug_msg_ref, **kw_args)
506
            if recover and not db.is_cached(query):
507
                return with_savepoint(db, run)
508
            else: return run() # don't need savepoint if cached
509
        except Exception, e:
510
            msg = strings.ustr(e.args[0])
511
            msg = re.sub(r'^PL/Python: \w+: ', r'', msg)
512
            
513
            match = re.match(r'^duplicate key value violates unique constraint '
514
                r'"(.+?)"', msg)
515
            if match:
516
                constraint, = match.groups()
517
                cols = []
518
                if recover: # need auto-rollback to run index_cols()
519
                    try: cols = index_cols(db, constraint)
520
                    except NotImplementedError: pass
521
                raise DuplicateKeyException(constraint, None, cols, e)
522
            
523
            match = re.match(r'^null value in column "(.+?)" violates not-null'
524
                r' constraint', msg)
525
            if match:
526
                col, = match.groups()
527
                raise NullValueException('NOT NULL', None, [col], e)
528
            
529
            match = re.match(r'^new row for relation "(.+?)" violates check '
530
                r'constraint "(.+?)"', msg)
531
            if match:
532
                table, constraint = match.groups()
533
                constraint = sql_gen.Col(constraint, table)
534
                cond = None
535
                if recover: # need auto-rollback to run constraint_cond()
536
                    try: cond = constraint_cond(db, constraint)
537
                    except NotImplementedError: pass
538
                raise CheckException(constraint.to_str(db), cond, [], e)
539
            
540
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
541
                r'|.+? field value out of range): "(.+?)"', msg)
542
            if match:
543
                value, = match.groups()
544
                raise InvalidValueException(strings.to_unicode(value), e)
545
            
546
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
547
                r'is of type', msg)
548
            if match:
549
                col, type_ = match.groups()
550
                raise MissingCastException(type_, col, e)
551
            
552
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
553
            if match:
554
                type_, name = match.groups()
555
                raise DuplicateException(type_, name, e)
556
            
557
            raise # no specific exception raised
558
    except log_ignore_excs:
559
        log_level += 2
560
        raise
561
    finally:
562
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
563

    
564
##### Basic queries
565

    
566
def is_explainable(query):
567
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
568
    return re.match(r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
569
        , query)
570

    
571
def explain(db, query, **kw_args):
572
    '''
573
    For params, see run_query().
574
    '''
575
    kw_args.setdefault('log_level', 4)
576
    
577
    return strings.join_lines(values(run_query(db, 'EXPLAIN '+query,
578
        recover=True, cacheable=True, **kw_args)))
579
        # not a higher log_level because it's useful to see what query is being
580
        # run before it's executed, which EXPLAIN effectively provides
581

    
582
def has_comment(query): return query.endswith('*/')
583

    
584
def with_explain_comment(db, query, **kw_args):
585
    if db.autoexplain and not has_comment(query) and is_explainable(query):
586
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
587
            +explain(db, query, **kw_args))
588
    return query
589

    
590
def next_version(name):
591
    version = 1 # first existing name was version 0
592
    match = re.match(r'^(.*)#(\d+)$', name)
593
    if match:
594
        name, version = match.groups()
595
        version = int(version)+1
596
    return sql_gen.concat(name, '#'+str(version))
597

    
598
def lock_table(db, table, mode):
599
    table = sql_gen.as_Table(table)
600
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
601

    
602
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
603
    '''Outputs a query to a temp table.
604
    For params, see run_query().
605
    '''
606
    if into == None: return run_query(db, query, **kw_args)
607
    
608
    assert isinstance(into, sql_gen.Table)
609
    
610
    into.is_temp = True
611
    # "temporary tables cannot specify a schema name", so remove schema
612
    into.schema = None
613
    
614
    kw_args['recover'] = True
615
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
616
    
617
    temp = not db.debug_temp # tables are permanent in debug_temp mode
618
    
619
    # Create table
620
    while True:
621
        create_query = 'CREATE'
622
        if temp: create_query += ' TEMP'
623
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
624
        
625
        try:
626
            cur = run_query(db, create_query, **kw_args)
627
                # CREATE TABLE AS sets rowcount to # rows in query
628
            break
629
        except DuplicateException, e:
630
            into.name = next_version(into.name)
631
            # try again with next version of name
632
    
633
    if add_pkey_: add_pkey(db, into)
634
    
635
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
636
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
637
    # table is going to be used in complex queries, it is wise to run ANALYZE on
638
    # the temporary table after it is populated."
639
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
640
    # If into is not a temp table, ANALYZE is useful but not required.
641
    analyze(db, into)
642
    
643
    return cur
644

    
645
order_by_pkey = object() # tells mk_select() to order by the pkey
646

    
647
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
648

    
649
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
650
    start=None, order_by=order_by_pkey, default_table=None):
651
    '''
652
    @param tables The single table to select from, or a list of tables to join
653
        together, with tables after the first being sql_gen.Join objects
654
    @param fields Use None to select all fields in the table
655
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
656
        * container can be any iterable type
657
        * compare_left_side: sql_gen.Code|str (for col name)
658
        * compare_right_side: sql_gen.ValueCond|literal value
659
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
660
        use all columns
661
    @return query
662
    '''
663
    # Parse tables param
664
    tables = lists.mk_seq(tables)
665
    tables = list(tables) # don't modify input! (list() copies input)
666
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
667
    
668
    # Parse other params
669
    if conds == None: conds = []
670
    elif dicts.is_dict(conds): conds = conds.items()
671
    conds = list(conds) # don't modify input! (list() copies input)
672
    assert limit == None or isinstance(limit, (int, long))
673
    assert start == None or isinstance(start, (int, long))
674
    if order_by is order_by_pkey:
675
        if distinct_on != []: order_by = None
676
        else: order_by = pkey(db, table0, recover=True)
677
    
678
    query = 'SELECT'
679
    
680
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
681
    
682
    # DISTINCT ON columns
683
    if distinct_on != []:
684
        query += '\nDISTINCT'
685
        if distinct_on is not distinct_on_all:
686
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
687
    
688
    # Columns
689
    if query.find('\n') >= 0: whitespace = '\n'
690
    else: whitespace = ' '
691
    if fields == None: query += whitespace+'*'
692
    else:
693
        assert fields != []
694
        if len(fields) > 1: whitespace = '\n'
695
        query += whitespace+('\n, '.join(map(parse_col, fields)))
696
    
697
    # Main table
698
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
699
    else: whitespace = ' '
700
    query += whitespace+'FROM '+table0.to_str(db)
701
    
702
    # Add joins
703
    left_table = table0
704
    for join_ in tables:
705
        table = join_.table
706
        
707
        # Parse special values
708
        if join_.type_ is sql_gen.filter_out: # filter no match
709
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
710
                sql_gen.CompareCond(None, '~=')))
711
        
712
        query += '\n'+join_.to_str(db, left_table)
713
        
714
        left_table = table
715
    
716
    missing = True
717
    if conds != []:
718
        if len(conds) == 1: whitespace = ' '
719
        else: whitespace = '\n'
720
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
721
            .to_str(db) for l, r in conds], 'WHERE')
722
    if order_by != None:
723
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
724
    if limit != None: query += '\nLIMIT '+str(limit)
725
    if start != None:
726
        if start != 0: query += '\nOFFSET '+str(start)
727
    
728
    query = with_explain_comment(db, query)
729
    
730
    return query
731

    
732
def select(db, *args, **kw_args):
733
    '''For params, see mk_select() and run_query()'''
734
    recover = kw_args.pop('recover', None)
735
    cacheable = kw_args.pop('cacheable', True)
736
    log_level = kw_args.pop('log_level', 2)
737
    
738
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
739
        log_level=log_level)
740

    
741
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
742
    embeddable=False, ignore=False, src=None):
743
    '''
744
    @param returning str|None An inserted column (such as pkey) to return
745
    @param embeddable Whether the query should be embeddable as a nested SELECT.
746
        Warning: If you set this and cacheable=True when the query is run, the
747
        query will be fully cached, not just if it raises an exception.
748
    @param ignore Whether to ignore duplicate keys.
749
    @param src Will be included in the name of any created function, to help
750
        identify the data source in pg_stat_activity.
751
    '''
752
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
753
    if cols == []: cols = None # no cols (all defaults) = unknown col names
754
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
755
    if select_query == None: select_query = 'DEFAULT VALUES'
756
    if returning != None: returning = sql_gen.as_Col(returning, table)
757
    
758
    first_line = 'INSERT INTO '+table.to_str(db)
759
    
760
    def mk_insert(select_query):
761
        query = first_line
762
        if cols != None:
763
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
764
        query += '\n'+select_query
765
        
766
        if returning != None:
767
            returning_name_col = sql_gen.to_name_only_col(returning)
768
            query += '\nRETURNING '+returning_name_col.to_str(db)
769
        
770
        return query
771
    
772
    return_type = 'unknown'
773
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
774
    
775
    lang = 'sql'
776
    if ignore:
777
        # Always return something to set the correct rowcount
778
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
779
        
780
        embeddable = True # must use function
781
        lang = 'plpgsql'
782
        
783
        if cols == None:
784
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
785
            row_vars = [sql_gen.Table('row')]
786
        else:
787
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
788
        
789
        query = '''\
790
DECLARE
791
    row '''+table.to_str(db)+'''%ROWTYPE;
792
BEGIN
793
    /* Need an EXCEPTION block for each individual row because "When an error is
794
    caught by an EXCEPTION clause, [...] all changes to persistent database
795
    state within the block are rolled back."
796
    This is unfortunate because "A block containing an EXCEPTION clause is
797
    significantly more expensive to enter and exit than a block without one."
798
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
799
#PLPGSQL-ERROR-TRAPPING)
800
    */
801
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
802
'''+select_query+'''
803
    LOOP
804
        BEGIN
805
            RETURN QUERY
806
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
807
;
808
        EXCEPTION
809
            WHEN unique_violation THEN NULL; -- continue to next row
810
        END;
811
    END LOOP;
812
END;\
813
'''
814
    else: query = mk_insert(select_query)
815
    
816
    if embeddable:
817
        # Create function
818
        function_name = sql_gen.clean_name(first_line)
819
        if src != None: function_name = src+': '+function_name
820
        while True:
821
            try:
822
                function = db.TempFunction(function_name)
823
                
824
                function_query = '''\
825
CREATE FUNCTION '''+function.to_str(db)+'''()
826
RETURNS SETOF '''+return_type+'''
827
LANGUAGE '''+lang+'''
828
AS $$
829
'''+query+'''
830
$$;
831
'''
832
                run_query(db, function_query, recover=True, cacheable=True,
833
                    log_ignore_excs=(DuplicateException,))
834
                break # this version was successful
835
            except DuplicateException, e:
836
                function_name = next_version(function_name)
837
                # try again with next version of name
838
        
839
        # Return query that uses function
840
        cols = None
841
        if returning != None: cols = [returning]
842
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
843
            cols) # AS clause requires function alias
844
        return mk_select(db, func_table, order_by=None)
845
    
846
    return query
847

    
848
def insert_select(db, table, *args, **kw_args):
    '''Runs an INSERT ... SELECT, optionally storing RETURNING values.
    For params, see mk_insert_select() and run_query_into().
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    # peek at (but leave in) the options that mk_insert_select() also needs
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    # remove the options that belong only to this function
    into = kw_args.pop('into', None)
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    if into != None: kw_args['embeddable'] = True # needed to fill a temp table
    if ignore: recover = True # ignoring duplicates requires error recovery
    
    # keep NULL rows on server when only the rowcount is needed
    rowcount_only = ignore and returning == None
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
871

    
872
# sentinel value: tells insert() to use the column's DEFAULT value
default = sql_gen.default # tells insert() to use the default value for a column
873

    
874
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row. For params, see insert_select().
    @param row sequence of values (in table column order), or a dict mapping
        column name -> value
    '''
    if lists.is_seq(row):
        cols = None # positional values: use the table's column order
    else: # dict: split into parallel col-name and value lists
        cols = row.keys()
        row = row.values()
    values_ = list(row) # ensure that "== []" works
    
    if values_ == []:
        query = None # insert_select() will emit DEFAULT VALUES
    else:
        query = sql_gen.Values(values_).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
886

    
887
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    # NOTE(review): despite the None default, changes must be a non-None
    # iterable -- the list comprehension below would raise TypeError otherwise
    table = sql_gen.as_Table(table)
    # normalize each change to a (name-only col, sql_gen value) pair
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        # rewrite each column's type to itself, computing the new value in the
        # USING clause; this rewrites rows in place instead of creating dead
        # tuples like UPDATE would
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table), cacheable_).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        # plain UPDATE ... SET ... [WHERE ...]
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
922

    
923
def update(db, table, *args, **kw_args):
    '''Runs an UPDATE. For params, see mk_update() and run_query().'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # table statistics may have changed
    return cur
933

    
934
def mk_delete(db, table, cond=None):
    '''Generates a DELETE query.
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    parts = ['DELETE FROM '+table.to_str(db)]
    if cond != None: parts.append('WHERE '+cond.to_str(db))
    
    return with_explain_comment(db, '\n'.join(parts))
945

    
946
def delete(db, table, *args, **kw_args):
    '''Runs a DELETE. For params, see mk_delete() and run_query().'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    autoanalyze(db, table) # table statistics may have changed
    return cur
956

    
957
def last_insert_id(db):
    '''Returns the most recently generated sequence/autoincrement value,
    or None if the underlying driver is not supported.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb':
        return db.insert_id()
    return None
962

    
963
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep the original table for the map key
        col.table = into # mutates the caller's Col, as documented above
        items.append((orig_col, col))
    preserve = set(preserve) # for O(1) membership tests below
    for col in cols:
        if col not in preserve:
            # the new name is the col's full string form (which includes its
            # table), guaranteeing uniqueness across the joined tables
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
994

    
995
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Joins tables into one flat table, renaming colliding columns.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    mapping = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    # select each original column under its new, collision-free name
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in mapping]
    query = mk_select(db, joins, select_cols, order_by=None, limit=limit,
        start=start)
    run_query_into(db, query, into=into, add_pkey_=True)
    return dict(mapping)
1004

    
1005
##### Database structure introspection
1006

    
1007
#### Expressions
1008

    
1009
bool_re = r'(?:true|false)' # matches a SQL boolean literal
1010

    
1011
def simplify_expr(expr):
    '''Simplifies a SQL expression string: folds NULL IS [NOT] NULL tests to
    boolean literals, drops boolean literals OR'ed into a condition, and
    strips redundant doubled parentheses.'''
    bool_lit = r'(?:true|false)' # same pattern as the module-level bool_re
    expr = expr.replace('(NULL IS NULL)', 'true')
    expr = expr.replace('(NULL IS NOT NULL)', 'false')
    expr = re.sub(r' OR '+bool_lit, r'', expr)
    expr = re.sub(bool_lit+r' OR ', r'', expr)
    # repeatedly collapse ((...)) -> (...) until no doubled parens remain
    while True:
        expr, subs_made = re.subn(r'\((\([^()]*\))\)', r'\1', expr)
        if subs_made == 0: return expr
1020

    
1021
name_re = r'(?:\w+|(?:"[^"]*")+)' # matches a SQL identifier, possibly quoted
1022

    
1023
def parse_expr_col(str_):
    '''Extracts and unescapes the underlying column name from an index
    expression of the form (func(col ...)); other strings are just
    unescaped as-is.'''
    func_call = re.match(r'^\('+name_re+r'\(('+name_re+r').*\)\)$', str_)
    if func_call != None: str_ = func_call.group(1)
    return sql_gen.unesc_name(str_)
1027

    
1028
def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param mapping dict of output col -> input col
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str the simplified, mapped expression
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = sql_gen.to_name_only_col(out)
        in_str = sql_gen.to_name_only_col(sql_gen.remove_col_rename(in_)
            ).to_str(db)
        
        # Replace out both with and without quotes.
        # re.escape() the name in case it contains regex metacharacters (e.g.
        # in a quoted identifier), and use a callable replacement so that any
        # backslashes in in_str are inserted literally rather than being
        # interpreted as group references/escapes.
        expr = expr.replace(out.to_str(db), in_str)
        expr = re.sub(r'\b'+re.escape(out.name)+r'\b',
            lambda match: in_str, expr)
        
        if in_cols_found != None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)
1046

    
1047
#### Tables
1048

    
1049
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Lists the names of tables matching the given patterns.
    @param exact If set, compares with = instead of LIKE.
    @raise NotImplementedError for unsupported database drivers
    '''
    compare = '=' if exact else 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    if module == 'MySQLdb':
        return values(run_query(db,
            'SHOW TABLES LIKE '+db.esc_value(table_like), cacheable=True,
            log_level=4))
    raise NotImplementedError("Can't list tables for "+module+' database')
1063

    
1064
def table_exists(db, table):
    '''Whether the table exists in its schema.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return list(matches) != []
1067

    
1068
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    return value(run_query(db, query, recover=recover, log_level=3))
1071

    
1072
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    # a LIMIT 0 select still returns the cursor's column metadata
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
1075

    
1076
def pkey(db, table, recover=None):
    '''Returns the table's primary key, assumed to be its first column.'''
    return table_cols(db, table, recover)[0]
1079

    
1080
not_null_col = 'not_null_col' # conventional name of a known NOT NULL column
1081

    
1082
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
1086

    
1087
def constraint_cond(db, constraint):
    '''Returns the SQL source text of a constraint's condition.
    @param constraint object with .table and .name (e.g. a sql_gen constraint)
    @raise NotImplementedError for unsupported database drivers
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    # fixed error message: it was copy-pasted from index_cols() and wrongly
    # said "Can't list index columns"
    else: raise NotImplementedError("Can't get constraint conditions for "
        +module+' database')
1102

    
1103
def index_cols(db, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return [str,...] each indexed column (or expression), parsed through
        parse_expr_col()
    @raise NotImplementedError for unsupported database drivers
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        # pg_get_indexdef(oid, col_no, true) returns the definition of just
        # that column; generate_series() walks all indnatts columns
        return map(parse_expr_col, values(run_query(db, '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
1119

    
1120
#### Functions
1121

    
1122
def function_exists(db, function):
    '''Whether a directly-callable (non-trigger) function exists.'''
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    matches = values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
        # TODO: order_by search_path schema order
    return list(matches) != []
1135

    
1136
##### Structural changes
1137

    
1138
#### Columns
1139

    
1140
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column, retrying under a versioned name on a name collision.
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            # mutates the caller's TypedCol, per the docstring contract
            col.name = next_version(col.name)
            # try again with next version of name
1160

    
1161
def add_not_null(db, col):
    '''Adds a NOT NULL constraint to a table column.'''
    table = col.table
    name_only = sql_gen.to_name_only_col(col)
    query = ('ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +name_only.to_str(db)+' SET NOT NULL')
    run_query(db, query, cacheable=True, log_level=3)
1166

    
1167
row_num_col = '_row_num' # name of the column created by add_row_num()
1168

    
1169
# column definition used by add_row_num(): an autoincrementing primary key
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')
1171

    
1172
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    # NOTE(review): this passes the shared module-level TypedCol; if add_col()
    # has to version the name on a DuplicateException, that rename persists
    # across later calls -- confirm this is intended
    add_col(db, table, row_num_typed_col, log_level=3)
1176

    
1177
#### Indexes
1178

    
1179
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(c).to_str(db) for c in cols]
    
    query = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +', '.join(col_strs)+')')
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1192

    
1193
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported expressions.
    @param exprs a single expr or a sequence of them
    @param table the table to index; if None, inferred from the first col
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs: build parallel lists of index expressions and their
    # underlying columns
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs): # i is currently unused
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        # ensure_not_null() may have wrapped the col in a function call
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr) # literals must be wrapped to be indexed
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # strip the table so the col renders as a bare name in the index def
        if isinstance(col, sql_gen.Col): col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    # recover=True: an identical index may already exist
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
1240

    
1241
# sentinel value for add_indexes()'s has_pkey param
already_indexed = object() # tells add_indexes() the pkey has already been added
1242

    
1243
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:] # the pkey's implicit index covers the first column
    for col in cols:
        add_index(db, col, table)
1254

    
1255
#### Tables
1256

    
1257
### Maintenance
1258

    
1259
def analyze(db, table):
    '''Refreshes the query planner's statistics for the table.'''
    run_query(db, 'ANALYZE '+sql_gen.as_Table(table).to_str(db), log_level=3)
1262

    
1263
def autoanalyze(db, table):
    '''Runs analyze() only if autoanalyze is enabled on this connection.'''
    if db.autoanalyze: analyze(db, table)
1265

    
1266
def vacuum(db, table):
    '''VACUUM ANALYZEs the table. Runs in autocommit mode because VACUUM
    cannot run inside a transaction.'''
    table = sql_gen.as_Table(table)
    def run():
        run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run)
1270

    
1271
### Lifecycle
1272

    
1273
def drop(db, type_, name):
    '''Drops a database object if it exists, cascading to dependents.
    @param type_ str SQL object type, e.g. 'TABLE'
    '''
    name = sql_gen.as_Name(name)
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')
1276

    
1277
def drop_table(db, table):
    '''Drops a table if it exists, cascading to dependents.'''
    drop(db, 'TABLE', table)
1278

    
1279
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like sql_gen.Table|None table whose structure to copy
    '''
    table = sql_gen.as_Table(table)
    
    # copy so we never modify the caller's list (or the shared [] default);
    # previously `cols[0] = ...` below wrote into the caller's list
    cols = list(cols)
    if like != None:
        cols.insert(0,
            sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL'))
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        run_query(db, str_, recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
    if table.is_temp:
        # temp tables may collide with a previous run's; retry under a
        # versioned name until creation succeeds
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # create() already made the pkey
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1326

    
1327
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1330

    
1331
### Data
1332

    
1333
def truncate(db, table, schema='public', **kw_args):
    '''Empties a table, cascading to dependents. For params, see run_query().'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1337

    
1338
def empty_temp(db, tables):
    '''Truncates each of the given temp tables (a single table or a sequence).'''
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1341

    
1342
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema. For kw_args, see tables().'''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
1345

    
1346
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    # literal-value "columns" can't be indexed or deduplicated on
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        # the UNIQUE index both enforces distinctness (via ignore=True below)
        # and speeds up the merge join
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    # ignore=True silently drops rows that violate the UNIQUE index,
    # i.e. the duplicates
    insert_select(db, new_table, None, mk_select(db, table, order_by=None,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table
(24-24/37)