# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
import lists
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None

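# Illustrative use of the retrieval helpers above (a sketch; the query is
# hypothetical and db is a DbConn, defined below):
#     count = value(run_query(db, 'SELECT count(*) FROM t'))
# rows()/values() iterate lazily; value_or_none() absorbs an empty result set.
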
##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table

##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

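# Illustrative db_config (a sketch; every value here is hypothetical). Keys
# follow db_config_names; 'schemas' is a comma-separated search_path prefix:
#     db_config = dict(engine='PostgreSQL', host='localhost', user='bien',
#         password='secret', database='vegbien', schemas='public')
#     db = connect(db_config) # connect = DbConn, defined below
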
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.autoanalyze = False
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self): return self.__db != None
    
    def close(self):
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        if not self.autocommit: self._savepoint -= 1
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
            
            # Record that a transaction is already open
            if not self.autocommit: self._savepoint += 1
        
        return self.__db
    
    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
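    # Caching note (illustrative): with caching enabled, re-running the same
    # query text replays the cached rows, or re-raises the cached exception,
    # through CacheCursor instead of hitting the database:
    #     cur = db.run_query('SELECT 1', cacheable=True) # cache miss: runs
    #     cur = db.run_query('SELECT 1', cacheable=True) # cache hit: replayed
    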
    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: return func()
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
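    # Illustrative use of with_savepoint() (a sketch; db is a DbConn and the
    # query is hypothetical): a failure inside func rolls back only func's own
    # work, so the caller can recover and continue the enclosing transaction:
    #     def func(): return db.run_query('SELECT 1')
    #     db.with_savepoint(func)
    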
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn

##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            msg = strings.ustr(e.args[0])
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)

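# Illustrative error handling (a sketch; insert_query is hypothetical):
# run_query() parses the database error message and raises a typed exception,
# so callers can recover from specific failures:
#     try: run_query(db, insert_query, recover=True)
#     except DuplicateKeyException, e: print e.name, e.cols
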
##### Basic queries

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))

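# Worked example of next_version() (illustrative):
#     next_version('t') == sql_gen.concat('t', '#1')
#     next_version('t#1') == sql_gen.concat('t', '#2')
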
def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_indexes_: add_indexes(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    missing = True
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query

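# Illustrative mk_select() call (a sketch; the table and column names are
# hypothetical):
#     query = mk_select(db, 'plots', fields=['plot_id'],
#         conds=[('site', 'MX')], order_by=None, limit=10)
# returns a SELECT ... FROM ... WHERE ... LIMIT string for run_query()/select().
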
def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)

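# Illustrative insert() call (a sketch; names are hypothetical). A dict row
# supplies the column names; a sequence row uses the table's column order:
#     insert(db, 'plots', {'plot_id': 1, 'site': 'MX'})
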
def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_indexes_=True)
    return dict(items)

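# Illustrative flatten() use (a sketch; all names are hypothetical): flattens
# several joined tables into one temp table whose column names don't collide:
#     mapping = flatten(db, sql_gen.Table('joined'), joins, cols, limit=1000,
#         start=0)
# joins/cols follow mk_select()'s tables/fields params; the returned mapping
# translates each original column to its renamed column in the new table.
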
##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    # Add index
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with next version of name

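# Illustrative add_index() calls (a sketch; the table and column names are
# hypothetical):
#     add_index(db, 'site', 'plots') # plain index on one column
#     add_index(db, ['site', 'plot_id'], 'plots', unique=True) # UNIQUE index
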
def add_index_col(db, col, suffix, expr, nullable=True):
    if sql_gen.index_col(col) != None: return # already has index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    col.table.index_cols[col.name] = new_col.name

# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()'''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)
    
    return expr

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop_table(db, table):
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    cols = list(cols) # don't modify input! (the pkey is replaced below)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now

def copy_table_struct(db, src, dest):
1229
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1230
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1231

    
1232
### Data
1233

    
1234
def truncate(db, table, schema='public', **kw_args):
1235
    '''For params, see run_query()'''
1236
    table = sql_gen.as_Table(table, schema)
1237
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1238

    
1239
def empty_temp(db, tables):
1240
    if db.debug_temp: return # leave temp tables there for debugging
1241
    tables = lists.mk_seq(tables)
1242
    for table in tables: truncate(db, table, log_level=3)
1243

    
1244
def empty_db(db, schema='public', **kw_args):
1245
    '''For kw_args, see tables()'''
1246
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
1247

    
1248
def distinct_table(db, table, distinct_on):
1249
    '''Creates a copy of a temp table which is distinct on the given columns.
1250
    The old and new tables will both get an index on these columns, to
1251
    facilitate merge joins.
1252
    @param distinct_on If empty, creates a table with one row. This is useful if
1253
        your distinct_on columns are all literal values.
1254
    @return The new table.
1255
    '''
1256
    new_table = sql_gen.suffixed_table(table, '_distinct')
1257
    
1258
    copy_table_struct(db, table, new_table)
1259
    
1260
    limit = None
1261
    if distinct_on == []: limit = 1 # one sample row
1262
    else:
1263
        add_index(db, distinct_on, new_table, unique=True)
1264
        add_index(db, distinct_on, table) # for join optimization
1265
    
1266
    insert_select(db, new_table, None, mk_select(db, table, start=0,
1267
        limit=limit), ignore=True)
1268
    analyze(db, new_table)
1269
    
1270
    return new_table
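
# Illustrative distinct_table() use (a sketch; temp_table is a hypothetical
# temp table):
#     deduped = distinct_table(db, temp_table, [sql_gen.Col('key')])
# deduped is a new temp table distinct on "key"; both tables get an index on
# "key" to facilitate the merge join of the deduplicating INSERT.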