Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import time
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
import profiling
13
from Proxy import Proxy
14
import rand
15
import sql_gen
16
import strings
17
import util
18

    
19
##### Exceptions
20

    
21
def get_cur_query(cur, input_query=None):
    '''Gets the text of the query that a cursor last ran.
    @param cur Cursor-like object. Its `query` attribute is used if present,
        otherwise its `_last_executed` attribute.
    @param input_query Fallback query, used when the cursor provides no
        (non-None) query text of its own
    @return The cursor's own query text, or '[input] ' followed by input_query
    '''
    if hasattr(cur, 'query'): executed = cur.query
    elif hasattr(cur, '_last_executed'): executed = cur._last_executed
    else: executed = None
    
    if executed is None: return '[input] '+strings.ustr(input_query)
    return executed
28

    
29
def _add_cursor_info(e, *args, **kw_args):
    '''Adds a 'query: ...' message to the exception e via exc.add_msg().
    For the remaining params, see get_cur_query().
    '''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(query))
32

    
33
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module.'''
    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the cursor's query (see get_cur_query()) is added
            to the exception's message
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is None: return
        _add_cursor_info(self, cur)
37

    
38
class ExceptionWithName(DbException):
    '''A DbException about a named object. The name is kept in self.name.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
42

    
43
class ExceptionWithValue(DbException):
    '''A DbException about a value. The value is kept in self.value and its
    repr() is shown in the message.
    '''
    def __init__(self, value, cause=None):
        msg = 'for value: '+strings.as_tt(repr(value))
        DbException.__init__(self, msg, cause)
        self.value = value
48

    
49
class ExceptionWithNameType(DbException):
    '''A DbException about an object identified by a type and a name, which
    are kept in self.type and self.name.
    '''
    def __init__(self, type_, name, cause=None):
        type_part = 'for type: '+strings.as_tt(str(type_))
        name_part = '; name: '+strings.as_tt(name)
        DbException.__init__(self, type_part+name_part, cause)
        self.type, self.name = type_, name
55

    
56
class ConstraintException(DbException):
    '''Raised when a constraint is violated.'''
    def __init__(self, name, cond, cols, cause=None):
        '''
        @param name The constraint's name
        @param cond str|None The constraint's condition, added to the message
            if not None
        @param cols The names of the affected columns, added to the message
            unless equal to []
        @param cause The underlying exception, if any
        '''
        parts = ['Violated ', strings.as_tt(name), ' constraint']
        if cond is not None: parts += [' with condition ', cond]
        if cols != []:
            parts += [' on columns: ', strings.as_tt(', '.join(cols))]
        DbException.__init__(self, ''.join(parts), cause)
        self.name, self.cond, self.cols = name, cond, cols
65

    
66
class MissingCastException(DbException):
    '''Raised when a column needs a cast to a type. The type and column are
    kept in self.type and self.col.
    '''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type, self.col = type_, col
72

    
73
class NameException(DbException):
    '''A DbException subclass with no added behavior.
    NOTE(review): not raised in the visible part of this file; confirm usage.
    '''
74

    
75
class DuplicateKeyException(ConstraintException):
    '''Raised by run_query() when the database reports that a duplicate key
    value violates a unique constraint.
    '''
76

    
77
class NullValueException(ConstraintException):
    '''Raised by run_query() when the database reports that a null value in a
    column violates a not-null constraint.
    '''
78

    
79
class CheckException(ConstraintException):
    '''Raised by run_query() when the database reports that a new row violates
    a check constraint.
    '''
80

    
81
class InvalidValueException(ExceptionWithValue):
    '''Raised by run_query() when the database reports invalid input syntax,
    an invalid input value, or a field value out of range.
    '''
82

    
83
class DuplicateException(ExceptionWithNameType):
    '''Raised by run_query() when the database reports that an object of some
    type and name already exists.
    '''
84

    
85
class EmptyRowException(DbException):
    '''A DbException subclass with no added behavior.
    NOTE(review): not raised in the visible part of this file; confirm usage.
    '''
86

    
87
##### Warnings
88

    
89
class DbWarning(UserWarning):
    '''Warning category for this module (a plain UserWarning subclass).'''
90

    
91
##### Result retrieval
92

    
93
def col_names(cur):
    '''Iterates over the column names of a cursor's result, i.e. the first item
    of each entry in cur.description.
    '''
    return (col_desc[0] for col_desc in cur.description)
94

    
95
def rows(cur):
    '''Iterates over a cursor's remaining rows, calling cur.fetchone() for each
    one and stopping when it returns None.
    '''
    def fetch(): return cur.fetchone()
    return iter(fetch, None)
96

    
97
def consume_rows(cur):
    '''Fetches all of a cursor's remaining rows and discards them.
    Used to fetch all rows so the result will be cached.
    '''
    remaining = rows(cur)
    iters.consume_iter(remaining)
100

    
101
def next_row(cur):
    '''Fetches a cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    return next(rows(cur))
102

    
103
def row(cur):
    '''Fetches a cursor's next row, then consumes all remaining rows.
    @raise StopIteration if there is no next row
    '''
    first = next_row(cur)
    consume_rows(cur) # read the rest so the whole result has been fetched
    return first
107

    
108
def next_value(cur):
    '''Fetches the first column of a cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    next_row_ = next_row(cur)
    return next_row_[0]
109

    
110
def value(cur):
    '''Fetches the first column of a cursor's next row, consuming all remaining
    rows (see row()).
    '''
    only_row = row(cur)
    return only_row[0]
111

    
112
def values(cur):
    '''Iterates over the first column of each of a cursor's remaining rows, by
    wrapping next_value() with iters.func_iter().
    '''
    def fetch_value(): return next_value(cur)
    return iters.func_iter(fetch_value)
113

    
114
def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no next row'''
    try: result = value(cur)
    except StopIteration: result = None
    return result
117

    
118
##### Escaping
119

    
120
def esc_name_by_module(module, name):
    '''Escapes a name using the quote character of a DB module.
    @param module 'psycopg2' or None (double quote), or 'MySQLdb' (backtick)
    @param name The name to escape
    @raise NotImplementedError for any other module
    '''
    quotes = {None: '"', 'psycopg2': '"', 'MySQLdb': '`'}
    if module not in quotes:
        raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quotes[module])
125

    
126
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a name for a DB engine listed in db_engines.
    @param engine A key of db_engines
    For the other params, see esc_name_by_module().
    '''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
128

    
129
def esc_name(db, name, **kw_args):
    '''Escapes a name for the database that a connection is connected to.
    @param db Connection wrapper whose `db` attribute is the underlying
        connection; its root module selects the quote character
    For the other params, see esc_name_by_module().
    '''
    module_name = util.root_module(db.db)
    return esc_name_by_module(module_name, name, **kw_args)
131

    
132
def qual_name(db, schema, table):
    '''Builds an escaped, optionally schema-qualified table name.
    @param schema str|None If not None, it is escaped and prepended to the
        table name, separated by '.'
    @param table The table name
    '''
    parts = [esc_name(db, table)]
    if schema is not None: parts.insert(0, esc_name(db, schema))
    return '.'.join(parts)
137

    
138
##### Database connections
139

    
140
# Names of db_config settings (presumably the keys a db_config dict may have)
db_config_names = [
    'engine', 'host', 'user', 'password', 'database', 'schemas',
]

# engine name -> (name of the DB module to import,
#     {db_config key: name that key is renamed to before calling connect()})
db_engines = dict(
    MySQL=('MySQLdb', dict(password='passwd', database='db')),
    PostgreSQL=('psycopg2', {})
)
146

    
147
# Exception classes treated as database errors. Starts with this module's own
# DbException; _add_module() adds the DatabaseError of each DB module it is given.
DatabaseErrors_set = set((DbException,))
# Tuple form of the set above; rebuilt by _add_module() whenever the set changes
DatabaseErrors = tuple(DatabaseErrors_set)
149

    
150
def _add_module(module):
    '''Registers a DB module's DatabaseError class in DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple from the set.
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
154

    
155
def db_config_str(db_config):
    '''Describes a db_config as '<engine> database <database>' '''
    engine = db_config['engine']
    database = db_config['database']
    return ' database '.join((engine, database))
157

    
158
def log_debug_none(msg, level=2):
    '''No-op debug logger. DbConn treats debugging as off when its log_debug
    callback equals this function.
    '''
    return None
159

    
160
class DbConn:
161
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param db_config dict of connection settings. _db() reads its 'engine'
            and optional 'schemas' entries and passes the rest to the DB
            module's connect().
        @param autocommit Whether do_autocommit() commits when outside a
            savepoint
        @param caching If false, run_query() treats every query as
            non-cacheable
        @param log_debug Callback log_debug(msg, level) for debug messages.
            Debugging counts as enabled iff this is not log_debug_none.
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        # Settings passed in by the caller
        self.db_config, self.src = db_config, src
        self.autocommit, self.caching = autocommit, caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        
        # Settings that callers may change after construction
        self.autoanalyze = self.autoexplain = False
        self.profile_row_ct = None
        
        # Connection state. _reset() asserts that _savepoint is 0, so it must
        # be set first.
        self._savepoint = 0
        self._reset()
182
    
183
    def __getattr__(self, name):
184
        if name == '__dict__': raise Exception('getting __dict__')
185
        if name == 'db': return self._db()
186
        else: raise AttributeError()
187
    
188
    def __getstate__(self):
189
        state = copy.copy(self.__dict__) # shallow copy
190
        state['log_debug'] = None # don't pickle the debug callback
191
        state['_DbConn__db'] = None # don't pickle the connection
192
        return state
193
    
194
    def clear_cache(self): self.query_results = {}
195
    
196
    def _reset(self):
197
        self.clear_cache()
198
        assert self._savepoint == 0
199
        self._notices_seen = set()
200
        self.__db = None
201
    
202
    def connected(self): return self.__db != None
203
    
204
    def close(self):
205
        if not self.connected(): return
206
        
207
        # Record that the automatic transaction is now closed
208
        self._savepoint -= 1
209
        
210
        self.db.close()
211
        self._reset()
212
    
213
    def reconnect(self):
214
        # Do not do this in test mode as it would roll back everything
215
        if self.autocommit: self.close()
216
        # Connection will be reopened automatically on first query
217
    
218
    def _db(self):
        '''Returns the underlying DB connection, opening and configuring it on
        first use.
        Opening a connection imports the engine's DB module, registers its
        DatabaseError, connects, records the automatically-open transaction in
        self._savepoint, and, if the config has 'schemas', prepends them to the
        search_path.
        '''
        if self.__db is not None: return self.__db # already connected
        
        # Process db_config
        settings = self.db_config.copy() # don't modify input!
        schemas = settings.pop('schemas', None)
        module_name, renames = db_engines[settings.pop('engine')]
        module = __import__(module_name)
        _add_module(module)
        for old_key, new_key in renames.items():
            # Settings that are not present do not need renaming
            try: util.rename_key(settings, old_key, new_key)
            except KeyError: pass
        
        # Connect
        self.__db = module.connect(**settings)
        
        # Record that a transaction is already open
        self._savepoint += 1
        
        # Configure connection
        if hasattr(self.db, 'set_isolation_level'):
            import psycopg2.extensions
            level = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED
            self.db.set_isolation_level(level)
        if schemas is not None:
            search_path = [self.esc_name(s) for s in schemas.split(',')]
            # Keep the existing search_path after the configured schemas
            existing = value(run_query(self, 'SHOW search_path', log_level=4))
            search_path.append(existing)
            run_query(self, 'SET search_path TO '+(','.join(search_path)),
                log_level=3)
        
        return self.__db
249
    
250
    class DbCursor(Proxy):
        '''Cursor proxy that records the rows it fetches and stores the
        finished result (or the exception raised) in the connection's
        query_results cache, keyed by the query.
        '''
        def __init__(self, outer):
            '''@param outer The DbConn to create the cursor on'''
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None # cache key; set by execute()
            self.result = [] # rows fetched so far, or the exception raised
        
        def execute(self, query):
            '''Runs the query on the underlying cursor.
            If it raises, the exception is cached as the result and re-raised.
            @return What the underlying cursor's execute() returned
            '''
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: inner_result = self.inner.execute(query)
                # Record the query actually run, whether or not it succeeded
                finally: self.query = get_cur_query(self.inner, query)
            except Exception as e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith(('CREATE', 'ALTER')): # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if 'ADD COLUMN' not in query or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return inner_result
        
        def fetchone(self):
            '''Fetches the next row and records it. Once the underlying cursor
            returns None (all rows fetched), the result is cached.
            '''
            fetched = self.inner.fetchone()
            if fetched is None: self._cache_result() # fetched all rows
            else: self.result.append(fetched)
            return fetched
        
        def _cache_result(self):
            '''Stores a CacheCursor for the current result under the query'''
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results is None: return
            if self._is_insert and not isinstance(self.result, Exception):
                return
            
            assert self.query_lookup != None
            cached_attrs = util.dict_subset(dicts.AttrsDictView(self),
                ['query', 'result', 'rowcount', 'description'])
            self.query_results[self.query_lookup] = self.CacheCursor(
                cached_attrs)
        
        class CacheCursor:
            '''Replays a cached result through the cursor methods execute() and
            fetchone().
            '''
            def __init__(self, cached_result):
                '''@param cached_result dict that becomes the instance dict'''
                self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                '''Re-raises a cached exception, or else restarts iteration
                over the cached rows. All arguments are ignored.
                '''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                '''Returns the next cached row, or None when there are no more'''
                try: return next(self.iter)
                except StopIteration: return None
313
    
314
    def esc_value(self, value):
        '''Escapes a value as a SQL literal, returned as unicode.
        Uses mogrify() when available; for MySQLdb, falls back to
        _mysql.escape_string().
        @raise NotImplementedError if neither is available for this database
        '''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise # bare raise keeps the original traceback
        return strings.to_unicode(str_)
323
    
324
    def esc_name(self, name):
        '''Escapes a name for this database, using the module-level
        esc_name() function.
        '''
        return esc_name(self, name) # calls global func
325
    
326
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
          (code that starts with a single quote gets an `E` prepended)
        '''
        return 'E'+str_ if str_.startswith("'") else str_
332
    
333
    def can_mogrify(self):
        '''Whether mogrify() is supported, i.e. whether the connection comes
        from the psycopg2 module.
        '''
        return util.root_module(self.db) == 'psycopg2'
336
    
337
    def mogrify(self, query, params=None):
338
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
339
        else: raise NotImplementedError("Can't mogrify query")
340
    
341
    def print_notices(self):
342
        if hasattr(self.db, 'notices'):
343
            for msg in self.db.notices:
344
                if msg not in self._notices_seen:
345
                    self._notices_seen.add(msg)
346
                    self.log_debug(msg, level=2)
347
    
348
    def run_query(self, query, cacheable=False, log_level=2,
349
        debug_msg_ref=None):
350
        '''
351
        @param log_ignore_excs The log_level will be increased by 2 if the query
352
            throws one of these exceptions.
353
        @param debug_msg_ref If specified, the log message will be returned in
354
            this instead of being output. This allows you to filter log messages
355
            depending on the result of the query.
356
        '''
357
        assert query != None
358
        
359
        if self.autocommit and self.src != None:
360
            query = sql_gen.esc_comment(self.src)+'\t'+query
361
        
362
        if not self.caching: cacheable = False
363
        used_cache = False
364
        
365
        if self.debug:
366
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
367
        try:
368
            # Get cursor
369
            if cacheable:
370
                try: cur = self.query_results[query]
371
                except KeyError: cur = self.DbCursor(self)
372
                else: used_cache = True
373
            else: cur = self.db.cursor()
374
            
375
            # Run query
376
            try: cur.execute(query)
377
            except Exception, e:
378
                _add_cursor_info(e, self, query)
379
                raise
380
            else: self.do_autocommit()
381
        finally:
382
            if self.debug:
383
                profiler.stop(self.profile_row_ct)
384
                
385
                ## Log or return query
386
                
387
                query = str(get_cur_query(cur, query))
388
                # Put the src comment on a separate line in the log file
389
                query = query.replace('\t', '\n', 1)
390
                
391
                msg = 'DB query: '
392
                
393
                if used_cache: msg += 'cache hit'
394
                elif cacheable: msg += 'cache miss'
395
                else: msg += 'non-cacheable'
396
                
397
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
398
                
399
                if debug_msg_ref != None: debug_msg_ref[0] = msg
400
                else: self.log_debug(msg, log_level)
401
                
402
                self.print_notices()
403
        
404
        return cur
405
    
406
    def is_cached(self, query): return query in self.query_results
407
    
408
    def with_autocommit(self, func):
        '''Calls func() with the connection's isolation level set to
        psycopg2's autocommit level, restoring the previous level afterwards
        (even if func() raises).
        @return func()'s return value
        '''
        from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
        
        saved_level = self.db.isolation_level
        self.db.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        try:
            return func()
        finally:
            self.db.set_isolation_level(saved_level)
416
    
417
    def with_savepoint(self, func):
418
        top = self._savepoint == 0
419
        savepoint = 'level_'+str(self._savepoint)
420
        
421
        if self.debug:
422
            self.log_debug('Begin transaction', level=4)
423
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
424
        
425
        # Must happen before running queries so they don't get autocommitted
426
        self._savepoint += 1
427
        
428
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
429
        else: query = 'SAVEPOINT '+savepoint
430
        self.run_query(query, log_level=4)
431
        try:
432
            return func()
433
            if top: self.run_query('COMMIT', log_level=4)
434
        except:
435
            if top: query = 'ROLLBACK'
436
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
437
            self.run_query(query, log_level=4)
438
            
439
            raise
440
        finally:
441
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
442
            # "The savepoint remains valid and can be rolled back to again"
443
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
444
            if not top:
445
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
446
            
447
            self._savepoint -= 1
448
            assert self._savepoint >= 0
449
            
450
            if self.debug:
451
                profiler.stop(self.profile_row_ct)
452
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
453
            
454
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
455
    
456
    def do_autocommit(self):
457
        '''Autocommits if outside savepoint'''
458
        assert self._savepoint >= 1
459
        if self.autocommit and self._savepoint == 1:
460
            self.log_debug('Autocommitting', level=4)
461
            self.db.commit()
462
    
463
    def col_info(self, col, cacheable=True):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns.
        @param col Column object with `name` and `table` attributes; the table
            must have `name` and `schema` (None = any schema) attributes
        @param cacheable Whether the lookup query may be cached
        @return sql_gen.TypedCol for the column
        '''
        info_table = sql_gen.Table('columns', 'information_schema')
        # Use udt_name as the type when data_type is USER-DEFINED
        type_expr = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        nullable_expr = sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))
        select_cols = [type_expr, 'column_default', nullable_expr]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema is not None: conds.append(('table_schema', schema))
        
        cur = select(self, info_table, select_cols, conds,
            order_by='table_schema', limit=1, cacheable=cacheable, log_level=4)
            # TODO: order_by search_path schema order
        type_, default, nullable = row(cur)
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
480
    
481
    def TempFunction(self, name):
        '''Creates a sql_gen.Function for a temporary function: in the pg_temp
        schema, or in no schema when debug_temp is on.
        '''
        schema = 'pg_temp'
        if self.debug_temp: schema = None
        return sql_gen.Function(name, schema)
485

    
486
# Alias: connect(...) constructs a DbConn
connect = DbConn
487

    
488
##### Recoverable querying
489

    
490
def with_savepoint(db, func):
    '''Function form of db.with_savepoint(func)'''
    return db.with_savepoint(func)
491

    
492
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, translating recognized database error messages into this
    module's exception classes.
    For params not listed here, see DbConn.run_query().
    @param recover Whether to run the query in a savepoint (unless its result
        is cached), so that an error leaves the connection usable
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    @return The cursor the query was run on
    @raise DuplicateKeyException|NullValueException|CheckException
        |InvalidValueException|MissingCastException|DuplicateException when
        the error message matches the corresponding pattern. Any other
        exception is re-raised unchanged.
    '''
    if recover is None: recover = False
    if log_ignore_excs is None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    # The log message is collected here and output in the finally block, after
    # log_level has had a chance to be adjusted
    debug_msg_ref = [None]
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception as e:
            # An exception without args has no message to parse. Re-raise it
            # rather than masking it with an IndexError.
            if not e.args: raise
            msg = strings.ustr(e.args[0])
            
            match = re.match(r'^duplicate key value violates unique constraint '
                r'"(.+?)"', msg)
            if match:
                constraint, = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, None, cols, e)
            
            match = re.match(r'^null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match:
                col, = match.groups()
                raise NullValueException('NOT NULL', None, [col], e)
            
            match = re.match(r'^new row for relation "(.+?)" violates check '
                r'constraint "(.+?)"', msg)
            if match:
                table, constraint = match.groups()
                constraint = sql_gen.Col(constraint, table)
                cond = None
                if recover: # need auto-rollback to run constraint_cond()
                    try: cond = constraint_cond(db, constraint)
                    except NotImplementedError: pass
                raise CheckException(constraint.to_str(db), cond, [], e)
            
            match = re.match(r'^(?:invalid input (?:syntax|value)\b.*?'
                r'|.+? field value out of range): "(.+?)"', msg)
            if match:
                # Named bad_value so as not to shadow the module-level value()
                bad_value, = match.groups()
                raise InvalidValueException(strings.to_unicode(bad_value), e)
            
            match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.match(r'^(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref[0] is not None:
            db.log_debug(debug_msg_ref[0], log_level)
562

    
563
##### Basic queries
564

    
565
def is_explainable(query):
    '''Whether a query starts with a statement keyword that EXPLAIN accepts.
    @return A match object (true) if so, otherwise None
    '''
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    explainable_re = re.compile(
        r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b')
    return explainable_re.match(query)
569

    
570
def explain(db, query, **kw_args):
    '''Runs EXPLAIN on a query and returns its output lines joined together.
    The EXPLAIN query is run recoverably and cacheably.
    For params, see run_query(). log_level defaults to 4.
    '''
    kw_args.setdefault('log_level', 4)
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides
    
    cur = run_query(db, 'EXPLAIN '+query, recover=True, cacheable=True,
        **kw_args)
    return strings.join_lines(values(cur))
580

    
581
def has_comment(query):
    '''Whether a query ends with a block comment, i.e. with `*/`'''
    return query.endswith('*/')
582

    
583
def with_explain_comment(db, query, **kw_args):
    '''If db.autoexplain is on, appends the query's EXPLAIN output to it as a
    comment. Queries that already end with a comment, or that are not
    explainable, are returned unchanged.
    For kw_args, see explain().
    '''
    if not db.autoexplain or has_comment(query) or not is_explainable(query):
        return query
    
    plan = explain(db, query, **kw_args)
    return query+'\n'+sql_gen.esc_comment(' EXPLAIN:\n'+plan)
588

    
589
def next_version(name):
    '''Makes the next versioned form of a name, using a `#<number>` suffix.
    A name without a version suffix gets `#1`; otherwise the existing version
    number is incremented. The suffix is appended with sql_gen.concat().
    '''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match is None:
        base, version = name, 1 # first existing name was version 0
    else:
        base = match.group(1)
        version = int(match.group(2))+1
    return sql_gen.concat(base, '#'+str(version))
596

    
597
def lock_table(db, table, mode):
    '''Runs LOCK TABLE on a table.
    @param table The table; converted with sql_gen.as_Table()
    @param mode The lock mode, inserted into the query as-is
    '''
    table_str = sql_gen.as_Table(table).to_str(db)
    run_query(db, 'LOCK TABLE '+table_str+' IN '+mode+' MODE')
600

    
601
def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table|None The table to create. It is modified: it is
        marked temp, its schema is removed, and its name is changed to the
        next version whenever the name is already taken. If None, the query is
        just run with run_query().
    @param add_pkey_ Whether to call add_pkey() on the created table
    @return The cursor of the (CREATE TABLE AS) query
    '''
    if into is None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    # tables are permanent in debug_temp mode
    if db.debug_temp: create_prefix = 'CREATE TABLE '
    else: create_prefix = 'CREATE TEMP TABLE '
    
    # Create table
    while True:
        create_query = create_prefix+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateException:
            # try again with next version of name
            into.name = next_version(into.name)
        else: break
    
    if add_pkey_: add_pkey(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
643

    
644
# Sentinel default for mk_select()'s order_by param. When it is passed,
# mk_select() orders by the first table's pkey, or does not order at all if
# distinct_on is used.
order_by_pkey = object() # tells mk_select() to order by the pkey
645

    
646
# Sentinel for mk_select()'s distinct_on param: emits a plain DISTINCT, without
# an ON (...) column list
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
647

    
648
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns. (The [] default is never modified.)
    @param limit int|None The LIMIT
    @param start int|None The OFFSET; omitted from the query if 0
    @param order_by The column to order by, None for no ordering, or
        order_by_pkey to order by the first table's pkey (only if distinct_on
        is not used)
    @param default_table The table for fields and distinct_on columns that do
        not specify one
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds is None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit is None or isinstance(limit, (int, long))
    assert start is None or isinstance(start, (int, long))
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    # Once the query is multi-line, each following clause gets its own line
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields is None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by is not None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit is not None: query += '\nLIMIT '+str(limit)
    if start is not None:
        if start != 0: query += '\nOFFSET '+str(start)
    
    query = with_explain_comment(db, query)
    
    return query
730

    
731
def select(db, *args, **kw_args):
    '''Builds a SELECT query with mk_select() and runs it with run_query().
    For params, see mk_select() and run_query(). The keyword args recover,
    cacheable (default True), and log_level (default 2) go to run_query();
    everything else goes to mk_select().
    '''
    log_level = kw_args.pop('log_level', 2)
    cacheable = kw_args.pop('cacheable', True)
    recover = kw_args.pop('recover', None)
    
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
739

    
740
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''Builds an INSERT query whose rows come from a SELECT (or other) query.
    With embeddable or ignore set, this also *runs* a query: it creates a
    function wrapping the insert, and returns a SELECT from that function.
    @param table The table to insert into
    @param cols The columns to insert into. None or [] = no column list.
    @param select_query The query providing the rows. None = DEFAULT VALUES.
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    @return query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    # Builds the plain INSERT statement. Reads `returning` at call time, so it
    # sees the placeholder column assigned in the `ignore` branch below.
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    # Return type of the wrapper function. Computed before `returning` may be
    # replaced by the NULL placeholder below, so it is then 'unknown'.
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        
        # row = the values to insert; row_vars = the loop variable(s) that
        # receive each row of select_query
        if cols == None:
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
            row_vars = [sql_gen.Table('row')]
        else:
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        # plpgsql body that inserts the rows one at a time, skipping any row
        # that raises unique_violation
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        if src != None: function_name = src+': '+function_name
        # Retry with the next version of the name until the name is not taken
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, order_by=None)
    
    return query
846

    
847
def insert_select(db, table, *args, **kw_args):
    '''Builds an INSERT ... SELECT with mk_insert_select() and runs it.
    For other params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in. If set, embeddable=True is passed to mk_insert_select().
    @param recover Passed to run_query_into(). Forced to True if ignore is set.
    @param cacheable Passed to run_query_into(). Defaults to True.
    @param log_level Passed to run_query_into(). Defaults to 2.
    @return The cursor returned by run_query_into()
    '''
    # Options handled here are popped so mk_insert_select() never sees them
    into = kw_args.pop('into', None)
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    # Options that are only inspected stay in kw_args for mk_insert_select()
    ignore = kw_args.get('ignore', False)
    returning = kw_args.get('returning', None)
    
    if into != None: kw_args['embeddable'] = True
    if ignore: recover = True
    
    # With ignore and no RETURNING col, the rows are sent to a 'rowcount' table
    # (keeps the NULL rows on the server), which is emptied afterwards
    rowcount_only = ignore and returning == None
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    query = mk_insert_select(db, table, *args, **kw_args)
    cur = run_query_into(db, query, into, recover=recover, cacheable=cacheable,
        log_level=log_level)
    
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur
870

    
871
# Alias of sql_gen.default. Used as a value in a row passed to insert().
default = sql_gen.default # tells insert() to use the default value for a column
872

    
873
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row.
    For other params, see insert_select()
    @param row Either a sequence of values (as determined by lists.is_seq()), or
        an object with keys() and values() mapping column names to values.
        * An empty row passes no query to insert_select()
    @return See insert_select()
    '''
    cols = None # sequence rows carry no column names
    if not lists.is_seq(row):
        cols = row.keys()
        row = row.values()
    row_values = list(row) # so that the emptiness check below works
    
    query = None
    if row_values != []: query = sql_gen.Values(row_values).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
885

    
886
def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''Builds the query that changes column values in a table.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, generates ALTER TABLE ... ALTER COLUMN ... TYPE ...
        USING instead of UPDATE, re-declaring each column with its current type
        (from db.col_info()). This locks the table and updates rows in place,
        which avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query, passed through with_explain_comment()
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        clauses = []
        for col, new_value in changes:
            col_str = col.to_str(db)
            # Keep the column's existing type; only the USING expr changes data
            type_ = db.col_info(sql_gen.with_default_table(col, table),
                cacheable_).type
            clauses.append('ALTER COLUMN '+col_str+' TYPE '+type_+'\nUSING '
                +new_value.to_str(db))
        query += ',\n'.join(clauses)
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        assignments = []
        for col, new_value in changes:
            assignments.append(col.to_str(db)+' = '+new_value.to_str(db))
        query += ',\n'.join(assignments)
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return with_explain_comment(db, query)
921

    
922
def update(db, table, *args, **kw_args):
    '''Builds an update query with mk_update() and runs it.
    For other params, see mk_update() and run_query()
    @param recover Passed to run_query(). Defaults to None.
    @param cacheable Passed to run_query(). Defaults to False.
    @param log_level Passed to run_query(). Defaults to 2.
    @return The cursor returned by run_query()
    '''
    # Pop the run_query() options so that mk_update() only gets its own params
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_update(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    
    autoanalyze(db, table)
    return cur
932

    
933
def mk_delete(db, table, cond=None):
    '''Builds a DELETE query.
    @param table sql_gen.Table|str The table to delete from. Converted with
        sql_gen.as_Table(), the same as in mk_update().
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
        If None, the query has no WHERE clause.
    @return str query, passed through with_explain_comment()
    '''
    # Accept plain table names too, consistent with mk_update()
    table = sql_gen.as_Table(table)
    
    query = 'DELETE FROM '+table.to_str(db)
    if cond != None: query += '\nWHERE '+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
944

    
945
def delete(db, table, *args, **kw_args):
    '''Builds a DELETE query with mk_delete() and runs it.
    For other params, see mk_delete() and run_query()
    @param recover Passed to run_query(). Defaults to None.
    @param cacheable Passed to run_query(). Defaults to True.
    @param log_level Passed to run_query(). Defaults to 2.
    @return The cursor returned by run_query()
    '''
    # Pop the run_query() options so that mk_delete() only gets its own params
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_delete(db, table, *args, **kw_args)
    cur = run_query(db, query, recover, cacheable, log_level=log_level)
    
    autoanalyze(db, table)
    return cur
955

    
956
def last_insert_id(db):
    '''Gets the ID generated by the most recent insert.
    @return For psycopg2, the result of SELECT lastval(). For MySQLdb, the
        result of db.insert_id(). For any other driver module, None.
    '''
    module = util.root_module(db.db)
    
    if module == 'psycopg2':
        return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb':
        return db.insert_id()
    return None # unsupported driver
961

    
962
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    unique_cols = lists.uniqify(cols)
    
    # Preserved cols come first: keep the name, retarget the table in place
    items = []
    for kept_col in preserve:
        before = copy.copy(kept_col) # snapshot with the original table
        kept_col.table = into
        items.append((before, kept_col))
    
    # Remaining cols are renamed to their full string form to avoid collisions
    kept = set(preserve)
    items.extend((col, sql_gen.Col(str(col), into, col.srcs))
        for col in unique_cols if col not in kept)
    
    if as_items: return items
    return dict(items)
993

    
994
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given columns from the joined tables into a new table, under
    the names produced by mk_flatten_mapping(). The select is run with
    run_query_into(..., add_pkey_=True).
    For other params, see mk_flatten_mapping()
    @param joins Passed to mk_select() as the tables to select from
    @param limit Passed to mk_select()
    @param start Passed to mk_select()
    @return See return value of mk_flatten_mapping()
    '''
    mapping_items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    
    # Select each original col under its new name
    select_cols = [sql_gen.NamedCol(new_col.name, orig_col)
        for orig_col, new_col in mapping_items]
    query = mk_select(db, joins, select_cols, order_by=None, limit=limit,
        start=start)
    run_query_into(db, query, into=into, add_pkey_=True)
    
    return dict(mapping_items)
1003

    
1004
##### Database structure introspection
1005

    
1006
#### Expressions
1007

    
1008
# Regexp fragment (non-capturing) for an SQL name: either a run of word chars,
# or one or more adjacent double-quoted segments (so "" inside quotes matches)
name_re = r'(?:\w+|(?:"[^"]*")+)'
1009

    
1010
def parse_expr_col(str_):
    '''Extracts the column name from an index expression.
    @param str_ Either a name, or an expression of the form (func(name...)), in
        which case the first name inside the function call is used
    @return The name, passed through sql_gen.unesc_name()
    '''
    func_call_re = r'^\('+name_re+r'\(('+name_re+r').*\)\)$'
    match = re.match(func_call_re, str_)
    if match is not None:
        str_ = match.group(1) # first arg of the function call
    return sql_gen.unesc_name(str_)
1014

    
1015
#### Tables
1016

    
1017
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Lists the names of the tables matching the given patterns.
    @param schema_like Schema to match. Only used for psycopg2.
    @param table_like Table name to match
    @param exact If set, psycopg2 compares with = instead of LIKE. The MySQLdb
        branch always uses SHOW TABLES LIKE.
    @return The table names, as produced by values()
    @raise NotImplementedError for any other driver module
    '''
    module = util.root_module(db.db)
    
    if module == 'MySQLdb':
        query = 'SHOW TABLES LIKE '+db.esc_value(table_like)
        return values(run_query(db, query, cacheable=True, log_level=4))
    
    if module != 'psycopg2':
        raise NotImplementedError("Can't list tables for "+module+' database')
    
    compare = 'LIKE'
    if exact: compare = '='
    conds = [
        ('schemaname', sql_gen.CompareCond(schema_like, compare)),
        ('tablename', sql_gen.CompareCond(table_like, compare)),
    ]
    return values(select(db, 'pg_tables', ['tablename'], conds,
        order_by='tablename', log_level=4))
1031

    
1032
def table_exists(db, table):
    '''Checks whether tables() finds an exact match for the table's schema and
    name.
    @param table sql_gen.Table|str
    @return bool
    '''
    table = sql_gen.as_Table(table)
    matches = list(tables(db, table.schema, table.name, exact=True))
    return matches != []
1035

    
1036
def table_row_count(db, table, recover=None):
    '''Gets a table's row count by selecting sql_gen.row_count from it.
    @param recover Passed to run_query()
    @return The single value of the query result
    '''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None)
    return value(run_query(db, query, recover=recover, log_level=3))
1039

    
1040
def table_cols(db, table, recover=None):
    '''Lists a table's column names.
    Runs a select with limit=0 and reads the column names from the cursor.
    @param recover Passed to select()
    @return list of column names
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
1043

    
1044
def pkey(db, table, recover=None):
    '''Gets the name of a table's primary key column.
    Assumed to be first column in table
    @param recover Passed to table_cols()
    '''
    col_names_ = table_cols(db, table, recover)
    return col_names_[0]
1047

    
1048
# Column name that table_not_null_col() looks for before falling back to pkey
not_null_col = 'not_null_col'
1049

    
1050
def table_not_null_col(db, table, recover=None):
    '''Gets the name of a column in the table to use as a not-null column.
    Name assumed to be the value of not_null_col. If not found, uses pkey.
    @param recover Passed to table_cols() and pkey()
    '''
    has_not_null_col = not_null_col in table_cols(db, table, recover)
    if not has_not_null_col:
        return pkey(db, table, recover)
    return not_null_col
1054

    
1055
def constraint_cond(db, constraint):
    '''Gets the condition of a constraint, as stored in pg_constraint.consrc.
    @param constraint Object with a table (having to_str(db)) and a name
    @return The single value of the query result
    @raise NotImplementedError for driver modules other than psycopg2
    '''
    # NOTE(review): pg_constraint.consrc was removed in PostgreSQL 12 (the
    # documented replacement is pg_get_constraintdef()) -- confirm which server
    # versions need to be supported before changing the query
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # Escape the qualified table name and constraint name as literals
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't get constraint condition for "
        +module+' database')
1070

    
1071
def index_cols(db, index):
    '''Lists the columns of an index, each parsed with parse_expr_col().
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @param index sql_gen.Table|str The index, named like a table
    @return list of column names
    @raise NotImplementedError for driver modules other than psycopg2
    '''
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    qual_index = sql_gen.Literal(index.to_str(db))
    # One result row per index column, containing that column's definition
    query = '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
    col_defs = values(run_query(db, query, cacheable=True, log_level=4))
    return [parse_expr_col(col_def) for col_def in col_defs]
1087

    
1088
#### Functions
1089

    
1090
def function_exists(db, function):
    '''Checks whether a non-trigger function with this name is listed in
    information_schema.routines.
    @param function sql_gen.Function|str If it has a schema, only that schema
        is searched.
    @return bool
    '''
    function = sql_gen.as_Function(function)
    
    conds = [('routine_name', function.name)]
    if function.schema != None:
        conds.append(('routine_schema', function.schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    routines = sql_gen.Table('routines', 'information_schema')
    # TODO: order_by search_path schema order
    found = values(select(db, routines, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))
    return list(found) != []
1103

    
1104
##### Structural changes
1105

    
1106
#### Columns
1107

    
1108
def add_col(db, table, col, comment=None, **kw_args):
    '''Adds a column to a table, renaming it until the name is not taken.
    @param table Must have to_str(db) (it is not converted with as_Table())
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
        * col.name is changed in place on each retry
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached.
        NOTE(review): this docstring previously said the query is not cached
        when comment is unset, but cacheable=True is passed unconditionally --
        confirm which is intended.
    @param kw_args Passed to run_query()
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    added = False
    while not added:
        # Rebuilt on every attempt because col.name may have changed
        query = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: query += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, query, recover=True, cacheable=True, **kw_args)
        except DuplicateException:
            # try again with next version of name
            col.name = next_version(col.name)
        else:
            added = True
1128

    
1129
def add_not_null(db, col):
    '''Sets a column to NOT NULL.
    @param col The column. Its table attribute identifies the table to alter.
    '''
    table_str = col.table.to_str(db)
    col_str = sql_gen.to_name_only_col(col).to_str(db)
    query = 'ALTER TABLE '+table_str+' ALTER COLUMN '+col_str+' SET NOT NULL'
    run_query(db, query, cacheable=True, log_level=3)
1134

    
1135
# Name of the row number column defined by row_num_typed_col
row_num_col = '_row_num'
1136

    
1137
# Definition of the row number column: a non-nullable serial primary key
row_num_typed_col = sql_gen.TypedCol(
    row_num_col,
    'serial',
    nullable=False,
    constraints='PRIMARY KEY',
)
1139

    
1140
def add_row_num(db, table):
    '''Adds a row number column to a table. Its definition is in
    row_num_typed_col, whose name is row_num_col. It will be the primary key.
    @param table Passed to add_col()
    '''
    # add_col() renames the col object it is given when the name is already
    # taken, so give it a private copy. Otherwise one name collision would
    # permanently rename the shared module-level definition.
    col = copy.copy(row_num_typed_col)
    add_col(db, table, col, log_level=3)
1144

    
1145
#### Indexes
1146

    
1147
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @param recover Only used when looking up the default column with pkey().
        The ALTER TABLE itself always runs with recover=True.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols is None: cols = [pkey(db, table, recover)]
    
    col_strs = []
    for col in cols:
        col_strs.append(sql_gen.to_name_only_col(col).to_str(db))
    
    query = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')')
    run_query(db, query, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))
1160

    
1161
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s).
    Currently, only function calls are supported as expressions.
    NOTE(review): this docstring previously said the index is only added "if it
    doesn't already exist", but nothing here checks for an existing index --
    presumably that relies on run_query()'s caching; confirm.
    @param exprs A single column/expression or a sequence of them
    @param table The table to index. If None, taken from the first column,
        which must then be a table column.
    @param unique Whether to create a UNIQUE index
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    index_exprs = []
    for expr in lists.mk_seq(exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # The col's table is cleared below, so work on a copy of the input
        expr = copy.deepcopy(expr)
        
        # Find the col that the expr is based on
        col = expr
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        assert isinstance(col, sql_gen.Col)
        
        # Default the table to that of the first col seen
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        # Index exprs use unqualified col names
        col.table = None
        
        index_exprs.append(expr)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    keyword = 'CREATE'
    if unique: keyword += ' UNIQUE'
    expr_strs = ', '.join((e.to_str(db) for e in index_exprs))
    query = keyword+' INDEX ON '+table.to_str(db)+' ('+expr_strs+')'
    run_query(db, query, recover=True, cacheable=True, log_level=3)
1207

    
1208
# Unique sentinel, compared by identity (`is`) in add_indexes()
already_indexed = object() # tells add_indexes() the pkey has already been added
1209

    
1210
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols_to_index = table_cols(db, table)
    
    if has_pkey:
        pkey_needed = has_pkey is not already_indexed
        if pkey_needed: add_pkey(db, table)
        cols_to_index = cols_to_index[1:] # first col is covered by the pkey
    
    for col in cols_to_index:
        add_index(db, col, table)
1221

    
1222
#### Tables
1223

    
1224
### Maintenance
1225

    
1226
def analyze(db, table):
    '''Runs ANALYZE on a table.
    @param table sql_gen.Table|str
    '''
    table_str = sql_gen.as_Table(table).to_str(db)
    run_query(db, 'ANALYZE '+table_str, log_level=3)
1229

    
1230
def autoanalyze(db, table):
    '''Runs analyze() on a table, but only if db.autoanalyze is set.'''
    if not db.autoanalyze: return
    analyze(db, table)
1232

    
1233
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on a table, via db.with_autocommit().
    @param table sql_gen.Table|str
    '''
    table = sql_gen.as_Table(table)
    
    def run_vacuum():
        return run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    
    db.with_autocommit(run_vacuum)
1237

    
1238
### Lifecycle
1239

    
1240
def drop(db, type_, name):
    '''Drops an object with DROP ... IF EXISTS ... CASCADE.
    @param type_ str The SQL object type keyword (e.g. 'TABLE'), inserted into
        the query as-is
    @param name The object's name. Converted with sql_gen.as_Name().
    '''
    name_str = sql_gen.as_Name(name).to_str(db)
    query = 'DROP '+type_+' IF EXISTS '+name_str+' CASCADE'
    run_query(db, query)
1243

    
1244
def drop_table(db, table):
    '''Drops a table. See drop().'''
    drop(db, 'TABLE', table)
1245

    
1246
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    If the name is already taken, table.name is changed to the next version of
    the name until the creation succeeds.
    @param table sql_gen.Table|str
    @param cols [sql_gen.TypedCol,...] The column names and types.
        * Any iterable is accepted. It is copied, so the caller's list is never
          modified.
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    @param like None|sql_gen.Table If set, a LIKE ... INCLUDING ALL clause for
        this table is placed before the other columns.
    '''
    table = sql_gen.as_Table(table)
    
    # Work on a private copy: the first element is replaced below, which must
    # not leak into the caller's list (or into the shared default value)
    cols = list(cols)
    
    if like != None:
        cols.insert(0,
            sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL'))
    if has_pkey:
        cols[0] = pkey_col = copy.copy(cols[0]) # don't modify input!
        pkey_col.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    while True:
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=2,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            table.name = next_version(table.name)
            # try again with next version of name
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # pkey was created with the table
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1289

    
1290
def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)
    Uses create_table() with like=src, without adding a pkey or indexes.
    @param src The table to copy the structure of
    @param dest The table to create
    '''
    create_table(db, dest, like=src, has_pkey=False, col_indexes=False)
1293

    
1294
### Data
1295

    
1296
def truncate(db, table, schema='public', **kw_args):
    '''Runs TRUNCATE ... CASCADE on a table.
    For other params, see run_query()
    @param table sql_gen.Table|str
    @param schema Passed to sql_gen.as_Table() along with table
    @return The result of run_query()
    '''
    table_str = sql_gen.as_Table(table, schema).to_str(db)
    query = 'TRUNCATE '+table_str+' CASCADE'
    return run_query(db, query, **kw_args)
1300

    
1301
def empty_temp(db, tables):
    '''Truncates each of the given tables.
    @param tables A single table or a sequence of tables
    '''
    for table in lists.mk_seq(tables):
        truncate(db, table, log_level=3)
1304

    
1305
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    table_names = tables(db, schema, **kw_args)
    for table in table_names:
        truncate(db, table, schema)
1308

    
1309
def distinct_table(db, table, distinct_on):
    '''Creates a copy of a temp table which is distinct on the given columns.
    The old and new tables will both get an index on these columns, to
    facilitate merge joins.
    The new table has the same structure as the old one, with '_distinct'
    appended to the name. It is filled using insert_select(..., ignore=True),
    so rows that violate its unique index are skipped.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @return The new table.
    '''
    new_table = sql_gen.suffixed_table(table, '_distinct')
    copy_table_struct(db, table, new_table)
    
    if distinct_on == []:
        limit = 1 # one sample row
    else:
        limit = None
        add_index(db, distinct_on, new_table, unique=True)
        add_index(db, distinct_on, table) # for join optimization
    
    src_query = mk_select(db, table, order_by=None, limit=limit)
    insert_select(db, new_table, None, src_query, ignore=True)
    analyze(db, new_table)
    
    return new_table
(24-24/37)