Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None):
    '''Returns the query most recently run on a cursor.
    Prefers the raw query recorded by the DB driver (psycopg2 exposes `query`,
    MySQLdb exposes `_last_executed`); falls back to the caller-supplied
    input_query, tagged with "[input] " to show it did not come from the driver.
    '''
    raw_query = None
    for attr in ('query', '_last_executed'): # first existing attr wins
        if hasattr(cur, attr):
            raw_query = getattr(cur, attr)
            break
    
    if raw_query == None: return '[input] '+strings.ustr(input_query)
    return raw_query
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query()'''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+strings.ustr(query))
class DbException(exc.ExceptionWithCause):
    '''Base class for all database exceptions in this module.
    
    @param msg Human-readable description
    @param cause The underlying driver exception, if any
    @param cur If given, the cursor whose query caused the error; the query
        text is appended to the exception's message
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)
class ExceptionWithName(DbException):
    '''DbException that records the name of the relevant database object.
    
    @param name Stored in self.name and included in the message
    '''
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name
class ExceptionWithNameValue(DbException):
    '''DbException that records a database object's name and an offending
    value (both stored on the instance and included in the message).'''
    def __init__(self, name, value, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
            +'; value: '+strings.as_tt(repr(value)), cause)
        self.name = name
        self.value = value
class ConstraintException(DbException):
    '''Raised when a database constraint is violated.
    
    @param name The constraint's name
    @param cols The columns the constraint covers
    '''
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols
class MissingCastException(DbException):
    '''Raised when a column is missing a cast to the required type
    (see run_query()'s parsing of "column ... is of type ..." errors).'''
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col
class NameException(DbException): pass
64

    
65
class DuplicateKeyException(ConstraintException): pass
66

    
67
class NullValueException(ConstraintException): pass
68

    
69
class FunctionValueException(ExceptionWithNameValue): pass
70

    
71
class DuplicateTableException(ExceptionWithName): pass
72

    
73
class DuplicateFunctionException(ExceptionWithName): pass
74

    
75
class EmptyRowException(DbException): pass
76

    
77
##### Warnings
78

    
79
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur):
    '''Yields the name (first field) of each entry in the cursor's result
    description.'''
    return (entry[0] for entry in cur.description)
def rows(cur):
    '''Returns an iterator over the cursor's rows, exhausted when fetchone()
    returns None.'''
    return iter(cur.fetchone, None) # sentinel form of iter()
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    # DbConn.DbCursor.fetchone() caches the result as a side effect of fetching
    iters.consume_iter(rows(cur))
def next_row(cur):
    '''Returns the cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    # builtin next() instead of .next(): works on Python 2.6+ and Python 3
    return next(rows(cur))
def row(cur):
    '''Returns the cursor's first row, after fetching (and discarding) the
    remaining rows so the full result gets cached.'''
    first = next_row(cur)
    consume_rows(cur)
    return first
def next_value(cur):
    '''Returns the first column of the cursor's next row.'''
    return operator.itemgetter(0)(next_row(cur))
def value(cur):
    '''Returns the first column of the cursor's first row, consuming all
    remaining rows so the result gets cached.'''
    return operator.itemgetter(0)(row(cur))
def values(cur): return iters.func_iter(lambda: next_value(cur))
103

    
104
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration on an
    empty result set.'''
    try: return value(cur)
    except StopIteration: return None
##### Escaping
109

    
110
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quoting style of the given DB-API
    module name ('psycopg2', 'MySQLdb', or None for the default double quote).
    @raise NotImplementedError for any other module
    '''
    if module in ('psycopg2', None): quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier based on the engine name (a key of db_engines).'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the given DbConn, based on its driver.'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified, name of a table.'''
    esc = lambda name: esc_name(db, name)
    table = esc(table)
    if schema == None: return table
    return esc(schema)+'.'+table
##### Database connections
129

    
130
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
131

    
132
# Maps engine name -> (DB-API module name,
#     mapping of standard db_config keys to driver-specific connect() kwargs)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
DatabaseErrors_set = set([DbException])
138
DatabaseErrors = tuple(DatabaseErrors_set)
139

    
140
def _add_module(module):
    '''Registers a DB-API module's DatabaseError in DatabaseErrors'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild the `except`-able tuple
def db_config_str(db_config):
    '''Formats a db_config as "<engine> database <database>" for messages.'''
    return ''.join([db_config['engine'], ' database ', db_config['database']])
def log_debug_none(msg, level=2):
    '''No-op debug logger; the default for DbConn's log_debug callback.
    (A def instead of an assigned lambda, per PEP 8; same name and signature,
    so DbConn's `log_debug != log_debug_none` check still works.)'''
    pass
class DbConn:
    '''A lazily-opened database connection with optional query-result caching,
    savepoint support, and autocommit.'''
    
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict of connection settings (see db_config_names)
        @param serializable Whether to use SERIALIZABLE transaction isolation
        @param autocommit Whether to commit after each top-level query
        @param caching Whether query results may be cached (see run_query())
        @param log_debug Callback fn(msg, level=...) for debug messages
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none # only log for real callbacks
        
        self.__db = None # underlying DB-API connection; opened on first use
        self.query_results = {} # query str -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
        self._notices_seen = set() # server notices already logged
    
    def __getattr__(self, name):
        # Only called when normal attribute lookup fails
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db() # `.db` lazily opens the connection
        else: raise AttributeError()
    
    def __getstate__(self):
        '''Supports pickling by dropping unpicklable attributes.'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self): return self.__db != None # whether already opened
    
    def _db(self):
        '''Returns the underlying DB-API connection, opening it if needed.'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Translate standard config keys to driver-specific kwarg names
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit: self.run_query(
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE', log_level=4)
            if schemas != None:
                # Prepend the requested schemas to the search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                self.run_query("SELECT set_config('search_path', "
                    +self.esc_value(schemas_)
                    +" || current_setting('search_path'), false)", log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Wraps a DB-API cursor, recording rows (or the raised exception) so
        the result can be cached in the owning DbConn's query_results.'''
        
        def __init__(self, outer):
            '''@param outer The owning DbConn'''
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None # the cache key (query as passed in)
            self.result = [] # accumulated rows, or the raised exception
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try:
                    cur = self.inner.execute(query)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            if self.rowcount == 0 and query.startswith('SELECT'): # empty SELECT
                consume_rows(self) # fetch all rows so result will be cached
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''A minimal cursor that replays a previously-cached result.'''
            
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Escapes a value as an SQL literal, using the driver's quoting.'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        '''Whether the driver supports cursor.mogrify() (psycopg2 only).'''
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        '''Returns the query with params interpolated by the driver.
        @raise NotImplementedError if the driver can't mogrify
        '''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        '''Logs any new server notices (each distinct notice only once).'''
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param cacheable Whether this query's result may be cached and reused
        @param log_level Level passed to log_debug for the query log message
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        @return a cursor (a CacheCursor on a cache hit)
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                try:
                    cur = self.query_results[query]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query)))
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func, autocommit=True):
        '''Runs func() under the given autocommit isolation level, restoring
        the previous level afterwards. psycopg2-specific.'''
        import psycopg2.extensions
        if autocommit:
            isolation_level = psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
        else: isolation_level = psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(isolation_level)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint: rolls back to it on exception,
        releases it (and possibly autocommits) on success.'''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                # Restore nesting depth whether or not func() raised
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting')
            self.db.commit()
    
    def col_info(self, col):
        '''Looks up a column's type, default, and nullability in
        information_schema.columns.
        @param col sql_gen.Col whose .table has .name and .schema
        @return sql_gen.TypedCol
        '''
        table = sql_gen.Table('columns', 'information_schema')
        cols = ['data_type', 'column_default', 'is_nullable']
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        type_, default, nullable = row(select(self, table, cols, conds,
            order_by='table_schema', limit=1, log_level=3))
            # TODO: order_by search_path schema order
        default = sql_gen.as_Code(default, self)
        
        return sql_gen.TypedCol(col.name, type_, default, nullable)
connect = DbConn
389

    
390
##### Recoverable querying
391

    
392
def with_savepoint(db, func): return db.with_savepoint(func)
393

    
394
def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, translating recognized
    PostgreSQL error messages into this module's typed exceptions.
    For other params, see DbConn.run_query()
    @param recover Whether to run inside a savepoint so a failure can be
        rolled back without aborting the enclosing transaction; also required
        for the typed-exception translation below
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs) # `except` requires a tuple
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None] 
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            if not recover: raise # need savepoint to run index_cols()
            msg = exc.str_(e)
            
            # Translate recognized error messages into typed exceptions
            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                constraint, table = match.groups()
                try: cols = index_cols(db, table, constraint)
                except NotImplementedError: raise e
                else: raise DuplicateKeyException(constraint, cols, e)
            
            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)
            
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.search(r'relation "(.+?)" already exists', msg)
            if match: raise DuplicateTableException(match.group(1), e)
            
            match = re.search(r'function "(.+?)" already exists', msg)
            if match: raise DuplicateFunctionException(match.group(1), e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2 # raise level for the deferred log message in `finally`
        raise
    finally:
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
##### Basic queries
456

    
457
def next_version(name):
    '''Returns the next version of a versioned name:
    "name" -> "name#1", "name#1" -> "name#2", etc.
    (The first, unsuffixed name counts as version 0.)'''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match == None: version = 1 # first existing name was version 0
    else:
        name, version_str = match.groups()
        version = int(version_str)+1
    return sql_gen.add_suffix(name, '#'+str(version))
def run_query_into(db, query, into=None, add_indexes_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table to create and fill, or None to just run query.
        Its .name is advanced with next_version() on a name collision, and its
        .schema is cleared when a temp table is created.
    @param add_indexes_ Whether to call add_indexes() on the new table
    @return the cursor from running the query
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    kw_args['recover'] = True # required for DuplicateTableException translation
    kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
    
    temp = not db.autocommit # tables are permanent in autocommit mode
    # "temporary tables cannot specify a schema name", so remove schema
    if temp: into.schema = None
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateTableException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_indexes_: add_indexes(db, into)
    
    return cur
order_by_pkey = object() # tells mk_select() to order by the pkey
499

    
500
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
501

    
502
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit int|None LIMIT clause value
    @param start int|None OFFSET clause value (0 emits no OFFSET clause)
    @param order_by Column to ORDER BY, order_by_pkey for the pkey, or None
    @param default_table Table qualifying unqualified column references
    @return query
    '''
    # Note: the [] default of distinct_on is only read, never mutated, and the
    # `!= []` comparisons below rely on it being a list.
    
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        # don't order by the pkey when DISTINCT ON other columns is requested
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += '\n'
    if fields == None: query += '*'
    else:
        assert fields != []
        query += '\n, '.join(map(parse_col, fields))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # `missing` tracks whether the query still lacks any scope-limiting clause
    missing = True
    if conds != []:
        # (removed an unused `whitespace` local that was computed here)
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return query
def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_select(db, *args, **kw_args)
    return run_query(db, query, recover, cacheable, log_level=log_level)
def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False):
    '''
    @param cols Columns to insert into; [] or None means no column list
    @param select_query SELECT (or VALUES) statement providing the rows;
        None inserts a single all-defaults row (DEFAULT VALUES)
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @return query
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None:
        cols = [sql_gen.to_name_only_col(v, table).to_str(db) for v in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    # Build query
    first_line = 'INSERT INTO '+table.to_str(db)
    query = first_line
    if cols != None: query += '\n('+', '.join(cols)+')'
    query += '\n'+select_query
    
    if returning != None:
        query += '\nRETURNING '+sql_gen.to_name_only_col(returning).to_str(db)
    
    if embeddable:
        assert returning != None
        
        # Wrap the INSERT in a function so it can be used as a nested SELECT
        function_name = sql_gen.clean_name(first_line)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        while True:
            try:
                function = sql_gen.TempFunction(function_name, db.autocommit)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS '''+return_type+'''
LANGUAGE sql
AS $$
'''+query+''';
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateFunctionException,))
                break # this version was successful
            except DuplicateFunctionException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            [returning]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return query
def insert_select(db, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # needed to SELECT INTO a table
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, into, recover=recover,
        cacheable=cacheable, log_level=log_level)
default = sql_gen.default # tells insert() to use the default value for a column
663

    
664
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row.
    @param row dict(col=value) or a sequence of values
    For other params, see insert_select()'''
    cols = None
    if not lists.is_seq(row):
        cols = row.keys()
        row = row.values()
    values_ = list(row) # ensure that "== []" works
    
    query = None
    if values_ != []: query = sql_gen.Values(values_).to_str(db)
    
    return insert_select(db, table, cols, query, *args, **kw_args)
def mk_update(db, table, changes=None, cond=None):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    assignments = [sql_gen.to_name_only_col(col, table).to_str(db)+' = '
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes]
    
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join(assignments)
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
def update(db, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    
    query = mk_update(db, *args, **kw_args)
    return run_query(db, query, recover)
def last_insert_id(db):
    '''Returns the id generated by the last insert, using whichever mechanism
    the driver provides (or None for unknown drivers).'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': return db.insert_id()
    return None
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep the original (pre-retarget) key
        col.table = into # deliberate mutation; documented in @param preserve
        items.append((orig_col, col))
    preserve = set(preserve) # for membership tests below
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given columns into a new table under distinct names.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, into=into)
    return dict(items)
def mk_track_data_error(db, errors_table, cols, value, error_code, error):
    '''Builds a query that inserts one error row per source column into
    errors_table, using a filter_out join against errors_table so rows it
    already contains are skipped.
    @param cols Nonempty sequence of the source columns for the error
    @param value, error_code, error Values recorded with each row
    @return query (str)
    '''
    assert cols != ()
    
    cols = map(sql_gen.to_name_only_col, cols)
    
    # One VALUES row per affected source column
    columns_cols = ['column']
    columns = sql_gen.NamedValues('columns', columns_cols,
        [[c.name] for c in cols])
    # The error info, cross-joined onto every column row
    values_cols = ['value', 'error_code', 'error']
    values = sql_gen.NamedValues('values', values_cols,
        [value, error_code, error])
    
    select_cols = columns_cols+values_cols
    name_only_cols = map(sql_gen.to_name_only_col, select_cols)
    errors_table = sql_gen.NamedTable('errors', errors_table)
    joins = [columns, sql_gen.Join(values, type_='CROSS'),
        sql_gen.Join(errors_table, dict(zip(name_only_cols, select_cols)),
        sql_gen.filter_out)]
    
    return mk_insert_select(db, errors_table, name_only_cols,
        mk_select(db, joins, select_cols, order_by=None))
def track_data_error(db, errors_table, cols, *args, **kw_args):
    '''Runs the query built by mk_track_data_error() (see it for params).
    @param errors_table If None, does nothing.
    '''
    if errors_table == None or cols == (): return
    
    query = mk_track_data_error(db, errors_table, cols, *args, **kw_args)
    run_query(db, query, cacheable=True, log_level=4)
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as the source columns) and converts errors to warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    '''
    col = sql_gen.as_Col(col)
    
    # Error tracking requires an errors table and a Col with known src columns
    can_track_errors = (errors_table != None
        and isinstance(col, sql_gen.Col) and col.srcs != ())
    if not can_track_errors: # just cast inline
        return sql_gen.CustomCode(col.to_str(db)+'::'+type_)
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    errors_table = sql_gen.as_Table(errors_table)
    srcs = [sql_gen.to_name_only_col(s) for s in col.srcs]
    function = sql_gen.TempFunction(
        str(sql_gen.FunctionCall(type_, *srcs)), db.autocommit)
    
    # Create a plpgsql wrapper function that performs the cast, trapping
    # data_exceptions and recording them in the errors table
    while True:
        ddl = '''\
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
RETURNS '''+type_+'''
LANGUAGE plpgsql
STRICT
AS $$
BEGIN
    /* The explicit cast to the return type is needed to make the cast happen
    inside the try block. (Implicit casts to the return type happen at the end
    of the function, outside any block.) */
    RETURN value::'''+type_+''';
EXCEPTION
    WHEN data_exception THEN
        -- Save error in errors table.
        -- Insert the value and error for *each* source column.
'''+mk_track_data_error(db, errors_table, srcs,
    *[sql_gen.CustomCode(n) for n in ('value', 'SQLSTATE', 'SQLERRM')])+''';
        
        RAISE WARNING '%', SQLERRM;
        RETURN NULL;
END;
$$;
'''
        try:
            run_query(db, ddl, recover=True, cacheable=True,
                log_ignore_excs=(DuplicateFunctionException,))
        except DuplicateFunctionException:
            # name taken: bump the version suffix and try again
            function.name = next_version(function.name)
        else:
            break # function created successfully
    
    return sql_gen.FunctionCall(function, col)
833

    
834
##### Database structure queries
835

    
836
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in table.'''
    query = mk_select(db, table, [sql_gen.row_count], order_by=None, start=0)
    return value(run_query(db, query, recover=recover, log_level=3))
839

    
840
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
843

    
844
def pkey(db, table, recover=None):
    '''Returns the table's pkey name (assumed to be the first column).'''
    cols = table_cols(db, table, recover)
    return cols[0]
847

    
848
# Column name that table_not_null_col() looks for when picking a table's
# NOT NULL column
not_null_col = 'not_null_col'
849

    
850
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
854

    
855
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    # First branch: plain column indexes (attnum in indkey); second branch:
    # expression indexes, whose column numbers are parsed out of indexprs
    return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = '''+db.esc_value(table)+'''
            AND index.relname = '''+db.esc_value(index)+'''
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = '''+db.esc_value(table)+'''
                AND index.relname = '''+db.esc_value(index)+'''
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
        , cacheable=True, log_level=4)))
895

    
896
def constraint_cols(db, table, constraint):
    '''Returns the columns in a constraint, in attnum order.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = '''+db.esc_value(table)+'''
    AND conname = '''+db.esc_value(constraint)+'''
ORDER BY attnum
'''
        )))
912

    
913
# Name of the serial row-number column added by add_row_num()
row_num_col = '_row_num'
914

    
915
def add_index(db, exprs, table=None, unique=False):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    '''
    if not lists.is_seq(exprs): exprs = [exprs]
    
    # Normalize each input into an index expression plus its underlying column
    parsed_exprs = []
    parsed_cols = []
    for raw in exprs:
        expr = sql_gen.as_Col(copy.deepcopy(raw)) # deepcopy: don't modify input
        
        # Extract the underlying column
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0] # function's first arg is the indexed column
            expr = sql_gen.Expr(expr)
        else:
            col = expr
        
        # Infer the table from the first column if not given
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None
        
        parsed_exprs.append(expr)
        parsed_cols.append(col)
    
    table = sql_gen.as_Table(table)
    index = sql_gen.Table(str(sql_gen.Col(
        ','.join(str(c) for c in parsed_cols), table)))
    
    sql = ('CREATE'+(' UNIQUE' if unique else '')+' INDEX '+index.to_str(db)
        +' ON '+table.to_str(db)
        +' ('+', '.join(e.to_str(db) for e in parsed_exprs)+')')
    
    try: run_query(db, sql, recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
955

    
956
def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None:
        cols = [pkey(db, table, recover)]
    rendered = ', '.join(sql_gen.to_name_only_col(c).to_str(db) for c in cols)
    
    sql = ('ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('+rendered+')')
    run_query(db, sql, recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateTableException,))
969

    
970
# Sentinel value for add_indexes()'s has_pkey param
already_indexed = object() # tells add_indexes() the pkey has already been added
971

    
972
def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    remaining = table_cols(db, table)
    if has_pkey:
        if has_pkey is not already_indexed:
            add_pkey(db, table)
        remaining = remaining[1:] # first column is covered by the pkey
    for col_name in remaining:
        add_index(db, col_name, table)
983

    
984
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    run_query(db, 'ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY', log_level=3)
990

    
991
def create_table(db, table, cols, has_pkey=True, col_indexes=True):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    # Fix: previously `cols[0] = ...` replaced an element of the *caller's*
    # list, contradicting the "don't modify input" intent. Copy the list first.
    cols = list(cols)
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # copy.copy: keep caller's TypedCol
        pkey.type += ' NOT NULL PRIMARY KEY'
    
    str_ = 'CREATE TABLE '+table.to_str(db)+' (\n'
    str_ += '\n, '.join(v.to_str(db) for v in cols)
    str_ += '\n);\n'
    run_query(db, str_, cacheable=True, log_level=2)
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed # pkey index just created above
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
1016

    
1017
def vacuum(db, table):
    '''Runs VACUUM ANALYZE on table (in autocommit mode).'''
    table_str = sql_gen.as_Table(table).to_str(db)
    db.with_autocommit(
        lambda: run_query(db, 'VACUUM ANALYZE '+table_str, log_level=3))
1021

    
1022
def truncate(db, table, schema='public'):
    '''Empties table, cascading to dependent tables.'''
    qualified = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+qualified.to_str(db)+' CASCADE')
1025

    
1026
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over table names matching the given patterns.
    @param exact If set, compares with = instead of LIKE.
    '''
    compare = '=' if exact else 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', log_level=4))
    if module == 'MySQLdb':
        return values(run_query(db,
            'SHOW TABLES LIKE '+db.esc_value(table_like), cacheable=True,
            log_level=4))
    raise NotImplementedError("Can't list tables for "+module+' database')
1040

    
1041
def table_exists(db, table):
    '''Whether the given table exists in its schema.'''
    table = sql_gen.as_Table(table)
    matches = tables(db, table.schema, table.name, exact=True)
    return list(matches) != []
1044

    
1045
def errors_table(db, table, if_exists=True):
    '''
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    if table.srcs != (): table = table.srcs[0] # use the original source table
    
    errors_table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not table_exists(db, errors_table_):
        return None
    return errors_table_
1056

    
1057
##### Database management
1058

    
1059
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema. For kw_args, see tables()'''
    for name in tables(db, schema, **kw_args):
        truncate(db, name, schema)
1062

    
1063
##### Heuristic queries
1064

    
1065
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    '''
    row = sql_gen.ColDict(db, table, row)
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = insert(db, table, row, pkey_, recover=True)
    except DuplicateKeyException as e:
        # Row already exists: look it up by the columns that collided
        dup_row = sql_gen.ColDict(db, table,
            util.dict_subset_right_join(row, e.cols))
        return value(select(db, table, [pkey_], dup_row, recover=True))
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    return value(cur)
1081

    
1082
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors'''
    try:
        return value(select(db, table, [pkey], row, limit=1, recover=True))
    except StopIteration: # no matching row
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
1088

    
1089
def is_func_result(col):
    '''Whether col is the 'result' column of a function-call table
    (a table whose name contains a '(').'''
    return col.name == 'result' and '(' in col.table.name
1091

    
1092
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Computes the default name of put_table()'s output pkeys table.'''
    def describe(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            col_table = in_col.table
            if col_table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col):
                in_col = col_table # omit col name
        return str(in_col)
    
    name = str(out_table)
    if is_func:
        try:
            arg_str = describe(mapping['value'])
        except KeyError: # no 'value' mapping: list every mapping entry
            arg_str = ', '.join(str(k)+'='+describe(v)
                for k, v in mapping.iteritems())
        name += '('+arg_str+')'
    else:
        try:
            rank_col = mapping['rank']
        except KeyError:
            name += '_pkeys'
        else: # has a rank column, so hierarchical
            name += '[rank='+describe(rank_col)+']'
    return name
1120

    
1121
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False, on_error=exc.raise_):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param row_ct_ref list|None If set, [0] is incremented by the insert's
        rowcount
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @param on_error Callback invoked with otherwise-unhandled DatabaseErrors
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    out_pkey = pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    if mapping == {}: # need at least one column for INSERT SELECT
        mapping = {out_pkey: None} # ColDict will replace with default value
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    errors_table_ = errors_table(db, in_tables0)
    in_pkey = pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    # Set column sources
    in_cols = filter(sql_gen.is_table_col, mapping.values())
    for col in in_cols:
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins, in_cols,
        preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    try: default = mapping[default]
    except KeyError:
        if default != None:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols):
        # First call creates the pkeys table; later calls append to it
        query = mk_select(db, joins, cols, order_by=None, start=0)
        if pkeys_table_exists_ref[0]:
            insert_select(db, into, pkeys_names, query)
        else:
            run_query_into(db, query, into=into)
            pkeys_table_exists_ref[0] = True
    
    limit_ref = [None]
    conds = set()
    distinct_on = sql_gen.ColDict(db, out_table)
    def mk_main_select(joins, cols):
        distinct_on_cols = [c.to_Col() for c in distinct_on.values()]
        return mk_select(db, joins, cols, conds, distinct_on_cols,
            limit=limit_ref[0], start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    
    def remove_all_rows():
        log_debug('Returning NULL for all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    
    def ignore(in_col, value, e):
        track_data_error(db, errors_table_, in_col.srcs, value, e.cause.pgcode,
            e.cause.pgerror)
        
        in_col_str = strings.as_tt(repr(in_col))
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
            level=2.5)
        add_index(db, in_col)
        
        log_debug('Ignoring rows with '+in_col_str+' = '
            +strings.as_tt(repr(value)))
    def remove_rows(in_col, value, e):
        ignore(in_col, value, e)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    def invalid2null(in_col, value, e):
        ignore(in_col, value, e)
        update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.add_suffix(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects
    join_cols = sql_gen.ColDict(db, out_table)
    # Fix: has_joins was only assigned inside the main case of the loop, so the
    # empty-pkeys special case (limit_ref[0] == 0 before the first main-case
    # iteration) broke out of the loop with has_joins unbound, causing a
    # NameError at the `if has_joins:` below. Initialize it here.
    has_joins = False
    while True:
        if limit_ref[0] == 0: # special case: produce no pkeys at all
            log_debug('Creating an empty pkeys table')
            cur = run_query_into(db, mk_select(db, out_table, [out_pkey],
                limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            # Only insert rows that don't already match an existing output row
            insert_joins.append(sql_gen.Join(out_table, join_cols,
                sql_gen.filter_out))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select(insert_joins, mapping.values())
        
        log_debug('Trying to insert new rows')
        try:
            cur = insert_select(db, out_table, mapping.keys(), main_select,
                **insert_args)
            break # insert successful
        except DuplicateKeyException as e:
            log_exc(e)
            
            old_join_cols = join_cols.copy()
            distinct_on.update(util.dict_subset(mapping, e.cols))
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
        except NullValueException as e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None, e)
        except FunctionValueException as e:
            log_exc(e)
            
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                in_col = sql_gen.unwrap_func_call(in_col, func_name)
                invalid2null(in_col, value, e)
        except MissingCastException as e:
            log_exc(e)
            
            out_col = e.col
            type_ = e.type
            
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
                +strings.as_tt(type_))
            def wrap_func(col): return cast(db, type_, col, errors_table_)
            mapping[out_col] = sql_gen.wrap(wrap_func, mapping[out_col])
        except DatabaseErrors as e:
            log_exc(e)
            
            log_debug('No handler for exception')
            on_error(e)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols)
    else:
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        run_query_into(db, mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
            insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert table_row_count(db, into) == table_row_count(db, in_table)
    
    srcs = []
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
    return sql_gen.Col(out_pkey, into, srcs)
1357

    
1358
##### Data cleanup
1359

    
1360
def cleanup_table(db, table, cols):
1361
    table = sql_gen.as_Table(table)
1362
    cols = map(sql_gen.as_Col, cols)
1363
    
1364
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
1365
        +db.esc_value(r'\N')+')')
1366
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
1367
        for v in cols]
1368
    
1369
    update(db, table, changes)
(24-24/36)