Project

General

Profile

1 3077 aaronmk
# Database import/export
2
3 3534 aaronmk
import copy
4 4995 aaronmk
import csv
5 3431 aaronmk
import operator
6 14820 aaronmk
import os
7 3714 aaronmk
import warnings
8 5588 aaronmk
import sys
9 3431 aaronmk
10 4995 aaronmk
import csvs
11 3077 aaronmk
import exc
12
import dicts
13
import sql
14
import sql_gen
15 5588 aaronmk
import streams
16 3077 aaronmk
import strings
17
import util
18
19 3645 aaronmk
##### Exceptions
20
21
# Can't use built-in SyntaxError because it stringifies to only the first line
22
class SyntaxError(Exception): pass
23
24 3081 aaronmk
##### Data cleanup
25
26 10187 aaronmk
def table_nulls_mapped__set(db, table):
27
    assert isinstance(table, sql_gen.Table)
28
    sql.run_query(db, 'SELECT util.table_nulls_mapped__set('
29
        +sql_gen.table2regclass_text(db, table)+')')
30
31
def table_nulls_mapped__get(db, table):
32
    assert isinstance(table, sql_gen.Table)
33
    return sql.value(sql.run_query(db, 'SELECT util.table_nulls_mapped__get('
34
        +sql_gen.table2regclass_text(db, table)+')'))
35
36 14822 aaronmk
null_strs_str_default = r',-,\N,NULL,N/A,UNKNOWN,nulo'
37 4209 aaronmk
38 14816 aaronmk
null_strs_str = os.getenv('null_strs', null_strs_str_default)
39
null_strs = null_strs_str.split(',')
40
41 4447 aaronmk
def cleanup_table(db, table):
42 10188 aaronmk
    '''idempotent'''
43 3081 aaronmk
    table = sql_gen.as_Table(table)
44 10195 aaronmk
    assert sql.table_exists(db, table)
45 10188 aaronmk
46
    if table_nulls_mapped__get(db, table): return # already cleaned up
47
48 5384 aaronmk
    cols = filter(lambda c: sql_gen.is_text_col(db, c),
49
        sql.table_cols(db, table))
50 5394 aaronmk
    try: pkey_col = sql.table_pkey_col(db, table)
51
    except sql.DoesNotExistException: pass
52
    else:
53
        try: cols.remove(pkey_col)
54
        except ValueError: pass
55 4928 aaronmk
    if not cols: return
56 3081 aaronmk
57 4993 aaronmk
    db.log_debug('Cleaning up table', level=1.5)
58
59 14821 aaronmk
    db.log_debug('null_strs = '+repr(null_strs), level=1.5)
60 14826 aaronmk
    expr = 'trim(both from %s)' # also converts character varying fields to text
61 4209 aaronmk
    for null in null_strs: expr = 'nullif('+expr+', '+db.esc_value(null)+')'
62
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db))) for v in cols]
63 3081 aaronmk
64 4444 aaronmk
    while True:
65
        try:
66
            sql.update(db, table, changes, in_place=True, recover=True)
67
            break # successful
68
        except sql.NullValueException, e:
69 4457 aaronmk
            db.log_debug('Caught exception: '+exc.str_(e))
70 4444 aaronmk
            col, = e.cols
71
            sql.drop_not_null(db, col)
72 4992 aaronmk
73
    db.log_debug('Vacuuming and reanalyzing table', level=1.5)
74
    sql.vacuum(db, table)
75 10188 aaronmk
76
    table_nulls_mapped__set(db, table)
77 3081 aaronmk
78 3078 aaronmk
##### Error tracking
79
80
def track_data_error(db, errors_table, cols, value, error_code, error):
81
    '''
82
    @param errors_table If None, does nothing.
83
    '''
84 5817 aaronmk
    if errors_table == None: return
85 3078 aaronmk
86 5815 aaronmk
    col_names = [c.name for c in cols]
87
    if not col_names: col_names = [None] # need at least one entry
88
    for col_name in col_names:
89 3078 aaronmk
        try:
90 5815 aaronmk
            sql.insert(db, errors_table, dict(column=col_name, value=value,
91 3078 aaronmk
                error_code=error_code, error=error), recover=True,
92
                cacheable=True, log_level=4)
93
        except sql.DuplicateKeyException: pass
94
95 3506 aaronmk
class ExcToErrorsTable(sql_gen.ExcToWarning):
96
    '''Handles an exception by saving it or converting it to a warning.'''
97 3511 aaronmk
    def __init__(self, return_, srcs, errors_table, value=None):
98 3506 aaronmk
        '''
99
        @param return_ See sql_gen.ExcToWarning
100
        @param srcs The column names for the errors table
101
        @param errors_table None|sql_gen.Table
102 3511 aaronmk
        @param value The value (or an expression for it) that caused the error
103 3506 aaronmk
        @pre The invalid value must be in a local variable "value" of type text.
104
        '''
105
        sql_gen.ExcToWarning.__init__(self, return_)
106
107 3511 aaronmk
        value = sql_gen.as_Code(value)
108
109 3506 aaronmk
        self.srcs = srcs
110
        self.errors_table = errors_table
111 3511 aaronmk
        self.value = value
112 3501 aaronmk
113 3506 aaronmk
    def to_str(self, db):
114
        if not self.srcs or self.errors_table == None:
115
            return sql_gen.ExcToWarning.to_str(self, db)
116
117 3459 aaronmk
        errors_table_cols = map(sql_gen.Col,
118
            ['column', 'value', 'error_code', 'error'])
119 3465 aaronmk
        col_names_query = sql.mk_select(db, sql_gen.NamedValues('c', None,
120 3506 aaronmk
            [[c.name] for c in self.srcs]), order_by=None)
121
        insert_query = sql.mk_insert_select(db, self.errors_table,
122
            errors_table_cols,
123 3465 aaronmk
            sql_gen.Values(errors_table_cols).to_str(db))+';\n'
124 3506 aaronmk
        return '''\
125 3459 aaronmk
-- Save error in errors table.
126
DECLARE
127
    error_code text := SQLSTATE;
128
    error text := SQLERRM;
129 3511 aaronmk
    value text := '''+self.value.to_str(db)+''';
130 3529 aaronmk
    "column" text;
131 3459 aaronmk
BEGIN
132
    -- Insert the value and error for *each* source column.
133 3529 aaronmk
'''+strings.indent(sql_gen.RowExcIgnore(None, col_names_query, insert_query,
134 3467 aaronmk
    row_var=errors_table_cols[0]).to_str(db))+'''
135 3459 aaronmk
END;
136 3501 aaronmk
137 3506 aaronmk
'''+self.return_.to_str(db)
138 3459 aaronmk
139 3507 aaronmk
def data_exception_handler(*args, **kw_args):
140 3506 aaronmk
    '''Handles a data_exception by saving it or converting it to a warning.
141
    For params, see ExcToErrorsTable().
142
    '''
143
    return sql_gen.data_exception_handler(ExcToErrorsTable(*args, **kw_args))
144
145 3078 aaronmk
def cast(db, type_, col, errors_table=None):
146
    '''Casts an (unrenamed) column or value.
147
    If errors_table set and col has srcs, saves errors in errors_table (using
148 3360 aaronmk
    col's srcs attr as source columns). Otherwise, converts errors to warnings.
149 3078 aaronmk
    @param col str|sql_gen.Col|sql_gen.Literal
150
    @param errors_table None|sql_gen.Table|str
151
    '''
152
    col = sql_gen.as_Col(col)
153
154 3112 aaronmk
    # Don't convert exceptions to warnings for user-supplied constants
155
    if isinstance(col, sql_gen.Literal): return sql_gen.Cast(type_, col)
156
157 3078 aaronmk
    assert not isinstance(col, sql_gen.NamedCol)
158
159 3460 aaronmk
    function_name = strings.first_word(type_)
160 3459 aaronmk
    srcs = col.srcs
161 3508 aaronmk
    save_errors = errors_table != None and srcs
162
    if save_errors: # function will be unique for the given srcs
163 3750 aaronmk
        function_name = strings.ustr(sql_gen.FunctionCall(function_name,
164 3508 aaronmk
            *map(sql_gen.to_name_only_col, srcs)))
165 3078 aaronmk
    function = db.TempFunction(function_name)
166
167 3464 aaronmk
    # Create function definition
168
    modifiers = 'STRICT'
169
    if not save_errors: modifiers = 'IMMUTABLE '+modifiers
170 5791 aaronmk
    value_param = sql_gen.FunctionParam('value', 'anyelement')
171 3511 aaronmk
    handler = data_exception_handler('RETURN NULL;\n', srcs, errors_table,
172
        value_param.name)
173 3464 aaronmk
    body = sql_gen.CustomCode(handler.to_str(db, '''\
174 3467 aaronmk
/* The explicit cast to the return type is needed to make the cast happen
175
inside the try block. (Implicit casts to the return type happen at the end
176
of the function, outside any block.) */
177 6300 aaronmk
RETURN '''+sql_gen.Cast(type_, sql_gen.CustomCode('value')).to_str(db)+''';
178 3464 aaronmk
'''))
179
    body.lang='plpgsql'
180 3500 aaronmk
    sql.define_func(db, sql_gen.FunctionDef(function, type_, body,
181 3511 aaronmk
        [value_param], modifiers))
182 3464 aaronmk
183 3078 aaronmk
    return sql_gen.FunctionCall(function, col)
184
185 3538 aaronmk
def func_wrapper_exception_handler(db, return_, args, errors_table):
186 3524 aaronmk
    '''Handles a function call's data_exceptions.
187
    Supports PL/Python functions.
188
    @param return_ See data_exception_handler()
189
    @param args [arg...] Function call's args
190
    @param errors_table See data_exception_handler()
191
    '''
192
    args = filter(sql_gen.has_srcs, args)
193
194
    srcs = sql_gen.cross_join_srcs(args)
195 3538 aaronmk
    value = sql_gen.merge_not_null(db, ',', args)
196 3524 aaronmk
    return sql_gen.NestedExcHandler(
197
        data_exception_handler(return_, srcs, errors_table, value)
198
        , sql_gen.plpythonu_error_handler
199
        )
200
201 3078 aaronmk
def cast_temp_col(db, type_, col, errors_table=None):
202
    '''Like cast(), but creates a new column with the cast values if the input
203
    is a column.
204
    @return The new column or cast value
205
    '''
206
    def cast_(col): return cast(db, type_, col, errors_table)
207
208
    try: col = sql_gen.underlying_col(col)
209
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(cast_, col)
210
211
    table = col.table
212 3173 aaronmk
    new_col = sql_gen.suffixed_col(col, '::'+strings.first_word(type_))
213 3078 aaronmk
    expr = cast_(col)
214
215
    # Add column
216
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
217 3750 aaronmk
    sql.add_col(db, table, new_typed_col, comment=strings.urepr(col)+'::'+type_)
218 3078 aaronmk
    new_col.name = new_typed_col.name # propagate any renaming
219
220 3110 aaronmk
    sql.update(db, table, [(new_col, expr)], in_place=True, recover=True)
221 3078 aaronmk
222
    return new_col
223
224
def errors_table(db, table, if_exists=True):
225
    '''
226
    @param if_exists If set, returns None if the errors table doesn't exist
227
    @return None|sql_gen.Table
228
    '''
229
    table = sql_gen.as_Table(table)
230
    if table.srcs != (): table = table.srcs[0]
231
232
    errors_table = sql_gen.suffixed_table(table, '.errors')
233
    if if_exists and not sql.table_exists(db, errors_table): return None
234
    return errors_table
235
236 4436 aaronmk
def mk_errors_table(db, table):
237
    errors_table_ = errors_table(db, table, if_exists=False)
238 4557 aaronmk
    if sql.table_exists(db, errors_table_, cacheable=False): return
239 4436 aaronmk
240
    typed_cols = [
241 5813 aaronmk
        sql_gen.TypedCol('column', 'text'),
242 4436 aaronmk
        sql_gen.TypedCol('value', 'text'),
243
        sql_gen.TypedCol('error_code', 'character varying(5)', nullable=False),
244
        sql_gen.TypedCol('error', 'text', nullable=False),
245
        ]
246
    sql.create_table(db, errors_table_, typed_cols, has_pkey=False)
247 8077 aaronmk
    index_cols = ['column', sql_gen.CustomCode('md5(value)'), 'error_code',
248
        sql_gen.CustomCode('md5(error)')]
249 4436 aaronmk
    sql.add_index(db, index_cols, errors_table_, unique=True)
250
251 3078 aaronmk
##### Import
252
253 5568 aaronmk
row_num_col_def = copy.copy(sql.row_num_col_def)
254
row_num_col_def.name = 'row_num'
255 5569 aaronmk
row_num_col_def.type = 'integer'
256 5568 aaronmk
257 5590 aaronmk
def append_csv(db, table, reader, header):
258 9508 aaronmk
    def esc_name_(name): return sql.esc_name(db, name)
259 5037 aaronmk
260 5017 aaronmk
    def log(msg, level=1): db.log_debug(msg, level)
261
262 5584 aaronmk
    # Wrap in standardizing stream
263 5590 aaronmk
    cols_ct = len(header)
264 14589 aaronmk
    stream = csvs.InputRewriter(csvs.ProgressInputFilter(
265
        csvs.ColCtFilter(reader, cols_ct), sys.stderr, n=1000))
266 14585 aaronmk
    #streams.copy(stream, sys.stderr) # to troubleshoot copy_expert() errors
267 5584 aaronmk
    dialect = stream.dialect # use default dialect
268
269
    # Create COPY FROM statement
270 9508 aaronmk
    if header == sql.table_col_names(db, table): cols_str = ''
271
    else: cols_str =' ('+(', '.join(map(esc_name_, header)))+')'
272
    copy_from = ('COPY '+table.to_str(db)+cols_str+' FROM STDIN DELIMITER '
273 5584 aaronmk
        +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
274
    assert not csvs.is_tsv(dialect)
275
    copy_from += ' CSV'
276
    if dialect.quoting != csv.QUOTE_NONE:
277
        quote_str = db.esc_value(dialect.quotechar)
278
        copy_from += ' QUOTE '+quote_str
279
        if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
280
    copy_from += ';\n'
281
282
    log(copy_from, level=2)
283
    try: db.db.cursor().copy_expert(copy_from, stream)
284
    except Exception, e: sql.parse_exception(db, e, recover=True)
285 5017 aaronmk
286 5591 aaronmk
def import_csv(db, table, reader, header):
287 4995 aaronmk
    def log(msg, level=1): db.log_debug(msg, level)
288
289
    # Get format info
290 5590 aaronmk
    col_names = map(strings.to_unicode, header)
291 4995 aaronmk
    for i, col in enumerate(col_names): # replace empty column names
292
        if col == '': col_names[i] = 'column_'+str(i)
293
294
    # Select schema and escape names
295
    def esc_name(name): return db.esc_name(name)
296
297
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
298 5594 aaronmk
    typed_cols.insert(0, row_num_col_def)
299
    header.insert(0, row_num_col_def.name)
300
    reader = csvs.RowNumFilter(reader)
301 4995 aaronmk
302
    log('Creating table')
303 4999 aaronmk
    # Note that this is not rolled back if the import fails. Instead, it is
304
    # cached, and will not be re-run if the import is retried.
305 4995 aaronmk
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
306
307 5001 aaronmk
    # Free memory used by deleted (rolled back) rows from any failed import.
308
    # This MUST be run so that the rows will be stored in inserted order, and
309
    # the row_num added after import will match up with the CSV's row order.
310 4999 aaronmk
    sql.truncate(db, table)
311
312 4995 aaronmk
    # Load the data
313 5590 aaronmk
    def load(): append_csv(db, table, reader, header)
314 5583 aaronmk
    sql.with_savepoint(db, load)
315 4995 aaronmk
316
    cleanup_table(db, table)
317
318 5719 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None, on_error=exc.reraise):
319 3077 aaronmk
    '''Recovers from errors.
320
    Only works under PostgreSQL (uses INSERT RETURNING).
321
    '''
322 5719 aaronmk
    return put_table(db, table, [], row, row_ct_ref, on_error=on_error)
323 3077 aaronmk
324
def get(db, table, row, pkey, row_ct_ref=None, create=False):
325
    '''Recovers from errors'''
326
    try:
327
        return sql.value(sql.select(db, table, [pkey], row, limit=1,
328
            recover=True))
329
    except StopIteration:
330
        if not create: raise
331
        return put(db, table, row, pkey, row_ct_ref) # insert new row
332
333
def is_func_result(col):
334
    return col.table.name.find('(') >= 0 and col.name == 'result'
335
336
def into_table_name(out_table, in_tables0, mapping, is_func):
337
    def in_col_str(in_col):
338
        in_col = sql_gen.remove_col_rename(in_col)
339
        if isinstance(in_col, sql_gen.Col):
340
            table = in_col.table
341
            if table == in_tables0:
342
                in_col = sql_gen.to_name_only_col(in_col)
343
            elif is_func_result(in_col): in_col = table # omit col name
344 3750 aaronmk
        return strings.ustr(in_col)
345 3077 aaronmk
346 4491 aaronmk
    str_ = strings.ustr(out_table)
347 3077 aaronmk
    if is_func:
348
        str_ += '('
349
350
        try: value_in_col = mapping['value']
351
        except KeyError:
352 4491 aaronmk
            str_ += ', '.join((strings.ustr(k)+'='+in_col_str(v)
353 3077 aaronmk
                for k, v in mapping.iteritems()))
354
        else: str_ += in_col_str(value_in_col)
355
356
        str_ += ')'
357
    else:
358
        out_col = 'rank'
359
        try: in_col = mapping[out_col]
360
        except KeyError: str_ += '_pkeys'
361
        else: # has a rank column, so hierarchical
362 4491 aaronmk
            str_ += '['+strings.ustr(out_col)+'='+in_col_str(in_col)+']'
363 3077 aaronmk
    return str_
364
365 3628 aaronmk
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, default=None,
366 3660 aaronmk
    col_defaults={}, on_error=exc.reraise):
367 3077 aaronmk
    '''Recovers from errors.
368
    Only works under PostgreSQL (uses INSERT RETURNING).
369 7392 aaronmk
370 7395 aaronmk
    Warning: This function's normalizing algorithm does not support database
371 10164 aaronmk
    triggers that populate fields covered by the unique constraint used to do
372
    the DISTINCT ON. Such fields must be populated by the mappings instead.
373
    (Other unique constraints and other non-unique fields are not affected by
374
    this restriction on triggers. Note that the primary key will normally not be
375
    the DISTINCT ON constraint, so trigger-populated natural keys are supported
376
    *unless* the input table contains duplicate rows for some generated keys.)
377 7395 aaronmk
378
    Note that much of the complexity of the normalizing algorithm is due to
379
    PostgreSQL (and other DB systems) not having a native command for
380 11033 aaronmk
    INSERT ON DUPLICATE SELECT (wiki.vegpath.org/INSERT_ON_DUPLICATE_SELECT).
381
    For PostgreSQL 9.1+, this can now be emulated using INSTEAD OF triggers.
382
    For earlier versions, you instead have to use this function.
383 10165 aaronmk
384 3077 aaronmk
    @param in_tables The main input table to select from, followed by a list of
385
        tables to join with it using the main input table's pkey
386
    @param mapping dict(out_table_col=in_table_col, ...)
387
        * out_table_col: str (*not* sql_gen.Col)
388
        * in_table_col: sql_gen.Col|literal-value
389
    @param default The *output* column to use as the pkey for missing rows.
390
        If this output column does not exist in the mapping, uses None.
391 11151 aaronmk
        Note that this will be used for *all* missing rows, regardless of which
392
        error caused them not to be inserted.
393 3618 aaronmk
    @param col_defaults Default values for required columns.
394 3077 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
395
    '''
396 3474 aaronmk
    import psycopg2.extensions
397
398 6220 aaronmk
    # Special handling for functions with hstore params
399
    if out_table == '_map':
400
        import psycopg2.extras
401
        psycopg2.extras.register_hstore(db.db)
402
403
        # Parse args
404
        try: value = mapping.pop('value')
405
        except KeyError: return None # value required
406
407 6226 aaronmk
        mapping = dict([(k, sql_gen.get_value(v))
408
            for k, v in mapping.iteritems()]) # unwrap literal value
409 6220 aaronmk
        mapping = dict(map=mapping, value=value) # non-value params -> hstore
410
411 3077 aaronmk
    out_table = sql_gen.as_Table(out_table)
412
413
    def log_debug(msg): db.log_debug(msg, level=1.5)
414
    def col_ustr(str_):
415
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
416
417
    log_debug('********** New iteration **********')
418
    log_debug('Inserting these input columns into '+strings.as_tt(
419
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
420
421
    is_function = sql.function_exists(db, out_table)
422
423 5192 aaronmk
    if is_function: row_ct_ref = None # only track inserted rows
424
425 4984 aaronmk
    # Warn if inserting empty table rows
426
    if not mapping and not is_function: # functions with no args OK
427
        warnings.warn(UserWarning('Inserting empty table row(s)'))
428
429 3077 aaronmk
    if is_function: out_pkey = 'result'
430 5388 aaronmk
    else: out_pkey = sql.pkey_name(db, out_table, recover=True)
431 3077 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
432
433 5069 aaronmk
    in_tables_ = copy.copy(in_tables) # don't modify input!
434 3432 aaronmk
    try: in_tables0 = in_tables_.pop(0) # first table is separate
435
    except IndexError: in_tables0 = None
436
    else:
437 5388 aaronmk
        in_pkey = sql.pkey_name(db, in_tables0, recover=True)
438 3432 aaronmk
        in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
439 3431 aaronmk
440
    # Determine if can use optimization for only literal values
441
    is_literals = not reduce(operator.or_, map(sql_gen.is_table_col,
442 3434 aaronmk
        mapping.values()), False)
443 3431 aaronmk
    is_literals_or_function = is_literals or is_function
444
445 3432 aaronmk
    if in_tables0 == None: errors_table_ = None
446
    else: errors_table_ = errors_table(db, in_tables0)
447 3431 aaronmk
448
    # Create input joins from list of input tables
449 3077 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
450
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
451
452 5381 aaronmk
    orig_mapping = mapping.copy()
453 3433 aaronmk
    if mapping == {} and not is_function: # need >= one column for INSERT SELECT
454
        mapping = {out_pkey: None} # ColDict will replace with default value
455
456 3431 aaronmk
    if not is_literals:
457 3628 aaronmk
        into = sql_gen.as_Table(into_table_name(out_table, in_tables0, mapping,
458
            is_function))
459 5553 aaronmk
        # Ensure into's out_pkey is different from in_pkey by prepending "out."
460 4495 aaronmk
        if is_function: into_out_pkey = out_pkey
461 5553 aaronmk
        else: into_out_pkey = 'out.'+out_pkey
462 3431 aaronmk
463
        # Set column sources
464
        in_cols = filter(sql_gen.is_table_col, mapping.values())
465
        for col in in_cols:
466
            if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
467
468
        log_debug('Joining together input tables into temp table')
469
        # Place in new table so don't modify input and for speed
470
        in_table = sql_gen.Table('in')
471
        mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
472
            in_cols, preserve=[in_pkey_col]))
473
        input_joins = [in_table]
474
        db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
475 3077 aaronmk
476 3692 aaronmk
    # Wrap mapping in a sql_gen.ColDict.
477
    # sql_gen.ColDict sanitizes both keys and values passed into it.
478
    # Do after applying dicts.join() because that returns a plain dict.
479 3077 aaronmk
    mapping = sql_gen.ColDict(db, out_table, mapping)
480
481 5239 aaronmk
    # Save all rows since in_table may have rows deleted
482 3431 aaronmk
    if is_literals: pass
483
    elif is_function: full_in_table = in_table
484 3386 aaronmk
    else:
485
        full_in_table = sql_gen.suffixed_table(in_table, '_full')
486 5530 aaronmk
        sql.copy_table(db, in_table, full_in_table)
487 3287 aaronmk
488 3077 aaronmk
    pkeys_table_exists_ref = [False]
489 5444 aaronmk
    def insert_into_pkeys(query, **kw_args):
490 3077 aaronmk
        if pkeys_table_exists_ref[0]:
491 4484 aaronmk
            sql.insert_select(db, into, [in_pkey, into_out_pkey], query,
492
                **kw_args)
493 3077 aaronmk
        else:
494 6800 aaronmk
            kw_args.setdefault('add_pkey_', True)
495 13005 aaronmk
            # don't warn if can't create pkey, because this just indicates that,
496
            # at some point in the import tree, it used a set-returning function
497
            kw_args.setdefault('add_pkey_warn', False)
498 6800 aaronmk
499
            sql.run_query_into(db, query, into=into, **kw_args)
500 3077 aaronmk
            pkeys_table_exists_ref[0] = True
501
502 5523 aaronmk
    def mk_main_select(joins, cols): return sql.mk_select(db, joins, cols)
503 3418 aaronmk
504 3431 aaronmk
    if is_literals: insert_in_table = None
505
    else:
506
        insert_in_table = in_table
507
        insert_in_tables = [insert_in_table]
508 3352 aaronmk
    join_cols = sql_gen.ColDict(db, out_table)
509 10845 aaronmk
    join_custom_cond = None
510 3077 aaronmk
511
    exc_strs = set()
512
    def log_exc(e):
513
        e_str = exc.str_(e, first_line_only=True)
514
        log_debug('Caught exception: '+e_str)
515 3552 aaronmk
        if e_str in exc_strs: # avoid infinite loops
516
            log_debug('Exception already seen, handler broken')
517
            on_error(e)
518
            remove_all_rows()
519 5718 aaronmk
            return False
520 3552 aaronmk
        else: exc_strs.add(e_str)
521 5718 aaronmk
        return True
522 3077 aaronmk
523 5377 aaronmk
    ignore_all_ref = [False]
524 3077 aaronmk
    def remove_all_rows():
525
        log_debug('Ignoring all rows')
526 5377 aaronmk
        ignore_all_ref[0] = True # just return the default value column
527 3077 aaronmk
528 5889 aaronmk
    def handle_unknown_exc(e):
529
        log_debug('No handler for exception')
530
        on_error(e)
531
        remove_all_rows()
532
533 5504 aaronmk
    def ensure_cond(cond, e, passed=False, failed=False):
534 10838 aaronmk
        '''
535
        @param passed at least one row passed the constraint
536
        @param failed at least one row failed the constraint
537
        '''
538 5449 aaronmk
        if is_literals: # we know the constraint was applied exactly once
539
            if passed: pass
540
            elif failed: remove_all_rows()
541
            else: raise NotImplementedError()
542 3704 aaronmk
        else:
543 5888 aaronmk
            if not is_function:
544
                out_table_cols = sql_gen.ColDict(db, out_table)
545
                out_table_cols.update(util.dict_subset_right_join({},
546
                    sql.table_col_names(db, out_table)))
547 3704 aaronmk
548
            in_cols = []
549 7117 aaronmk
            cond = strings.ustr(cond)
550 5818 aaronmk
            orig_cond = cond
551 5367 aaronmk
            cond = sql_gen.map_expr(db, cond, mapping, in_cols)
552 5888 aaronmk
            if not is_function:
553
                cond = sql_gen.map_expr(db, cond, out_table_cols)
554 3704 aaronmk
555 5442 aaronmk
            log_debug('Ignoring rows that do not satisfy '+strings.as_tt(cond))
556 5449 aaronmk
            cur = None
557
            if cond == sql_gen.false_expr:
558
                assert failed
559
                remove_all_rows()
560
            elif cond == sql_gen.true_expr: assert passed
561 5895 aaronmk
            else:
562
                while True:
563
                    not_cond = sql_gen.NotCond(sql_gen.CustomCode(cond))
564
                    try:
565
                        cur = sql.delete(db, insert_in_table, not_cond)
566
                        break
567
                    except sql.DoesNotExistException, e:
568
                        if e.type != 'column': raise
569
570
                        last_cond = cond
571
                        cond = sql_gen.map_expr(db, cond, {e.name: None})
572
                        if cond == last_cond: raise # not fixable
573 5449 aaronmk
574 5827 aaronmk
            # If any rows failed cond
575
            if failed or cur != None and cur.rowcount > 0:
576 5505 aaronmk
                track_data_error(db, errors_table_,
577
                    sql_gen.cross_join_srcs(in_cols), None, e.cause.pgcode,
578 7117 aaronmk
                    strings.ensure_newl(strings.ustr(e.cause.pgerror))
579
                    +'condition: '+orig_cond+'\ntranslated condition: '+cond)
580 3352 aaronmk
581 3294 aaronmk
    not_null_cols = set()
582 3077 aaronmk
    def ignore(in_col, value, e):
583 3630 aaronmk
        if sql_gen.is_table_col(in_col):
584
            in_col = sql_gen.with_table(in_col, insert_in_table)
585
586
            track_data_error(db, errors_table_, in_col.srcs, value,
587
                e.cause.pgcode, e.cause.pgerror)
588
589
            sql.add_index(db, in_col, insert_in_table) # enable fast filtering
590
            if value != None and in_col not in not_null_cols:
591 4492 aaronmk
                log_debug('Replacing invalid value '
592
                    +strings.as_tt(strings.urepr(value))+' with NULL in column '
593
                    +strings.as_tt(in_col.to_str(db)))
594 3630 aaronmk
                sql.update(db, insert_in_table, [(in_col, None)],
595
                    sql_gen.ColValueCond(in_col, value))
596
            else:
597 3637 aaronmk
                log_debug('Ignoring rows with '+strings.as_tt(in_col.to_str(db))
598 4492 aaronmk
                    +' = '+strings.as_tt(strings.urepr(value)))
599 3630 aaronmk
                sql.delete(db, insert_in_table,
600
                    sql_gen.ColValueCond(in_col, value))
601
                if value == None: not_null_cols.add(in_col)
602 3293 aaronmk
        else:
603 3630 aaronmk
            assert isinstance(in_col, sql_gen.NamedCol)
604 3684 aaronmk
            in_value = sql_gen.remove_col_rename(in_col)
605
            assert sql_gen.is_literal(in_value)
606
            if value == in_value.value:
607
                if value != None:
608
                    log_debug('Replacing invalid literal '
609
                        +strings.as_tt(in_col.to_str(db))+' with NULL')
610
                    mapping[in_col.name] = None
611
                else:
612
                    remove_all_rows()
613
            # otherwise, all columns were being ignore()d because the specific
614
            # column couldn't be identified, and this was not the invalid column
615 3077 aaronmk
616 3431 aaronmk
    if not is_literals:
617
        def insert_pkeys_table(which):
618
            return sql_gen.Table(sql_gen.concat(in_table.name,
619
                '_insert_'+which+'_pkeys'))
620
        insert_out_pkeys = insert_pkeys_table('out')
621
        insert_in_pkeys = insert_pkeys_table('in')
622 3077 aaronmk
623 3918 aaronmk
    def mk_func_call():
624 3550 aaronmk
        args = dict(((k.name, v) for k, v in mapping.iteritems()))
625 4484 aaronmk
        return sql_gen.FunctionCall(out_table, **args), args
626 3550 aaronmk
627 12150 aaronmk
    def handle_MissingCastException(e):
628
        if not log_exc(e): return False
629
630
        type_ = e.type
631
        if e.col == None: out_cols = mapping.keys()
632
        else: out_cols = [e.col]
633
634
        for out_col in out_cols:
635
            log_debug('Casting '+strings.as_tt(strings.repr_no_u(out_col))
636
                +' input to '+strings.as_tt(type_))
637
            in_col = mapping[out_col]
638
            while True:
639
                try:
640 14074 aaronmk
                    cast_col = cast_temp_col(db, type_, in_col, errors_table_)
641
                    mapping[out_col] = cast_col
642
                    if out_col in join_cols: join_cols[out_col] = cast_col
643 12150 aaronmk
                    break # cast successful
644
                except sql.InvalidValueException, e:
645
                    if not log_exc(e): return False
646
647
                    ignore(in_col, e.value, e)
648
649
        return True
650
651 5239 aaronmk
    missing_msg = None
652
653 3077 aaronmk
    # Do inserts and selects
654
    while True:
655 3473 aaronmk
        has_joins = join_cols != {}
656
657 5377 aaronmk
        if ignore_all_ref[0]: break # unrecoverable error, so don't do main case
658 3077 aaronmk
659
        # Prepare to insert new rows
660 3918 aaronmk
        if is_function:
661 5726 aaronmk
            if is_literals:
662
                log_debug('Calling function')
663
                func_call, args = mk_func_call()
664 3077 aaronmk
        else:
665 3550 aaronmk
            log_debug('Trying to insert new rows')
666 3291 aaronmk
            insert_args = dict(recover=True, cacheable=False)
667
            if has_joins:
668
                insert_args.update(dict(ignore=True))
669
            else:
670 3431 aaronmk
                insert_args.update(dict(returning=out_pkey))
671
                if not is_literals:
672
                    insert_args.update(dict(into=insert_out_pkeys))
673 3291 aaronmk
            main_select = mk_main_select([insert_in_table], [sql_gen.with_table(
674
                c, insert_in_table) for c in mapping.values()])
675 3077 aaronmk
676 3292 aaronmk
        try:
677
            cur = None
678 3077 aaronmk
            if is_function:
679 3917 aaronmk
                if is_literals:
680
                    cur = sql.select(db, fields=[func_call], recover=True,
681
                        cacheable=True)
682 5444 aaronmk
                else:
683 5726 aaronmk
                    log_debug('Defining wrapper function')
684
685
                    func_call, args = mk_func_call()
686
                    func_call = sql_gen.NamedCol(into_out_pkey, func_call)
687
688
                    # Create empty pkeys table so its row type can be used
689
                    insert_into_pkeys(sql.mk_select(db, input_joins,
690 6801 aaronmk
                        [in_pkey_col, func_call], limit=0), add_pkey_=False,
691
                        recover=True)
692 5726 aaronmk
                    result_type = db.col_info(sql_gen.Col(into_out_pkey,
693
                        into)).type
694
695
                    ## Create error handling wrapper function
696
697
                    wrapper = db.TempFunction(sql_gen.concat(into.name,
698
                        '_wrap'))
699
700
                    select_cols = [in_pkey_col]+args.values()
701
                    row_var = copy.copy(sql_gen.row_var)
702
                    row_var.set_srcs([in_table])
703
                    in_pkey_var = sql_gen.Col(in_pkey, row_var)
704
705
                    args = dict(((k, sql_gen.with_table(v, row_var))
706
                        for k, v in args.iteritems()))
707
                    func_call = sql_gen.FunctionCall(out_table, **args)
708
709
                    def mk_return(result):
710
                        return sql_gen.ReturnQuery(sql.mk_select(db,
711
                            fields=[in_pkey_var, result], explain=False))
712
                    exc_handler = func_wrapper_exception_handler(db,
713
                        mk_return(sql_gen.Cast(result_type, None)),
714
                        args.values(), errors_table_)
715
716
                    sql.define_func(db, sql_gen.FunctionDef(wrapper,
717
                        sql_gen.SetOf(into),
718
                        sql_gen.RowExcIgnore(sql_gen.RowType(in_table),
719
                            sql.mk_select(db, input_joins),
720
                            mk_return(func_call), exc_handler=exc_handler)
721
                        ))
722
                    wrapper_table = sql_gen.FunctionCall(wrapper)
723
724
                    log_debug('Calling function')
725 5444 aaronmk
                    insert_into_pkeys(sql.mk_select(db, wrapper_table,
726 5894 aaronmk
                        order_by=None), recover=True, cacheable=False)
727 8820 aaronmk
                    sql.add_pkey_or_index(db, into)
728 3077 aaronmk
            else:
729 3292 aaronmk
                cur = sql.insert_select(db, out_table, mapping.keys(),
730 3077 aaronmk
                    main_select, **insert_args)
731
            break # insert successful
732
        except sql.MissingCastException, e:
733 12150 aaronmk
            if not handle_MissingCastException(e): break
734 3077 aaronmk
        except sql.DuplicateKeyException, e:
735 5718 aaronmk
            if not log_exc(e): break
736 3077 aaronmk
737 3274 aaronmk
            # Different rows violating different unique constraints not
738
            # supported
739
            assert not join_cols
740
741 10845 aaronmk
            join_custom_cond = e.cond
742 5504 aaronmk
            if e.cond != None: ensure_cond(e.cond, e, passed=True)
743 5450 aaronmk
744 3077 aaronmk
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
745
            log_debug('Ignoring existing rows, comparing on these columns:\n'
746
                +strings.as_inline_table(join_cols, ustr=col_ustr))
747 3102 aaronmk
748 3431 aaronmk
            if is_literals:
749
                return sql.value(sql.select(db, out_table, [out_pkey_col],
750 4025 aaronmk
                    join_cols, order_by=None))
751 3431 aaronmk
752 7180 aaronmk
            # Uniquify and filter input table to avoid (most) duplicate keys
753
            # (Additional duplicates may be added concurrently and will be
754
            # filtered out separately upon insert.)
755 12153 aaronmk
            while True:
756
                try:
757
                    insert_in_table = sql.distinct_table(db, insert_in_table,
758
                        join_cols.values(), [insert_in_table,
759
                        sql_gen.Join(out_table, join_cols, sql_gen.filter_out,
760
                        e.cond)])
761
                    insert_in_tables.append(insert_in_table)
762
                    break # insert successful
763
                except sql.MissingCastException, e1: # don't modify outer e
764
                    if not handle_MissingCastException(e1): break
765 3077 aaronmk
        except sql.NullValueException, e:
766 5718 aaronmk
            if not log_exc(e): break
767 3077 aaronmk
768
            out_col, = e.cols
769
            try: in_col = mapping[out_col]
770 3618 aaronmk
            except KeyError, e:
771
                try: in_col = mapping[out_col] = col_defaults[out_col]
772
                except KeyError:
773 5239 aaronmk
                    missing_msg = 'Missing mapping for NOT NULL column '+out_col
774
                    log_debug(missing_msg)
775 3618 aaronmk
                    remove_all_rows()
776 3294 aaronmk
            else: ignore(in_col, None, e)
777 3352 aaronmk
        except sql.CheckException, e:
778 5718 aaronmk
            if not log_exc(e): break
779 3352 aaronmk
780 5504 aaronmk
            ensure_cond(e.cond, e, failed=True)
781 3413 aaronmk
        except sql.InvalidValueException, e:
782 5718 aaronmk
            if not log_exc(e): break
783 3413 aaronmk
784
            for in_col in mapping.values(): ignore(in_col, e.value, e)
785 3474 aaronmk
        except psycopg2.extensions.TransactionRollbackError, e:
786 5718 aaronmk
            if not log_exc(e): break
787 3474 aaronmk
            # retry
788 3077 aaronmk
        except sql.DatabaseErrors, e:
789 5718 aaronmk
            if not log_exc(e): break
790 3077 aaronmk
791 5889 aaronmk
            handle_unknown_exc(e)
792 3077 aaronmk
        # after exception handled, rerun loop with additional constraints
793
794 5239 aaronmk
    # Resolve default value column
795
    if default != None:
796 5381 aaronmk
        if ignore_all_ref[0]: mapping.update(orig_mapping) # use input cols
797 5239 aaronmk
        try: default = mapping[default]
798
        except KeyError:
799
            db.log_debug('Default value column '
800
                +strings.as_tt(strings.repr_no_u(default))
801
                +' does not exist in mapping, falling back to None', level=2.1)
802
            default = None
803 5380 aaronmk
        else: default = sql_gen.remove_col_rename(default)
804 5239 aaronmk
805
    if missing_msg != None and default == None:
806
        warnings.warn(UserWarning(missing_msg))
807
        # not an error because sometimes the mappings include
808
        # extra tables which aren't used by the dataset
809
810
    # Handle unrecoverable errors
811 5377 aaronmk
    if ignore_all_ref[0]:
812 5373 aaronmk
        log_debug('Returning default: '+strings.as_tt(strings.urepr(default)))
813
        return default
814 5239 aaronmk
815 3077 aaronmk
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
816
        row_ct_ref[0] += cur.rowcount
817
818 12868 aaronmk
    if is_literals: return sql.value_or_none(cur) # support multi-row functions
819 3530 aaronmk
820
    if is_function: pass # pkeys table already created
821 3077 aaronmk
    elif has_joins:
822 10845 aaronmk
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols,
823
            custom_cond=join_custom_cond)]
824 3077 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
825 5444 aaronmk
        insert_into_pkeys(sql.mk_select(db, select_joins, [in_pkey_col,
826
            sql_gen.NamedCol(into_out_pkey, out_pkey_col)], order_by=None))
827 3077 aaronmk
    else:
828
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
829
830
        log_debug('Getting input table pkeys of inserted rows')
831 3285 aaronmk
        # Note that mk_main_select() does not use ORDER BY. Instead, assume that
832
        # since the SELECT query is identical to the one used in INSERT SELECT,
833
        # its rows will be retrieved in the same order.
834 3077 aaronmk
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
835
            into=insert_in_pkeys)
836
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
837
838
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
839
            db, insert_in_pkeys)
840
841
        log_debug('Combining output and input pkeys in inserted order')
842
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
843
            {sql.row_num_col: sql_gen.join_same_not_null})]
844 4484 aaronmk
        in_col = sql_gen.Col(in_pkey, insert_in_pkeys)
845
        out_col = sql_gen.NamedCol(into_out_pkey,
846
            sql_gen.Col(out_pkey, insert_out_pkeys))
847 5444 aaronmk
        insert_into_pkeys(sql.mk_select(db, pkey_joins, [in_col, out_col],
848
            order_by=None))
849 3077 aaronmk
850
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
851
852 5374 aaronmk
    if not is_function: # is_function doesn't leave holes
853 3187 aaronmk
        log_debug('Setting pkeys of missing rows to '
854 4492 aaronmk
            +strings.as_tt(strings.urepr(default)))
855 5993 aaronmk
856
        full_in_pkey_col = sql_gen.Col(in_pkey, full_in_table)
857 5380 aaronmk
        if sql_gen.is_table_col(default):
858
            default = sql_gen.with_table(default, full_in_table)
859 3287 aaronmk
        missing_rows_joins = [full_in_table, sql_gen.Join(into,
860 3187 aaronmk
            {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
861
            # must use join_same_not_null or query will take forever
862 5993 aaronmk
863
        insert_args = dict(order_by=None)
864
        if not sql.table_has_pkey(db, full_in_table): # in_table has duplicates
865
            insert_args.update(dict(distinct_on=[full_in_pkey_col]))
866
867 5444 aaronmk
        insert_into_pkeys(sql.mk_select(db, missing_rows_joins,
868 5993 aaronmk
            [full_in_pkey_col, sql_gen.NamedCol(into_out_pkey, default)],
869
            **insert_args))
870 3187 aaronmk
    # otherwise, there is already an entry for every row
871 3077 aaronmk
872 3530 aaronmk
    sql.empty_temp(db, insert_in_tables+[full_in_table])
873
874
    srcs = []
875 3619 aaronmk
    if is_function: srcs = sql_gen.cols_srcs(in_cols)
876 4484 aaronmk
    return sql_gen.Col(into_out_pkey, into, srcs)