Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the raw query stored on a DB-API cursor, if available.
    psycopg2 exposes it as cur.query; MySQLdb as cur._last_executed.
    Falls back to a description built from the input query and params.
    '''
    raw_query = None
    for attr in 'query', '_last_executed':
        if hasattr(cur, attr):
            raw_query = getattr(cur, attr)
            break
    
    if raw_query != None: return raw_query
    return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
28
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates an exception with the query that caused it.
    For params, see get_cur_query()'''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+str(query))
31

    
32
class DbException(exc.ExceptionWithCause):
    '''Base class for all database errors in this module.
    If a cursor is given, the offending query is appended to the message.'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)
36

    
37
class ExceptionWithName(DbException):
    '''DbException tagged with the name of the object involved (e.g. a table
    or function name); stored in self.name.'''
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
        self.name = name
41

    
42
class ExceptionWithNameValue(DbException):
    '''DbException tagged with a name and the offending value; stored in
    self.name and self.value.'''
    def __init__(self, name, value, cause=None):
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
            +'; value: '+strings.as_tt(repr(value)), cause)
        self.name = name
        self.value = value
48

    
49
class ConstraintException(DbException):
    '''A database constraint was violated.
    self.name is the constraint name; self.cols the columns it covers.'''
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
        self.name = name
        self.cols = cols
55

    
56
class MissingCastException(DbException):
    '''A value needed an explicit cast to the column's type.
    self.type is the required type; self.col the column name.'''
    def __init__(self, type_, col, cause=None):
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col), cause)
        self.type = type_
        self.col = col
62

    
63
# Error related to a database identifier name
class NameException(DbException): pass
64

    
65
# Unique-constraint violation (raised by run_query()'s error parsing)
class DuplicateKeyException(ConstraintException): pass
66

    
67
# NOT NULL constraint violation (raised by run_query()'s error parsing)
class NullValueException(ConstraintException): pass
68

    
69
# Invalid input value passed to a database function
class FunctionValueException(ExceptionWithNameValue): pass
70

    
71
# "relation ... already exists"; name is the existing table's name
class DuplicateTableException(ExceptionWithName): pass
72

    
73
# "function ... already exists"; name is the existing function's name
class DuplicateFunctionException(ExceptionWithName): pass
74

    
75
# A row that was expected to contain data is empty
class EmptyRowException(DbException): pass
76

    
77
##### Warnings
78

    
79
# Warning category for database issues (see warnings.warn() in mk_select())
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur):
    '''Yields the name of each column in the cursor's result set.'''
    return (desc[0] for desc in cur.description)
84

    
85
def rows(cur):
    '''Iterates over the cursor's rows until fetchone() returns None.'''
    return iter(cur.fetchone, None)
86

    
87
def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    # DbCursor.fetchone() records each row in self.result as a side effect,
    # so exhausting the iterator populates the cache
    iters.consume_iter(rows(cur))
90

    
91
# Returns the next row; raises StopIteration when no rows remain
def next_row(cur): return rows(cur).next()
92

    
93
def row(cur):
    '''Returns the first row, then consumes the rest so the result will be
    cached (see consume_rows()).'''
    first = next_row(cur)
    consume_rows(cur)
    return first
97

    
98
# First column of the next row; raises StopIteration when exhausted
def next_value(cur): return next_row(cur)[0]
99

    
100
# First column of the first row (consumes remaining rows; see row())
def value(cur): return row(cur)[0]
101

    
102
# Iterator over the first column of each remaining row (via iters.func_iter)
def values(cur): return iters.func_iter(lambda: next_value(cur))
103

    
104
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the result set is empty.'''
    try:
        return value(cur)
    except StopIteration:
        return None
107

    
108
##### Escaping
109

    
110
def esc_name_by_module(module, name):
    '''Escapes a database identifier using the quote character of the given
    DB-API module. None defaults to PostgreSQL-style double quotes.
    @raise NotImplementedError for unsupported modules
    '''
    if module in ('psycopg2', None): quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
115

    
116
# Escapes name using the quoting style of the engine's DB-API module
def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
118

    
119
# Escapes name using the quoting style of db's underlying DB-API module
def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
121

    
122
def qual_name(db, schema, table):
    '''Returns the escaped table name, qualified with the escaped schema name
    when a schema is given.'''
    table = esc_name(db, table)
    if schema == None: return table
    return esc_name(db, schema)+'.'+table
127

    
128
##### Database connections
129

    
130
# Keys recognized in a db_config dict (see db_engines for per-driver renames)
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
131

    
132
# Maps engine name -> (DB-API module name, db_config key renames for connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}
136

    
137
DatabaseErrors_set = set([DbException])
138
DatabaseErrors = tuple(DatabaseErrors_set)
139

    
140
def _add_module(module):
    '''Registers a DB-API module's DatabaseError class in DatabaseErrors so
    its errors can be caught generically.'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
144

    
145
def db_config_str(db_config):
    '''One-line human-readable description of a db_config dict.'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
147

    
148
# Cache key for a query: params made hashable (dicts.make_hashable) so the
# pair can be used as a key in DbConn.query_results
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
149

    
150
# No-op debug logger; the default for DbConn.log_debug
log_debug_none = lambda msg, level=2: None
151

    
152
class DbConn:
    '''A lazily-connecting database connection wrapper.
    Supports query-result caching (see DbCursor/CacheCursor), savepoint
    nesting, and autocommit outside savepoints. Picklable: __getstate__()
    drops the live connection and the debug callback.
    '''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict with keys from db_config_names
        @param serializable Use SERIALIZABLE isolation (unless autocommit)
        @param autocommit Commit after each query outside savepoints
        @param caching Enable the query-result cache used by run_query()
        @param log_debug Callback f(msg, level) for debug output
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none # only log if a real callback
        
        self.__db = None # the DB-API connection; created on first use
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
        self._notices_seen = set() # notices already passed to log_debug
    
    def __getattr__(self, name):
        # 'db' is the only virtual attribute: the lazily-created connection
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        # Support pickling by dropping unpicklable members
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    # Whether a connection has been established yet
    def connected(self): return self.__db != None
    
    def _db(self):
        '''Returns the underlying DB-API connection, connecting on first use
        and applying isolation level and search_path configuration.'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename config keys to the driver's connect() parameter names
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit: self.run_query(
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE', log_level=4)
            if schemas != None:
                # Prepend the configured schemas to the search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                self.run_query("SELECT set_config('search_path', \
%s || current_setting('search_path'), false)", [schemas_], log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        '''Cursor proxy that records rows as they are fetched and stores the
        completed result (or the raised exception) in outer.query_results.'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None # cache key; set by execute()
            self.result = [] # fetched rows, or the exception raised
        
        def execute(self, query, params=None):
            if params == None or params == [] or params == ():# not using params
                esc_query = strings.esc_for_mogrify(query)
            else: esc_query = query
            
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    cur = self.inner.execute(esc_query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query, params)
            except Exception, e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result, presenting the cursor interface that
            run_query() expects (execute()/fetchone()).'''
            # Adopts the cached attrs (query, result, rowcount, description)
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Returns value escaped as a SQL literal, using the driver's own
        escaping (mogrify for psycopg2, _mysql.escape_string for MySQLdb).'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def can_mogrify(self):
        # Only psycopg2 cursors provide mogrify()
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        '''Returns the query with params substituted, if supported.
        @raise NotImplementedError if the driver lacks mogrify()'''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def print_notices(self):
        '''Forwards new server notices (psycopg2 conn.notices) to log_debug,
        deduplicating via _notices_seen.'''
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, params=None, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions.
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        # NOTE(review): log_ignore_excs is documented but handled by the
        # module-level run_query(), not accepted here — see that function.
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query, params)
        finally:
            self.print_notices()
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query,
                    params)))
        
        return cur
    
    # Whether this exact query+params combination has a cached result
    def is_cached(self, query, params=None):
        return _query_lookup(query, params) in self.query_results
    
    def with_autocommit(self, func, autocommit=True):
        '''Runs func() with the connection's autocommit flag temporarily set,
        restoring the previous value afterwards.'''
        prev_autocommit = self.db.autocommit
        self.db.autocommit = autocommit
        try: return func()
        finally: self.db.autocommit = prev_autocommit
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint named by the current nesting depth.
        Rolls back to the savepoint on any exception; releases it (and
        autocommits if outermost) on success.'''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try:
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommitting')
            self.db.commit()
    
    def col_default(self, col):
        '''Returns the column's default value (as sql_gen code) looked up in
        information_schema.columns.'''
        table = sql_gen.Table('columns', 'information_schema')
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        return sql_gen.as_Code(value(select(self, table, ['column_default'],
            conds, order_by='table_schema', limit=1, log_level=3)))
            # TODO: order_by search_path schema order
379

    
380
# Module-level alias: sql.connect(db_config, ...) constructs a DbConn
connect = DbConn
381

    
382
##### Querying
383

    
384
def run_raw_query(db, *args, **kw_args):
    '''Runs a query without the error parsing done by run_query().
    For params, see DbConn.run_query()'''
    return db.run_query(*args, **kw_args)
387

    
388
def mogrify(db, query, params):
    '''Returns the query with params substituted by the driver.
    @raise NotImplementedError for drivers without mogrify() support
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2': raise NotImplementedError(
        "Can't mogrify query for "+module+' database')
    return db.db.cursor().mogrify(query, params)
393

    
394
##### Recoverable querying
395

    
396
# Runs func() inside a savepoint; see DbConn.with_savepoint()
def with_savepoint(db, func): return db.with_savepoint(func)
397

    
398
def run_query(db, query, params=None, recover=None, cacheable=False,
    log_level=2, log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, and translates known
    database error messages into this module's specific exception types
    (DuplicateKeyException, NullValueException, etc.).
    For params, see run_raw_query()
    @param recover If True, run inside a savepoint and parse any error
    @param log_ignore_excs Exceptions whose log_level is raised by 2
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None] 
    
    try:
        try:
            def run(): return run_raw_query(db, query, params, cacheable,
                log_level, debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query, params):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            if not recover: raise # need savepoint to run index_cols()
            msg = exc.str_(e)
            
            # The following regexes match PostgreSQL error message formats
            # and raise the corresponding module-specific exception.
            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                constraint, table = match.groups()
                # index_cols() is defined elsewhere in this file (not in view)
                try: cols = index_cols(db, table, constraint)
                except NotImplementedError: raise e
                else: raise DuplicateKeyException(constraint, cols, e)
            
            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)
            
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.search(r'relation "(.+?)" already exists', msg)
            if match: raise DuplicateTableException(match.group(1), e)
            
            match = re.search(r'function "(.+?)" already exists', msg)
            if match: raise DuplicateFunctionException(match.group(1), e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # Emit the deferred log message (if logging was deferred above)
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
458

    
459
##### Basic queries
460

    
461
def next_version(name):
    '''Returns name with its "#N" version suffix incremented.
    A name without a suffix gets "#1" (the unsuffixed name is version 0).
    '''
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        base, version_str = match.groups()
        name = base
        version = int(version_str)+1
    else:
        version = 1 # first existing name was version 0
    return sql_gen.add_suffix(name, '#'+str(version))
468

    
469
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into None|sql_gen.Table If None, just runs the query; otherwise,
        runs CREATE [TEMP] TABLE ... AS with the query, retrying under new
        versioned names (next_version()) until no name collision occurs.
        Note: into.schema/into.name may be modified in place.
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        assert isinstance(into, sql_gen.Table)
        
        kw_args['recover'] = True
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
        
        temp = not db.autocommit # tables are permanent in autocommit mode
        # "temporary tables cannot specify a schema name", so remove schema
        if temp: into.schema = None
        
        while True:
            try:
                create_query = 'CREATE'
                if temp: create_query += ' TEMP'
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
                
                return run_query(db, create_query, params, *args, **kw_args)
                    # CREATE TABLE AS sets rowcount to # rows in query
            except DuplicateTableException, e:
                into.name = next_version(into.name)
                # try again with next version of name
495

    
496
# Sentinel: tells mk_select() to order by the pkey (compared with `is`)
order_by_pkey = object() # tells mk_select() to order by the pkey
497

    
498
# Sentinel: tells mk_select() to SELECT DISTINCT without an ON clause
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
499

    
500
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds a SELECT query string (does not run it).
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return tuple(query, params)
    '''
    # NOTE: distinct_on's default [] is never mutated here (only compared),
    # so the shared mutable default is safe.
    
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += '\n'
    if fields == None: query += '*'
    else: query += '\n, '.join(map(parse_col, fields))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # Warn if the query has no row-limiting clause at all
    missing = True
    if conds != []:
        # (removed: an unused local that picked ' ' vs '\n' based on
        # len(conds) but was never referenced)
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, [])
578

    
579
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT query.
    For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True) # SELECTs are cached by default
    log_level = kw_args.pop('log_level', 2)
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
587

    
588
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False):
    '''Builds an INSERT ... SELECT query (does not run it, except that
    embeddable=True creates a wrapper SQL function as a side effect).
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @return tuple(query, params)
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None:
        cols = [sql_gen.to_name_only_col(v, table).to_str(db) for v in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    # Build query
    first_line = 'INSERT INTO '+table.to_str(db)
    query = first_line
    if cols != None: query += '\n('+', '.join(cols)+')'
    query += '\n'+select_query
    
    if returning != None:
        query += '\nRETURNING '+sql_gen.to_name_only_col(returning).to_str(db)
    
    if embeddable:
        assert returning != None
        
        # Create function: wraps the INSERT so it can be used as a SELECT
        function_name = sql_gen.clean_name(first_line)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        # Retry under new versioned names until no function-name collision
        while True:
            try:
                function = sql_gen.TempFunction(function_name, db.autocommit)
                
                function_query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''()
RETURNS '''+return_type+'''
LANGUAGE sql
AS $$
'''+mogrify(db, query, params)+''';
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateFunctionException,))
                break # this version was successful
            except DuplicateFunctionException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(function),
            [returning]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return (query, params)
643

    
644
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT ... SELECT query.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    # capturing RETURNING rows requires the embeddable (function) form
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into, recover=recover,
        cacheable=cacheable, log_level=log_level)
658

    
659
# Re-exported sentinel from sql_gen
default = sql_gen.default # tells insert() to use the default value for a column
660

    
661
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row, given as a sequence of values or a dict of
    col -> value. For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else: # a mapping: split into column names and values
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    query = sql_gen.Values(row).to_str(db) if row != [] else None
    
    return insert_select(db, table, cols, query, [], *args, **kw_args)
673

    
674
def mk_update(db, table, changes=None, cond=None):
    '''Builds an UPDATE query string (does not run it).
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    assignments = [sql_gen.to_name_only_col(col, table).to_str(db)+' = '
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes]
    
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join(assignments)
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
689

    
690
def update(db, *args, **kw_args):
    '''Builds and runs an UPDATE query.
    For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    query = mk_update(db, *args, **kw_args)
    return run_query(db, query, [], recover)
695

    
696
def last_insert_id(db):
    '''Returns the id generated by the last INSERT on this connection,
    or None for engines without such a facility.
    @param db DbConn
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # BUG FIX: insert_id() is a method of the MySQLdb connection (db.db),
    # not of DbConn — DbConn.__getattr__ raises AttributeError for any
    # attribute other than 'db', so db.insert_id() could never succeed.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
701

    
702
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    # NOTE: the mutable default preserve=[] is safe here — the list itself is
    # never modified (only its Col elements are, as documented above)
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep an unmodified copy as the map key
        col.table = into # in-place retarget, as documented above
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            # rename to the col's string form, carrying over its src columns
            items.append((col, sql_gen.Col(str(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
733

    
734
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Flattens the joined tables into a single temp table of renamed columns.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query, params = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, params, into=into)
    return dict(items)
743

    
744
def mk_track_data_error(db, errors_table, cols, value, error_code, error):
    '''Builds a query that inserts one error record per source column into
    errors_table, cross-joining the column names with the (value, error_code,
    error) triple and filtering out rows already present in errors_table.
    @return str query
    '''
    cols = map(sql_gen.to_name_only_col, cols)
    
    columns_cols = ['column']
    columns = sql_gen.NamedValues('columns', columns_cols,
        [[c.name] for c in cols])
    values_cols = ['value', 'error_code', 'error']
    values = sql_gen.NamedValues('values', values_cols,
        [value, error_code, error])
    
    select_cols = columns_cols+values_cols
    name_only_cols = map(sql_gen.to_name_only_col, select_cols)
    errors_table = sql_gen.NamedTable('errors', errors_table)
    # filter_out join skips (column, value, ...) combinations already recorded
    joins = [columns, sql_gen.Join(values, type_='CROSS'),
        sql_gen.Join(errors_table, dict(zip(name_only_cols, select_cols)),
        sql_gen.filter_out)]
    
    return mk_insert_select(db, errors_table, name_only_cols,
        *mk_select(db, joins, select_cols, order_by=None))[0]
763

    
764
def track_data_error(db, errors_table, *args, **kw_args):
    '''Immediately runs the query built by mk_track_data_error().
    @param errors_table If None, does nothing.
    '''
    if errors_table == None: return
    query = mk_track_data_error(db, errors_table, *args, **kw_args)
    run_query(db, query, cacheable=True, log_level=4)
771

    
772
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If a column, converts any errors to warnings.
    @param col sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
        If set and col is a column with srcs, saves any errors in this table,
        using column's srcs attr as the source columns.
    @return The cast expression: a sql_gen.CustomCode for literals, otherwise
        a sql_gen.FunctionCall on a generated casting function
    '''
    if isinstance(col, sql_gen.Literal): # literal value, so just cast
        return sql_gen.CustomCode(col.to_str(db)+'::'+type_)
    
    assert isinstance(col, sql_gen.Col)
    assert not isinstance(col, sql_gen.NamedCol)
    
    # Errors can only be saved if there is somewhere to put them and source
    # columns to attribute them to
    save_errors = errors_table != None and col.srcs != ()
    if save_errors:
        errors_table = sql_gen.as_Table(errors_table)
        srcs = map(sql_gen.to_name_only_col, col.srcs)
        # Embed the source columns in the function name, so each distinct set
        # of source columns gets its own error-tracking cast function
        function_name = str(sql_gen.FunctionCall(type_, *srcs))
    else: function_name = type_
    function = sql_gen.TempFunction(function_name, db.autocommit)
    
    while True:
        # Create function definition
        query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
RETURNS '''+type_+'''
LANGUAGE plpgsql
'''
        # An error-saving function INSERTs into the errors table, so it cannot
        # be declared IMMUTABLE
        if not save_errors: query += 'IMMUTABLE\n'
        query += '''\
STRICT
AS $$
BEGIN
    /* The explicit cast to the return type is needed to make the cast happen
    inside the try block. (Implicit casts to the return type happen at the end
    of the function, outside any block.) */
    RETURN value::'''+type_+''';
EXCEPTION
    WHEN data_exception THEN
'''
        if save_errors:
            query += '''\
        -- Save error in errors table.
        -- Insert the value and error for *each* source column.
'''+mk_track_data_error(db, errors_table, srcs,
    *map(sql_gen.CustomCode, ['value', 'SQLSTATE', 'SQLERRM']))+''';
        
'''
        query += '''\
        RAISE WARNING '%', SQLERRM;
        RETURN NULL;
END;
$$;
'''
        
        # Create function
        try:
            run_query(db, query, recover=True, cacheable=True,
                log_ignore_excs=(DuplicateFunctionException,))
            break # successful
        except DuplicateFunctionException:
            if save_errors: function.name = next_version(function.name)
                # try again with next version of name
            else: break # plain cast function, so only need one version
    
    return sql_gen.FunctionCall(function, col)
839

    
840
##### Database structure queries
841

    
842
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in a table.'''
    select_args = mk_select(db, table, [sql_gen.row_count], order_by=None,
        start=0)
    return value(run_query(db, *select_args, recover=recover, log_level=3))
845

    
846
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    empty_cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(empty_cur))
849

    
850
def pkey(db, table, recover=None):
    '''Returns the table's primary key column name.
    Assumed to be first column in table'''
    cols = table_cols(db, table, recover)
    return cols[0]
853

    
854
# Column name that table_not_null_col() looks for before falling back to pkey()
not_null_col = 'not_null_col'
855

    
856
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
860

    
861
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @return list of column names in the index, in column-number order
    @raise NotImplementedError for non-PostgreSQL databases
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The UNION's first branch handles plain column indexes (attnum listed
        # directly in indkey); the second handles expression indexes, whose
        # column numbers must be regex-parsed out of the serialized expression
        # tree in indexprs (each ':varattno N' node references a column)
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
901

    
902
def constraint_cols(db, table, constraint):
    '''Returns the column names covered by a table constraint, in column-number
    order.
    NOTE(review): unlike index_cols(), this query is run without cacheable=True
    or a log_level — confirm whether that inconsistency is intentional.
    @raise NotImplementedError for non-PostgreSQL databases
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
918

    
919
# Name of the serial row-number column added by add_row_num()
row_num_col = '_row_num'
920

    
921
def add_index(db, exprs, table=None, unique=False):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    @param exprs A column/expression, or a list of them
    @param table Defaults to the table of the first column expression
    @param unique If set, creates a UNIQUE index
    '''
    if not lists.is_seq(exprs): exprs = [exprs]
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = copy.deepcopy(expr) # don't modify input!
        expr = sql_gen.as_Col(expr)
        
        # Extract col
        if isinstance(expr, sql_gen.FunctionCall):
            col = expr.args[0] # name the index after the function's argument
            expr = sql_gen.Expr(expr) # presumably wraps for use as an
                # index expression — TODO confirm against sql_gen
        else: col = expr
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        col.table = None # so the index name uses just the column name
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    # Index name is the comma-joined column names, qualified by the table
    index = sql_gen.Table(str(sql_gen.Col(','.join(map(str, cols)), table)))
    
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    
    try: run_query(db, str_, recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
961

    
962
def add_pkey(db, table, cols=None, recover=None):
963
    '''Adds a primary key.
964
    @param cols [sql_gen.Col,...] The columns in the primary key.
965
        Defaults to the first column in the table.
966
    @pre The table must not already have a primary key.
967
    '''
968
    table = sql_gen.as_Table(table)
969
    if cols == None: cols = [pkey(db, table, recover)]
970
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
971
    
972
    index = sql_gen.as_Table(sql_gen.add_suffix(table.name, '_pkey'))
973
    try:
974
        run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
975
            +index.to_str(db)+' PRIMARY KEY ('+(', '.join(col_strs))+')',
976
            recover=True, cacheable=True, log_level=3,
977
            log_ignore_excs=(DuplicateTableException,))
978
    except DuplicateTableException, e:
979
        index.name = next_version(index.name)
980
        # try again with next version of name
981

    
982
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    run_query(db,
        'ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY', log_level=3)
988

    
989
def create_table(db, table, cols, has_pkey=True, col_indexes=True):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    # Fix: work on a copy of the list, so the pkey replacement below doesn't
    # mutate the caller's list (the old code's "don't modify input!" only
    # protected the TypedCol object, not the list element)
    cols = list(cols)
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.type += ' NOT NULL PRIMARY KEY'
    
    str_ = 'CREATE TABLE '+table.to_str(db)+' (\n'
    str_ += '\n, '.join(v.to_str(db) for v in cols)
    str_ += '\n);\n'
    run_query(db, str_, cacheable=True, log_level=2)
    
    # Add indexes
    if has_pkey: index_cols = cols[1:] # pkey already has an implicit index
    else: index_cols = cols
    def add_indexes():
        for col in index_cols: add_index(db, col.to_Col(), table)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes # defer
    elif col_indexes: add_indexes() # add now
1016

    
1017
def vacuum(db, table):
    '''VACUUM ANALYZEs a table, running the query with autocommit on.'''
    table = sql_gen.as_Table(table)
    def run(): run_query(db, 'VACUUM ANALYZE '+table.to_str(db), log_level=3)
    db.with_autocommit(run)
1021

    
1022
def truncate(db, table, schema='public'):
    '''Empties a table, cascading to dependent tables.
    Fix: normalize table with sql_gen.as_Table() like the sibling functions do;
    the old code called table.to_str() directly, which crashed when passed a
    plain table-name string (e.g. from empty_db(), which passes the names
    yielded by tables()).
    @param schema Currently unused — the table's own qualification (if any) is
        used. TODO confirm whether schema-qualification should be applied here.
    '''
    table = sql_gen.as_Table(table)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE')
1024

    
1025
def tables(db, schema_like='public', table_like='%', exact=False):
    '''Iterates over the names of tables matching the given patterns.
    @param exact If set, compares names with = instead of LIKE
    @raise NotImplementedError for unsupported database modules
    '''
    compare = '=' if exact else 'LIKE'
    
    db_module = util.root_module(db.db)
    query_params = {'schema_like': schema_like, 'table_like': table_like}
    if db_module == 'psycopg2':
        return values(run_query(db, '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname '''+compare+''' %(schema_like)s
    AND tablename '''+compare+''' %(table_like)s
ORDER BY tablename
''',
            query_params, cacheable=True, log_level=4))
    elif db_module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s',
            query_params, cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+db_module+
        ' database')
1045

    
1046
def table_exists(db, table):
    '''Whether the given table exists in its schema.'''
    table = sql_gen.as_Table(table)
    matches = list(tables(db, table.schema, table.name, exact=True))
    return matches != []
1049

    
1050
def errors_table(db, table, if_exists=True):
    '''Returns the errors table associated with a table.
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    # Use the table's first source table when it has one
    if table.srcs != (): table = table.srcs[0]
    
    errors_table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not table_exists(db, errors_table_): return None
    return errors_table_
1061

    
1062
##### Database management
1063

    
1064
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema.
    For kw_args, see tables()'''
    for table_ in tables(db, schema, **kw_args):
        truncate(db, table_, schema)
1067

    
1068
##### Heuristic queries
1069

    
1070
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Upserts a row and returns its pkey value. Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The pkey column name; defaults to the table's pkey
    @param row_ct_ref [count] If set, [0] is incremented by the rowcount
    '''
    row = sql_gen.ColDict(db, table, row)
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
    
    try:
        cur = insert(db, table, row, pkey_, recover=True)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return value(cur)
    except DuplicateKeyException, e:
        # Row already exists: look up its pkey using only the columns that
        # conflicted in the unique constraint
        row = sql_gen.ColDict(db, table,
            util.dict_subset_right_join(row, e.cols))
        return value(select(db, table, [pkey_], row, recover=True))
1086

    
1087
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up a row's pkey value. Recovers from errors.
    @param create If set, inserts the row when no match is found.
    '''
    try:
        return value(select(db, table, [pkey], row, limit=1, recover=True))
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # new row
        raise
1093

    
1094
def is_func_result(col):
    '''Whether col is the 'result' column of a function-call table
    (a table whose name contains a parenthesis).'''
    return '(' in col.table.name and col.name == 'result'
1096

    
1097
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Picks the default name for the table that will hold the in/out pkeys.'''
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return str(in_col)
    
    name = str(out_table)
    if not is_func: return name+'_pkeys'
    
    # Function: name it like a call, e.g. "func(arg)" or "func(k=v, ...)"
    name += '('
    try: value_in_col = mapping['value']
    except KeyError:
        name += ', '.join((str(k)+'='+in_col_str(v)
            for k, v in mapping.iteritems()))
    else: name += in_col_str(value_in_col)
    return name+')'
1120

    
1121
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    errors_table_ = errors_table(db, in_tables0)
    in_pkey = pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    # Set column sources
    in_cols = filter(sql_gen.is_table_col, mapping.values())
    for col in in_cols:
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins, in_cols,
        preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    try: default = mapping[default]
    except KeyError:
        if default != None:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    out_pkey = pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    # The pkeys table is created by the first call (run_query_into) and
    # appended to by later calls (insert_select)
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols):
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
        if pkeys_table_exists_ref[0]:
            insert_select(db, into, pkeys_names, query, params)
        else:
            run_query_into(db, query, params, into=into)
            pkeys_table_exists_ref[0] = True
    
    # Mutable state shared with the error handlers below: extra WHERE conds,
    # DISTINCT ON columns, and a row limit (0 = return NULL for all rows)
    limit_ref = [None]
    conds = set()
    distinct_on = []
    def mk_main_select(joins, cols):
        return mk_select(db, joins, cols, conds, distinct_on,
            limit=limit_ref[0], start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    
    def remove_all_rows():
        log_debug('Returning NULL for all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    
    # Records the error in the errors table and indexes the offending column
    def ignore(in_col, value, e):
        track_data_error(db, errors_table_, in_col.srcs, value, e.cause.pgcode,
            e.cause.pgerror)
        
        in_col_str = strings.as_tt(repr(in_col))
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
            level=2.5)
        add_index(db, in_col)
        
        log_debug('Ignoring rows with '+in_col_str+' = '
            +strings.as_tt(repr(value)))
    # Excludes rows with the offending value via an extra WHERE condition
    def remove_rows(in_col, value, e):
        ignore(in_col, value, e)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    # Replaces the offending value with NULL in the temp input table
    def invalid2null(in_col, value, e):
        ignore(in_col, value, e)
        update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.add_suffix(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects
    # Retry loop: each handled exception tightens the query (conds/join_cols/
    # mapping) and the insert is re-attempted; asserts above guard against
    # retrying the same failure forever
    join_cols = sql_gen.ColDict(db, out_table)
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
                limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            # Filter out rows that already exist in the output table
            distinct_on = [v.to_Col() for v in join_cols.values()]
            insert_joins.append(sql_gen.Join(out_table, join_cols,
                sql_gen.filter_out))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select(insert_joins, mapping.values())[0]
        
        log_debug('Trying to insert new rows')
        try:
            cur = insert_select(db, out_table, mapping.keys(), main_select,
                **insert_args)
            break # insert successful
        except DuplicateKeyException, e:
            log_exc(e)
            
            old_join_cols = join_cols.copy()
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
        except NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None, e)
        except FunctionValueException, e:
            log_exc(e)
            
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                in_col = sql_gen.unwrap_func_call(in_col, func_name)
                invalid2null(in_col, value, e)
        except MissingCastException, e:
            log_exc(e)
            
            out_col = e.col
            type_ = e.type
            
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
                +strings.as_tt(type_))
            def wrap_func(col): return cast(db, type_, col, errors_table_)
            mapping[out_col] = sql_gen.wrap(wrap_func, mapping[out_col])
        except DatabaseErrors, e:
            log_exc(e)
            
            msg = 'No handler for exception: '+exc.str_(e)
            warnings.warn(DbWarning(msg))
            log_debug(msg)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols)
    else:
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
            insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert table_row_count(db, into) == table_row_count(db, in_table)
    
    srcs = []
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
    return sql_gen.Col(out_pkey, into, srcs)
1354

    
1355
##### Data cleanup
1356

    
1357
def cleanup_table(db, table, cols):
    '''Maps empty strings and the r'\N' NULL placeholder to SQL NULL in the
    given columns, trimming surrounding whitespace first.
    NOTE(review): assumes db.esc_value() returns a safely-quoted SQL literal —
    confirm against the DbConn implementation.'''
    table = sql_gen.as_Table(table)
    cols = map(sql_gen.as_Col, cols)
    
    # nullif(nullif(trim(col), ''), '\N')
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
        +db.esc_value(r'\N')+')')
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
        for v in cols]
    
    update(db, table, changes)
(24-24/36)