Project

General

Profile

1
# Database access
2

    
3
import copy
4
import re
5
import warnings
6

    
7
import exc
8
import dicts
9
import iters
10
import lists
11
from Proxy import Proxy
12
import rand
13
import sql_gen
14
import strings
15
import util
16

    
17
##### Exceptions
18

    
19
def get_cur_query(cur, input_query=None):
20
    raw_query = None
21
    if hasattr(cur, 'query'): raw_query = cur.query
22
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
23
    
24
    if raw_query != None: return raw_query
25
    else: return '[input] '+strings.ustr(input_query)
26

    
27
def _add_cursor_info(e, *args, **kw_args):
28
    '''For params, see get_cur_query()'''
29
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))
30

    
31
class DbException(exc.ExceptionWithCause):
32
    def __init__(self, msg, cause=None, cur=None):
33
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
34
        if cur != None: _add_cursor_info(self, cur)
35

    
36
class ExceptionWithName(DbException):
37
    def __init__(self, name, cause=None):
38
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
39
        self.name = name
40

    
41
class ExceptionWithNameValue(DbException):
42
    def __init__(self, name, value, cause=None):
43
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
44
            +'; value: '+strings.as_tt(repr(value)), cause)
45
        self.name = name
46
        self.value = value
47

    
48
class ExceptionWithNameType(DbException):
49
    def __init__(self, type_, name, cause=None):
50
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
51
            +'; name: '+strings.as_tt(name), cause)
52
        self.type = type_
53
        self.name = name
54

    
55
class ConstraintException(DbException):
56
    def __init__(self, name, cols, cause=None):
57
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
58
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
59
        self.name = name
60
        self.cols = cols
61

    
62
class MissingCastException(DbException):
63
    def __init__(self, type_, col, cause=None):
64
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
65
            +' on column: '+strings.as_tt(col), cause)
66
        self.type = type_
67
        self.col = col
68

    
69
class NameException(DbException): pass
70

    
71
class DuplicateKeyException(ConstraintException): pass
72

    
73
class NullValueException(ConstraintException): pass
74

    
75
class FunctionValueException(ExceptionWithNameValue): pass
76

    
77
class DuplicateException(ExceptionWithNameType): pass
78

    
79
class EmptyRowException(DbException): pass
80

    
81
##### Warnings
82

    
83
class DbWarning(UserWarning): pass
84

    
85
##### Result retrieval
86

    
87
def col_names(cur): return (col[0] for col in cur.description)
88

    
89
def rows(cur): return iter(lambda: cur.fetchone(), None)
90

    
91
def consume_rows(cur):
92
    '''Used to fetch all rows so result will be cached'''
93
    iters.consume_iter(rows(cur))
94

    
95
def next_row(cur): return rows(cur).next()
96

    
97
def row(cur):
98
    row_ = next_row(cur)
99
    consume_rows(cur)
100
    return row_
101

    
102
def next_value(cur): return next_row(cur)[0]
103

    
104
def value(cur): return row(cur)[0]
105

    
106
def values(cur): return iters.func_iter(lambda: next_value(cur))
107

    
108
def value_or_none(cur):
109
    try: return value(cur)
110
    except StopIteration: return None
111

    
112
##### Escaping
113

    
114
def esc_name_by_module(module, name):
115
    if module == 'psycopg2' or module == None: quote = '"'
116
    elif module == 'MySQLdb': quote = '`'
117
    else: raise NotImplementedError("Can't escape name for "+module+' database')
118
    return sql_gen.esc_name(name, quote)
119

    
120
def esc_name_by_engine(engine, name, **kw_args):
121
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
122

    
123
def esc_name(db, name, **kw_args):
124
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
125

    
126
def qual_name(db, schema, table):
127
    def esc_name_(name): return esc_name(db, name)
128
    table = esc_name_(table)
129
    if schema != None: return esc_name_(schema)+'.'+table
130
    else: return table
131

    
132
##### Database connections
133

    
134
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
135

    
136
db_engines = {
137
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
138
    'PostgreSQL': ('psycopg2', {}),
139
}
140

    
141
DatabaseErrors_set = set([DbException])
142
DatabaseErrors = tuple(DatabaseErrors_set)
143

    
144
def _add_module(module):
145
    DatabaseErrors_set.add(module.DatabaseError)
146
    global DatabaseErrors
147
    DatabaseErrors = tuple(DatabaseErrors_set)
148

    
149
def db_config_str(db_config):
150
    return db_config['engine']+' database '+db_config['database']
151

    
152
log_debug_none = lambda msg, level=2: None
153

    
154
class DbConn:
155
    def __init__(self, db_config, autocommit=True, caching=True,
156
        log_debug=log_debug_none, debug_temp=False):
157
        '''
158
        @param debug_temp Whether temporary objects should instead be permanent.
159
            This assists in debugging the internal objects used by the program.
160
        '''
161
        self.db_config = db_config
162
        self.autocommit = autocommit
163
        self.caching = caching
164
        self.log_debug = log_debug
165
        self.debug = log_debug != log_debug_none
166
        self.debug_temp = debug_temp
167
        self.autoanalyze = False
168
        
169
        self.__db = None
170
        self.query_results = {}
171
        self._savepoint = 0
172
        self._notices_seen = set()
173
    
174
    def __getattr__(self, name):
175
        if name == '__dict__': raise Exception('getting __dict__')
176
        if name == 'db': return self._db()
177
        else: raise AttributeError()
178
    
179
    def __getstate__(self):
180
        state = copy.copy(self.__dict__) # shallow copy
181
        state['log_debug'] = None # don't pickle the debug callback
182
        state['_DbConn__db'] = None # don't pickle the connection
183
        return state
184
    
185
    def connected(self): return self.__db != None
186
    
187
    def _db(self):
188
        if self.__db == None:
189
            # Process db_config
190
            db_config = self.db_config.copy() # don't modify input!
191
            schemas = db_config.pop('schemas', None)
192
            module_name, mappings = db_engines[db_config.pop('engine')]
193
            module = __import__(module_name)
194
            _add_module(module)
195
            for orig, new in mappings.iteritems():
196
                try: util.rename_key(db_config, orig, new)
197
                except KeyError: pass
198
            
199
            # Connect
200
            self.__db = module.connect(**db_config)
201
            
202
            # Configure connection
203
            if hasattr(self.db, 'set_isolation_level'):
204
                import psycopg2.extensions
205
                self.db.set_isolation_level(
206
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
207
            if schemas != None:
208
                search_path = [self.esc_name(s) for s in schemas.split(',')]
209
                search_path.append(value(run_query(self, 'SHOW search_path',
210
                    log_level=4)))
211
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
212
                    log_level=3)
213
        
214
        return self.__db
215
    
216
    class DbCursor(Proxy):
217
        def __init__(self, outer):
218
            Proxy.__init__(self, outer.db.cursor())
219
            self.outer = outer
220
            self.query_results = outer.query_results
221
            self.query_lookup = None
222
            self.result = []
223
        
224
        def execute(self, query):
225
            self._is_insert = query.startswith('INSERT')
226
            self.query_lookup = query
227
            try:
228
                try:
229
                    cur = self.inner.execute(query)
230
                    self.outer.do_autocommit()
231
                finally: self.query = get_cur_query(self.inner, query)
232
            except Exception, e:
233
                _add_cursor_info(e, self, query)
234
                self.result = e # cache the exception as the result
235
                self._cache_result()
236
                raise
237
            
238
            # Always cache certain queries
239
            if query.startswith('CREATE') or query.startswith('ALTER'):
240
                # structural changes
241
                # Rest of query must be unique in the face of name collisions,
242
                # so don't cache ADD COLUMN unless it has distinguishing comment
243
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
244
                    self._cache_result()
245
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
246
                consume_rows(self) # fetch all rows so result will be cached
247
            
248
            return cur
249
        
250
        def fetchone(self):
251
            row = self.inner.fetchone()
252
            if row != None: self.result.append(row)
253
            # otherwise, fetched all rows
254
            else: self._cache_result()
255
            return row
256
        
257
        def _cache_result(self):
258
            # For inserts that return a result set, don't cache result set since
259
            # inserts are not idempotent. Other non-SELECT queries don't have
260
            # their result set read, so only exceptions will be cached (an
261
            # invalid query will always be invalid).
262
            if self.query_results != None and (not self._is_insert
263
                or isinstance(self.result, Exception)):
264
                
265
                assert self.query_lookup != None
266
                self.query_results[self.query_lookup] = self.CacheCursor(
267
                    util.dict_subset(dicts.AttrsDictView(self),
268
                    ['query', 'result', 'rowcount', 'description']))
269
        
270
        class CacheCursor:
271
            def __init__(self, cached_result): self.__dict__ = cached_result
272
            
273
            def execute(self, *args, **kw_args):
274
                if isinstance(self.result, Exception): raise self.result
275
                # otherwise, result is a rows list
276
                self.iter = iter(self.result)
277
            
278
            def fetchone(self):
279
                try: return self.iter.next()
280
                except StopIteration: return None
281
    
282
    def esc_value(self, value):
283
        try: str_ = self.mogrify('%s', [value])
284
        except NotImplementedError, e:
285
            module = util.root_module(self.db)
286
            if module == 'MySQLdb':
287
                import _mysql
288
                str_ = _mysql.escape_string(value)
289
            else: raise e
290
        return strings.to_unicode(str_)
291
    
292
    def esc_name(self, name): return esc_name(self, name) # calls global func
293
    
294
    def std_code(self, str_):
295
        '''Standardizes SQL code.
296
        * Ensures that string literals are prefixed by `E`
297
        '''
298
        if str_.startswith("'"): str_ = 'E'+str_
299
        return str_
300
    
301
    def can_mogrify(self):
302
        module = util.root_module(self.db)
303
        return module == 'psycopg2'
304
    
305
    def mogrify(self, query, params=None):
306
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
307
        else: raise NotImplementedError("Can't mogrify query")
308
    
309
    def print_notices(self):
310
        if hasattr(self.db, 'notices'):
311
            for msg in self.db.notices:
312
                if msg not in self._notices_seen:
313
                    self._notices_seen.add(msg)
314
                    self.log_debug(msg, level=2)
315
    
316
    def run_query(self, query, cacheable=False, log_level=2,
317
        debug_msg_ref=None):
318
        '''
319
        @param log_ignore_excs The log_level will be increased by 2 if the query
320
            throws one of these exceptions.
321
        @param debug_msg_ref If specified, the log message will be returned in
322
            this instead of being output. This allows you to filter log messages
323
            depending on the result of the query.
324
        '''
325
        assert query != None
326
        
327
        if not self.caching: cacheable = False
328
        used_cache = False
329
        
330
        def log_msg(query):
331
            if used_cache: cache_status = 'cache hit'
332
            elif cacheable: cache_status = 'cache miss'
333
            else: cache_status = 'non-cacheable'
334
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
335
        
336
        try:
337
            # Get cursor
338
            if cacheable:
339
                try:
340
                    cur = self.query_results[query]
341
                    used_cache = True
342
                except KeyError: cur = self.DbCursor(self)
343
            else: cur = self.db.cursor()
344
            
345
            # Log query
346
            if self.debug and debug_msg_ref == None: # log before running
347
                self.log_debug(log_msg(query), log_level)
348
            
349
            # Run query
350
            cur.execute(query)
351
        finally:
352
            self.print_notices()
353
            if self.debug and debug_msg_ref != None: # return after running
354
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
355
        
356
        return cur
357
    
358
    def is_cached(self, query): return query in self.query_results
359
    
360
    def with_autocommit(self, func):
361
        import psycopg2.extensions
362
        
363
        prev_isolation_level = self.db.isolation_level
364
        self.db.set_isolation_level(
365
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
366
        try: return func()
367
        finally: self.db.set_isolation_level(prev_isolation_level)
368
    
369
    def with_savepoint(self, func):
370
        savepoint = 'level_'+str(self._savepoint)
371
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
372
        self._savepoint += 1
373
        try: return func()
374
        except:
375
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
376
            raise
377
        finally:
378
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
379
            # "The savepoint remains valid and can be rolled back to again"
380
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
381
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
382
            
383
            self._savepoint -= 1
384
            assert self._savepoint >= 0
385
            
386
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
387
    
388
    def do_autocommit(self):
389
        '''Autocommits if outside savepoint'''
390
        assert self._savepoint >= 0
391
        if self.autocommit and self._savepoint == 0:
392
            self.log_debug('Autocommitting', level=4)
393
            self.db.commit()
394
    
395
    def col_info(self, col):
396
        table = sql_gen.Table('columns', 'information_schema')
397
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
398
            'USER-DEFINED'), sql_gen.Col('udt_name'))
399
        cols = [type_, 'column_default',
400
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
401
        
402
        conds = [('table_name', col.table.name), ('column_name', col.name)]
403
        schema = col.table.schema
404
        if schema != None: conds.append(('table_schema', schema))
405
        
406
        type_, default, nullable = row(select(self, table, cols, conds,
407
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
408
            # TODO: order_by search_path schema order
409
        default = sql_gen.as_Code(default, self)
410
        
411
        return sql_gen.TypedCol(col.name, type_, default, nullable)
412
    
413
    def TempFunction(self, name):
414
        if self.debug_temp: schema = None
415
        else: schema = 'pg_temp'
416
        return sql_gen.Function(name, schema)
417

    
418
connect = DbConn
419

    
420
##### Recoverable querying
421

    
422
def with_savepoint(db, func): return db.with_savepoint(func)
423

    
424
def run_query(db, query, recover=None, cacheable=False, log_level=2,
425
    log_ignore_excs=None, **kw_args):
426
    '''For params, see DbConn.run_query()'''
427
    if recover == None: recover = False
428
    if log_ignore_excs == None: log_ignore_excs = ()
429
    log_ignore_excs = tuple(log_ignore_excs)
430
    
431
    debug_msg_ref = None # usually, db.run_query() logs query before running it
432
    # But if filtering with log_ignore_excs, wait until after exception parsing
433
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
434
    
435
    try:
436
        try:
437
            def run(): return db.run_query(query, cacheable, log_level,
438
                debug_msg_ref, **kw_args)
439
            if recover and not db.is_cached(query):
440
                return with_savepoint(db, run)
441
            else: return run() # don't need savepoint if cached
442
        except Exception, e:
443
            msg = exc.str_(e)
444
            
445
            match = re.search(r'duplicate key value violates unique constraint '
446
                r'"((_?[^\W_]+)_.+?)"', msg)
447
            if match:
448
                constraint, table = match.groups()
449
                cols = []
450
                if recover: # need auto-rollback to run index_cols()
451
                    try: cols = index_cols(db, table, constraint)
452
                    except NotImplementedError: pass
453
                raise DuplicateKeyException(constraint, cols, e)
454
            
455
            match = re.search(r'null value in column "(.+?)" violates not-null'
456
                r' constraint', msg)
457
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
458
            
459
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
460
                r'|date/time field value out of range): "(.+?)"\n'
461
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
462
            if match:
463
                value, name = match.groups()
464
                raise FunctionValueException(name, strings.to_unicode(value), e)
465
            
466
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
467
                r'is of type', msg)
468
            if match:
469
                col, type_ = match.groups()
470
                raise MissingCastException(type_, col, e)
471
            
472
            match = re.search(r'\b(\S+) "(.+?)".*? already exists', msg)
473
            if match:
474
                type_, name = match.groups()
475
                raise DuplicateException(type_, name, e)
476
            
477
            raise # no specific exception raised
478
    except log_ignore_excs:
479
        log_level += 2
480
        raise
481
    finally:
482
        if debug_msg_ref != None and debug_msg_ref[0] != None:
483
            db.log_debug(debug_msg_ref[0], log_level)
484

    
485
##### Basic queries
486

    
487
def next_version(name):
488
    version = 1 # first existing name was version 0
489
    match = re.match(r'^(.*)#(\d+)$', name)
490
    if match:
491
        name, version = match.groups()
492
        version = int(version)+1
493
    return sql_gen.concat(name, '#'+str(version))
494

    
495
def lock_table(db, table, mode):
496
    table = sql_gen.as_Table(table)
497
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')
498

    
499
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
500
    '''Outputs a query to a temp table.
501
    For params, see run_query().
502
    '''
503
    if into == None: return run_query(db, query, **kw_args)
504
    
505
    assert isinstance(into, sql_gen.Table)
506
    
507
    into.is_temp = True
508
    # "temporary tables cannot specify a schema name", so remove schema
509
    into.schema = None
510
    
511
    kw_args['recover'] = True
512
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
513
    
514
    temp = not db.debug_temp # tables are permanent in debug_temp mode
515
    
516
    # Create table
517
    while True:
518
        create_query = 'CREATE'
519
        if temp: create_query += ' TEMP'
520
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
521
        
522
        try:
523
            cur = run_query(db, create_query, **kw_args)
524
                # CREATE TABLE AS sets rowcount to # rows in query
525
            break
526
        except DuplicateException, e:
527
            into.name = next_version(into.name)
528
            # try again with next version of name
529
    
530
    if add_indexes_: add_indexes(db, into)
531
    
532
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
533
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
534
    # table is going to be used in complex queries, it is wise to run ANALYZE on
535
    # the temporary table after it is populated."
536
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
537
    # If into is not a temp table, ANALYZE is useful but not required.
538
    analyze(db, into)
539
    
540
    return cur
541

    
542
order_by_pkey = object() # tells mk_select() to order by the pkey
543

    
544
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
545

    
546
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
547
    start=None, order_by=order_by_pkey, default_table=None):
548
    '''
549
    @param tables The single table to select from, or a list of tables to join
550
        together, with tables after the first being sql_gen.Join objects
551
    @param fields Use None to select all fields in the table
552
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
553
        * container can be any iterable type
554
        * compare_left_side: sql_gen.Code|str (for col name)
555
        * compare_right_side: sql_gen.ValueCond|literal value
556
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
557
        use all columns
558
    @return query
559
    '''
560
    # Parse tables param
561
    tables = lists.mk_seq(tables)
562
    tables = list(tables) # don't modify input! (list() copies input)
563
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
564
    
565
    # Parse other params
566
    if conds == None: conds = []
567
    elif dicts.is_dict(conds): conds = conds.items()
568
    conds = list(conds) # don't modify input! (list() copies input)
569
    assert limit == None or type(limit) == int
570
    assert start == None or type(start) == int
571
    if order_by is order_by_pkey:
572
        if distinct_on != []: order_by = None
573
        else: order_by = pkey(db, table0, recover=True)
574
    
575
    query = 'SELECT'
576
    
577
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
578
    
579
    # DISTINCT ON columns
580
    if distinct_on != []:
581
        query += '\nDISTINCT'
582
        if distinct_on is not distinct_on_all:
583
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
584
    
585
    # Columns
586
    if fields == None:
587
        if query.find('\n') >= 0: whitespace = '\n'
588
        else: whitespace = ' '
589
        query += whitespace+'*'
590
    else:
591
        assert fields != []
592
        query += '\n'+('\n, '.join(map(parse_col, fields)))
593
    
594
    # Main table
595
    query += '\nFROM '+table0.to_str(db)
596
    
597
    # Add joins
598
    left_table = table0
599
    for join_ in tables:
600
        table = join_.table
601
        
602
        # Parse special values
603
        if join_.type_ is sql_gen.filter_out: # filter no match
604
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
605
                sql_gen.CompareCond(None, '~=')))
606
        
607
        query += '\n'+join_.to_str(db, left_table)
608
        
609
        left_table = table
610
    
611
    missing = True
612
    if conds != []:
613
        if len(conds) == 1: whitespace = ' '
614
        else: whitespace = '\n'
615
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
616
            .to_str(db) for l, r in conds], 'WHERE')
617
        missing = False
618
    if order_by != None:
619
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
620
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
621
    if start != None:
622
        if start != 0: query += '\nOFFSET '+str(start)
623
        missing = False
624
    if missing: warnings.warn(DbWarning(
625
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
626
    
627
    return query
628

    
629
def select(db, *args, **kw_args):
630
    '''For params, see mk_select() and run_query()'''
631
    recover = kw_args.pop('recover', None)
632
    cacheable = kw_args.pop('cacheable', True)
633
    log_level = kw_args.pop('log_level', 2)
634
    
635
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
636
        log_level=log_level)
637

    
638
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
639
    embeddable=False, ignore=False):
640
    '''
641
    @param returning str|None An inserted column (such as pkey) to return
642
    @param embeddable Whether the query should be embeddable as a nested SELECT.
643
        Warning: If you set this and cacheable=True when the query is run, the
644
        query will be fully cached, not just if it raises an exception.
645
    @param ignore Whether to ignore duplicate keys.
646
    '''
647
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
648
    if cols == []: cols = None # no cols (all defaults) = unknown col names
649
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
650
    if select_query == None: select_query = 'DEFAULT VALUES'
651
    if returning != None: returning = sql_gen.as_Col(returning, table)
652
    
653
    first_line = 'INSERT INTO '+table.to_str(db)
654
    
655
    def mk_insert(select_query):
656
        query = first_line
657
        if cols != None:
658
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
659
        query += '\n'+select_query
660
        
661
        if returning != None:
662
            returning_name_col = sql_gen.to_name_only_col(returning)
663
            query += '\nRETURNING '+returning_name_col.to_str(db)
664
        
665
        return query
666
    
667
    return_type = 'unknown'
668
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
669
    
670
    lang = 'sql'
671
    if ignore:
672
        # Always return something to set the correct rowcount
673
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
674
        
675
        embeddable = True # must use function
676
        lang = 'plpgsql'
677
        
678
        if cols == None:
679
            row = [sql_gen.Col(sql_gen.all_cols, 'row')]
680
            row_vars = [sql_gen.Table('row')]
681
        else:
682
            row_vars = row = [sql_gen.Col(c.name, 'row') for c in cols]
683
        
684
        query = '''\
685
DECLARE
686
    row '''+table.to_str(db)+'''%ROWTYPE;
687
BEGIN
688
    /* Need an EXCEPTION block for each individual row because "When an error is
689
    caught by an EXCEPTION clause, [...] all changes to persistent database
690
    state within the block are rolled back."
691
    This is unfortunate because "A block containing an EXCEPTION clause is
692
    significantly more expensive to enter and exit than a block without one."
693
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
694
#PLPGSQL-ERROR-TRAPPING)
695
    */
696
    FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
697
'''+select_query+'''
698
    LOOP
699
        BEGIN
700
            RETURN QUERY
701
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
702
;
703
        EXCEPTION
704
            WHEN unique_violation THEN NULL; -- continue to next row
705
        END;
706
    END LOOP;
707
END;\
708
'''
709
    else: query = mk_insert(select_query)
710
    
711
    if embeddable:
712
        # Create function
713
        function_name = sql_gen.clean_name(first_line)
714
        while True:
715
            try:
716
                function = db.TempFunction(function_name)
717
                
718
                function_query = '''\
719
CREATE FUNCTION '''+function.to_str(db)+'''()
720
RETURNS SETOF '''+return_type+'''
721
LANGUAGE '''+lang+'''
722
AS $$
723
'''+query+'''
724
$$;
725
'''
726
                run_query(db, function_query, recover=True, cacheable=True,
727
                    log_ignore_excs=(DuplicateException,))
728
                break # this version was successful
729
            except DuplicateException, e:
730
                function_name = next_version(function_name)
731
                # try again with next version of name
732
        
733
        # Return query that uses function
734
        cols = None
735
        if returning != None: cols = [returning]
736
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
737
            cols) # AS clause requires function alias
738
        return mk_select(db, func_table, start=0, order_by=None)
739
    
740
    return query
741

    
742
def insert_select(db, table, *args, **kw_args):
743
    '''For params, see mk_insert_select() and run_query_into()
744
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
745
        values in
746
    '''
747
    into = kw_args.pop('into', None)
748
    if into != None: kw_args['embeddable'] = True
749
    recover = kw_args.pop('recover', None)
750
    if kw_args.get('ignore', False): recover = True
751
    cacheable = kw_args.pop('cacheable', True)
752
    log_level = kw_args.pop('log_level', 2)
753
    
754
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
755
        into, recover=recover, cacheable=cacheable, log_level=log_level)
756
    autoanalyze(db, table)
757
    return cur
758

    
759
default = sql_gen.default # tells insert() to use the default value for a column
760

    
761
def insert(db, table, row, *args, **kw_args):
762
    '''For params, see insert_select()'''
763
    if lists.is_seq(row): cols = None
764
    else:
765
        cols = row.keys()
766
        row = row.values()
767
    row = list(row) # ensure that "== []" works
768
    
769
    if row == []: query = None
770
    else: query = sql_gen.Values(row).to_str(db)
771
    
772
    return insert_select(db, table, cols, query, *args, **kw_args)
773

    
774
def mk_update(db, table, changes=None, cond=None, in_place=False):
775
    '''
776
    @param changes [(col, new_value),...]
777
        * container can be any iterable type
778
        * col: sql_gen.Code|str (for col name)
779
        * new_value: sql_gen.Code|literal value
780
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
781
    @param in_place If set, locks the table and updates rows in place.
782
        This avoids creating dead rows in PostgreSQL.
783
        * cond must be None
784
    @return str query
785
    '''
786
    table = sql_gen.as_Table(table)
787
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
788
        for c, v in changes]
789
    
790
    if in_place:
791
        assert cond == None
792
        
793
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
794
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
795
            +db.col_info(sql_gen.with_default_table(c, table)).type
796
            +'\nUSING '+v.to_str(db) for c, v in changes))
797
    else:
798
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
799
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
800
            for c, v in changes))
801
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
802
    
803
    return query
804

    
805
def update(db, table, *args, **kw_args):
806
    '''For params, see mk_update() and run_query()'''
807
    recover = kw_args.pop('recover', None)
808
    cacheable = kw_args.pop('cacheable', False)
809
    log_level = kw_args.pop('log_level', 2)
810
    
811
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
812
        cacheable, log_level=log_level)
813
    autoanalyze(db, table)
814
    return cur
815

    
816
def last_insert_id(db):
817
    module = util.root_module(db.db)
818
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
819
    elif module == 'MySQLdb': return db.insert_id()
820
    else: return None
821

    
822
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
823
    '''Creates a mapping from original column names (which may have collisions)
824
    to names that will be distinct among the columns' tables.
825
    This is meant to be used for several tables that are being joined together.
826
    @param cols The columns to combine. Duplicates will be removed.
827
    @param into The table for the new columns.
828
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
829
        columns will be included in the mapping even if they are not in cols.
830
        The tables of the provided Col objects will be changed to into, so make
831
        copies of them if you want to keep the original tables.
832
    @param as_items Whether to return a list of dict items instead of a dict
833
    @return dict(orig_col=new_col, ...)
834
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
835
        * new_col: sql_gen.Col(orig_col_name, into)
836
        * All mappings use the into table so its name can easily be
837
          changed for all columns at once
838
    '''
839
    cols = lists.uniqify(cols)
840
    
841
    items = []
842
    for col in preserve:
843
        orig_col = copy.copy(col)
844
        col.table = into
845
        items.append((orig_col, col))
846
    preserve = set(preserve)
847
    for col in cols:
848
        if col not in preserve:
849
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
850
    
851
    if not as_items: items = dict(items)
852
    return items
853

    
854
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
855
    '''For params, see mk_flatten_mapping()
856
    @return See return value of mk_flatten_mapping()
857
    '''
858
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
859
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
860
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
861
        into=into, add_indexes_=True)
862
    return dict(items)
863

    
864
##### Database structure introspection
865

    
866
#### Tables
867

    
868
def tables(db, schema_like='public', table_like='%', exact=False):
869
    if exact: compare = '='
870
    else: compare = 'LIKE'
871
    
872
    module = util.root_module(db.db)
873
    if module == 'psycopg2':
874
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
875
            ('tablename', sql_gen.CompareCond(table_like, compare))]
876
        return values(select(db, 'pg_tables', ['tablename'], conds,
877
            order_by='tablename', log_level=4))
878
    elif module == 'MySQLdb':
879
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
880
            , cacheable=True, log_level=4))
881
    else: raise NotImplementedError("Can't list tables for "+module+' database')
882

    
883
def table_exists(db, table):
884
    table = sql_gen.as_Table(table)
885
    return list(tables(db, table.schema, table.name, exact=True)) != []
886

    
887
def table_row_count(db, table, recover=None):
888
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
889
        order_by=None, start=0), recover=recover, log_level=3))
890

    
891
def table_cols(db, table, recover=None):
892
    return list(col_names(select(db, table, limit=0, order_by=None,
893
        recover=recover, log_level=4)))
894

    
895
def pkey(db, table, recover=None):
896
    '''Assumed to be first column in table'''
897
    return table_cols(db, table, recover)[0]
898

    
899
not_null_col = 'not_null_col'
900

    
901
def table_not_null_col(db, table, recover=None):
902
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
903
    if not_null_col in table_cols(db, table, recover): return not_null_col
904
    else: return pkey(db, table, recover)
905

    
906
def index_cols(db, table, index):
907
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
908
    automatically created. When you don't know whether something is a UNIQUE
909
    constraint or a UNIQUE index, use this function.'''
910
    module = util.root_module(db.db)
911
    if module == 'psycopg2':
912
        return list(values(run_query(db, '''\
913
SELECT attname
914
FROM
915
(
916
        SELECT attnum, attname
917
        FROM pg_index
918
        JOIN pg_class index ON index.oid = indexrelid
919
        JOIN pg_class table_ ON table_.oid = indrelid
920
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
921
        WHERE
922
            table_.relname = '''+db.esc_value(table)+'''
923
            AND index.relname = '''+db.esc_value(index)+'''
924
    UNION
925
        SELECT attnum, attname
926
        FROM
927
        (
928
            SELECT
929
                indrelid
930
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
931
                    AS indkey
932
            FROM pg_index
933
            JOIN pg_class index ON index.oid = indexrelid
934
            JOIN pg_class table_ ON table_.oid = indrelid
935
            WHERE
936
                table_.relname = '''+db.esc_value(table)+'''
937
                AND index.relname = '''+db.esc_value(index)+'''
938
        ) s
939
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
940
) s
941
ORDER BY attnum
942
'''
943
            , cacheable=True, log_level=4)))
944
    else: raise NotImplementedError("Can't list index columns for "+module+
945
        ' database')
946

    
947
def constraint_cols(db, table, constraint):
948
    module = util.root_module(db.db)
949
    if module == 'psycopg2':
950
        return list(values(run_query(db, '''\
951
SELECT attname
952
FROM pg_constraint
953
JOIN pg_class ON pg_class.oid = conrelid
954
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
955
WHERE
956
    relname = '''+db.esc_value(table)+'''
957
    AND conname = '''+db.esc_value(constraint)+'''
958
ORDER BY attnum
959
'''
960
            )))
961
    else: raise NotImplementedError("Can't list constraint columns for "+module+
962
        ' database')
963

    
964
#### Functions
965

    
966
def function_exists(db, function):
967
    function = sql_gen.as_Function(function)
968
    
969
    info_table = sql_gen.Table('routines', 'information_schema')
970
    conds = [('routine_name', function.name)]
971
    schema = function.schema
972
    if schema != None: conds.append(('routine_schema', schema))
973
    # Exclude trigger functions, since they cannot be called directly
974
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
975
    
976
    return list(values(select(db, info_table, ['routine_name'], conds,
977
        order_by='routine_schema', limit=1, log_level=4))) != []
978
        # TODO: order_by search_path schema order
979

    
980
##### Structural changes
981

    
982
#### Columns
983

    
984
def add_col(db, table, col, comment=None, **kw_args):
985
    '''
986
    @param col TypedCol Name may be versioned, so be sure to propagate any
987
        renaming back to any source column for the TypedCol.
988
    @param comment None|str SQL comment used to distinguish columns of the same
989
        name from each other when they contain different data, to allow the
990
        ADD COLUMN query to be cached. If not set, query will not be cached.
991
    '''
992
    assert isinstance(col, sql_gen.TypedCol)
993
    
994
    while True:
995
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
996
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
997
        
998
        try:
999
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
1000
            break
1001
        except DuplicateException:
1002
            col.name = next_version(col.name)
1003
            # try again with next version of name
1004

    
1005
def add_not_null(db, col):
1006
    table = col.table
1007
    col = sql_gen.to_name_only_col(col)
1008
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
1009
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)
1010

    
1011
row_num_col = '_row_num'
1012

    
1013
row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
1014
    constraints='PRIMARY KEY')
1015

    
1016
def add_row_num(db, table):
1017
    '''Adds a row number column to a table. Its name is in row_num_col. It will
1018
    be the primary key.'''
1019
    add_col(db, table, row_num_typed_col, log_level=3)
1020

    
1021
#### Indexes
1022

    
1023
def add_pkey(db, table, cols=None, recover=None):
1024
    '''Adds a primary key.
1025
    @param cols [sql_gen.Col,...] The columns in the primary key.
1026
        Defaults to the first column in the table.
1027
    @pre The table must not already have a primary key.
1028
    '''
1029
    table = sql_gen.as_Table(table)
1030
    if cols == None: cols = [pkey(db, table, recover)]
1031
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
1032
    
1033
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
1034
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
1035
        log_ignore_excs=(DuplicateException,))
1036

    
1037
def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
1038
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
1039
    Currently, only function calls are supported as expressions.
1040
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
1041
        This allows indexes to be used for comparisons where NULLs are equal.
1042
    '''
1043
    exprs = lists.mk_seq(exprs)
1044
    
1045
    # Parse exprs
1046
    old_exprs = exprs[:]
1047
    exprs = []
1048
    cols = []
1049
    for i, expr in enumerate(old_exprs):
1050
        expr = sql_gen.as_Col(expr, table)
1051
        
1052
        # Handle nullable columns
1053
        if ensure_not_null_:
1054
            try: expr = ensure_not_null(db, expr)
1055
            except KeyError: pass # unknown type, so just create plain index
1056
        
1057
        # Extract col
1058
        expr = copy.deepcopy(expr) # don't modify input!
1059
        if isinstance(expr, sql_gen.FunctionCall):
1060
            col = expr.args[0]
1061
            expr = sql_gen.Expr(expr)
1062
        else: col = expr
1063
        assert isinstance(col, sql_gen.Col)
1064
        
1065
        # Extract table
1066
        if table == None:
1067
            assert sql_gen.is_table_col(col)
1068
            table = col.table
1069
        
1070
        col.table = None
1071
        
1072
        exprs.append(expr)
1073
        cols.append(col)
1074
    
1075
    table = sql_gen.as_Table(table)
1076
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
1077
    
1078
    # Add index
1079
    while True:
1080
        str_ = 'CREATE'
1081
        if unique: str_ += ' UNIQUE'
1082
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
1083
            ', '.join((v.to_str(db) for v in exprs)))+')'
1084
        
1085
        try:
1086
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
1087
                log_ignore_excs=(DuplicateException,))
1088
            break
1089
        except DuplicateException:
1090
            index.name = next_version(index.name)
1091
            # try again with next version of name
1092

    
1093
def add_index_col(db, col, suffix, expr, nullable=True):
1094
    if sql_gen.index_col(col) != None: return # already has index col
1095
    
1096
    new_col = sql_gen.suffixed_col(col, suffix)
1097
    
1098
    # Add column
1099
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
1100
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
1101
        log_level=3)
1102
    new_col.name = new_typed_col.name # propagate any renaming
1103
    
1104
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
1105
        log_level=3)
1106
    if not nullable: add_not_null(db, new_col)
1107
    add_index(db, new_col)
1108
    
1109
    col.table.index_cols[col.name] = new_col
1110

    
1111
# Controls when ensure_not_null() will use index columns
1112
not_null_index_cols_min_rows = 0 # rows; initially always use index columns
1113

    
1114
def ensure_not_null(db, col):
1115
    '''For params, see sql_gen.ensure_not_null()'''
1116
    expr = sql_gen.ensure_not_null(db, col)
1117
    
1118
    # If a nullable column in a temp table, add separate index column instead.
1119
    # Note that for small datasources, this adds 6-25% to the total import time.
1120
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
1121
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
1122
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
1123
        expr = sql_gen.index_col(col)
1124
    
1125
    return expr
1126

    
1127
already_indexed = object() # tells add_indexes() the pkey has already been added
1128

    
1129
def add_indexes(db, table, has_pkey=True):
1130
    '''Adds an index on all columns in a table.
1131
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
1132
        index should be added on the first column.
1133
        * If already_indexed, the pkey is assumed to have already been added
1134
    '''
1135
    cols = table_cols(db, table)
1136
    if has_pkey:
1137
        if has_pkey is not already_indexed: add_pkey(db, table)
1138
        cols = cols[1:]
1139
    for col in cols: add_index(db, col, table)
1140

    
1141
#### Tables
1142

    
1143
### Maintenance
1144

    
1145
def analyze(db, table):
1146
    table = sql_gen.as_Table(table)
1147
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)
1148

    
1149
def autoanalyze(db, table):
1150
    if db.autoanalyze: analyze(db, table)
1151

    
1152
def vacuum(db, table):
1153
    table = sql_gen.as_Table(table)
1154
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
1155
        log_level=3))
1156

    
1157
### Lifecycle
1158

    
1159
def drop_table(db, table):
1160
    table = sql_gen.as_Table(table)
1161
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')
1162

    
1163
def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
1164
    like=None):
1165
    '''Creates a table.
1166
    @param cols [sql_gen.TypedCol,...] The column names and types
1167
    @param has_pkey If set, the first column becomes the primary key.
1168
    @param col_indexes bool|[ref]
1169
        * If True, indexes will be added on all non-pkey columns.
1170
        * If a list reference, [0] will be set to a function to do this.
1171
          This can be used to delay index creation until the table is populated.
1172
    '''
1173
    table = sql_gen.as_Table(table)
1174
    
1175
    if like != None:
1176
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
1177
            ]+cols
1178
    if has_pkey:
1179
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
1180
        pkey.constraints = 'PRIMARY KEY'
1181
    
1182
    temp = table.is_temp and not db.debug_temp
1183
        # temp tables permanent in debug_temp mode
1184
    
1185
    # Create table
1186
    while True:
1187
        str_ = 'CREATE'
1188
        if temp: str_ += ' TEMP'
1189
        str_ += ' TABLE '+table.to_str(db)+' (\n'
1190
        str_ += '\n, '.join(c.to_str(db) for c in cols)
1191
        str_ += '\n);\n'
1192
        
1193
        try:
1194
            run_query(db, str_, cacheable=True, log_level=2,
1195
                log_ignore_excs=(DuplicateException,))
1196
            break
1197
        except DuplicateException:
1198
            table.name = next_version(table.name)
1199
            # try again with next version of name
1200
    
1201
    # Add indexes
1202
    if has_pkey: has_pkey = already_indexed
1203
    def add_indexes_(): add_indexes(db, table, has_pkey)
1204
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
1205
    elif col_indexes: add_indexes_() # add now
1206

    
1207
def copy_table_struct(db, src, dest):
1208
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
1209
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)
1210

    
1211
### Data
1212

    
1213
def truncate(db, table, schema='public', **kw_args):
1214
    '''For params, see run_query()'''
1215
    table = sql_gen.as_Table(table, schema)
1216
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)
1217

    
1218
def empty_temp(db, tables):
1219
    if db.debug_temp: return # leave temp tables there for debugging
1220
    tables = lists.mk_seq(tables)
1221
    for table in tables: truncate(db, table, log_level=3)
1222

    
1223
def empty_db(db, schema='public', **kw_args):
1224
    '''For kw_args, see tables()'''
1225
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
(24-24/37)