Project

General

Profile

1
# Database import/export
2

    
3
import copy
4
import csv
5
import operator
6
import warnings
7
import sys
8

    
9
import csvs
10
import exc
11
import dicts
12
import sql
13
import sql_gen
14
import streams
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
# Can't use built-in SyntaxError because it stringifies to only the first line
class SyntaxError(Exception):
    '''Syntax error whose str() preserves the full (multi-line) message.'''
    pass
22

    
23
##### Data cleanup
24

    
25
def table_nulls_mapped__set(db, table):
    '''Flags table (via the util.table_nulls_mapped__set() DB function) as
    having had its NULL-equivalent strings mapped to NULL.
    @param table sql_gen.Table
    '''
    assert isinstance(table, sql_gen.Table)
    regclass = sql_gen.table2regclass_text(db, table)
    sql.run_query(db, 'SELECT util.table_nulls_mapped__set('+regclass+')')
29

    
30
def table_nulls_mapped__get(db, table):
    '''Reads the flag set by table_nulls_mapped__set().
    @param table sql_gen.Table
    @return whether the table's NULL-equivalent strings have been mapped
    '''
    assert isinstance(table, sql_gen.Table)
    regclass = sql_gen.table2regclass_text(db, table)
    return sql.value(sql.run_query(db,
        'SELECT util.table_nulls_mapped__get('+regclass+')'))
34

    
35
null_strs = ['', '-', r'\N', 'NULL', 'N/A', 'NA', 'UNKNOWN', 'nulo']
36
    # NA: this will remove a common abbr for North America, but we don't use the
37
    # continent, so this is OK
38

    
39
def cleanup_table(db, table):
    '''Maps NULL-equivalent strings (null_strs) to NULL in the table's text
    columns, trimming surrounding whitespace in the process. idempotent.
    '''
    table = sql_gen.as_Table(table)
    assert sql.table_exists(db, table)
    
    # the util.table_nulls_mapped__* flag is what makes this idempotent
    if table_nulls_mapped__get(db, table): return # already cleaned up
    
    # only text columns can contain NULL-equivalent strings
    cols = filter(lambda c: sql_gen.is_text_col(db, c),
        sql.table_cols(db, table))
    try: pkey_col = sql.table_pkey_col(db, table)
    except sql.DoesNotExistException: pass
    else:
        # leave the pkey column untouched
        try: cols.remove(pkey_col)
        except ValueError: pass
    if not cols: return
    
    db.log_debug('Cleaning up table', level=1.5)
    
    # build nullif(...nullif(trim(both from %s), <null_str>)...) around each col
    expr = 'trim(both from %s)'
    for null in null_strs: expr = 'nullif('+expr+', '+db.esc_value(null)+')'
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db))) for v in cols]
    
    while True:
        try:
            sql.update(db, table, changes, in_place=True, recover=True)
            break # successful
        except sql.NullValueException, e:
            # a column that received mapped NULLs has a NOT NULL constraint;
            # drop the constraint and retry the update
            db.log_debug('Caught exception: '+exc.str_(e))
            col, = e.cols
            sql.drop_not_null(db, col)
    
    db.log_debug('Vacuuming and reanalyzing table', level=1.5)
    sql.vacuum(db, table)
    
    # record that cleanup is done so later calls are no-ops
    table_nulls_mapped__set(db, table)
74

    
75
##### Error tracking
76

    
77
def track_data_error(db, errors_table, cols, value, error_code, error):
    '''Records a data error against each of the given columns.
    @param errors_table If None, does nothing.
    '''
    if errors_table == None: return
    
    names = [col.name for col in cols]
    if not names: names = [None] # need at least one entry
    for name in names:
        row = dict(column=name, value=value, error_code=error_code,
            error=error)
        # the same error may already have been recorded; skip duplicates
        try: sql.insert(db, errors_table, row, recover=True, cacheable=True,
            log_level=4)
        except sql.DuplicateKeyException: pass
91

    
92
class ExcToErrorsTable(sql_gen.ExcToWarning):
    '''Handles an exception by saving it or converting it to a warning.
    
    Generates PL/pgSQL handler code that records the error (SQLSTATE, SQLERRM,
    and the offending value) in the errors table, once per source column.
    '''
    def __init__(self, return_, srcs, errors_table, value=None):
        '''
        @param return_ See sql_gen.ExcToWarning
        @param srcs The column names for the errors table
        @param errors_table None|sql_gen.Table
        @param value The value (or an expression for it) that caused the error
        @pre The invalid value must be in a local variable "value" of type text.
        '''
        sql_gen.ExcToWarning.__init__(self, return_)
        
        value = sql_gen.as_Code(value)
        
        self.srcs = srcs
        self.errors_table = errors_table
        self.value = value
    
    def to_str(self, db):
        # with no source columns or no errors table there is nowhere to save
        # the error, so fall back to converting it to a warning
        if not self.srcs or self.errors_table == None:
            return sql_gen.ExcToWarning.to_str(self, db)
        
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        # one row per source column name
        col_names_query = sql.mk_select(db, sql_gen.NamedValues('c', None,
            [[c.name] for c in self.srcs]), order_by=None)
        insert_query = sql.mk_insert_select(db, self.errors_table,
            errors_table_cols,
            sql_gen.Values(errors_table_cols).to_str(db))+';\n'
        return '''\
-- Save error in errors table.
DECLARE
    error_code text := SQLSTATE;
    error text := SQLERRM;
    value text := '''+self.value.to_str(db)+''';
    "column" text;
BEGIN
    -- Insert the value and error for *each* source column.
'''+strings.indent(sql_gen.RowExcIgnore(None, col_names_query, insert_query,
    row_var=errors_table_cols[0]).to_str(db))+'''
END;

'''+self.return_.to_str(db)
135

    
136
def data_exception_handler(*args, **kw_args):
    '''Builds a handler for a data_exception.
    The exception is either saved to the errors table or converted to a
    warning. For params, see ExcToErrorsTable().
    '''
    handler = ExcToErrorsTable(*args, **kw_args)
    return sql_gen.data_exception_handler(handler)
141

    
142
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as source columns). Otherwise, converts errors to warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    @return sql_gen.Cast (for literals) or a sql_gen.FunctionCall wrapping col
    '''
    col = sql_gen.as_Col(col)
    
    # Don't convert exceptions to warnings for user-supplied constants
    if isinstance(col, sql_gen.Literal): return sql_gen.Cast(type_, col)
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    function_name = strings.first_word(type_)
    srcs = col.srcs
    save_errors = errors_table != None and srcs
    if save_errors: # function will be unique for the given srcs
        function_name = strings.ustr(sql_gen.FunctionCall(function_name,
            *map(sql_gen.to_name_only_col, srcs)))
    function = db.TempFunction(function_name)
    
    # Create function definition
    modifiers = 'STRICT'
    # when not saving errors, additionally mark the cast function IMMUTABLE
    if not save_errors: modifiers = 'IMMUTABLE '+modifiers
    value_param = sql_gen.FunctionParam('value', 'anyelement')
    # on a data error, the function returns NULL (after saving/warning)
    handler = data_exception_handler('RETURN NULL;\n', srcs, errors_table,
        value_param.name)
    body = sql_gen.CustomCode(handler.to_str(db, '''\
/* The explicit cast to the return type is needed to make the cast happen
inside the try block. (Implicit casts to the return type happen at the end
of the function, outside any block.) */
RETURN '''+sql_gen.Cast(type_, sql_gen.CustomCode('value')).to_str(db)+''';
'''))
    body.lang='plpgsql'
    sql.define_func(db, sql_gen.FunctionDef(function, type_, body,
        [value_param], modifiers))
    
    return sql_gen.FunctionCall(function, col)
181

    
182
def func_wrapper_exception_handler(db, return_, args, errors_table):
    '''Handles a function call's data_exceptions.
    Supports PL/Python functions.
    @param return_ See data_exception_handler()
    @param args [arg...] Function call's args
    @param errors_table See data_exception_handler()
    '''
    # only args with source columns can be attributed in the errors table
    args = filter(sql_gen.has_srcs, args)
    
    handler = data_exception_handler(return_, sql_gen.cross_join_srcs(args),
        errors_table, sql_gen.merge_not_null(db, ',', args))
    return sql_gen.NestedExcHandler(handler, sql_gen.plpythonu_error_handler)
197

    
198
def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but creates a new column with the cast values if the input
    is a column.
    @return The new column or cast value
    '''
    def do_cast(col): return cast(db, type_, col, errors_table)
    
    try: col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException:
        # no table to add a column to, so just wrap the value in the cast
        return sql_gen.wrap(do_cast, col)
    
    table = col.table
    new_col = sql_gen.suffixed_col(col, '::'+strings.first_word(type_))
    expr = do_cast(col)
    
    # Add the new column, recording the cast in its comment
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, new_typed_col, comment=strings.urepr(col)+'::'+type_)
    new_col.name = new_typed_col.name # propagate any renaming
    
    # Fill it with the cast values
    sql.update(db, table, [(new_col, expr)], in_place=True, recover=True)
    
    return new_col
220

    
221
def errors_table(db, table, if_exists=True):
    '''Returns the errors table associated with a data table.
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    # the errors table belongs to the original source table, if there is one
    if table.srcs != (): table = table.srcs[0]
    
    table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, table_): return None
    return table_
232

    
233
def mk_errors_table(db, table):
    '''Creates the errors table for a data table, if it doesn't already exist.'''
    errors_table_ = errors_table(db, table, if_exists=False)
    if sql.table_exists(db, errors_table_, cacheable=False): return
    
    sql.create_table(db, errors_table_, [
        sql_gen.TypedCol('column', 'text'),
        sql_gen.TypedCol('value', 'text'),
        sql_gen.TypedCol('error_code', 'character varying(5)', nullable=False),
        sql_gen.TypedCol('error', 'text', nullable=False),
        ], has_pkey=False)
    # prevent duplicate entries for the same (column, value, error)
    sql.add_index(db, ['column', sql_gen.CustomCode('md5(value)'),
        'error_code', sql_gen.CustomCode('md5(error)')], errors_table_,
        unique=True)
247

    
248
##### Import
249

    
250
row_num_col_def = copy.copy(sql.row_num_col_def)
251
row_num_col_def.name = 'row_num'
252
row_num_col_def.type = 'integer'
253

    
254
def append_csv(db, table, reader, header):
    '''COPYs rows from a CSV reader into an existing table.
    @param header [col_name...] the CSV's column names
    '''
    def esc_name_(name): return sql.esc_name(db, name)
    
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Wrap in standardizing stream
    cols_ct = len(header)
    stream = csvs.InputRewriter(csvs.ProgressInputFilter(
        csvs.ColCtFilter(reader, cols_ct), sys.stderr, n=1000))
    #streams.copy(stream, sys.stderr) # to troubleshoot copy_expert() errors
    dialect = stream.dialect # use default dialect
    
    # Create COPY FROM statement
    # omit the column list when the CSV's columns match the table's exactly
    if header == sql.table_col_names(db, table): cols_str = ''
    else: cols_str =' ('+(', '.join(map(esc_name_, header)))+')'
    copy_from = ('COPY '+table.to_str(db)+cols_str+' FROM STDIN DELIMITER '
        +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
    assert not csvs.is_tsv(dialect)
    copy_from += ' CSV'
    if dialect.quoting != csv.QUOTE_NONE:
        quote_str = db.esc_value(dialect.quotechar)
        copy_from += ' QUOTE '+quote_str
        # doublequote dialects escape the quote char by doubling it
        if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
    copy_from += ';\n'
    
    log(copy_from, level=2)
    # let sql.parse_exception() translate/re-raise any DB error
    try: db.db.cursor().copy_expert(copy_from, stream)
    except Exception, e: sql.parse_exception(db, e, recover=True)
282

    
283
def import_csv(db, table, reader, header):
    '''Imports a CSV stream into a new database table.
    
    Creates the table (cached, so a retry reuses it), truncates leftover rows
    from any failed import, loads the data within a savepoint, and then runs
    cleanup_table() on the result.
    
    Note: prepends the row_num column name to header (modifies caller's list).
    @param reader CSV reader producing the data rows
    @param header [col_name...] the CSV's column names
    '''
    # fix: removed the unused local helper esc_name() (defined, never called)
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Get format info
    col_names = map(strings.to_unicode, header)
    for i, col in enumerate(col_names): # replace empty column names
        if col == '': col_names[i] = 'column_'+str(i)
    
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
    # prepend a row-number column so rows keep the CSV's row order
    typed_cols.insert(0, row_num_col_def)
    header.insert(0, row_num_col_def.name)
    reader = csvs.RowNumFilter(reader)
    
    log('Creating table')
    # Note that this is not rolled back if the import fails. Instead, it is
    # cached, and will not be re-run if the import is retried.
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
    
    # Free memory used by deleted (rolled back) rows from any failed import.
    # This MUST be run so that the rows will be stored in inserted order, and
    # the row_num added after import will match up with the CSV's row order.
    sql.truncate(db, table)
    
    # Load the data
    def load(): append_csv(db, table, reader, header)
    sql.with_savepoint(db, load)
    
    cleanup_table(db, table)
314

    
315
def put(db, table, row, pkey_=None, row_ct_ref=None, on_error=exc.reraise):
    '''Puts a single row into a table (see put_table()).
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    '''
    # note: pkey_ is unused here; it is accepted for caller compatibility
    return put_table(db, table, [], row, row_ct_ref, on_error=on_error)
320

    
321
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up the pkey of the row matching `row`. Recovers from errors.
    @param create If set, inserts the row when no match is found.
    '''
    try:
        cur = sql.select(db, table, [pkey], row, limit=1, recover=True)
        return sql.value(cur)
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
329

    
330
def is_func_result(col):
    '''Whether col is the "result" column of a function call's output table
    (detected by a "(" in the table name).'''
    if col.name != 'result': return False
    return '(' in col.table.name
332

    
333
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Picks a name for the table that will hold the output pkeys.'''
    def fmt_in_col(col):
        col = sql_gen.remove_col_rename(col)
        if isinstance(col, sql_gen.Col):
            if col.table == in_tables0:
                col = sql_gen.to_name_only_col(col)
            elif is_func_result(col): col = col.table # omit col name
        return strings.ustr(col)
    
    name = strings.ustr(out_table)
    if is_func:
        # name looks like a function call: out_table(arg=..., ...)
        try: value_in_col = mapping['value']
        except KeyError:
            args = ', '.join((strings.ustr(k)+'='+fmt_in_col(v)
                for k, v in mapping.iteritems()))
        else: args = fmt_in_col(value_in_col)
        name += '('+args+')'
    else:
        try: rank_in_col = mapping['rank']
        except KeyError: name += '_pkeys'
        else: # has a rank column, so hierarchical
            name += '[rank='+fmt_in_col(rank_in_col)+']'
    return name
361

    
362
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, default=None,
363
    col_defaults={}, on_error=exc.reraise):
364
    '''Recovers from errors.
365
    Only works under PostgreSQL (uses INSERT RETURNING).
366
    
367
    Warning: This function's normalizing algorithm does not support database
368
    triggers that populate fields covered by the unique constraint used to do
369
    the DISTINCT ON. Such fields must be populated by the mappings instead.
370
    (Other unique constraints and other non-unique fields are not affected by
371
    this restriction on triggers. Note that the primary key will normally not be
372
    the DISTINCT ON constraint, so trigger-populated natural keys are supported
373
    *unless* the input table contains duplicate rows for some generated keys.)
374
    
375
    Note that much of the complexity of the normalizing algorithm is due to
376
    PostgreSQL (and other DB systems) not having a native command for
377
    INSERT ON DUPLICATE SELECT (wiki.vegpath.org/INSERT_ON_DUPLICATE_SELECT).
378
    For PostgreSQL 9.1+, this can now be emulated using INSTEAD OF triggers.
379
    For earlier versions, you instead have to use this function.
380
    
381
    @param in_tables The main input table to select from, followed by a list of
382
        tables to join with it using the main input table's pkey
383
    @param mapping dict(out_table_col=in_table_col, ...)
384
        * out_table_col: str (*not* sql_gen.Col)
385
        * in_table_col: sql_gen.Col|literal-value
386
    @param default The *output* column to use as the pkey for missing rows.
387
        If this output column does not exist in the mapping, uses None.
388
        Note that this will be used for *all* missing rows, regardless of which
389
        error caused them not to be inserted.
390
    @param col_defaults Default values for required columns.
391
    @return sql_gen.Col Where the output pkeys are made available
392
    '''
393
    import psycopg2.extensions
394
    
395
    # Special handling for functions with hstore params
396
    if out_table == '_map':
397
        import psycopg2.extras
398
        psycopg2.extras.register_hstore(db.db)
399
        
400
        # Parse args
401
        try: value = mapping.pop('value')
402
        except KeyError: return None # value required
403
        
404
        mapping = dict([(k, sql_gen.get_value(v))
405
            for k, v in mapping.iteritems()]) # unwrap literal value
406
        mapping = dict(map=mapping, value=value) # non-value params -> hstore
407
    
408
    out_table = sql_gen.as_Table(out_table)
409
    
410
    def log_debug(msg): db.log_debug(msg, level=1.5)
411
    def col_ustr(str_):
412
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
413
    
414
    log_debug('********** New iteration **********')
415
    log_debug('Inserting these input columns into '+strings.as_tt(
416
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
417
    
418
    is_function = sql.function_exists(db, out_table)
419
    
420
    if is_function: row_ct_ref = None # only track inserted rows
421
    
422
    # Warn if inserting empty table rows
423
    if not mapping and not is_function: # functions with no args OK
424
        warnings.warn(UserWarning('Inserting empty table row(s)'))
425
    
426
    if is_function: out_pkey = 'result'
427
    else: out_pkey = sql.pkey_name(db, out_table, recover=True)
428
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
429
    
430
    in_tables_ = copy.copy(in_tables) # don't modify input!
431
    try: in_tables0 = in_tables_.pop(0) # first table is separate
432
    except IndexError: in_tables0 = None
433
    else:
434
        in_pkey = sql.pkey_name(db, in_tables0, recover=True)
435
        in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
436
    
437
    # Determine if can use optimization for only literal values
438
    is_literals = not reduce(operator.or_, map(sql_gen.is_table_col,
439
        mapping.values()), False)
440
    is_literals_or_function = is_literals or is_function
441
    
442
    if in_tables0 == None: errors_table_ = None
443
    else: errors_table_ = errors_table(db, in_tables0)
444
    
445
    # Create input joins from list of input tables
446
    input_joins = [in_tables0]+[sql_gen.Join(v,
447
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
448
    
449
    orig_mapping = mapping.copy()
450
    if mapping == {} and not is_function: # need >= one column for INSERT SELECT
451
        mapping = {out_pkey: None} # ColDict will replace with default value
452
    
453
    if not is_literals:
454
        into = sql_gen.as_Table(into_table_name(out_table, in_tables0, mapping,
455
            is_function))
456
        # Ensure into's out_pkey is different from in_pkey by prepending "out."
457
        if is_function: into_out_pkey = out_pkey
458
        else: into_out_pkey = 'out.'+out_pkey
459
        
460
        # Set column sources
461
        in_cols = filter(sql_gen.is_table_col, mapping.values())
462
        for col in in_cols:
463
            if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
464
        
465
        log_debug('Joining together input tables into temp table')
466
        # Place in new table so don't modify input and for speed
467
        in_table = sql_gen.Table('in')
468
        mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
469
            in_cols, preserve=[in_pkey_col]))
470
        input_joins = [in_table]
471
        db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
472
    
473
    # Wrap mapping in a sql_gen.ColDict.
474
    # sql_gen.ColDict sanitizes both keys and values passed into it.
475
    # Do after applying dicts.join() because that returns a plain dict.
476
    mapping = sql_gen.ColDict(db, out_table, mapping)
477
    
478
    # Save all rows since in_table may have rows deleted
479
    if is_literals: pass
480
    elif is_function: full_in_table = in_table
481
    else:
482
        full_in_table = sql_gen.suffixed_table(in_table, '_full')
483
        sql.copy_table(db, in_table, full_in_table)
484
    
485
    pkeys_table_exists_ref = [False]
486
    def insert_into_pkeys(query, **kw_args):
487
        if pkeys_table_exists_ref[0]:
488
            sql.insert_select(db, into, [in_pkey, into_out_pkey], query,
489
                **kw_args)
490
        else:
491
            kw_args.setdefault('add_pkey_', True)
492
            # don't warn if can't create pkey, because this just indicates that,
493
            # at some point in the import tree, it used a set-returning function
494
            kw_args.setdefault('add_pkey_warn', False)
495
            
496
            sql.run_query_into(db, query, into=into, **kw_args)
497
            pkeys_table_exists_ref[0] = True
498
    
499
    def mk_main_select(joins, cols): return sql.mk_select(db, joins, cols)
500
    
501
    if is_literals: insert_in_table = None
502
    else:
503
        insert_in_table = in_table
504
        insert_in_tables = [insert_in_table]
505
    join_cols = sql_gen.ColDict(db, out_table)
506
    join_custom_cond = None
507
    
508
    exc_strs = set()
509
    def log_exc(e):
510
        e_str = exc.str_(e, first_line_only=True)
511
        log_debug('Caught exception: '+e_str)
512
        if e_str in exc_strs: # avoid infinite loops
513
            log_debug('Exception already seen, handler broken')
514
            on_error(e)
515
            remove_all_rows()
516
            return False
517
        else: exc_strs.add(e_str)
518
        return True
519
    
520
    ignore_all_ref = [False]
521
    def remove_all_rows():
522
        log_debug('Ignoring all rows')
523
        ignore_all_ref[0] = True # just return the default value column
524
    
525
    def handle_unknown_exc(e):
526
        log_debug('No handler for exception')
527
        on_error(e)
528
        remove_all_rows()
529
    
530
    def ensure_cond(cond, e, passed=False, failed=False):
531
        '''
532
        @param passed at least one row passed the constraint
533
        @param failed at least one row failed the constraint
534
        '''
535
        if is_literals: # we know the constraint was applied exactly once
536
            if passed: pass
537
            elif failed: remove_all_rows()
538
            else: raise NotImplementedError()
539
        else:
540
            if not is_function:
541
                out_table_cols = sql_gen.ColDict(db, out_table)
542
                out_table_cols.update(util.dict_subset_right_join({},
543
                    sql.table_col_names(db, out_table)))
544
            
545
            in_cols = []
546
            cond = strings.ustr(cond)
547
            orig_cond = cond
548
            cond = sql_gen.map_expr(db, cond, mapping, in_cols)
549
            if not is_function:
550
                cond = sql_gen.map_expr(db, cond, out_table_cols)
551
            
552
            log_debug('Ignoring rows that do not satisfy '+strings.as_tt(cond))
553
            cur = None
554
            if cond == sql_gen.false_expr:
555
                assert failed
556
                remove_all_rows()
557
            elif cond == sql_gen.true_expr: assert passed
558
            else:
559
                while True:
560
                    not_cond = sql_gen.NotCond(sql_gen.CustomCode(cond))
561
                    try:
562
                        cur = sql.delete(db, insert_in_table, not_cond)
563
                        break
564
                    except sql.DoesNotExistException, e:
565
                        if e.type != 'column': raise
566
                        
567
                        last_cond = cond
568
                        cond = sql_gen.map_expr(db, cond, {e.name: None})
569
                        if cond == last_cond: raise # not fixable
570
            
571
            # If any rows failed cond
572
            if failed or cur != None and cur.rowcount > 0:
573
                track_data_error(db, errors_table_,
574
                    sql_gen.cross_join_srcs(in_cols), None, e.cause.pgcode,
575
                    strings.ensure_newl(strings.ustr(e.cause.pgerror))
576
                    +'condition: '+orig_cond+'\ntranslated condition: '+cond)
577
    
578
    not_null_cols = set()
579
    def ignore(in_col, value, e):
580
        if sql_gen.is_table_col(in_col):
581
            in_col = sql_gen.with_table(in_col, insert_in_table)
582
            
583
            track_data_error(db, errors_table_, in_col.srcs, value,
584
                e.cause.pgcode, e.cause.pgerror)
585
            
586
            sql.add_index(db, in_col, insert_in_table) # enable fast filtering
587
            if value != None and in_col not in not_null_cols:
588
                log_debug('Replacing invalid value '
589
                    +strings.as_tt(strings.urepr(value))+' with NULL in column '
590
                    +strings.as_tt(in_col.to_str(db)))
591
                sql.update(db, insert_in_table, [(in_col, None)],
592
                    sql_gen.ColValueCond(in_col, value))
593
            else:
594
                log_debug('Ignoring rows with '+strings.as_tt(in_col.to_str(db))
595
                    +' = '+strings.as_tt(strings.urepr(value)))
596
                sql.delete(db, insert_in_table,
597
                    sql_gen.ColValueCond(in_col, value))
598
                if value == None: not_null_cols.add(in_col)
599
        else:
600
            assert isinstance(in_col, sql_gen.NamedCol)
601
            in_value = sql_gen.remove_col_rename(in_col)
602
            assert sql_gen.is_literal(in_value)
603
            if value == in_value.value:
604
                if value != None:
605
                    log_debug('Replacing invalid literal '
606
                        +strings.as_tt(in_col.to_str(db))+' with NULL')
607
                    mapping[in_col.name] = None
608
                else:
609
                    remove_all_rows()
610
            # otherwise, all columns were being ignore()d because the specific
611
            # column couldn't be identified, and this was not the invalid column
612
    
613
    if not is_literals:
614
        def insert_pkeys_table(which):
615
            return sql_gen.Table(sql_gen.concat(in_table.name,
616
                '_insert_'+which+'_pkeys'))
617
        insert_out_pkeys = insert_pkeys_table('out')
618
        insert_in_pkeys = insert_pkeys_table('in')
619
    
620
    def mk_func_call():
621
        args = dict(((k.name, v) for k, v in mapping.iteritems()))
622
        return sql_gen.FunctionCall(out_table, **args), args
623
    
624
    def handle_MissingCastException(e):
625
        if not log_exc(e): return False
626
        
627
        type_ = e.type
628
        if e.col == None: out_cols = mapping.keys()
629
        else: out_cols = [e.col]
630
        
631
        for out_col in out_cols:
632
            log_debug('Casting '+strings.as_tt(strings.repr_no_u(out_col))
633
                +' input to '+strings.as_tt(type_))
634
            in_col = mapping[out_col]
635
            while True:
636
                try:
637
                    cast_col = cast_temp_col(db, type_, in_col, errors_table_)
638
                    mapping[out_col] = cast_col
639
                    if out_col in join_cols: join_cols[out_col] = cast_col
640
                    break # cast successful
641
                except sql.InvalidValueException, e:
642
                    if not log_exc(e): return False
643
                    
644
                    ignore(in_col, e.value, e)
645
        
646
        return True
647
    
648
    missing_msg = None
649
    
650
    # Do inserts and selects
651
    while True:
652
        has_joins = join_cols != {}
653
        
654
        if ignore_all_ref[0]: break # unrecoverable error, so don't do main case
655
        
656
        # Prepare to insert new rows
657
        if is_function:
658
            if is_literals:
659
                log_debug('Calling function')
660
                func_call, args = mk_func_call()
661
        else:
662
            log_debug('Trying to insert new rows')
663
            insert_args = dict(recover=True, cacheable=False)
664
            if has_joins:
665
                insert_args.update(dict(ignore=True))
666
            else:
667
                insert_args.update(dict(returning=out_pkey))
668
                if not is_literals:
669
                    insert_args.update(dict(into=insert_out_pkeys))
670
            main_select = mk_main_select([insert_in_table], [sql_gen.with_table(
671
                c, insert_in_table) for c in mapping.values()])
672
        
673
        try:
674
            cur = None
675
            if is_function:
676
                if is_literals:
677
                    cur = sql.select(db, fields=[func_call], recover=True,
678
                        cacheable=True)
679
                else:
680
                    log_debug('Defining wrapper function')
681
                    
682
                    func_call, args = mk_func_call()
683
                    func_call = sql_gen.NamedCol(into_out_pkey, func_call)
684
                    
685
                    # Create empty pkeys table so its row type can be used
686
                    insert_into_pkeys(sql.mk_select(db, input_joins,
687
                        [in_pkey_col, func_call], limit=0), add_pkey_=False,
688
                        recover=True)
689
                    result_type = db.col_info(sql_gen.Col(into_out_pkey,
690
                        into)).type
691
                    
692
                    ## Create error handling wrapper function
693
                    
694
                    wrapper = db.TempFunction(sql_gen.concat(into.name,
695
                        '_wrap'))
696
                    
697
                    select_cols = [in_pkey_col]+args.values()
698
                    row_var = copy.copy(sql_gen.row_var)
699
                    row_var.set_srcs([in_table])
700
                    in_pkey_var = sql_gen.Col(in_pkey, row_var)
701
                    
702
                    args = dict(((k, sql_gen.with_table(v, row_var))
703
                        for k, v in args.iteritems()))
704
                    func_call = sql_gen.FunctionCall(out_table, **args)
705
                    
706
                    def mk_return(result):
707
                        return sql_gen.ReturnQuery(sql.mk_select(db,
708
                            fields=[in_pkey_var, result], explain=False))
709
                    exc_handler = func_wrapper_exception_handler(db,
710
                        mk_return(sql_gen.Cast(result_type, None)),
711
                        args.values(), errors_table_)
712
                    
713
                    sql.define_func(db, sql_gen.FunctionDef(wrapper,
714
                        sql_gen.SetOf(into),
715
                        sql_gen.RowExcIgnore(sql_gen.RowType(in_table),
716
                            sql.mk_select(db, input_joins),
717
                            mk_return(func_call), exc_handler=exc_handler)
718
                        ))
719
                    wrapper_table = sql_gen.FunctionCall(wrapper)
720
                    
721
                    log_debug('Calling function')
722
                    insert_into_pkeys(sql.mk_select(db, wrapper_table,
723
                        order_by=None), recover=True, cacheable=False)
724
                    sql.add_pkey_or_index(db, into)
725
            else:
726
                cur = sql.insert_select(db, out_table, mapping.keys(),
727
                    main_select, **insert_args)
728
            break # insert successful
729
        except sql.MissingCastException, e:
730
            if not handle_MissingCastException(e): break
731
        except sql.DuplicateKeyException, e:
732
            if not log_exc(e): break
733
            
734
            # Different rows violating different unique constraints not
735
            # supported
736
            assert not join_cols
737
            
738
            join_custom_cond = e.cond
739
            if e.cond != None: ensure_cond(e.cond, e, passed=True)
740
            
741
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
742
            log_debug('Ignoring existing rows, comparing on these columns:\n'
743
                +strings.as_inline_table(join_cols, ustr=col_ustr))
744
            
745
            if is_literals:
746
                return sql.value(sql.select(db, out_table, [out_pkey_col],
747
                    join_cols, order_by=None))
748
            
749
            # Uniquify and filter input table to avoid (most) duplicate keys
750
            # (Additional duplicates may be added concurrently and will be
751
            # filtered out separately upon insert.)
752
            while True:
753
                try:
754
                    insert_in_table = sql.distinct_table(db, insert_in_table,
755
                        join_cols.values(), [insert_in_table,
756
                        sql_gen.Join(out_table, join_cols, sql_gen.filter_out,
757
                        e.cond)])
758
                    insert_in_tables.append(insert_in_table)
759
                    break # insert successful
760
                except sql.MissingCastException, e1: # don't modify outer e
761
                    if not handle_MissingCastException(e1): break
762
        except sql.NullValueException, e:
763
            if not log_exc(e): break
764
            
765
            out_col, = e.cols
766
            try: in_col = mapping[out_col]
767
            except KeyError, e:
768
                try: in_col = mapping[out_col] = col_defaults[out_col]
769
                except KeyError:
770
                    missing_msg = 'Missing mapping for NOT NULL column '+out_col
771
                    log_debug(missing_msg)
772
                    remove_all_rows()
773
            else: ignore(in_col, None, e)
774
        except sql.CheckException, e:
775
            if not log_exc(e): break
776
            
777
            ensure_cond(e.cond, e, failed=True)
778
        except sql.InvalidValueException, e:
779
            if not log_exc(e): break
780
            
781
            for in_col in mapping.values(): ignore(in_col, e.value, e)
782
        except psycopg2.extensions.TransactionRollbackError, e:
783
            if not log_exc(e): break
784
            # retry
785
        except sql.DatabaseErrors, e:
786
            if not log_exc(e): break
787
            
788
            handle_unknown_exc(e)
789
        # after exception handled, rerun loop with additional constraints
790
    
791
    # Resolve default value column
792
    if default != None:
793
        if ignore_all_ref[0]: mapping.update(orig_mapping) # use input cols
794
        try: default = mapping[default]
795
        except KeyError:
796
            db.log_debug('Default value column '
797
                +strings.as_tt(strings.repr_no_u(default))
798
                +' does not exist in mapping, falling back to None', level=2.1)
799
            default = None
800
        else: default = sql_gen.remove_col_rename(default)
801
    
802
    if missing_msg != None and default == None:
803
        warnings.warn(UserWarning(missing_msg))
804
        # not an error because sometimes the mappings include
805
        # extra tables which aren't used by the dataset
806
    
807
    # Handle unrecoverable errors
808
    if ignore_all_ref[0]:
809
        log_debug('Returning default: '+strings.as_tt(strings.urepr(default)))
810
        return default
811
    
812
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
813
        row_ct_ref[0] += cur.rowcount
814
    
815
    if is_literals: return sql.value_or_none(cur) # support multi-row functions
816
    
817
    if is_function: pass # pkeys table already created
818
    elif has_joins:
819
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols,
820
            custom_cond=join_custom_cond)]
821
        log_debug('Getting output table pkeys of existing/inserted rows')
822
        insert_into_pkeys(sql.mk_select(db, select_joins, [in_pkey_col,
823
            sql_gen.NamedCol(into_out_pkey, out_pkey_col)], order_by=None))
824
    else:
825
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
826
        
827
        log_debug('Getting input table pkeys of inserted rows')
828
        # Note that mk_main_select() does not use ORDER BY. Instead, assume that
829
        # since the SELECT query is identical to the one used in INSERT SELECT,
830
        # its rows will be retrieved in the same order.
831
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
832
            into=insert_in_pkeys)
833
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
834
        
835
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
836
            db, insert_in_pkeys)
837
        
838
        log_debug('Combining output and input pkeys in inserted order')
839
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
840
            {sql.row_num_col: sql_gen.join_same_not_null})]
841
        in_col = sql_gen.Col(in_pkey, insert_in_pkeys)
842
        out_col = sql_gen.NamedCol(into_out_pkey,
843
            sql_gen.Col(out_pkey, insert_out_pkeys))
844
        insert_into_pkeys(sql.mk_select(db, pkey_joins, [in_col, out_col],
845
            order_by=None))
846
        
847
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
848
    
849
    if not is_function: # is_function doesn't leave holes
850
        log_debug('Setting pkeys of missing rows to '
851
            +strings.as_tt(strings.urepr(default)))
852
        
853
        full_in_pkey_col = sql_gen.Col(in_pkey, full_in_table)
854
        if sql_gen.is_table_col(default):
855
            default = sql_gen.with_table(default, full_in_table)
856
        missing_rows_joins = [full_in_table, sql_gen.Join(into,
857
            {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
858
            # must use join_same_not_null or query will take forever
859
        
860
        insert_args = dict(order_by=None)
861
        if not sql.table_has_pkey(db, full_in_table): # in_table has duplicates
862
            insert_args.update(dict(distinct_on=[full_in_pkey_col]))
863
        
864
        insert_into_pkeys(sql.mk_select(db, missing_rows_joins,
865
            [full_in_pkey_col, sql_gen.NamedCol(into_out_pkey, default)],
866
            **insert_args))
867
    # otherwise, there is already an entry for every row
868
    
869
    sql.empty_temp(db, insert_in_tables+[full_in_table])
870
    
871
    srcs = []
872
    if is_function: srcs = sql_gen.cols_srcs(in_cols)
873
    return sql_gen.Col(into_out_pkey, into, srcs)
(37-37/49)