# Database import/export

1
# Database import/export
2

    
3
import copy
4
import csv
5
import operator
6
import warnings
7
import sys
8

    
9
import csvs
10
import exc
11
import dicts
12
import sql
13
import sql_gen
14
import streams
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
# Can't use built-in SyntaxError because it stringifies to only the first line
class SyntaxError(Exception):
    '''Exception whose str() preserves every line of the message.'''
    pass
22

    
23
##### Data cleanup
24

    
25
def table_nulls_mapped__set(db, table):
    '''Records (via util.table_nulls_mapped__set()) that the table has had its
    null strings mapped to NULL.
    '''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__set('
        +sql_gen.table2regclass_text(db, table)+')')
    sql.run_query(db, query)
29

    
30
def table_nulls_mapped__get(db, table):
    '''Returns (via util.table_nulls_mapped__get()) whether the table has had
    its null strings mapped to NULL.
    '''
    assert isinstance(table, sql_gen.Table)
    query = ('SELECT util.table_nulls_mapped__get('
        +sql_gen.table2regclass_text(db, table)+')')
    return sql.value(sql.run_query(db, query))
34

    
35
# Input strings that are treated as SQL NULL (see cleanup_table())
null_strs = ['', '-', r'\N', 'NULL', 'UNKNOWN', 'nulo']
36

    
37
def cleanup_table(db, table):
    '''Trims whitespace and maps the recognized null strings (null_strs) to
    SQL NULL in every text column of the table, except the pkey column.
    NOT NULL constraints that the mapping would violate are dropped.
    '''
    table = sql_gen.as_Table(table)
    text_cols = [c for c in sql.table_cols(db, table)
        if sql_gen.is_text_col(db, c)]
    try: pkey_col = sql.table_pkey_col(db, table)
    except sql.DoesNotExistException: pass
    else:
        # Leave the pkey column untouched
        if pkey_col in text_cols: text_cols.remove(pkey_col)
    if not text_cols: return
    
    db.log_debug('Cleaning up table', level=1.5)
    
    # Build nullif(...nullif(trim(both from <col>), <null_str>)...) for each
    # entry in null_strs
    expr = 'trim(both from %s)'
    for null_str in null_strs:
        expr = 'nullif('+expr+', '+db.esc_value(null_str)+')'
    changes = [(c, sql_gen.CustomCode(expr % c.to_str(db)))
        for c in text_cols]
    
    # Retry until no NOT NULL constraint blocks the update
    while True:
        try:
            sql.update(db, table, changes, in_place=True, recover=True)
            break # successful
        except sql.NullValueException as e:
            db.log_debug('Caught exception: '+exc.str_(e))
            col, = e.cols
            sql.drop_not_null(db, col)
    
    db.log_debug('Vacuuming and reanalyzing table', level=1.5)
    sql.vacuum(db, table)
65

    
66
##### Error tracking
67

    
68
def track_data_error(db, errors_table, cols, value, error_code, error):
    '''Records a data error against each of the given source columns.
    @param errors_table If None, does nothing.
    '''
    if errors_table == None: return
    
    col_names = [c.name for c in cols]
    if not col_names: col_names = [None] # need at least one entry
    for col_name in col_names:
        row = dict(column=col_name, value=value, error_code=error_code,
            error=error)
        try:
            sql.insert(db, errors_table, row, recover=True, cacheable=True,
                log_level=4)
        except sql.DuplicateKeyException: pass # error already recorded
82

    
83
class ExcToErrorsTable(sql_gen.ExcToWarning):
    '''Handles an exception by saving it or converting it to a warning.'''
    def __init__(self, return_, srcs, errors_table, value=None):
        '''
        @param return_ See sql_gen.ExcToWarning
        @param srcs The column names for the errors table
        @param errors_table None|sql_gen.Table
        @param value The value (or an expression for it) that caused the error
        @pre The invalid value must be in a local variable "value" of type text.
        '''
        sql_gen.ExcToWarning.__init__(self, return_)
        
        self.srcs = srcs
        self.errors_table = errors_table
        self.value = sql_gen.as_Code(value)
    
    def to_str(self, db):
        # Without source columns or an errors table, fall back to a warning
        if not self.srcs or self.errors_table == None:
            return sql_gen.ExcToWarning.to_str(self, db)
        
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        col_names_query = sql.mk_select(db, sql_gen.NamedValues('c', None,
            [[c.name] for c in self.srcs]), order_by=None)
        insert_query = sql.mk_insert_select(db, self.errors_table,
            errors_table_cols,
            sql_gen.Values(errors_table_cols).to_str(db))+';\n'
        return '''\
-- Save error in errors table.
DECLARE
    error_code text := SQLSTATE;
    error text := SQLERRM;
    value text := '''+self.value.to_str(db)+''';
    "column" text;
BEGIN
    -- Insert the value and error for *each* source column.
'''+strings.indent(sql_gen.RowExcIgnore(None, col_names_query, insert_query,
    row_var=errors_table_cols[0]).to_str(db))+'''
END;

'''+self.return_.to_str(db)
126

    
127
def data_exception_handler(*args, **kw_args):
    '''Handles a data_exception by saving it or converting it to a warning.
    For params, see ExcToErrorsTable().
    '''
    handler = ExcToErrorsTable(*args, **kw_args)
    return sql_gen.data_exception_handler(handler)
132

    
133
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as source columns). Otherwise, converts errors to warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    '''
    col = sql_gen.as_Col(col)
    
    # Don't convert exceptions to warnings for user-supplied constants
    if isinstance(col, sql_gen.Literal): return sql_gen.Cast(type_, col)
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    srcs = col.srcs
    save_errors = errors_table != None and srcs
    function_name = strings.first_word(type_)
    if save_errors: # function will be unique for the given srcs
        function_name = strings.ustr(sql_gen.FunctionCall(function_name,
            *map(sql_gen.to_name_only_col, srcs)))
    function = db.TempFunction(function_name)
    
    # Create function definition
    modifiers = 'STRICT'
    if not save_errors: modifiers = 'IMMUTABLE '+modifiers
    value_param = sql_gen.FunctionParam('value', 'anyelement')
    exc_handler = data_exception_handler('RETURN NULL;\n', srcs, errors_table,
        value_param.name)
    body = sql_gen.CustomCode(exc_handler.to_str(db, '''\
/* The explicit cast to the return type is needed to make the cast happen
inside the try block. (Implicit casts to the return type happen at the end
of the function, outside any block.) */
RETURN '''+sql_gen.Cast(type_, sql_gen.CustomCode('value')).to_str(db)+''';
'''))
    body.lang = 'plpgsql'
    sql.define_func(db, sql_gen.FunctionDef(function, type_, body,
        [value_param], modifiers))
    
    # Apply the (possibly error-handling) cast function to the column
    return sql_gen.FunctionCall(function, col)
172

    
173
def func_wrapper_exception_handler(db, return_, args, errors_table):
    '''Handles a function call's data_exceptions.
    Supports PL/Python functions.
    @param return_ See data_exception_handler()
    @param args [arg...] Function call's args
    @param errors_table See data_exception_handler()
    '''
    # Only args with source columns can be tracked in the errors table
    args = [a for a in args if sql_gen.has_srcs(a)]
    
    srcs = sql_gen.cross_join_srcs(args)
    value = sql_gen.merge_not_null(db, ',', args)
    data_handler = data_exception_handler(return_, srcs, errors_table, value)
    return sql_gen.NestedExcHandler(data_handler,
        sql_gen.plpythonu_error_handler)
188

    
189
def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but creates a new column with the cast values if the input
    is a column.
    @return The new column or cast value
    '''
    def do_cast(c): return cast(db, type_, c, errors_table)
    
    try: col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException:
        # Not a table column, so just wrap the value in the cast
        return sql_gen.wrap(do_cast, col)
    
    table = col.table
    new_col = sql_gen.suffixed_col(col, '::'+strings.first_word(type_))
    cast_expr = do_cast(col) # also defines the cast function, if needed
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, new_typed_col, comment=strings.urepr(col)+'::'+type_)
    new_col.name = new_typed_col.name # propagate any renaming
    
    # Fill the new column with the cast values
    sql.update(db, table, [(new_col, cast_expr)], in_place=True, recover=True)
    
    return new_col
211

    
212
def errors_table(db, table, if_exists=True):
    '''
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    if table.srcs != (): table = table.srcs[0] # use the original source table
    
    errors_table_ = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, errors_table_): return None
    return errors_table_
223

    
224
def mk_errors_table(db, table):
    '''Creates the errors table for the given table, if it doesn't exist.'''
    errors_table_ = errors_table(db, table, if_exists=False)
    if sql.table_exists(db, errors_table_, cacheable=False): return
    
    typed_cols = [
        sql_gen.TypedCol('column', 'text'),
        sql_gen.TypedCol('value', 'text'),
        sql_gen.TypedCol('error_code', 'character varying(5)', nullable=False),
        sql_gen.TypedCol('error', 'text', nullable=False),
        ]
    sql.create_table(db, errors_table_, typed_cols, has_pkey=False)
    # NOTE(review): md5() on value/error presumably keeps the unique index
    # entries within size limits for long values -- confirm
    index_cols = ['column', sql_gen.CustomCode('md5(value)'), 'error_code',
        sql_gen.CustomCode('md5(error)')]
    sql.add_index(db, index_cols, errors_table_, unique=True)
238

    
239
##### Import
240

    
241
# Definition of the explicit row-number column prepended to imported tables
# (see import_csv())
row_num_col_def = copy.copy(sql.row_num_col_def)
row_num_col_def.name = 'row_num'
row_num_col_def.type = 'integer'
244

    
245
def append_csv(db, table, reader, header):
    '''Appends the rows of a CSV reader to an existing table via COPY FROM.'''
    def esc_name_(name): return sql.esc_name(db, name)
    
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Wrap in standardizing stream
    stream = csvs.InputRewriter(streams.ProgressInputStream(csvs.StreamFilter(
        csvs.ColCtFilter(reader, len(header))), sys.stderr,
        msg='Read %d row(s)', n=1000))
    dialect = stream.dialect # use default dialect
    
    # Create COPY FROM statement
    if header == sql.table_col_names(db, table): cols_str = ''
    else: cols_str = ' ('+(', '.join(map(esc_name_, header)))+')'
    copy_from = ('COPY '+table.to_str(db)+cols_str+' FROM STDIN DELIMITER '
        +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
    assert not csvs.is_tsv(dialect)
    copy_from += ' CSV'
    if dialect.quoting != csv.QUOTE_NONE:
        quote_str = db.esc_value(dialect.quotechar)
        copy_from += ' QUOTE '+quote_str
        if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
    copy_from += ';\n'
    
    log(copy_from, level=2)
    try: db.db.cursor().copy_expert(copy_from, stream)
    except Exception as e: sql.parse_exception(db, e, recover=True)
273

    
274
def import_csv(db, table, reader, header):
    '''Imports a CSV into a new database table, then cleans up the imported
    data (see cleanup_table()).
    @param header The CSV's column names. A row_num column name is prepended
        to it (the list is modified in place).
    '''
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Get format info
    # (list() keeps this working if map() ever returns an iterator)
    col_names = list(map(strings.to_unicode, header))
    for i, col in enumerate(col_names): # replace empty column names
        if col == '': col_names[i] = 'column_'+str(i)
    
    # Removed unused inner helper esc_name(); name escaping is handled by
    # append_csv()
    
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
    typed_cols.insert(0, row_num_col_def)
    header.insert(0, row_num_col_def.name)
    reader = csvs.RowNumFilter(reader)
    
    log('Creating table')
    # Note that this is not rolled back if the import fails. Instead, it is
    # cached, and will not be re-run if the import is retried.
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
    
    # Free memory used by deleted (rolled back) rows from any failed import.
    # This MUST be run so that the rows will be stored in inserted order, and
    # the row_num added after import will match up with the CSV's row order.
    sql.truncate(db, table)
    
    # Load the data
    def load(): append_csv(db, table, reader, header)
    sql.with_savepoint(db, load)
    
    cleanup_table(db, table)
305

    
306
def put(db, table, row, pkey_=None, row_ct_ref=None, on_error=exc.reraise):
    '''Inserts a row into a table via put_table(). Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ Not used by this function; kept so existing callers that pass
        it positionally keep working.
    '''
    return put_table(db, table, [], row, row_ct_ref, on_error=on_error)
311

    
312
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up a row's pkey, optionally inserting the row if it's missing.
    Recovers from errors.
    @param create If set, insert the row when no match is found.
    '''
    try:
        cur = sql.select(db, table, [pkey], row, limit=1, recover=True)
        return sql.value(cur)
    except StopIteration: # no matching row
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
320

    
321
def is_func_result(col):
    '''Returns whether col is the "result" column of a function-call table
    (a table whose name contains a "(").
    '''
    return '(' in col.table.name and col.name == 'result'
323

    
324
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Generates a descriptive name for the pkeys table of an insert/select
    operation, based on the output table and the input mappings.
    '''
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return strings.ustr(in_col)
    
    name = strings.ustr(out_table)
    if is_func:
        # Format as a function call: out_table(args)
        try: value_in_col = mapping['value']
        except KeyError: # no value column: list all mappings as kw args
            args_str = ', '.join((strings.ustr(k)+'='+in_col_str(v)
                for k, v in mapping.iteritems()))
        else: args_str = in_col_str(value_in_col)
        name += '('+args_str+')'
    else:
        out_col = 'rank'
        try: in_col = mapping[out_col]
        except KeyError: name += '_pkeys'
        else: # has a rank column, so hierarchical
            name += '['+strings.ustr(out_col)+'='+in_col_str(in_col)+']'
    return name
352

    
353
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, default=None,
354
    col_defaults={}, on_error=exc.reraise):
355
    '''Recovers from errors.
356
    Only works under PostgreSQL (uses INSERT RETURNING).
357
    
358
    Warning: This function's normalizing algorithm does not support database
359
    triggers that populate fields covered by the unique constraint used to do
360
    the DISTINCT ON. Such fields must be populated by the mappings instead.
361
    (Other unique constraints and other non-unique fields are not affected by
362
    this restriction on triggers. Note that the primary key will normally not be
363
    the DISTINCT ON constraint, so trigger-populated natural keys are supported
364
    *unless* the input table contains duplicate rows for some generated keys.)
365
    
366
    Note that much of the complexity of the normalizing algorithm is due to
367
    PostgreSQL (and other DB systems) not having a native command for
368
    insert/on duplicate select. This operation is a cross between MySQL's
369
    INSERT ON DUPLICATE KEY UPDATE (which does not provide SELECT
370
    functionality), and PostgreSQL's INSERT RETURNING (which throws an error
371
    on duplicate instead of returning the existing row).
372
    
373
    @param in_tables The main input table to select from, followed by a list of
374
        tables to join with it using the main input table's pkey
375
    @param mapping dict(out_table_col=in_table_col, ...)
376
        * out_table_col: str (*not* sql_gen.Col)
377
        * in_table_col: sql_gen.Col|literal-value
378
    @param default The *output* column to use as the pkey for missing rows.
379
        If this output column does not exist in the mapping, uses None.
380
    @param col_defaults Default values for required columns.
381
    @return sql_gen.Col Where the output pkeys are made available
382
    '''
383
    import psycopg2.extensions
384
    
385
    # Special handling for functions with hstore params
386
    if out_table == '_map':
387
        import psycopg2.extras
388
        psycopg2.extras.register_hstore(db.db)
389
        
390
        # Parse args
391
        try: value = mapping.pop('value')
392
        except KeyError: return None # value required
393
        
394
        mapping = dict([(k, sql_gen.get_value(v))
395
            for k, v in mapping.iteritems()]) # unwrap literal value
396
        mapping = dict(map=mapping, value=value) # non-value params -> hstore
397
    
398
    out_table = sql_gen.as_Table(out_table)
399
    
400
    def log_debug(msg): db.log_debug(msg, level=1.5)
401
    def col_ustr(str_):
402
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
403
    
404
    log_debug('********** New iteration **********')
405
    log_debug('Inserting these input columns into '+strings.as_tt(
406
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
407
    
408
    is_function = sql.function_exists(db, out_table)
409
    
410
    if is_function: row_ct_ref = None # only track inserted rows
411
    
412
    # Warn if inserting empty table rows
413
    if not mapping and not is_function: # functions with no args OK
414
        warnings.warn(UserWarning('Inserting empty table row(s)'))
415
    
416
    if is_function: out_pkey = 'result'
417
    else: out_pkey = sql.pkey_name(db, out_table, recover=True)
418
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
419
    
420
    in_tables_ = copy.copy(in_tables) # don't modify input!
421
    try: in_tables0 = in_tables_.pop(0) # first table is separate
422
    except IndexError: in_tables0 = None
423
    else:
424
        in_pkey = sql.pkey_name(db, in_tables0, recover=True)
425
        in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
426
    
427
    # Determine if can use optimization for only literal values
428
    is_literals = not reduce(operator.or_, map(sql_gen.is_table_col,
429
        mapping.values()), False)
430
    is_literals_or_function = is_literals or is_function
431
    
432
    if in_tables0 == None: errors_table_ = None
433
    else: errors_table_ = errors_table(db, in_tables0)
434
    
435
    # Create input joins from list of input tables
436
    input_joins = [in_tables0]+[sql_gen.Join(v,
437
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
438
    
439
    orig_mapping = mapping.copy()
440
    if mapping == {} and not is_function: # need >= one column for INSERT SELECT
441
        mapping = {out_pkey: None} # ColDict will replace with default value
442
    
443
    if not is_literals:
444
        into = sql_gen.as_Table(into_table_name(out_table, in_tables0, mapping,
445
            is_function))
446
        # Ensure into's out_pkey is different from in_pkey by prepending "out."
447
        if is_function: into_out_pkey = out_pkey
448
        else: into_out_pkey = 'out.'+out_pkey
449
        
450
        # Set column sources
451
        in_cols = filter(sql_gen.is_table_col, mapping.values())
452
        for col in in_cols:
453
            if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
454
        
455
        log_debug('Joining together input tables into temp table')
456
        # Place in new table so don't modify input and for speed
457
        in_table = sql_gen.Table('in')
458
        mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
459
            in_cols, preserve=[in_pkey_col]))
460
        input_joins = [in_table]
461
        db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
462
    
463
    # Wrap mapping in a sql_gen.ColDict.
464
    # sql_gen.ColDict sanitizes both keys and values passed into it.
465
    # Do after applying dicts.join() because that returns a plain dict.
466
    mapping = sql_gen.ColDict(db, out_table, mapping)
467
    
468
    # Save all rows since in_table may have rows deleted
469
    if is_literals: pass
470
    elif is_function: full_in_table = in_table
471
    else:
472
        full_in_table = sql_gen.suffixed_table(in_table, '_full')
473
        sql.copy_table(db, in_table, full_in_table)
474
    
475
    pkeys_table_exists_ref = [False]
476
    def insert_into_pkeys(query, **kw_args):
477
        if pkeys_table_exists_ref[0]:
478
            sql.insert_select(db, into, [in_pkey, into_out_pkey], query,
479
                **kw_args)
480
        else:
481
            kw_args.setdefault('add_pkey_', True)
482
            
483
            sql.run_query_into(db, query, into=into, **kw_args)
484
            pkeys_table_exists_ref[0] = True
485
    
486
    def mk_main_select(joins, cols): return sql.mk_select(db, joins, cols)
487
    
488
    if is_literals: insert_in_table = None
489
    else:
490
        insert_in_table = in_table
491
        insert_in_tables = [insert_in_table]
492
    join_cols = sql_gen.ColDict(db, out_table)
493
    
494
    exc_strs = set()
495
    def log_exc(e):
496
        e_str = exc.str_(e, first_line_only=True)
497
        log_debug('Caught exception: '+e_str)
498
        if e_str in exc_strs: # avoid infinite loops
499
            log_debug('Exception already seen, handler broken')
500
            on_error(e)
501
            remove_all_rows()
502
            return False
503
        else: exc_strs.add(e_str)
504
        return True
505
    
506
    ignore_all_ref = [False]
507
    def remove_all_rows():
508
        log_debug('Ignoring all rows')
509
        ignore_all_ref[0] = True # just return the default value column
510
    
511
    def handle_unknown_exc(e):
512
        log_debug('No handler for exception')
513
        on_error(e)
514
        remove_all_rows()
515
    
516
    def ensure_cond(cond, e, passed=False, failed=False):
517
        if is_literals: # we know the constraint was applied exactly once
518
            if passed: pass
519
            elif failed: remove_all_rows()
520
            else: raise NotImplementedError()
521
        else:
522
            if not is_function:
523
                out_table_cols = sql_gen.ColDict(db, out_table)
524
                out_table_cols.update(util.dict_subset_right_join({},
525
                    sql.table_col_names(db, out_table)))
526
            
527
            in_cols = []
528
            cond = strings.ustr(cond)
529
            orig_cond = cond
530
            cond = sql_gen.map_expr(db, cond, mapping, in_cols)
531
            if not is_function:
532
                cond = sql_gen.map_expr(db, cond, out_table_cols)
533
            
534
            log_debug('Ignoring rows that do not satisfy '+strings.as_tt(cond))
535
            cur = None
536
            if cond == sql_gen.false_expr:
537
                assert failed
538
                remove_all_rows()
539
            elif cond == sql_gen.true_expr: assert passed
540
            else:
541
                while True:
542
                    not_cond = sql_gen.NotCond(sql_gen.CustomCode(cond))
543
                    try:
544
                        cur = sql.delete(db, insert_in_table, not_cond)
545
                        break
546
                    except sql.DoesNotExistException, e:
547
                        if e.type != 'column': raise
548
                        
549
                        last_cond = cond
550
                        cond = sql_gen.map_expr(db, cond, {e.name: None})
551
                        if cond == last_cond: raise # not fixable
552
            
553
            # If any rows failed cond
554
            if failed or cur != None and cur.rowcount > 0:
555
                track_data_error(db, errors_table_,
556
                    sql_gen.cross_join_srcs(in_cols), None, e.cause.pgcode,
557
                    strings.ensure_newl(strings.ustr(e.cause.pgerror))
558
                    +'condition: '+orig_cond+'\ntranslated condition: '+cond)
559
    
560
    not_null_cols = set()
561
    def ignore(in_col, value, e):
562
        if sql_gen.is_table_col(in_col):
563
            in_col = sql_gen.with_table(in_col, insert_in_table)
564
            
565
            track_data_error(db, errors_table_, in_col.srcs, value,
566
                e.cause.pgcode, e.cause.pgerror)
567
            
568
            sql.add_index(db, in_col, insert_in_table) # enable fast filtering
569
            if value != None and in_col not in not_null_cols:
570
                log_debug('Replacing invalid value '
571
                    +strings.as_tt(strings.urepr(value))+' with NULL in column '
572
                    +strings.as_tt(in_col.to_str(db)))
573
                sql.update(db, insert_in_table, [(in_col, None)],
574
                    sql_gen.ColValueCond(in_col, value))
575
            else:
576
                log_debug('Ignoring rows with '+strings.as_tt(in_col.to_str(db))
577
                    +' = '+strings.as_tt(strings.urepr(value)))
578
                sql.delete(db, insert_in_table,
579
                    sql_gen.ColValueCond(in_col, value))
580
                if value == None: not_null_cols.add(in_col)
581
        else:
582
            assert isinstance(in_col, sql_gen.NamedCol)
583
            in_value = sql_gen.remove_col_rename(in_col)
584
            assert sql_gen.is_literal(in_value)
585
            if value == in_value.value:
586
                if value != None:
587
                    log_debug('Replacing invalid literal '
588
                        +strings.as_tt(in_col.to_str(db))+' with NULL')
589
                    mapping[in_col.name] = None
590
                else:
591
                    remove_all_rows()
592
            # otherwise, all columns were being ignore()d because the specific
593
            # column couldn't be identified, and this was not the invalid column
594
    
595
    if not is_literals:
596
        def insert_pkeys_table(which):
597
            return sql_gen.Table(sql_gen.concat(in_table.name,
598
                '_insert_'+which+'_pkeys'))
599
        insert_out_pkeys = insert_pkeys_table('out')
600
        insert_in_pkeys = insert_pkeys_table('in')
601
    
602
    def mk_func_call():
603
        args = dict(((k.name, v) for k, v in mapping.iteritems()))
604
        return sql_gen.FunctionCall(out_table, **args), args
605
    
606
    missing_msg = None
607
    
608
    # Do inserts and selects
609
    while True:
610
        has_joins = join_cols != {}
611
        
612
        if ignore_all_ref[0]: break # unrecoverable error, so don't do main case
613
        
614
        # Prepare to insert new rows
615
        if is_function:
616
            if is_literals:
617
                log_debug('Calling function')
618
                func_call, args = mk_func_call()
619
        else:
620
            log_debug('Trying to insert new rows')
621
            insert_args = dict(recover=True, cacheable=False)
622
            if has_joins:
623
                insert_args.update(dict(ignore=True))
624
            else:
625
                insert_args.update(dict(returning=out_pkey))
626
                if not is_literals:
627
                    insert_args.update(dict(into=insert_out_pkeys))
628
            main_select = mk_main_select([insert_in_table], [sql_gen.with_table(
629
                c, insert_in_table) for c in mapping.values()])
630
        
631
        try:
632
            cur = None
633
            if is_function:
634
                if is_literals:
635
                    cur = sql.select(db, fields=[func_call], recover=True,
636
                        cacheable=True)
637
                else:
638
                    log_debug('Defining wrapper function')
639
                    
640
                    func_call, args = mk_func_call()
641
                    func_call = sql_gen.NamedCol(into_out_pkey, func_call)
642
                    
643
                    # Create empty pkeys table so its row type can be used
644
                    insert_into_pkeys(sql.mk_select(db, input_joins,
645
                        [in_pkey_col, func_call], limit=0), add_pkey_=False,
646
                        recover=True)
647
                    result_type = db.col_info(sql_gen.Col(into_out_pkey,
648
                        into)).type
649
                    
650
                    ## Create error handling wrapper function
651
                    
652
                    wrapper = db.TempFunction(sql_gen.concat(into.name,
653
                        '_wrap'))
654
                    
655
                    select_cols = [in_pkey_col]+args.values()
656
                    row_var = copy.copy(sql_gen.row_var)
657
                    row_var.set_srcs([in_table])
658
                    in_pkey_var = sql_gen.Col(in_pkey, row_var)
659
                    
660
                    args = dict(((k, sql_gen.with_table(v, row_var))
661
                        for k, v in args.iteritems()))
662
                    func_call = sql_gen.FunctionCall(out_table, **args)
663
                    
664
                    def mk_return(result):
665
                        return sql_gen.ReturnQuery(sql.mk_select(db,
666
                            fields=[in_pkey_var, result], explain=False))
667
                    exc_handler = func_wrapper_exception_handler(db,
668
                        mk_return(sql_gen.Cast(result_type, None)),
669
                        args.values(), errors_table_)
670
                    
671
                    sql.define_func(db, sql_gen.FunctionDef(wrapper,
672
                        sql_gen.SetOf(into),
673
                        sql_gen.RowExcIgnore(sql_gen.RowType(in_table),
674
                            sql.mk_select(db, input_joins),
675
                            mk_return(func_call), exc_handler=exc_handler)
676
                        ))
677
                    wrapper_table = sql_gen.FunctionCall(wrapper)
678
                    
679
                    log_debug('Calling function')
680
                    insert_into_pkeys(sql.mk_select(db, wrapper_table,
681
                        order_by=None), recover=True, cacheable=False)
682
                    sql.add_pkey_or_index(db, into)
683
            else:
684
                cur = sql.insert_select(db, out_table, mapping.keys(),
685
                    main_select, **insert_args)
686
            break # insert successful
687
        except sql.MissingCastException, e:
688
            if not log_exc(e): break
689
            
690
            type_ = e.type
691
            if e.col == None: out_cols = mapping.keys()
692
            else: out_cols = [e.col]
693
            
694
            for out_col in out_cols:
695
                log_debug('Casting '+strings.as_tt(strings.repr_no_u(out_col))
696
                    +' input to '+strings.as_tt(type_))
697
                in_col = mapping[out_col]
698
                while True:
699
                    try:
700
                        mapping[out_col] = cast_temp_col(db, type_, in_col,
701
                            errors_table_)
702
                        break # cast successful
703
                    except sql.InvalidValueException, e:
704
                        if not log_exc(e): break
705
                        
706
                        ignore(in_col, e.value, e)
707
        except sql.DuplicateKeyException, e:
708
            if not log_exc(e): break
709
            
710
            # Different rows violating different unique constraints not
711
            # supported
712
            assert not join_cols
713
            
714
            if e.cond != None: ensure_cond(e.cond, e, passed=True)
715
            
716
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
717
            log_debug('Ignoring existing rows, comparing on these columns:\n'
718
                +strings.as_inline_table(join_cols, ustr=col_ustr))
719
            
720
            if is_literals:
721
                return sql.value(sql.select(db, out_table, [out_pkey_col],
722
                    join_cols, order_by=None))
723
            
724
            # Uniquify and filter input table to avoid (most) duplicate keys
725
            # (Additional duplicates may be added concurrently and will be
726
            # filtered out separately upon insert.)
727
            insert_in_table = sql.distinct_table(db, insert_in_table,
728
                join_cols.values(), [insert_in_table,
729
                sql_gen.Join(out_table, join_cols, sql_gen.filter_out)])
730
            insert_in_tables.append(insert_in_table)
731
        except sql.NullValueException, e:
732
            if not log_exc(e): break
733
            
734
            out_col, = e.cols
735
            try: in_col = mapping[out_col]
736
            except KeyError, e:
737
                try: in_col = mapping[out_col] = col_defaults[out_col]
738
                except KeyError:
739
                    missing_msg = 'Missing mapping for NOT NULL column '+out_col
740
                    log_debug(missing_msg)
741
                    remove_all_rows()
742
            else: ignore(in_col, None, e)
743
        except sql.CheckException, e:
744
            if not log_exc(e): break
745
            
746
            ensure_cond(e.cond, e, failed=True)
747
        except sql.InvalidValueException, e:
748
            if not log_exc(e): break
749
            
750
            for in_col in mapping.values(): ignore(in_col, e.value, e)
751
        except psycopg2.extensions.TransactionRollbackError, e:
752
            if not log_exc(e): break
753
            # retry
754
        except sql.DatabaseErrors, e:
755
            if not log_exc(e): break
756
            
757
            handle_unknown_exc(e)
758
        # after exception handled, rerun loop with additional constraints
759
    
760
    # Resolve default value column
761
    if default != None:
762
        if ignore_all_ref[0]: mapping.update(orig_mapping) # use input cols
763
        try: default = mapping[default]
764
        except KeyError:
765
            db.log_debug('Default value column '
766
                +strings.as_tt(strings.repr_no_u(default))
767
                +' does not exist in mapping, falling back to None', level=2.1)
768
            default = None
769
        else: default = sql_gen.remove_col_rename(default)
770
    
771
    if missing_msg != None and default == None:
772
        warnings.warn(UserWarning(missing_msg))
773
        # not an error because sometimes the mappings include
774
        # extra tables which aren't used by the dataset
775
    
776
    # Handle unrecoverable errors
777
    if ignore_all_ref[0]:
778
        log_debug('Returning default: '+strings.as_tt(strings.urepr(default)))
779
        return default
780
    
781
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
782
        row_ct_ref[0] += cur.rowcount
783
    
784
    if is_literals: return sql.value(cur)
785
    
786
    if is_function: pass # pkeys table already created
787
    elif has_joins:
788
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
789
        log_debug('Getting output table pkeys of existing/inserted rows')
790
        insert_into_pkeys(sql.mk_select(db, select_joins, [in_pkey_col,
791
            sql_gen.NamedCol(into_out_pkey, out_pkey_col)], order_by=None))
792
    else:
793
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
794
        
795
        log_debug('Getting input table pkeys of inserted rows')
796
        # Note that mk_main_select() does not use ORDER BY. Instead, assume that
797
        # since the SELECT query is identical to the one used in INSERT SELECT,
798
        # its rows will be retrieved in the same order.
799
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
800
            into=insert_in_pkeys)
801
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
802
        
803
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
804
            db, insert_in_pkeys)
805
        
806
        log_debug('Combining output and input pkeys in inserted order')
807
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
808
            {sql.row_num_col: sql_gen.join_same_not_null})]
809
        in_col = sql_gen.Col(in_pkey, insert_in_pkeys)
810
        out_col = sql_gen.NamedCol(into_out_pkey,
811
            sql_gen.Col(out_pkey, insert_out_pkeys))
812
        insert_into_pkeys(sql.mk_select(db, pkey_joins, [in_col, out_col],
813
            order_by=None))
814
        
815
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
816
    
817
    if not is_function: # is_function doesn't leave holes
818
        log_debug('Setting pkeys of missing rows to '
819
            +strings.as_tt(strings.urepr(default)))
820
        
821
        full_in_pkey_col = sql_gen.Col(in_pkey, full_in_table)
822
        if sql_gen.is_table_col(default):
823
            default = sql_gen.with_table(default, full_in_table)
824
        missing_rows_joins = [full_in_table, sql_gen.Join(into,
825
            {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
826
            # must use join_same_not_null or query will take forever
827
        
828
        insert_args = dict(order_by=None)
829
        if not sql.table_has_pkey(db, full_in_table): # in_table has duplicates
830
            insert_args.update(dict(distinct_on=[full_in_pkey_col]))
831
        
832
        insert_into_pkeys(sql.mk_select(db, missing_rows_joins,
833
            [full_in_pkey_col, sql_gen.NamedCol(into_out_pkey, default)],
834
            **insert_args))
835
    # otherwise, there is already an entry for every row
836
    
837
    sql.empty_temp(db, insert_in_tables+[full_in_table])
838
    
839
    srcs = []
840
    if is_function: srcs = sql_gen.cols_srcs(in_cols)
841
    return sql_gen.Col(into_out_pkey, into, srcs)
(37-37/49)