# Database access

import copy
import re
import warnings

import exc
import dicts
import iters
import lists
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name

class ExceptionWithNameValue(DbException):
    def __init__(self, name, value, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
            +'; value: '+strings.as_tt(repr(value)), cause)
        self.name = name
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(str(type_))
            +'; name: '+strings.as_tt(name), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col

class NameException(DbException): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class FunctionValueException(ExceptionWithNameValue): pass

class DuplicateException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None
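
# How these helpers compose (an illustrative sketch; cur is any DB-API cursor
# on which a query has already been executed):
#     print list(col_names(cur)) # the column names, from cur.description
#     total = value(cur) # first column of the first row; the remaining rows
#         # are consumed too, so a DbCursor can cache the complete result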

##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
133

    
134
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
135

    
136
db_engines = {
137
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
138
    'PostgreSQL': ('psycopg2', {}),
139
}
140

    
141
DatabaseErrors_set = set([DbException])
142
DatabaseErrors = tuple(DatabaseErrors_set)
143

    
144
def _add_module(module):
145
    DatabaseErrors_set.add(module.DatabaseError)
146
    global DatabaseErrors
147
    DatabaseErrors = tuple(DatabaseErrors_set)
148

    
149
def db_config_str(db_config):
150
    return db_config['engine']+' database '+db_config['database']
151

    
152
log_debug_none = lambda msg, level=2: None
153

    
154
class DbConn:
155
    def __init__(self, db_config, autocommit=True, caching=True,
156
        log_debug=log_debug_none, debug_temp=False):
157
        '''
158
        @param debug_temp Whether temporary objects should instead be permanent.
159
            This assists in debugging the internal objects used by the program.
160
        '''
161
        self.db_config = db_config
162
        self.autocommit = autocommit
163
        self.caching = caching
164
        self.log_debug = log_debug
165
        self.debug = log_debug != log_debug_none
166
        self.debug_temp = debug_temp
167
        self.autoanalyze = False
168
        
169
        self.__db = None
170
        self.query_results = {}
171
        self._savepoint = 0
172
        self._notices_seen = set()
173
    
174
    def __getattr__(self, name):
175
        if name == '__dict__': raise Exception('getting __dict__')
176
        if name == 'db': return self._db()
177
        else: raise AttributeError()
178
    
179
    def __getstate__(self):
180
        state = copy.copy(self.__dict__) # shallow copy
181
        state['log_debug'] = None # don't pickle the debug callback
182
        state['_DbConn__db'] = None # don't pickle the connection
183
        return state
184
    
185
    def connected(self): return self.__db != None
186
    
187
    def _db(self):
188
        if self.__db == None:
189
            # Process db_config
190
            db_config = self.db_config.copy() # don't modify input!
191
            schemas = db_config.pop('schemas', None)
192
            module_name, mappings = db_engines[db_config.pop('engine')]
193
            module = __import__(module_name)
194
            _add_module(module)
195
            for orig, new in mappings.iteritems():
196
                try: util.rename_key(db_config, orig, new)
197
                except KeyError: pass
198
            
199
            # Connect
200
            self.__db = module.connect(**db_config)
201
            
202
            # Configure connection
203
            if hasattr(self.db, 'set_isolation_level'):
204
                import psycopg2.extensions
205
                self.db.set_isolation_level(
206
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
207
            if schemas != None:
208
                search_path = [self.esc_name(s) for s in schemas.split(',')]
209
                search_path.append(value(run_query(self, 'SHOW search_path',
210
                    log_level=4)))
211
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
212
                    log_level=3)
213
        
214
        return self.__db
215
    
216

    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has a distinguishing
                # comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache the result set
            # since inserts are not idempotent. Other non-SELECT queries don't
            # have their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None

    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)

    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
        
        return cur

    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: return func()
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col):
        table = sql_gen.Table('columns', 'information_schema')
        type_ = sql_gen.Coalesce(sql_gen.Nullif(sql_gen.Col('data_type'),
            'USER-DEFINED'), sql_gen.Col('udt_name'))
        cols = [type_, 'column_default',
            sql_gen.Cast('boolean', sql_gen.Col('is_nullable'))]
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, cacheable=False, log_level=4))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn
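
# Example (illustrative only; the config values are placeholders for a
# reachable PostgreSQL instance):
#     db = connect({'engine': 'PostgreSQL', 'host': 'localhost',
#         'user': 'user', 'password': 'password', 'database': 'mydb'})
#     print value(run_query(db, 'SELECT 1')) # connects lazily on first query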

##### Recoverable querying

def with_savepoint(db, func): return db.with_savepoint(func)

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For other params, see DbConn.run_query()
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None]
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            msg = exc.str_(e)
            
            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                constraint, table = match.groups()
                cols = []
                if recover: # need auto-rollback to run index_cols()
                    try: cols = index_cols(db, table, constraint)
                    except NotImplementedError: pass
                raise DuplicateKeyException(constraint, cols, e)
            
            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)
            
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.search(r'\b(\S+) "(.+?)".*? already exists', msg)
            if match:
                type_, name = match.groups()
                raise DuplicateException(type_, name, e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
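
# A sketch of how the parsed exceptions are meant to be used (illustrative;
# the table and column names are hypothetical):
#     try: run_query(db, "INSERT INTO plots (id) VALUES (1)", recover=True)
#     except DuplicateKeyException, e:
#         print 'duplicate key on columns:', e.cols # e.g. ['id']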

##### Basic queries

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))
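
# Examples of the versioning scheme (assuming sql_gen.concat() appends while
# respecting identifier length limits):
#     next_version('widgets') #=> 'widgets#1'
#     next_version('widgets#1') #=> 'widgets#2'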

def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_indexes_: add_indexes(db, into)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur
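
# Sketch (illustrative; 'plots' is a hypothetical table): materialize a query
# into a temp table whose name is versioned on collision:
#     out = sql_gen.Table('pairs')
#     run_query_into(db, 'SELECT * FROM plots', into=out)
#     # out.name may now be 'pairs#1' etc. if 'pairs' already existed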

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if fields == None:
        if query.find('\n') >= 0: whitespace = '\n'
        else: whitespace = ' '
        query += whitespace+'*'
    else:
        assert fields != []
        query += '\n'+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    missing = True
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query
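
# A sketch of typical usage (illustrative; 'plots' is a hypothetical table and
# the generated SQL is shown only approximately):
#     query = mk_select(db, 'plots', ['id', 'name'], {'site': 'A'}, limit=10)
#     # roughly: SELECT id, name FROM plots WHERE site = 'A'
#     #     ORDER BY <pkey> LIMIT 10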

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = 'unknown'
    if returning != None: return_type = returning.to_str(db)+'%TYPE'
    
    lang = 'sql'
    if ignore:
        assert cols != None
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        lang = 'plpgsql'
        row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = '''\
DECLARE
    row '''+table.to_str(db)+'''%ROWTYPE;
BEGIN
    /* Need an EXCEPTION block for each individual row because "When an error is
    caught by an EXCEPTION clause, [...] all changes to persistent database
    state within the block are rolled back."
    This is unfortunate because "A block containing an EXCEPTION clause is
    significantly more expensive to enter and exit than a block without one."
    (http://www.postgresql.org/docs/8.3/static/plpgsql-control-structures.html\
#PLPGSQL-ERROR-TRAPPING)
    */
    FOR '''+(', '.join((c.to_str(db) for c in row)))+''' IN
'''+select_query+'''
    LOOP
        BEGIN
            RETURN QUERY
'''+mk_insert(sql_gen.Values(row).to_str(db))+'''
;
        EXCEPTION
            WHEN unique_violation THEN NULL; -- continue to next row
        END;
    END LOOP;
END;\
'''
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        while True:
            try:
                function = db.TempFunction(function_name)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS SETOF '''+return_type+'''
LANGUAGE '''+lang+'''
AS $$
'''+query+'''
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            cols) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if kw_args.get('ignore', False): recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
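
# Examples (illustrative; 'plots' is a hypothetical table):
#     insert(db, 'plots', {'name': 'A', 'area': 100}) # dict: cols and values
#     insert(db, 'plots', [1, 'A', 100]) # sequence: values in column order
#     insert(db, 'plots', {'area': default}) # use the column's DEFAULT value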

def mk_update(db, table, changes=None, cond=None, in_place=False):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '
            +db.col_info(sql_gen.with_default_table(c, table)).type
            +'\nUSING '+v.to_str(db) for c, v in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
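
# in_place mode rewrites each column via ALTER TABLE ... TYPE ... USING, so the
# rows are rewritten once instead of leaving dead tuples. A sketch of the SQL
# (illustrative; 'plots' and 'name' are hypothetical, and the column's current
# type is looked up with db.col_info()):
#     mk_update(db, 'plots', [('name', 'unknown')], in_place=True)
#     # roughly: ALTER TABLE plots ALTER COLUMN name TYPE text
#     #     USING 'unknown'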

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_indexes_=True)
    return dict(items)

##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, exact=True)) != []

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)
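
# Example (illustrative; 'plots' is a hypothetical table):
#     if table_exists(db, sql_gen.Table('plots')):
#         print pkey(db, 'plots'), table_row_count(db, 'plots')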

def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')

def constraint_cols(db, table, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
            )))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')

#### Functions

def function_exists(db, function):
    function = sql_gen.as_Function(function)
    
    info_table = sql_gen.Table('routines', 'information_schema')
    conds = [('routine_name', function.name)]
    schema = function.schema
    if schema != None: conds.append(('routine_schema', schema))
    # Exclude trigger functions, since they cannot be called directly
    conds.append(('data_type', sql_gen.CompareCond('trigger', '!=')))
    
    return list(values(select(db, info_table, ['routine_name'], conds,
        order_by='routine_schema', limit=1, log_level=4))) != []
        # TODO: order_by search_path schema order

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_typed_col = sql_gen.TypedCol(row_num_col, 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    add_col(db, table, row_num_typed_col, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0]
            expr = sql_gen.Expr(expr)
        else: col = expr
        assert isinstance(col, sql_gen.Col)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    # Add index
    while True:
        str_ = 'CREATE'
        if unique: str_ += ' UNIQUE'
        str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
            ', '.join((v.to_str(db) for v in exprs)))+')'
        
        try:
            run_query(db, str_, recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateException,))
            break
        except DuplicateException:
            index.name = next_version(index.name)
            # try again with next version of name

def add_index_col(db, col, suffix, expr, nullable=True):
    if sql_gen.index_col(col) != None: return # already has index col
    
    new_col = sql_gen.suffixed_col(col, suffix)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, db.col_info(col).type)
    add_col(db, col.table, new_typed_col, comment='src: '+repr(col),
        log_level=3)
    new_col.name = new_typed_col.name # propagate any renaming
    
    update(db, col.table, [(new_col, expr)], in_place=True, cacheable=True,
        log_level=3)
    if not nullable: add_not_null(db, new_col)
    add_index(db, new_col)
    
    col.table.index_cols[col.name] = new_col

# Controls when ensure_not_null() will use index columns
not_null_index_cols_min_rows = 0 # rows; initially always use index columns

def ensure_not_null(db, col):
    '''For params, see sql_gen.ensure_not_null()'''
    expr = sql_gen.ensure_not_null(db, col)
    
    # If a nullable column in a temp table, add separate index column instead.
    # Note that for small datasources, this adds 6-25% to the total import time.
    if (sql_gen.is_temp_col(col) and isinstance(expr, sql_gen.EnsureNotNull)
        and table_row_count(db, col.table) >= not_null_index_cols_min_rows):
        add_index_col(db, col, '::NOT NULL', expr, nullable=False)
        expr = sql_gen.index_col(col)
    
    return expr

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop_table(db, table):
    table = sql_gen.as_Table(table)
    return run_query(db, 'DROP TABLE IF EXISTS '+table.to_str(db)+' CASCADE')

def create_table(db, table, cols, has_pkey=True, col_indexes=True):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    str_ = 'CREATE TABLE '+table.to_str(db)+' (\n'
    str_ += '\n, '.join(v.to_str(db) for v in cols)
    str_ += '\n);\n'
    run_query(db, str_, cacheable=True, log_level=2)
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
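
# Example (illustrative; 'plots' is a hypothetical table):
#     create_table(db, sql_gen.Table('plots'), [
#         sql_gen.TypedCol('id', 'serial'), sql_gen.TypedCol('name', 'text')])
#     # roughly: CREATE TABLE plots (id serial PRIMARY KEY, name text);
#     # followed by an index on the non-pkey column name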

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    if db.debug_temp: return # leave temp tables there for debugging
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

##### Database management

##### Data cleanup

def cleanup_table(db, table, cols):
    table = sql_gen.as_Table(table)
    cols = map(sql_gen.as_Col, cols)
    
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
        +db.esc_value(r'\N')+')')
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
        for v in cols]
    
    update(db, table, changes, in_place=True)
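
# Example (illustrative; 'plots' is a hypothetical table): turns both '' and
# the '\N' NULL marker into SQL NULL in the given columns:
#     cleanup_table(db, 'plots', ['name', 'notes'])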