Project

General

Profile

1
# Database import/export
2

    
3
import copy
4
import csv
5
import operator
6
import os
7
import warnings
8
import sys
9

    
10
import csvs
11
import exc
12
import dicts
13
import sql
14
import sql_gen
15
import streams
16
import strings
17
import util
18

    
19
##### Exceptions
20

    
21
# Can't use built-in SyntaxError because it stringifies to only the first line
class SyntaxError(Exception):
    '''Module-local replacement whose str() preserves all message lines.'''
    pass
23

    
24
##### Data cleanup
25

    
26
def table_nulls_mapped__set(db, table):
    '''Flags table as having had its null strings mapped to NULL.'''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__set('
        +sql_gen.table2regclass_text(db, table)+')')
    sql.run_query(db, query)
30

    
31
def table_nulls_mapped__get(db, table):
    '''@return whether table's null strings have already been mapped to NULL'''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__get('
        +sql_gen.table2regclass_text(db, table)+')')
    return sql.value(sql.run_query(db, query))
35

    
36
# NA: this will remove a common abbr for North America, but we don't use the
# continent, so this is OK
null_strs_str_default = r',-,\N,NULL,N/A,NA,UNKNOWN,nulo'

# the active list may be overridden via the null_strs environment variable,
# as a comma-separated string
null_strs_str = os.getenv('null_strs', null_strs_str_default)
null_strs = null_strs_str.split(',')
42

    
43
def cleanup_table(db, table):
    '''Maps each null string (null_strs) to NULL and trims whitespace in all
    non-pkey text columns of table.
    idempotent: skipped if the table is already flagged as cleaned up
    (table_nulls_mapped__get()).
    '''
    table = sql_gen.as_Table(table)
    assert sql.table_exists(db, table)
    
    if table_nulls_mapped__get(db, table): return # already cleaned up
    
    # only text columns can contain null strings
    cols = filter(lambda c: sql_gen.is_text_col(db, c),
        sql.table_cols(db, table))
    try: pkey_col = sql.table_pkey_col(db, table)
    except sql.DoesNotExistException: pass
    else:
        # never rewrite the pkey, even when it's a text column
        try: cols.remove(pkey_col)
        except ValueError: pass # pkey is not a text column
    if not cols: return # nothing to clean up
    
    db.log_debug('Cleaning up table', level=1.5)
    
    # build nested nullif(...) calls around trim(), one layer per null string
    expr = 'trim(both from %s)'
    for null in null_strs: expr = 'nullif('+expr+', '+db.esc_value(null)+')'
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db))) for v in cols]
    
    while True:
        try:
            sql.update(db, table, changes, in_place=True, recover=True)
            break # successful
        except sql.NullValueException, e:
            # a NOT NULL column received a mapped NULL: drop the constraint
            # and retry the update
            db.log_debug('Caught exception: '+exc.str_(e))
            col, = e.cols
            sql.drop_not_null(db, col)
    
    db.log_debug('Vacuuming and reanalyzing table', level=1.5)
    sql.vacuum(db, table)
    
    # flag the table so re-runs become no-ops
    table_nulls_mapped__set(db, table)
78

    
79
##### Error tracking
80

    
81
def track_data_error(db, errors_table, cols, value, error_code, error):
    '''Records one error row per source column in the errors table.
    Duplicate errors are silently skipped.
    @param errors_table If None, does nothing.
    '''
    if errors_table == None: return
    
    col_names = [c.name for c in cols] or [None] # need at least one entry
    for col_name in col_names:
        row = dict(column=col_name, value=value, error_code=error_code,
            error=error)
        try:
            sql.insert(db, errors_table, row, recover=True, cacheable=True,
                log_level=4)
        except sql.DuplicateKeyException: pass # error already recorded
95

    
96
class ExcToErrorsTable(sql_gen.ExcToWarning):
    '''Handles an exception by saving it or converting it to a warning.'''
    def __init__(self, return_, srcs, errors_table, value=None):
        '''
        @param return_ See sql_gen.ExcToWarning
        @param srcs The column names for the errors table
        @param errors_table None|sql_gen.Table
        @param value The value (or an expression for it) that caused the error
        @pre The invalid value must be in a local variable "value" of type text.
        '''
        sql_gen.ExcToWarning.__init__(self, return_)
        
        value = sql_gen.as_Code(value)
        
        self.srcs = srcs
        self.errors_table = errors_table
        self.value = value
    
    def to_str(self, db):
        # with no source columns or no errors table there is nowhere to save
        # the error, so fall back to converting it to a warning
        if not self.srcs or self.errors_table == None:
            return sql_gen.ExcToWarning.to_str(self, db)
        
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        # one row per source column name; iterated over by the RowExcIgnore
        # loop below, binding each name to the "column" row variable
        col_names_query = sql.mk_select(db, sql_gen.NamedValues('c', None,
            [[c.name] for c in self.srcs]), order_by=None)
        insert_query = sql.mk_insert_select(db, self.errors_table,
            errors_table_cols,
            sql_gen.Values(errors_table_cols).to_str(db))+';\n'
        # emit a PL/pgSQL exception-handler body: capture SQLSTATE/SQLERRM and
        # the offending value, then insert them once per source column
        return '''\
-- Save error in errors table.
DECLARE
    error_code text := SQLSTATE;
    error text := SQLERRM;
    value text := '''+self.value.to_str(db)+''';
    "column" text;
BEGIN
    -- Insert the value and error for *each* source column.
'''+strings.indent(sql_gen.RowExcIgnore(None, col_names_query, insert_query,
    row_var=errors_table_cols[0]).to_str(db))+'''
END;

'''+self.return_.to_str(db)
139

    
140
def data_exception_handler(*args, **kw_args):
    '''Handles a data_exception by saving it or converting it to a warning.
    For params, see ExcToErrorsTable().
    '''
    handler = ExcToErrorsTable(*args, **kw_args)
    return sql_gen.data_exception_handler(handler)
145

    
146
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as source columns). Otherwise, converts errors to warnings.
    @param type_ The SQL type to cast to
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    @return sql_gen.FunctionCall|sql_gen.Cast The cast expression
    '''
    col = sql_gen.as_Col(col)
    
    # Don't convert exceptions to warnings for user-supplied constants
    if isinstance(col, sql_gen.Literal): return sql_gen.Cast(type_, col)
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    function_name = strings.first_word(type_)
    srcs = col.srcs
    save_errors = errors_table != None and srcs
    if save_errors: # function will be unique for the given srcs
        function_name = strings.ustr(sql_gen.FunctionCall(function_name,
            *map(sql_gen.to_name_only_col, srcs)))
    function = db.TempFunction(function_name)
    
    # Create function definition
    # STRICT: NULL input yields NULL output without invoking the function
    modifiers = 'STRICT'
    # IMMUTABLE only when the function has no side effects (no errors table)
    if not save_errors: modifiers = 'IMMUTABLE '+modifiers
    value_param = sql_gen.FunctionParam('value', 'anyelement')
    # on a cast failure, record/warn and return NULL instead of aborting
    handler = data_exception_handler('RETURN NULL;\n', srcs, errors_table,
        value_param.name)
    body = sql_gen.CustomCode(handler.to_str(db, '''\
/* The explicit cast to the return type is needed to make the cast happen
inside the try block. (Implicit casts to the return type happen at the end
of the function, outside any block.) */
RETURN '''+sql_gen.Cast(type_, sql_gen.CustomCode('value')).to_str(db)+''';
'''))
    body.lang='plpgsql'
    sql.define_func(db, sql_gen.FunctionDef(function, type_, body,
        [value_param], modifiers))
    
    # the cast is performed by calling the wrapper function on the column
    return sql_gen.FunctionCall(function, col)
185

    
186
def func_wrapper_exception_handler(db, return_, args, errors_table):
    '''Handles a function call's data_exceptions.
    Supports PL/Python functions.
    @param return_ See data_exception_handler()
    @param args [arg...] Function call's args
    @param errors_table See data_exception_handler()
    '''
    # only args with source columns can be tracked in the errors table
    args = [arg for arg in args if sql_gen.has_srcs(arg)]
    
    data_handler = data_exception_handler(return_,
        sql_gen.cross_join_srcs(args), errors_table,
        sql_gen.merge_not_null(db, ',', args))
    return sql_gen.NestedExcHandler(data_handler,
        sql_gen.plpythonu_error_handler)
201

    
202
def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but creates a new column with the cast values if the input
    is a column.
    @param type_ The SQL type to cast to
    @param errors_table See cast()
    @return The new column or cast value
    '''
    def cast_(col): return cast(db, type_, col, errors_table)
    
    try: col = sql_gen.underlying_col(col)
    # not a table column, so there is no table to add a cast column to
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(cast_, col)
    
    table = col.table
    # name the new column after the original, tagged with the cast type
    new_col = sql_gen.suffixed_col(col, '::'+strings.first_word(type_))
    expr = cast_(col)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, new_typed_col, comment=strings.urepr(col)+'::'+type_)
    new_col.name = new_typed_col.name # propagate any renaming
    
    # populate the new column with the cast values, in place
    sql.update(db, table, [(new_col, expr)], in_place=True, recover=True)
    
    return new_col
224

    
225
def errors_table(db, table, if_exists=True):
    '''Returns the errors table associated with table.
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    if table.srcs != (): table = table.srcs[0] # use the original source table
    
    errors_table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, errors_table_): return None
    return errors_table_
236

    
237
def mk_errors_table(db, table):
    '''Creates the errors table for table if it doesn't already exist.'''
    errors_table_ = errors_table(db, table, if_exists=False)
    if sql.table_exists(db, errors_table_, cacheable=False): return
    
    cols = [
        sql_gen.TypedCol('column', 'text'),
        sql_gen.TypedCol('value', 'text'),
        sql_gen.TypedCol('error_code', 'character varying(5)', nullable=False),
        sql_gen.TypedCol('error', 'text', nullable=False),
        ]
    sql.create_table(db, errors_table_, cols, has_pkey=False)
    
    # deduplicate errors on (column, value, error_code, error); the long text
    # values are hashed with md5() to keep the index entries small
    sql.add_index(db, ['column', sql_gen.CustomCode('md5(value)'),
        'error_code', sql_gen.CustomCode('md5(error)')], errors_table_,
        unique=True)
251

    
252
##### Import
253

    
254
# local copy of sql.row_num_col_def with this module's name and type overrides
row_num_col_def = copy.copy(sql.row_num_col_def)
row_num_col_def.name, row_num_col_def.type = 'row_num', 'integer'
257

    
258
def append_csv(db, table, reader, header):
    '''Appends a CSV reader's rows to an existing table using COPY FROM STDIN.
    @param reader CSV reader yielding the data rows
    @param header [col_name...] The CSV's column names
    '''
    def esc_name_(name): return sql.esc_name(db, name)
    
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Wrap in standardizing stream
    cols_ct = len(header)
    # normalize each row to the header's column count, report progress on
    # stderr (every 1000 rows), and re-serialize rows as CSV text so the
    # result can be fed to copy_expert() as a file-like stream
    stream = csvs.InputRewriter(csvs.ProgressInputFilter(
        csvs.ColCtFilter(reader, cols_ct), sys.stderr, n=1000))
    #streams.copy(stream, sys.stderr) # to troubleshoot copy_expert() errors
    dialect = stream.dialect # use default dialect
    
    # Create COPY FROM statement
    # the column list may be omitted when the CSV matches the table exactly
    if header == sql.table_col_names(db, table): cols_str = ''
    else: cols_str =' ('+(', '.join(map(esc_name_, header)))+')'
    copy_from = ('COPY '+table.to_str(db)+cols_str+' FROM STDIN DELIMITER '
        +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
    assert not csvs.is_tsv(dialect)
    copy_from += ' CSV'
    if dialect.quoting != csv.QUOTE_NONE:
        quote_str = db.esc_value(dialect.quotechar)
        copy_from += ' QUOTE '+quote_str
        # doublequote dialects escape the quote char by doubling it
        if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
    copy_from += ';\n'
    
    log(copy_from, level=2)
    # translate any DB error into the wrapper exceptions callers expect
    try: db.db.cursor().copy_expert(copy_from, stream)
    except Exception, e: sql.parse_exception(db, e, recover=True)
286

    
287
def import_csv(db, table, reader, header):
    '''Creates table, imports a CSV into it (append_csv()), and then cleans up
    the new table (cleanup_table()).
    @param reader CSV reader yielding the data rows
    @param header [col_name...] The CSV's column names; modified in place to
        include the prepended row_num column
    '''
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Get format info
    col_names = map(strings.to_unicode, header)
    for i, col in enumerate(col_names): # replace empty column names
        if col == '': col_names[i] = 'column_'+str(i)
    
    # Select schema and escape names
    # NOTE(review): esc_name appears unused in this function
    def esc_name(name): return db.esc_name(name)
    
    # all columns are created as text; casting happens later as needed
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
    # prepend a row_num column so the CSV's row order is preserved
    typed_cols.insert(0, row_num_col_def)
    header.insert(0, row_num_col_def.name)
    reader = csvs.RowNumFilter(reader)
    
    log('Creating table')
    # Note that this is not rolled back if the import fails. Instead, it is
    # cached, and will not be re-run if the import is retried.
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
    
    # Free memory used by deleted (rolled back) rows from any failed import.
    # This MUST be run so that the rows will be stored in inserted order, and
    # the row_num added after import will match up with the CSV's row order.
    sql.truncate(db, table)
    
    # Load the data
    def load(): append_csv(db, table, reader, header)
    sql.with_savepoint(db, load)
    
    cleanup_table(db, table)
318

    
319
def put(db, table, row, pkey_=None, row_ct_ref=None, on_error=exc.reraise):
    '''Inserts a single row by delegating to put_table() with no input tables.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ NOTE(review): accepted but never used; presumably kept for
        callers that pass it positionally — confirm before removing
    @param row_ct_ref See put_table()
    @param on_error See put_table()
    '''
    return put_table(db, table, [], row, row_ct_ref, on_error=on_error)
324

    
325
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up the pkey of the row matching row. Recovers from errors.
    @param create If set, inserts the row when there is no match.
    '''
    try:
        cur = sql.select(db, table, [pkey], row, limit=1, recover=True)
        return sql.value(cur)
    except StopIteration: # no matching row found
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
333

    
334
def is_func_result(col):
    '''Returns whether col is the "result" column of a function-call table
    (recognized by parentheses in the table name).'''
    return ('(' in col.table.name) and col.name == 'result'
336

    
337
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Generates a descriptive name for the table that holds the output pkeys
    of a put_table() call.
    @param out_table The output table (or function) being inserted into
    @param in_tables0 The main input table; its columns are shown name-only
    @param mapping dict(out_table_col=in_table_col, ...)
    @param is_func Whether out_table is a DB function rather than a table
    '''
    def in_col_str(in_col):
        # render an input column compactly for use in the generated name
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return strings.ustr(in_col)
    
    str_ = strings.ustr(out_table)
    if is_func:
        # format as a function call: name(args)
        str_ += '('
        
        # if there is a "value" arg, show only it; otherwise list all args
        try: value_in_col = mapping['value']
        except KeyError:
            str_ += ', '.join((strings.ustr(k)+'='+in_col_str(v)
                for k, v in mapping.iteritems()))
        else: str_ += in_col_str(value_in_col)
        
        str_ += ')'
    else:
        out_col = 'rank'
        try: in_col = mapping[out_col]
        except KeyError: str_ += '_pkeys'
        else: # has a rank column, so hierarchical
            str_ += '['+strings.ustr(out_col)+'='+in_col_str(in_col)+']'
    return str_
365

    
366
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, default=None,
367
    col_defaults={}, on_error=exc.reraise):
368
    '''Recovers from errors.
369
    Only works under PostgreSQL (uses INSERT RETURNING).
370
    
371
    Warning: This function's normalizing algorithm does not support database
372
    triggers that populate fields covered by the unique constraint used to do
373
    the DISTINCT ON. Such fields must be populated by the mappings instead.
374
    (Other unique constraints and other non-unique fields are not affected by
375
    this restriction on triggers. Note that the primary key will normally not be
376
    the DISTINCT ON constraint, so trigger-populated natural keys are supported
377
    *unless* the input table contains duplicate rows for some generated keys.)
378
    
379
    Note that much of the complexity of the normalizing algorithm is due to
380
    PostgreSQL (and other DB systems) not having a native command for
381
    INSERT ON DUPLICATE SELECT (wiki.vegpath.org/INSERT_ON_DUPLICATE_SELECT).
382
    For PostgreSQL 9.1+, this can now be emulated using INSTEAD OF triggers.
383
    For earlier versions, you instead have to use this function.
384
    
385
    @param in_tables The main input table to select from, followed by a list of
386
        tables to join with it using the main input table's pkey
387
    @param mapping dict(out_table_col=in_table_col, ...)
388
        * out_table_col: str (*not* sql_gen.Col)
389
        * in_table_col: sql_gen.Col|literal-value
390
    @param default The *output* column to use as the pkey for missing rows.
391
        If this output column does not exist in the mapping, uses None.
392
        Note that this will be used for *all* missing rows, regardless of which
393
        error caused them not to be inserted.
394
    @param col_defaults Default values for required columns.
395
    @return sql_gen.Col Where the output pkeys are made available
396
    '''
397
    import psycopg2.extensions
398
    
399
    # Special handling for functions with hstore params
400
    if out_table == '_map':
401
        import psycopg2.extras
402
        psycopg2.extras.register_hstore(db.db)
403
        
404
        # Parse args
405
        try: value = mapping.pop('value')
406
        except KeyError: return None # value required
407
        
408
        mapping = dict([(k, sql_gen.get_value(v))
409
            for k, v in mapping.iteritems()]) # unwrap literal value
410
        mapping = dict(map=mapping, value=value) # non-value params -> hstore
411
    
412
    out_table = sql_gen.as_Table(out_table)
413
    
414
    def log_debug(msg): db.log_debug(msg, level=1.5)
415
    def col_ustr(str_):
416
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
417
    
418
    log_debug('********** New iteration **********')
419
    log_debug('Inserting these input columns into '+strings.as_tt(
420
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
421
    
422
    is_function = sql.function_exists(db, out_table)
423
    
424
    if is_function: row_ct_ref = None # only track inserted rows
425
    
426
    # Warn if inserting empty table rows
427
    if not mapping and not is_function: # functions with no args OK
428
        warnings.warn(UserWarning('Inserting empty table row(s)'))
429
    
430
    if is_function: out_pkey = 'result'
431
    else: out_pkey = sql.pkey_name(db, out_table, recover=True)
432
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
433
    
434
    in_tables_ = copy.copy(in_tables) # don't modify input!
435
    try: in_tables0 = in_tables_.pop(0) # first table is separate
436
    except IndexError: in_tables0 = None
437
    else:
438
        in_pkey = sql.pkey_name(db, in_tables0, recover=True)
439
        in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
440
    
441
    # Determine if can use optimization for only literal values
442
    is_literals = not reduce(operator.or_, map(sql_gen.is_table_col,
443
        mapping.values()), False)
444
    is_literals_or_function = is_literals or is_function
445
    
446
    if in_tables0 == None: errors_table_ = None
447
    else: errors_table_ = errors_table(db, in_tables0)
448
    
449
    # Create input joins from list of input tables
450
    input_joins = [in_tables0]+[sql_gen.Join(v,
451
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
452
    
453
    orig_mapping = mapping.copy()
454
    if mapping == {} and not is_function: # need >= one column for INSERT SELECT
455
        mapping = {out_pkey: None} # ColDict will replace with default value
456
    
457
    if not is_literals:
458
        into = sql_gen.as_Table(into_table_name(out_table, in_tables0, mapping,
459
            is_function))
460
        # Ensure into's out_pkey is different from in_pkey by prepending "out."
461
        if is_function: into_out_pkey = out_pkey
462
        else: into_out_pkey = 'out.'+out_pkey
463
        
464
        # Set column sources
465
        in_cols = filter(sql_gen.is_table_col, mapping.values())
466
        for col in in_cols:
467
            if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
468
        
469
        log_debug('Joining together input tables into temp table')
470
        # Place in new table so don't modify input and for speed
471
        in_table = sql_gen.Table('in')
472
        mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
473
            in_cols, preserve=[in_pkey_col]))
474
        input_joins = [in_table]
475
        db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
476
    
477
    # Wrap mapping in a sql_gen.ColDict.
478
    # sql_gen.ColDict sanitizes both keys and values passed into it.
479
    # Do after applying dicts.join() because that returns a plain dict.
480
    mapping = sql_gen.ColDict(db, out_table, mapping)
481
    
482
    # Save all rows since in_table may have rows deleted
483
    if is_literals: pass
484
    elif is_function: full_in_table = in_table
485
    else:
486
        full_in_table = sql_gen.suffixed_table(in_table, '_full')
487
        sql.copy_table(db, in_table, full_in_table)
488
    
489
    pkeys_table_exists_ref = [False]
490
    def insert_into_pkeys(query, **kw_args):
491
        if pkeys_table_exists_ref[0]:
492
            sql.insert_select(db, into, [in_pkey, into_out_pkey], query,
493
                **kw_args)
494
        else:
495
            kw_args.setdefault('add_pkey_', True)
496
            # don't warn if can't create pkey, because this just indicates that,
497
            # at some point in the import tree, it used a set-returning function
498
            kw_args.setdefault('add_pkey_warn', False)
499
            
500
            sql.run_query_into(db, query, into=into, **kw_args)
501
            pkeys_table_exists_ref[0] = True
502
    
503
    def mk_main_select(joins, cols): return sql.mk_select(db, joins, cols)
504
    
505
    if is_literals: insert_in_table = None
506
    else:
507
        insert_in_table = in_table
508
        insert_in_tables = [insert_in_table]
509
    join_cols = sql_gen.ColDict(db, out_table)
510
    join_custom_cond = None
511
    
512
    exc_strs = set()
513
    def log_exc(e):
514
        e_str = exc.str_(e, first_line_only=True)
515
        log_debug('Caught exception: '+e_str)
516
        if e_str in exc_strs: # avoid infinite loops
517
            log_debug('Exception already seen, handler broken')
518
            on_error(e)
519
            remove_all_rows()
520
            return False
521
        else: exc_strs.add(e_str)
522
        return True
523
    
524
    ignore_all_ref = [False]
525
    def remove_all_rows():
526
        log_debug('Ignoring all rows')
527
        ignore_all_ref[0] = True # just return the default value column
528
    
529
    def handle_unknown_exc(e):
530
        log_debug('No handler for exception')
531
        on_error(e)
532
        remove_all_rows()
533
    
534
    def ensure_cond(cond, e, passed=False, failed=False):
535
        '''
536
        @param passed at least one row passed the constraint
537
        @param failed at least one row failed the constraint
538
        '''
539
        if is_literals: # we know the constraint was applied exactly once
540
            if passed: pass
541
            elif failed: remove_all_rows()
542
            else: raise NotImplementedError()
543
        else:
544
            if not is_function:
545
                out_table_cols = sql_gen.ColDict(db, out_table)
546
                out_table_cols.update(util.dict_subset_right_join({},
547
                    sql.table_col_names(db, out_table)))
548
            
549
            in_cols = []
550
            cond = strings.ustr(cond)
551
            orig_cond = cond
552
            cond = sql_gen.map_expr(db, cond, mapping, in_cols)
553
            if not is_function:
554
                cond = sql_gen.map_expr(db, cond, out_table_cols)
555
            
556
            log_debug('Ignoring rows that do not satisfy '+strings.as_tt(cond))
557
            cur = None
558
            if cond == sql_gen.false_expr:
559
                assert failed
560
                remove_all_rows()
561
            elif cond == sql_gen.true_expr: assert passed
562
            else:
563
                while True:
564
                    not_cond = sql_gen.NotCond(sql_gen.CustomCode(cond))
565
                    try:
566
                        cur = sql.delete(db, insert_in_table, not_cond)
567
                        break
568
                    except sql.DoesNotExistException, e:
569
                        if e.type != 'column': raise
570
                        
571
                        last_cond = cond
572
                        cond = sql_gen.map_expr(db, cond, {e.name: None})
573
                        if cond == last_cond: raise # not fixable
574
            
575
            # If any rows failed cond
576
            if failed or cur != None and cur.rowcount > 0:
577
                track_data_error(db, errors_table_,
578
                    sql_gen.cross_join_srcs(in_cols), None, e.cause.pgcode,
579
                    strings.ensure_newl(strings.ustr(e.cause.pgerror))
580
                    +'condition: '+orig_cond+'\ntranslated condition: '+cond)
581
    
582
    not_null_cols = set()
583
    def ignore(in_col, value, e):
584
        if sql_gen.is_table_col(in_col):
585
            in_col = sql_gen.with_table(in_col, insert_in_table)
586
            
587
            track_data_error(db, errors_table_, in_col.srcs, value,
588
                e.cause.pgcode, e.cause.pgerror)
589
            
590
            sql.add_index(db, in_col, insert_in_table) # enable fast filtering
591
            if value != None and in_col not in not_null_cols:
592
                log_debug('Replacing invalid value '
593
                    +strings.as_tt(strings.urepr(value))+' with NULL in column '
594
                    +strings.as_tt(in_col.to_str(db)))
595
                sql.update(db, insert_in_table, [(in_col, None)],
596
                    sql_gen.ColValueCond(in_col, value))
597
            else:
598
                log_debug('Ignoring rows with '+strings.as_tt(in_col.to_str(db))
599
                    +' = '+strings.as_tt(strings.urepr(value)))
600
                sql.delete(db, insert_in_table,
601
                    sql_gen.ColValueCond(in_col, value))
602
                if value == None: not_null_cols.add(in_col)
603
        else:
604
            assert isinstance(in_col, sql_gen.NamedCol)
605
            in_value = sql_gen.remove_col_rename(in_col)
606
            assert sql_gen.is_literal(in_value)
607
            if value == in_value.value:
608
                if value != None:
609
                    log_debug('Replacing invalid literal '
610
                        +strings.as_tt(in_col.to_str(db))+' with NULL')
611
                    mapping[in_col.name] = None
612
                else:
613
                    remove_all_rows()
614
            # otherwise, all columns were being ignore()d because the specific
615
            # column couldn't be identified, and this was not the invalid column
616
    
617
    if not is_literals:
618
        def insert_pkeys_table(which):
619
            return sql_gen.Table(sql_gen.concat(in_table.name,
620
                '_insert_'+which+'_pkeys'))
621
        insert_out_pkeys = insert_pkeys_table('out')
622
        insert_in_pkeys = insert_pkeys_table('in')
623
    
624
    def mk_func_call():
625
        args = dict(((k.name, v) for k, v in mapping.iteritems()))
626
        return sql_gen.FunctionCall(out_table, **args), args
627
    
628
    def handle_MissingCastException(e):
629
        if not log_exc(e): return False
630
        
631
        type_ = e.type
632
        if e.col == None: out_cols = mapping.keys()
633
        else: out_cols = [e.col]
634
        
635
        for out_col in out_cols:
636
            log_debug('Casting '+strings.as_tt(strings.repr_no_u(out_col))
637
                +' input to '+strings.as_tt(type_))
638
            in_col = mapping[out_col]
639
            while True:
640
                try:
641
                    cast_col = cast_temp_col(db, type_, in_col, errors_table_)
642
                    mapping[out_col] = cast_col
643
                    if out_col in join_cols: join_cols[out_col] = cast_col
644
                    break # cast successful
645
                except sql.InvalidValueException, e:
646
                    if not log_exc(e): return False
647
                    
648
                    ignore(in_col, e.value, e)
649
        
650
        return True
651
    
652
    missing_msg = None
653
    
654
    # Do inserts and selects
655
    while True:
656
        has_joins = join_cols != {}
657
        
658
        if ignore_all_ref[0]: break # unrecoverable error, so don't do main case
659
        
660
        # Prepare to insert new rows
661
        if is_function:
662
            if is_literals:
663
                log_debug('Calling function')
664
                func_call, args = mk_func_call()
665
        else:
666
            log_debug('Trying to insert new rows')
667
            insert_args = dict(recover=True, cacheable=False)
668
            if has_joins:
669
                insert_args.update(dict(ignore=True))
670
            else:
671
                insert_args.update(dict(returning=out_pkey))
672
                if not is_literals:
673
                    insert_args.update(dict(into=insert_out_pkeys))
674
            main_select = mk_main_select([insert_in_table], [sql_gen.with_table(
675
                c, insert_in_table) for c in mapping.values()])
676
        
677
        try:
678
            cur = None
679
            if is_function:
680
                if is_literals:
681
                    cur = sql.select(db, fields=[func_call], recover=True,
682
                        cacheable=True)
683
                else:
684
                    log_debug('Defining wrapper function')
685
                    
686
                    func_call, args = mk_func_call()
687
                    func_call = sql_gen.NamedCol(into_out_pkey, func_call)
688
                    
689
                    # Create empty pkeys table so its row type can be used
690
                    insert_into_pkeys(sql.mk_select(db, input_joins,
691
                        [in_pkey_col, func_call], limit=0), add_pkey_=False,
692
                        recover=True)
693
                    result_type = db.col_info(sql_gen.Col(into_out_pkey,
694
                        into)).type
695
                    
696
                    ## Create error handling wrapper function
697
                    
698
                    wrapper = db.TempFunction(sql_gen.concat(into.name,
699
                        '_wrap'))
700
                    
701
                    select_cols = [in_pkey_col]+args.values()
702
                    row_var = copy.copy(sql_gen.row_var)
703
                    row_var.set_srcs([in_table])
704
                    in_pkey_var = sql_gen.Col(in_pkey, row_var)
705
                    
706
                    args = dict(((k, sql_gen.with_table(v, row_var))
707
                        for k, v in args.iteritems()))
708
                    func_call = sql_gen.FunctionCall(out_table, **args)
709
                    
710
                    def mk_return(result):
711
                        return sql_gen.ReturnQuery(sql.mk_select(db,
712
                            fields=[in_pkey_var, result], explain=False))
713
                    exc_handler = func_wrapper_exception_handler(db,
714
                        mk_return(sql_gen.Cast(result_type, None)),
715
                        args.values(), errors_table_)
716
                    
717
                    sql.define_func(db, sql_gen.FunctionDef(wrapper,
718
                        sql_gen.SetOf(into),
719
                        sql_gen.RowExcIgnore(sql_gen.RowType(in_table),
720
                            sql.mk_select(db, input_joins),
721
                            mk_return(func_call), exc_handler=exc_handler)
722
                        ))
723
                    wrapper_table = sql_gen.FunctionCall(wrapper)
724
                    
725
                    log_debug('Calling function')
726
                    insert_into_pkeys(sql.mk_select(db, wrapper_table,
727
                        order_by=None), recover=True, cacheable=False)
728
                    sql.add_pkey_or_index(db, into)
729
            else:
730
                cur = sql.insert_select(db, out_table, mapping.keys(),
731
                    main_select, **insert_args)
732
            break # insert successful
733
        except sql.MissingCastException, e:
734
            if not handle_MissingCastException(e): break
735
        except sql.DuplicateKeyException, e:
736
            if not log_exc(e): break
737
            
738
            # Different rows violating different unique constraints not
739
            # supported
740
            assert not join_cols
741
            
742
            join_custom_cond = e.cond
743
            if e.cond != None: ensure_cond(e.cond, e, passed=True)
744
            
745
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
746
            log_debug('Ignoring existing rows, comparing on these columns:\n'
747
                +strings.as_inline_table(join_cols, ustr=col_ustr))
748
            
749
            if is_literals:
750
                return sql.value(sql.select(db, out_table, [out_pkey_col],
751
                    join_cols, order_by=None))
752
            
753
            # Uniquify and filter input table to avoid (most) duplicate keys
754
            # (Additional duplicates may be added concurrently and will be
755
            # filtered out separately upon insert.)
756
            while True:
757
                try:
758
                    insert_in_table = sql.distinct_table(db, insert_in_table,
759
                        join_cols.values(), [insert_in_table,
760
                        sql_gen.Join(out_table, join_cols, sql_gen.filter_out,
761
                        e.cond)])
762
                    insert_in_tables.append(insert_in_table)
763
                    break # insert successful
764
                except sql.MissingCastException, e1: # don't modify outer e
765
                    if not handle_MissingCastException(e1): break
766
        except sql.NullValueException, e:
767
            if not log_exc(e): break
768
            
769
            out_col, = e.cols
770
            try: in_col = mapping[out_col]
771
            except KeyError, e:
772
                try: in_col = mapping[out_col] = col_defaults[out_col]
773
                except KeyError:
774
                    missing_msg = 'Missing mapping for NOT NULL column '+out_col
775
                    log_debug(missing_msg)
776
                    remove_all_rows()
777
            else: ignore(in_col, None, e)
778
        except sql.CheckException, e:
779
            if not log_exc(e): break
780
            
781
            ensure_cond(e.cond, e, failed=True)
782
        except sql.InvalidValueException, e:
783
            if not log_exc(e): break
784
            
785
            for in_col in mapping.values(): ignore(in_col, e.value, e)
786
        except psycopg2.extensions.TransactionRollbackError, e:
787
            if not log_exc(e): break
788
            # retry
789
        except sql.DatabaseErrors, e:
790
            if not log_exc(e): break
791
            
792
            handle_unknown_exc(e)
793
        # after exception handled, rerun loop with additional constraints
794
    
795
    # Resolve default value column
796
    if default != None:
797
        if ignore_all_ref[0]: mapping.update(orig_mapping) # use input cols
798
        try: default = mapping[default]
799
        except KeyError:
800
            db.log_debug('Default value column '
801
                +strings.as_tt(strings.repr_no_u(default))
802
                +' does not exist in mapping, falling back to None', level=2.1)
803
            default = None
804
        else: default = sql_gen.remove_col_rename(default)
805
    
806
    if missing_msg != None and default == None:
807
        warnings.warn(UserWarning(missing_msg))
808
        # not an error because sometimes the mappings include
809
        # extra tables which aren't used by the dataset
810
    
811
    # Handle unrecoverable errors
812
    if ignore_all_ref[0]:
813
        log_debug('Returning default: '+strings.as_tt(strings.urepr(default)))
814
        return default
815
    
816
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
817
        row_ct_ref[0] += cur.rowcount
818
    
819
    if is_literals: return sql.value_or_none(cur) # support multi-row functions
820
    
821
    if is_function: pass # pkeys table already created
822
    elif has_joins:
823
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols,
824
            custom_cond=join_custom_cond)]
825
        log_debug('Getting output table pkeys of existing/inserted rows')
826
        insert_into_pkeys(sql.mk_select(db, select_joins, [in_pkey_col,
827
            sql_gen.NamedCol(into_out_pkey, out_pkey_col)], order_by=None))
828
    else:
829
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
830
        
831
        log_debug('Getting input table pkeys of inserted rows')
832
        # Note that mk_main_select() does not use ORDER BY. Instead, assume that
833
        # since the SELECT query is identical to the one used in INSERT SELECT,
834
        # its rows will be retrieved in the same order.
835
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
836
            into=insert_in_pkeys)
837
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
838
        
839
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
840
            db, insert_in_pkeys)
841
        
842
        log_debug('Combining output and input pkeys in inserted order')
843
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
844
            {sql.row_num_col: sql_gen.join_same_not_null})]
845
        in_col = sql_gen.Col(in_pkey, insert_in_pkeys)
846
        out_col = sql_gen.NamedCol(into_out_pkey,
847
            sql_gen.Col(out_pkey, insert_out_pkeys))
848
        insert_into_pkeys(sql.mk_select(db, pkey_joins, [in_col, out_col],
849
            order_by=None))
850
        
851
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
852
    
853
    if not is_function: # is_function doesn't leave holes
854
        log_debug('Setting pkeys of missing rows to '
855
            +strings.as_tt(strings.urepr(default)))
856
        
857
        full_in_pkey_col = sql_gen.Col(in_pkey, full_in_table)
858
        if sql_gen.is_table_col(default):
859
            default = sql_gen.with_table(default, full_in_table)
860
        missing_rows_joins = [full_in_table, sql_gen.Join(into,
861
            {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
862
            # must use join_same_not_null or query will take forever
863
        
864
        insert_args = dict(order_by=None)
865
        if not sql.table_has_pkey(db, full_in_table): # in_table has duplicates
866
            insert_args.update(dict(distinct_on=[full_in_pkey_col]))
867
        
868
        insert_into_pkeys(sql.mk_select(db, missing_rows_joins,
869
            [full_in_pkey_col, sql_gen.NamedCol(into_out_pkey, default)],
870
            **insert_args))
871
    # otherwise, there is already an entry for every row
872
    
873
    sql.empty_temp(db, insert_in_tables+[full_in_table])
874
    
875
    srcs = []
876
    if is_function: srcs = sql_gen.cols_srcs(in_cols)
877
    return sql_gen.Col(into_out_pkey, into, srcs)
(37-37/49)