Project

General

Profile

1
# Database import/export

import copy
import csv
import operator
import os
import sys
import warnings

import csvs
import dicts
import exc
import sql
import sql_gen
import streams
import strings
import util
17

    
18
##### Exceptions
19

    
20
# Can't use built-in SyntaxError because it stringifies to only the first line
class SyntaxError(Exception):
    '''Syntax error whose str() preserves the entire (possibly multiline)
    message, unlike the built-in SyntaxError.'''
    pass
22

    
23
##### Data cleanup
24

    
25
def table_nulls_mapped__set(db, table):
    '''Marks table as having had its NULL-equivalent strings mapped to NULL
    (see cleanup_table()). State is stored in the DB via util.*.'''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__set('
        +sql_gen.table2regclass_text(db, table)+')')
    sql.run_query(db, query)
29

    
30
def table_nulls_mapped__get(db, table):
    '''@return whether table has already had its NULL-equivalent strings
    mapped to NULL (see table_nulls_mapped__set())'''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__get('
        +sql_gen.table2regclass_text(db, table)+')')
    return sql.value(sql.run_query(db, query))
34

    
35
null_strs_str_default = r',-,\N,NULL,N/A,NA,UNKNOWN,nulo'
36
    # NA: this will remove a common abbr for North America, but we don't use the
37
    # continent, so this is OK
38

    
39
null_strs_str = os.getenv('null_strs', null_strs_str_default)
40
null_strs = null_strs_str.split(',')
41

    
42
def cleanup_table(db, table):
    '''Trims whitespace and maps NULL-equivalent strings (null_strs) to NULL
    in every text column of table, then records that this was done so repeat
    calls are no-ops. idempotent.
    @param table str|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    assert sql.table_exists(db, table)
    
    if table_nulls_mapped__get(db, table): return # already cleaned up
    
    # Only text columns need cleanup; the pkey column is excluded because
    # setting it to NULL would violate the primary key constraint
    cols = filter(lambda c: sql_gen.is_text_col(db, c),
        sql.table_cols(db, table))
    try: pkey_col = sql.table_pkey_col(db, table)
    except sql.DoesNotExistException: pass
    else:
        try: cols.remove(pkey_col)
        except ValueError: pass
    if not cols: return
    
    db.log_debug('Cleaning up table', level=1.5)
    
    # Build nullif(nullif(...trim(both from %s)...)) wrapping one nullif()
    # per null str, then instantiate it for each column
    expr = 'trim(both from %s)'
    for null in null_strs: expr = 'nullif('+expr+', '+db.esc_value(null)+')'
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db))) for v in cols]
    
    # Retry the update, dropping NOT NULL constraints that the newly-mapped
    # NULLs violate, until it succeeds
    while True:
        try:
            sql.update(db, table, changes, in_place=True, recover=True)
            break # successful
        except sql.NullValueException, e:
            db.log_debug('Caught exception: '+exc.str_(e))
            col, = e.cols
            sql.drop_not_null(db, col)
    
    # Reclaim space freed by the in-place update and refresh planner stats
    db.log_debug('Vacuuming and reanalyzing table', level=1.5)
    sql.vacuum(db, table)
    
    table_nulls_mapped__set(db, table)
77

    
78
##### Error tracking
79

    
80
def track_data_error(db, errors_table, cols, value, error_code, error):
    '''Records a data error against each affected column.
    @param errors_table If None, does nothing.
    @param cols [sql_gen.Col...] The source columns the error applies to
    '''
    if errors_table == None: return
    
    names = [col.name for col in cols]
    if not names: names = [None] # need at least one entry
    for name in names:
        row = dict(column=name, value=value, error_code=error_code,
            error=error)
        # a duplicate entry means this exact error was already recorded
        try: sql.insert(db, errors_table, row, recover=True, cacheable=True,
            log_level=4)
        except sql.DuplicateKeyException: pass
94

    
95
class ExcToErrorsTable(sql_gen.ExcToWarning):
    '''Handles an exception by saving it or converting it to a warning.'''
    def __init__(self, return_, srcs, errors_table, value=None):
        '''
        @param return_ See sql_gen.ExcToWarning
        @param srcs The column names for the errors table
        @param errors_table None|sql_gen.Table
        @param value The value (or an expression for it) that caused the error
        @pre The invalid value must be in a local variable "value" of type text.
        '''
        sql_gen.ExcToWarning.__init__(self, return_)
        
        value = sql_gen.as_Code(value)
        
        self.srcs = srcs
        self.errors_table = errors_table
        self.value = value
    
    def to_str(self, db):
        '''@return str PL/pgSQL code that saves the error to the errors table
        (one row per source column), or warning-conversion code when there is
        nowhere to save it'''
        # With no source columns or no errors table there is nothing to save,
        # so fall back to converting the exception to a warning
        if not self.srcs or self.errors_table == None:
            return sql_gen.ExcToWarning.to_str(self, db)
        
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        # VALUES list of the source column names, iterated one name at a time
        col_names_query = sql.mk_select(db, sql_gen.NamedValues('c', None,
            [[c.name] for c in self.srcs]), order_by=None)
        insert_query = sql.mk_insert_select(db, self.errors_table,
            errors_table_cols,
            sql_gen.Values(errors_table_cols).to_str(db))+';\n'
        return '''\
-- Save error in errors table.
DECLARE
    error_code text := SQLSTATE;
    error text := SQLERRM;
    value text := '''+self.value.to_str(db)+''';
    "column" text;
BEGIN
    -- Insert the value and error for *each* source column.
'''+strings.indent(sql_gen.RowExcIgnore(None, col_names_query, insert_query,
    row_var=errors_table_cols[0]).to_str(db))+'''
END;

'''+self.return_.to_str(db)
138

    
139
def data_exception_handler(*args, **kw_args):
    '''Handles a data_exception by saving it or converting it to a warning.
    For params, see ExcToErrorsTable().
    '''
    handler = ExcToErrorsTable(*args, **kw_args)
    return sql_gen.data_exception_handler(handler)
144

    
145
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as source columns). Otherwise, converts errors to warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    @return sql_gen.FunctionCall|sql_gen.Cast The cast expression
    '''
    col = sql_gen.as_Col(col)
    
    # Don't convert exceptions to warnings for user-supplied constants
    if isinstance(col, sql_gen.Literal): return sql_gen.Cast(type_, col)
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    function_name = strings.first_word(type_)
    srcs = col.srcs
    save_errors = errors_table != None and srcs
    if save_errors: # function will be unique for the given srcs
        function_name = strings.ustr(sql_gen.FunctionCall(function_name,
            *map(sql_gen.to_name_only_col, srcs)))
    function = db.TempFunction(function_name)
    
    # Create function definition
    modifiers = 'STRICT'
    # error-saving functions have the side effect of writing the errors table,
    # so only the non-saving variant can be declared IMMUTABLE
    if not save_errors: modifiers = 'IMMUTABLE '+modifiers
    value_param = sql_gen.FunctionParam('value', 'anyelement')
    handler = data_exception_handler('RETURN NULL;\n', srcs, errors_table,
        value_param.name)
    body = sql_gen.CustomCode(handler.to_str(db, '''\
/* The explicit cast to the return type is needed to make the cast happen
inside the try block. (Implicit casts to the return type happen at the end
of the function, outside any block.) */
RETURN '''+sql_gen.Cast(type_, sql_gen.CustomCode('value')).to_str(db)+''';
'''))
    body.lang='plpgsql'
    sql.define_func(db, sql_gen.FunctionDef(function, type_, body,
        [value_param], modifiers))
    
    # the cast is performed by calling the temp function on the column
    return sql_gen.FunctionCall(function, col)
184

    
185
def func_wrapper_exception_handler(db, return_, args, errors_table):
    '''Handles a function call's data_exceptions.
    Supports PL/Python functions.
    @param return_ See data_exception_handler()
    @param args [arg...] Function call's args
    @param errors_table See data_exception_handler()
    '''
    # only args that carry source-column info can be attributed in errors_table
    tracked = [arg for arg in args if sql_gen.has_srcs(arg)]
    
    srcs = sql_gen.cross_join_srcs(tracked)
    value = sql_gen.merge_not_null(db, ',', tracked)
    data_handler = data_exception_handler(return_, srcs, errors_table, value)
    return sql_gen.NestedExcHandler(data_handler,
        sql_gen.plpythonu_error_handler)
200

    
201
def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but creates a new column with the cast values if the input
    is a column.
    @return The new column or cast value
    '''
    def do_cast(value): return cast(db, type_, value, errors_table)
    
    # values that aren't backed by a table column are cast in place instead
    try: col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(do_cast, col)
    
    table = col.table
    new_col = sql_gen.suffixed_col(col, '::'+strings.first_word(type_))
    cast_expr = do_cast(col)
    
    # Create the new column, then fill it with the cast values
    typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, typed_col, comment=strings.urepr(col)+'::'+type_)
    new_col.name = typed_col.name # propagate any renaming
    sql.update(db, table, [(new_col, cast_expr)], in_place=True, recover=True)
    
    return new_col
223

    
224
def errors_table(db, table, if_exists=True):
    '''Locates the errors table belonging to a (possibly derived) table.
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    # a derived table's errors belong to its original source table
    if table.srcs != (): table = table.srcs[0]
    
    errors_table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, errors_table_): return None
    return errors_table_
235

    
236
def mk_errors_table(db, table):
    '''Creates the errors table for table unless it already exists.'''
    errors_table_ = errors_table(db, table, if_exists=False)
    if sql.table_exists(db, errors_table_, cacheable=False): return
    
    cols = [
        sql_gen.TypedCol('column', 'text'),
        sql_gen.TypedCol('value', 'text'),
        sql_gen.TypedCol('error_code', 'character varying(5)', nullable=False),
        sql_gen.TypedCol('error', 'text', nullable=False),
        ]
    sql.create_table(db, errors_table_, cols, has_pkey=False)
    # md5() keeps index entries small for arbitrarily long values/errors
    unique_cols = ['column', sql_gen.CustomCode('md5(value)'), 'error_code',
        sql_gen.CustomCode('md5(error)')]
    sql.add_index(db, unique_cols, errors_table_, unique=True)
250

    
251
##### Import
252

    
253
row_num_col_def = copy.copy(sql.row_num_col_def)
254
row_num_col_def.name = 'row_num'
255
row_num_col_def.type = 'integer'
256

    
257
def append_csv(db, table, reader, header):
    '''COPYs rows from a CSV reader into an existing table.
    @param table sql_gen.Table
    @param header The CSV's column names, in order
    '''
    def esc_name_(name): return sql.esc_name(db, name)
    
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Wrap in standardizing stream
    cols_ct = len(header)
    stream = csvs.InputRewriter(csvs.ProgressInputFilter(
        csvs.ColCtFilter(reader, cols_ct), sys.stderr, n=1000))
    #streams.copy(stream, sys.stderr) # to troubleshoot copy_expert() errors
    dialect = stream.dialect # use default dialect
    
    # Create COPY FROM statement
    # when the CSV's columns exactly match the table's, the column list can
    # be omitted entirely
    if header == sql.table_col_names(db, table): cols_str = ''
    else: cols_str =' ('+(', '.join(map(esc_name_, header)))+')'
    copy_from = ('COPY '+table.to_str(db)+cols_str+' FROM STDIN DELIMITER '
        +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
    assert not csvs.is_tsv(dialect)
    copy_from += ' CSV'
    if dialect.quoting != csv.QUOTE_NONE:
        quote_str = db.esc_value(dialect.quotechar)
        copy_from += ' QUOTE '+quote_str
        if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
    copy_from += ';\n'
    
    log(copy_from, level=2)
    # translate any DB error into this codebase's exception types
    try: db.db.cursor().copy_expert(copy_from, stream)
    except Exception, e: sql.parse_exception(db, e, recover=True)
285

    
286
def import_csv(db, table, reader, header):
    '''Imports a CSV into a new table, then cleans the table up.
    @param table sql_gen.Table The table to create and import into
    @param reader CSV reader providing the data rows
    @param header The CSV's column names, in order. Note: modified in place
        (the row-number column name is prepended).
    '''
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Get format info
    col_names = map(strings.to_unicode, header)
    for i, col in enumerate(col_names): # replace empty column names
        if col == '': col_names[i] = 'column_'+str(i)
    
    # Add a row-number column so the CSV's row order is preserved
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
    typed_cols.insert(0, row_num_col_def)
    header.insert(0, row_num_col_def.name)
    reader = csvs.RowNumFilter(reader)
    
    log('Creating table')
    # Note that this is not rolled back if the import fails. Instead, it is
    # cached, and will not be re-run if the import is retried.
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
    
    # Free memory used by deleted (rolled back) rows from any failed import.
    # This MUST be run so that the rows will be stored in inserted order, and
    # the row_num added after import will match up with the CSV's row order.
    sql.truncate(db, table)
    
    # Load the data
    def load(): append_csv(db, table, reader, header)
    sql.with_savepoint(db, load)
    
    cleanup_table(db, table)
317

    
318
def put(db, table, row, pkey_=None, row_ct_ref=None, on_error=exc.reraise):
    '''Inserts a single row via put_table(). Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ unused; accepted for signature compatibility with get()-style
        callers
    @return See put_table()
    '''
    return put_table(db, table, [], row, row_ct_ref, on_error=on_error)
323

    
324
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up a row's pkey, optionally inserting the row if it's missing.
    Recovers from errors.
    @param create If set, inserts the row when no match is found
    '''
    try:
        cur = sql.select(db, table, [pkey], row, limit=1, recover=True)
        return sql.value(cur)
    except StopIteration: # no matching row
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
332

    
333
def is_func_result(col):
    '''@return whether col is the "result" column of a function's output table
    (recognized by "(" appearing in the table name)'''
    in_func_table = col.table.name.find('(') >= 0
    return in_func_table and col.name == 'result'
335

    
336
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Generates a descriptive name for the pkeys table of an insert.
    @param in_tables0 The main input table
    @param mapping dict(out_table_col=in_table_col, ...)
    @param is_func whether out_table is a DB function rather than a table
    @return str
    '''
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return strings.ustr(in_col)
    
    str_ = strings.ustr(out_table)
    if is_func:
        # function-style name: out_table(args)
        str_ += '('
        
        # a lone "value" arg is abbreviated to just its input column
        try: value_in_col = mapping['value']
        except KeyError:
            str_ += ', '.join((strings.ustr(k)+'='+in_col_str(v)
                for k, v in mapping.iteritems()))
        else: str_ += in_col_str(value_in_col)
        
        str_ += ')'
    else:
        out_col = 'rank'
        try: in_col = mapping[out_col]
        except KeyError: str_ += '_pkeys'
        else: # has a rank column, so hierarchical
            str_ += '['+strings.ustr(out_col)+'='+in_col_str(in_col)+']'
    return str_
364

    
365
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, default=None,
366
    col_defaults={}, on_error=exc.reraise):
367
    '''Recovers from errors.
368
    Only works under PostgreSQL (uses INSERT RETURNING).
369
    
370
    Warning: This function's normalizing algorithm does not support database
371
    triggers that populate fields covered by the unique constraint used to do
372
    the DISTINCT ON. Such fields must be populated by the mappings instead.
373
    (Other unique constraints and other non-unique fields are not affected by
374
    this restriction on triggers. Note that the primary key will normally not be
375
    the DISTINCT ON constraint, so trigger-populated natural keys are supported
376
    *unless* the input table contains duplicate rows for some generated keys.)
377
    
378
    Note that much of the complexity of the normalizing algorithm is due to
379
    PostgreSQL (and other DB systems) not having a native command for
380
    INSERT ON DUPLICATE SELECT (wiki.vegpath.org/INSERT_ON_DUPLICATE_SELECT).
381
    For PostgreSQL 9.1+, this can now be emulated using INSTEAD OF triggers.
382
    For earlier versions, you instead have to use this function.
383
    
384
    @param in_tables The main input table to select from, followed by a list of
385
        tables to join with it using the main input table's pkey
386
    @param mapping dict(out_table_col=in_table_col, ...)
387
        * out_table_col: str (*not* sql_gen.Col)
388
        * in_table_col: sql_gen.Col|literal-value
389
    @param default The *output* column to use as the pkey for missing rows.
390
        If this output column does not exist in the mapping, uses None.
391
        Note that this will be used for *all* missing rows, regardless of which
392
        error caused them not to be inserted.
393
    @param col_defaults Default values for required columns.
394
    @return sql_gen.Col Where the output pkeys are made available
395
    '''
396
    import psycopg2.extensions
397
    
398
    # Special handling for functions with hstore params
399
    if out_table == '_map':
400
        import psycopg2.extras
401
        psycopg2.extras.register_hstore(db.db)
402
        
403
        # Parse args
404
        try: value = mapping.pop('value')
405
        except KeyError: return None # value required
406
        
407
        mapping = dict([(k, sql_gen.get_value(v))
408
            for k, v in mapping.iteritems()]) # unwrap literal value
409
        mapping = dict(map=mapping, value=value) # non-value params -> hstore
410
    
411
    out_table = sql_gen.as_Table(out_table)
412
    
413
    def log_debug(msg): db.log_debug(msg, level=1.5)
414
    def col_ustr(str_):
415
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
416
    
417
    log_debug('********** New iteration **********')
418
    log_debug('Inserting these input columns into '+strings.as_tt(
419
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
420
    
421
    is_function = sql.function_exists(db, out_table)
422
    
423
    if is_function: row_ct_ref = None # only track inserted rows
424
    
425
    # Warn if inserting empty table rows
426
    if not mapping and not is_function: # functions with no args OK
427
        warnings.warn(UserWarning('Inserting empty table row(s)'))
428
    
429
    if is_function: out_pkey = 'result'
430
    else: out_pkey = sql.pkey_name(db, out_table, recover=True)
431
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
432
    
433
    in_tables_ = copy.copy(in_tables) # don't modify input!
434
    try: in_tables0 = in_tables_.pop(0) # first table is separate
435
    except IndexError: in_tables0 = None
436
    else:
437
        in_pkey = sql.pkey_name(db, in_tables0, recover=True)
438
        in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
439
    
440
    # Determine if can use optimization for only literal values
441
    is_literals = not reduce(operator.or_, map(sql_gen.is_table_col,
442
        mapping.values()), False)
443
    is_literals_or_function = is_literals or is_function
444
    
445
    if in_tables0 == None: errors_table_ = None
446
    else: errors_table_ = errors_table(db, in_tables0)
447
    
448
    # Create input joins from list of input tables
449
    input_joins = [in_tables0]+[sql_gen.Join(v,
450
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
451
    
452
    orig_mapping = mapping.copy()
453
    if mapping == {} and not is_function: # need >= one column for INSERT SELECT
454
        mapping = {out_pkey: None} # ColDict will replace with default value
455
    
456
    if not is_literals:
457
        into = sql_gen.as_Table(into_table_name(out_table, in_tables0, mapping,
458
            is_function))
459
        # Ensure into's out_pkey is different from in_pkey by prepending "out."
460
        if is_function: into_out_pkey = out_pkey
461
        else: into_out_pkey = 'out.'+out_pkey
462
        
463
        # Set column sources
464
        in_cols = filter(sql_gen.is_table_col, mapping.values())
465
        for col in in_cols:
466
            if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
467
        
468
        log_debug('Joining together input tables into temp table')
469
        # Place in new table so don't modify input and for speed
470
        in_table = sql_gen.Table('in')
471
        mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
472
            in_cols, preserve=[in_pkey_col]))
473
        input_joins = [in_table]
474
        db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
475
    
476
    # Wrap mapping in a sql_gen.ColDict.
477
    # sql_gen.ColDict sanitizes both keys and values passed into it.
478
    # Do after applying dicts.join() because that returns a plain dict.
479
    mapping = sql_gen.ColDict(db, out_table, mapping)
480
    
481
    # Save all rows since in_table may have rows deleted
482
    if is_literals: pass
483
    elif is_function: full_in_table = in_table
484
    else:
485
        full_in_table = sql_gen.suffixed_table(in_table, '_full')
486
        sql.copy_table(db, in_table, full_in_table)
487
    
488
    pkeys_table_exists_ref = [False]
489
    def insert_into_pkeys(query, **kw_args):
490
        if pkeys_table_exists_ref[0]:
491
            sql.insert_select(db, into, [in_pkey, into_out_pkey], query,
492
                **kw_args)
493
        else:
494
            kw_args.setdefault('add_pkey_', True)
495
            # don't warn if can't create pkey, because this just indicates that,
496
            # at some point in the import tree, it used a set-returning function
497
            kw_args.setdefault('add_pkey_warn', False)
498
            
499
            sql.run_query_into(db, query, into=into, **kw_args)
500
            pkeys_table_exists_ref[0] = True
501
    
502
    def mk_main_select(joins, cols): return sql.mk_select(db, joins, cols)
503
    
504
    if is_literals: insert_in_table = None
505
    else:
506
        insert_in_table = in_table
507
        insert_in_tables = [insert_in_table]
508
    join_cols = sql_gen.ColDict(db, out_table)
509
    join_custom_cond = None
510
    
511
    exc_strs = set()
512
    def log_exc(e):
513
        e_str = exc.str_(e, first_line_only=True)
514
        log_debug('Caught exception: '+e_str)
515
        if e_str in exc_strs: # avoid infinite loops
516
            log_debug('Exception already seen, handler broken')
517
            on_error(e)
518
            remove_all_rows()
519
            return False
520
        else: exc_strs.add(e_str)
521
        return True
522
    
523
    ignore_all_ref = [False]
524
    def remove_all_rows():
525
        log_debug('Ignoring all rows')
526
        ignore_all_ref[0] = True # just return the default value column
527
    
528
    def handle_unknown_exc(e):
529
        log_debug('No handler for exception')
530
        on_error(e)
531
        remove_all_rows()
532
    
533
    def ensure_cond(cond, e, passed=False, failed=False):
534
        '''
535
        @param passed at least one row passed the constraint
536
        @param failed at least one row failed the constraint
537
        '''
538
        if is_literals: # we know the constraint was applied exactly once
539
            if passed: pass
540
            elif failed: remove_all_rows()
541
            else: raise NotImplementedError()
542
        else:
543
            if not is_function:
544
                out_table_cols = sql_gen.ColDict(db, out_table)
545
                out_table_cols.update(util.dict_subset_right_join({},
546
                    sql.table_col_names(db, out_table)))
547
            
548
            in_cols = []
549
            cond = strings.ustr(cond)
550
            orig_cond = cond
551
            cond = sql_gen.map_expr(db, cond, mapping, in_cols)
552
            if not is_function:
553
                cond = sql_gen.map_expr(db, cond, out_table_cols)
554
            
555
            log_debug('Ignoring rows that do not satisfy '+strings.as_tt(cond))
556
            cur = None
557
            if cond == sql_gen.false_expr:
558
                assert failed
559
                remove_all_rows()
560
            elif cond == sql_gen.true_expr: assert passed
561
            else:
562
                while True:
563
                    not_cond = sql_gen.NotCond(sql_gen.CustomCode(cond))
564
                    try:
565
                        cur = sql.delete(db, insert_in_table, not_cond)
566
                        break
567
                    except sql.DoesNotExistException, e:
568
                        if e.type != 'column': raise
569
                        
570
                        last_cond = cond
571
                        cond = sql_gen.map_expr(db, cond, {e.name: None})
572
                        if cond == last_cond: raise # not fixable
573
            
574
            # If any rows failed cond
575
            if failed or cur != None and cur.rowcount > 0:
576
                track_data_error(db, errors_table_,
577
                    sql_gen.cross_join_srcs(in_cols), None, e.cause.pgcode,
578
                    strings.ensure_newl(strings.ustr(e.cause.pgerror))
579
                    +'condition: '+orig_cond+'\ntranslated condition: '+cond)
580
    
581
    not_null_cols = set()
582
    def ignore(in_col, value, e):
583
        if sql_gen.is_table_col(in_col):
584
            in_col = sql_gen.with_table(in_col, insert_in_table)
585
            
586
            track_data_error(db, errors_table_, in_col.srcs, value,
587
                e.cause.pgcode, e.cause.pgerror)
588
            
589
            sql.add_index(db, in_col, insert_in_table) # enable fast filtering
590
            if value != None and in_col not in not_null_cols:
591
                log_debug('Replacing invalid value '
592
                    +strings.as_tt(strings.urepr(value))+' with NULL in column '
593
                    +strings.as_tt(in_col.to_str(db)))
594
                sql.update(db, insert_in_table, [(in_col, None)],
595
                    sql_gen.ColValueCond(in_col, value))
596
            else:
597
                log_debug('Ignoring rows with '+strings.as_tt(in_col.to_str(db))
598
                    +' = '+strings.as_tt(strings.urepr(value)))
599
                sql.delete(db, insert_in_table,
600
                    sql_gen.ColValueCond(in_col, value))
601
                if value == None: not_null_cols.add(in_col)
602
        else:
603
            assert isinstance(in_col, sql_gen.NamedCol)
604
            in_value = sql_gen.remove_col_rename(in_col)
605
            assert sql_gen.is_literal(in_value)
606
            if value == in_value.value:
607
                if value != None:
608
                    log_debug('Replacing invalid literal '
609
                        +strings.as_tt(in_col.to_str(db))+' with NULL')
610
                    mapping[in_col.name] = None
611
                else:
612
                    remove_all_rows()
613
            # otherwise, all columns were being ignore()d because the specific
614
            # column couldn't be identified, and this was not the invalid column
615
    
616
    if not is_literals:
617
        def insert_pkeys_table(which):
618
            return sql_gen.Table(sql_gen.concat(in_table.name,
619
                '_insert_'+which+'_pkeys'))
620
        insert_out_pkeys = insert_pkeys_table('out')
621
        insert_in_pkeys = insert_pkeys_table('in')
622
    
623
    def mk_func_call():
624
        args = dict(((k.name, v) for k, v in mapping.iteritems()))
625
        return sql_gen.FunctionCall(out_table, **args), args
626
    
627
    def handle_MissingCastException(e):
628
        if not log_exc(e): return False
629
        
630
        type_ = e.type
631
        if e.col == None: out_cols = mapping.keys()
632
        else: out_cols = [e.col]
633
        
634
        for out_col in out_cols:
635
            log_debug('Casting '+strings.as_tt(strings.repr_no_u(out_col))
636
                +' input to '+strings.as_tt(type_))
637
            in_col = mapping[out_col]
638
            while True:
639
                try:
640
                    cast_col = cast_temp_col(db, type_, in_col, errors_table_)
641
                    mapping[out_col] = cast_col
642
                    if out_col in join_cols: join_cols[out_col] = cast_col
643
                    break # cast successful
644
                except sql.InvalidValueException, e:
645
                    if not log_exc(e): return False
646
                    
647
                    ignore(in_col, e.value, e)
648
        
649
        return True
650
    
651
    missing_msg = None
652
    
653
    # Do inserts and selects
654
    while True:
655
        has_joins = join_cols != {}
656
        
657
        if ignore_all_ref[0]: break # unrecoverable error, so don't do main case
658
        
659
        # Prepare to insert new rows
660
        if is_function:
661
            if is_literals:
662
                log_debug('Calling function')
663
                func_call, args = mk_func_call()
664
        else:
665
            log_debug('Trying to insert new rows')
666
            insert_args = dict(recover=True, cacheable=False)
667
            if has_joins:
668
                insert_args.update(dict(ignore=True))
669
            else:
670
                insert_args.update(dict(returning=out_pkey))
671
                if not is_literals:
672
                    insert_args.update(dict(into=insert_out_pkeys))
673
            main_select = mk_main_select([insert_in_table], [sql_gen.with_table(
674
                c, insert_in_table) for c in mapping.values()])
675
        
676
        try:
677
            cur = None
678
            if is_function:
679
                if is_literals:
680
                    cur = sql.select(db, fields=[func_call], recover=True,
681
                        cacheable=True)
682
                else:
683
                    log_debug('Defining wrapper function')
684
                    
685
                    func_call, args = mk_func_call()
686
                    func_call = sql_gen.NamedCol(into_out_pkey, func_call)
687
                    
688
                    # Create empty pkeys table so its row type can be used
689
                    insert_into_pkeys(sql.mk_select(db, input_joins,
690
                        [in_pkey_col, func_call], limit=0), add_pkey_=False,
691
                        recover=True)
692
                    result_type = db.col_info(sql_gen.Col(into_out_pkey,
693
                        into)).type
694
                    
695
                    ## Create error handling wrapper function
696
                    
697
                    wrapper = db.TempFunction(sql_gen.concat(into.name,
698
                        '_wrap'))
699
                    
700
                    select_cols = [in_pkey_col]+args.values()
701
                    row_var = copy.copy(sql_gen.row_var)
702
                    row_var.set_srcs([in_table])
703
                    in_pkey_var = sql_gen.Col(in_pkey, row_var)
704
                    
705
                    args = dict(((k, sql_gen.with_table(v, row_var))
706
                        for k, v in args.iteritems()))
707
                    func_call = sql_gen.FunctionCall(out_table, **args)
708
                    
709
                    def mk_return(result):
710
                        return sql_gen.ReturnQuery(sql.mk_select(db,
711
                            fields=[in_pkey_var, result], explain=False))
712
                    exc_handler = func_wrapper_exception_handler(db,
713
                        mk_return(sql_gen.Cast(result_type, None)),
714
                        args.values(), errors_table_)
715
                    
716
                    sql.define_func(db, sql_gen.FunctionDef(wrapper,
717
                        sql_gen.SetOf(into),
718
                        sql_gen.RowExcIgnore(sql_gen.RowType(in_table),
719
                            sql.mk_select(db, input_joins),
720
                            mk_return(func_call), exc_handler=exc_handler)
721
                        ))
722
                    wrapper_table = sql_gen.FunctionCall(wrapper)
723
                    
724
                    log_debug('Calling function')
725
                    insert_into_pkeys(sql.mk_select(db, wrapper_table,
726
                        order_by=None), recover=True, cacheable=False)
727
                    sql.add_pkey_or_index(db, into)
728
            else:
729
                cur = sql.insert_select(db, out_table, mapping.keys(),
730
                    main_select, **insert_args)
731
            break # insert successful
732
        except sql.MissingCastException, e:
733
            if not handle_MissingCastException(e): break
734
        except sql.DuplicateKeyException, e:
735
            if not log_exc(e): break
736
            
737
            # Different rows violating different unique constraints not
738
            # supported
739
            assert not join_cols
740
            
741
            join_custom_cond = e.cond
742
            if e.cond != None: ensure_cond(e.cond, e, passed=True)
743
            
744
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
745
            log_debug('Ignoring existing rows, comparing on these columns:\n'
746
                +strings.as_inline_table(join_cols, ustr=col_ustr))
747
            
748
            if is_literals:
749
                return sql.value(sql.select(db, out_table, [out_pkey_col],
750
                    join_cols, order_by=None))
751
            
752
            # Uniquify and filter input table to avoid (most) duplicate keys
753
            # (Additional duplicates may be added concurrently and will be
754
            # filtered out separately upon insert.)
755
            while True:
756
                try:
757
                    insert_in_table = sql.distinct_table(db, insert_in_table,
758
                        join_cols.values(), [insert_in_table,
759
                        sql_gen.Join(out_table, join_cols, sql_gen.filter_out,
760
                        e.cond)])
761
                    insert_in_tables.append(insert_in_table)
762
                    break # insert successful
763
                except sql.MissingCastException, e1: # don't modify outer e
764
                    if not handle_MissingCastException(e1): break
765
        except sql.NullValueException, e:
766
            if not log_exc(e): break
767
            
768
            out_col, = e.cols
769
            try: in_col = mapping[out_col]
770
            except KeyError, e:
771
                try: in_col = mapping[out_col] = col_defaults[out_col]
772
                except KeyError:
773
                    missing_msg = 'Missing mapping for NOT NULL column '+out_col
774
                    log_debug(missing_msg)
775
                    remove_all_rows()
776
            else: ignore(in_col, None, e)
777
        except sql.CheckException, e:
778
            if not log_exc(e): break
779
            
780
            ensure_cond(e.cond, e, failed=True)
781
        except sql.InvalidValueException, e:
782
            if not log_exc(e): break
783
            
784
            for in_col in mapping.values(): ignore(in_col, e.value, e)
785
        except psycopg2.extensions.TransactionRollbackError, e:
786
            if not log_exc(e): break
787
            # retry
788
        except sql.DatabaseErrors, e:
789
            if not log_exc(e): break
790
            
791
            handle_unknown_exc(e)
792
        # after exception handled, rerun loop with additional constraints
793
    
794
    # Resolve default value column
795
    if default != None:
796
        if ignore_all_ref[0]: mapping.update(orig_mapping) # use input cols
797
        try: default = mapping[default]
798
        except KeyError:
799
            db.log_debug('Default value column '
800
                +strings.as_tt(strings.repr_no_u(default))
801
                +' does not exist in mapping, falling back to None', level=2.1)
802
            default = None
803
        else: default = sql_gen.remove_col_rename(default)
804
    
805
    if missing_msg != None and default == None:
806
        warnings.warn(UserWarning(missing_msg))
807
        # not an error because sometimes the mappings include
808
        # extra tables which aren't used by the dataset
809
    
810
    # Handle unrecoverable errors
811
    if ignore_all_ref[0]:
812
        log_debug('Returning default: '+strings.as_tt(strings.urepr(default)))
813
        return default
814
    
815
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
816
        row_ct_ref[0] += cur.rowcount
817
    
818
    if is_literals: return sql.value_or_none(cur) # support multi-row functions
819
    
820
    if is_function: pass # pkeys table already created
821
    elif has_joins:
822
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols,
823
            custom_cond=join_custom_cond)]
824
        log_debug('Getting output table pkeys of existing/inserted rows')
825
        insert_into_pkeys(sql.mk_select(db, select_joins, [in_pkey_col,
826
            sql_gen.NamedCol(into_out_pkey, out_pkey_col)], order_by=None))
827
    else:
828
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
829
        
830
        log_debug('Getting input table pkeys of inserted rows')
831
        # Note that mk_main_select() does not use ORDER BY. Instead, assume that
832
        # since the SELECT query is identical to the one used in INSERT SELECT,
833
        # its rows will be retrieved in the same order.
834
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
835
            into=insert_in_pkeys)
836
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
837
        
838
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
839
            db, insert_in_pkeys)
840
        
841
        log_debug('Combining output and input pkeys in inserted order')
842
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
843
            {sql.row_num_col: sql_gen.join_same_not_null})]
844
        in_col = sql_gen.Col(in_pkey, insert_in_pkeys)
845
        out_col = sql_gen.NamedCol(into_out_pkey,
846
            sql_gen.Col(out_pkey, insert_out_pkeys))
847
        insert_into_pkeys(sql.mk_select(db, pkey_joins, [in_col, out_col],
848
            order_by=None))
849
        
850
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
851
    
852
    if not is_function: # is_function doesn't leave holes
853
        log_debug('Setting pkeys of missing rows to '
854
            +strings.as_tt(strings.urepr(default)))
855
        
856
        full_in_pkey_col = sql_gen.Col(in_pkey, full_in_table)
857
        if sql_gen.is_table_col(default):
858
            default = sql_gen.with_table(default, full_in_table)
859
        missing_rows_joins = [full_in_table, sql_gen.Join(into,
860
            {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
861
            # must use join_same_not_null or query will take forever
862
        
863
        insert_args = dict(order_by=None)
864
        if not sql.table_has_pkey(db, full_in_table): # in_table has duplicates
865
            insert_args.update(dict(distinct_on=[full_in_pkey_col]))
866
        
867
        insert_into_pkeys(sql.mk_select(db, missing_rows_joins,
868
            [full_in_pkey_col, sql_gen.NamedCol(into_out_pkey, default)],
869
            **insert_args))
870
    # otherwise, there is already an entry for every row
871
    
872
    sql.empty_temp(db, insert_in_tables+[full_in_table])
873
    
874
    srcs = []
875
    if is_function: srcs = sql_gen.cols_srcs(in_cols)
876
    return sql_gen.Col(into_out_pkey, into, srcs)
(37-37/49)