# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
import lists
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '+strings.as_tt(repr(value)),
            cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None
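
# Usage sketch (illustrative only): these helpers wrap a DB-API cursor's
# fetchone() in iterator form. Assuming a connected `db` (see DbConn below):
#     cur = db.run_query('SELECT 1')
#     value(cur) # -> 1; returns the first value and consumes remaining rows
#     list(col_names(cur)) # the column names, from cur.description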

##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
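
# Worked example: esc_name() picks the quote character from the driver module,
# so the same call adapts to the engine:
#     esc_name_by_module('psycopg2', 'my col') # -> "my col" (PostgreSQL style)
#     esc_name_by_module('MySQLdb', 'my col') # -> `my col` (MySQL style)
#     qual_name(db, 'public', 'plots') # schema-qualified, e.g. public.plots
# (Exact output depends on sql_gen.esc_name(); 'plots' is a placeholder name.)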

##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']
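
# Example db_config (keys are those in db_config_names; values are
# placeholders):
#     db_config = {'engine': 'PostgreSQL', 'host': 'localhost', 'user': 'me',
#         'password': 'secret', 'database': 'mydb', 'schemas': 'public'}
# db_engines maps each engine to its DB-API module plus the keyword renames
# that module expects, e.g. MySQLdb takes passwd/db for password/database.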

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.autoanalyze = False
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self): return self.__db != None
    
    def close(self):
        if not self.connected(): return
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                search_path.append(value(run_query(self, 'SHOW search_path',
                    log_level=4)))
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has a distinguishing
                # comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions.
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: return func()
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        finally:
            # Always release the savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside a savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn
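
# Connection usage sketch (config values are placeholders):
#     db = connect({'engine': 'PostgreSQL', 'host': 'localhost', 'user': 'me',
#         'password': 'secret', 'database': 'mydb', 'schemas': 'public'})
#     cur = db.run_query('SELECT 1', cacheable=True) # cache miss; runs query
#     cur = db.run_query('SELECT 1', cacheable=True) # cache hit; replays result
#     db.close()
# The connection itself is opened lazily, on the first attribute access to
# db.db (see __getattr__()/_db() above).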

##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)
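
# Savepoint usage sketch: each call wraps func in its own SAVEPOINT level, so a
# failing query rolls back only that level and the enclosing transaction stays
# usable ('t' is a hypothetical table used only for illustration):
#     def insert_row(): return run_query(db, 'INSERT INTO t DEFAULT VALUES')
#     try: with_savepoint(db, insert_row)
#     except DuplicateKeyException: pass # outer transaction is still valid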

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For params, see DbConn.run_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs the query before running
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need a savepoint if cached
        except Exception, e:
            msg = strings.ustr(e.args[0])
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"((_?[^\W_]+(?=[._]))?.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                value, = match.groups()
                raise InvalidValueException(strings.to_unicode(value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
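
# Error-translation sketch: run_query() parses the driver's error message into
# the typed exceptions defined above, so callers can branch on the failure
# ('t' is a hypothetical table used only for illustration):
#     try: run_query(db, "INSERT INTO t (id) VALUES (1)", recover=True)
#     except DuplicateKeyException, e: print e.name, e.cols
#     except NullValueException, e: print e.cols
# recover=True runs the query under a savepoint, so helpers such as
# index_cols() can still be run after the failure.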

##### Basic queries

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))
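
# Worked example (assuming sql_gen.concat() simply appends the suffix):
#     next_version('widgets') # -> 'widgets#1'
#     next_version('widgets#1') # -> 'widgets#2'
# The retry loops below use this to pick a fresh name after a
# DuplicateException.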

def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove the schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # of rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with the next version of the name
    
    if add_indexes_: add_indexes(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    missing = True
    if conds != []:
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query
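
# Usage sketch; the exact quoting comes from sql_gen, but the shape is roughly
# as follows ('plots'/'siteid' are placeholder names):
#     mk_select(db, 'plots', ['siteid'], {'siteid': 5}, limit=10, order_by=None)
# ->
#     SELECT
#     siteid
#     FROM plots
#     WHERE siteid = 5
#     LIMIT 10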

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''
    @param returning str|None An inserted column (such as the pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with the next version of the name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
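
# Insert usage sketch ('plots'/'sites' and their columns are placeholder
# names):
#     insert(db, 'plots', {'name': 'A'}) # dict row -> named columns
#     insert(db, 'plots', []) # empty row -> INSERT ... DEFAULT VALUES
#     insert_select(db, 'plots', ['name'],
#         mk_select(db, 'sites', ['name'], order_by=None, start=0),
#         ignore=True) # duplicate keys are skipped via the plpgsql wrapper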

def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur
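
# Update usage sketch (placeholder names):
#     update(db, 'plots', [('name', 'B')],
#         cond=sql_gen.ColValueCond('id', 1)) # ordinary UPDATE ... WHERE
#     update(db, 'plots', [('name', 'B')], in_place=True)
#         # rewrites the column with ALTER TABLE ... TYPE ... USING to avoid
#         # dead rows; cond must be None in this mode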

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_indexes_=True)
    return dict(items)

##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []
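
# Introspection usage sketch (placeholder names):
#     list(tables(db, 'public', 'plot%')) # table names matching a LIKE pattern
#     table_exists(db, sql_gen.Table('plots', 'public')) # -> True or False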

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be the first column in the table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with the next version of the name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create a plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    # Add index
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with the next version of the name

def add_index_col(db, col, suffix, expr, nullable=True):
    if sql_gen.index_col(col) != None: return # already has an index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    col.table.index_cols[col.name] = new_col.name

# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()'''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If a nullable column in a temp table, add a separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)
    
    return expr

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop_table(db, table):
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables are permanent in debug_temp mode
    
    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with the next version of the name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now

def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    if db.debug_temp: return # leave temp tables there for debugging
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else:
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, table, start=0,
        limit=limit), ignore=True)
    analyze(db, new_table)
    
    return new_table
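
# End-to-end sketch of the lifecycle helpers (placeholder names; a sketch under
# assumed sql_gen behavior, not a tested recipe):
#     t = sql_gen.Table('staging')
#     t.is_temp = True
#     create_table(db, t, [row_num_typed_col, sql_gen.TypedCol('name', 'text')])
#     insert(db, t, {'name': 'A'})
#     deduped = distinct_table(db, t, [sql_gen.Col('name', t)])
#     empty_temp(db, [t, deduped])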