Project

General

Profile

1
# Database import/export
2

    
3
import exc
4
import dicts
5
import sql
6
import sql_gen
7
import strings
8
import util
9

    
10
##### Error tracking
11

    
12
def track_data_error(db, errors_table, cols, value, error_code, error):
    '''Records a data error in errors_table, once for each source column.
    A row that is already present in errors_table (duplicate key) is skipped.
    @param errors_table If None, does nothing.
    @param cols The source columns the error is attributed to. Each entry's
        name attr is stored in the "column" field. If (), does nothing.
    @param value The offending value
    @param error_code The error code to store alongside the value
    @param error The error message to store alongside the value
    '''
    # Identity check, so that the table object's own equality operator is never
    # consulted just to find out whether an errors table was supplied
    if errors_table is None or cols == (): return
    
    for col in cols:
        try:
            sql.insert(db, errors_table, dict(column=col.name, value=value,
                error_code=error_code, error=error), recover=True,
                cacheable=True, log_level=4)
        except sql.DuplicateKeyException: pass # this error was already saved
24

    
25
def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as the source columns) and converts errors to warnings.
    In that case the cast is performed by a generated plpgsql function, which
    returns NULL for values that raise a data_exception.
    @param type_ str The name of the SQL type to cast to
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    @return sql_gen.Cast if errors can't be saved, otherwise a
        sql_gen.FunctionCall of the generated function on col
    '''
    col = sql_gen.as_Col(col)
    save_errors = (errors_table is not None and isinstance(col, sql_gen.Col)
        and col.srcs != ())
    if not save_errors: return sql_gen.Cast(type_, col) # can't save errors
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    errors_table = sql_gen.as_Table(errors_table)
    # Real lists (not map() results), because srcs is consumed twice below:
    # once for the function name and once for the source-columns query
    srcs = [sql_gen.to_name_only_col(src) for src in col.srcs]
    function_name = str(sql_gen.FunctionCall(type_, *srcs))
    function = db.TempFunction(function_name)
    
    while True:
        # Create function definition
        # (a list for the same reason as srcs: it is passed to two calls)
        errors_table_cols = [sql_gen.Col(name)
            for name in ['column', 'value', 'error_code', 'error']]
        query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
RETURNS '''+type_+'''
LANGUAGE plpgsql
STRICT
AS $$
BEGIN
    /* The explicit cast to the return type is needed to make the cast happen
    inside the try block. (Implicit casts to the return type happen at the end
    of the function, outside any block.) */
    RETURN value::'''+type_+''';
EXCEPTION
    WHEN data_exception THEN
        -- Save error in errors table.
        DECLARE
            error_code text := SQLSTATE;
            error text := SQLERRM;
            "column" text;
        BEGIN
            -- Insert the value and error for *each* source column.
            FOR "column" IN
'''+sql.mk_select(db, sql_gen.NamedValues('c', None, [[c.name] for c in srcs]),
    order_by=None, start=0)+'''
            LOOP
                BEGIN
'''+sql.mk_insert_select(db, errors_table, errors_table_cols,
    sql_gen.Values(errors_table_cols).to_str(db))+''';
                EXCEPTION
                    WHEN unique_violation THEN NULL; -- continue to next row
                END;
            END LOOP;
        END;
        
        RAISE WARNING '%', SQLERRM;
        RETURN NULL;
END;
$$;
'''
        
        # Create function
        try:
            sql.run_query(db, query, recover=True, cacheable=True,
                log_ignore_excs=(sql.DuplicateException,))
            break # successful
        except sql.DuplicateException:
            function.name = sql.next_version(function.name)
            # try again with next version of name
    
    return sql_gen.FunctionCall(function, col)
97

    
98
def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but when the input is backed by a table column, stores the
    cast values in a new column of that table instead of returning an
    expression.
    @return The new column, or the cast value if there is no underlying
        table column
    '''
    def do_cast(to_cast): return cast(db, type_, to_cast, errors_table)
    
    try: src_col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException:
        # Not a table column, so there is nowhere to add a column
        return sql_gen.wrap(do_cast, col)
    
    owner = src_col.table
    cast_name = sql_gen.concat(src_col.name, '::'+type_)
    result_col = sql_gen.Col(cast_name, owner, src_col.srcs)
    cast_expr = do_cast(src_col)
    
    # Add the column that will hold the cast values
    typed = sql_gen.TypedCol(result_col.name, type_)
    sql.add_col(db, owner, typed, comment='src: '+repr(src_col))
    result_col.name = typed.name # propagate any renaming
    
    # Fill it in and index it
    sql.update(db, owner, [(result_col, cast_expr)], in_place=True,
        cacheable=True)
    sql.add_index(db, result_col)
    
    return result_col
121

    
122
def errors_table(db, table, if_exists=True):
    '''Gets the errors table that goes with a table. If the table has srcs,
    its first src is used instead. The errors table is the table's name
    suffixed with '.errors'.
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    src_table = sql_gen.as_Table(table)
    if src_table.srcs != (): src_table = src_table.srcs[0]
    
    candidate = sql_gen.suffixed_table(src_table, '.errors')
    if not if_exists: return candidate # caller doesn't care if it exists
    if sql.table_exists(db, candidate): return candidate
    return None
133

    
134
##### Import
135

    
136
def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Inserts a row and returns its pkey. If the row collides with an existing
    row (duplicate key), returns the existing row's pkey instead.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param pkey_ The pkey column to return. If None, looked up from the table.
    @param row_ct_ref If set, its first element is incremented by the number
        of rows inserted (when the cursor reports a rowcount).
    '''
    row = sql_gen.ColDict(db, table, row)
    if pkey_ is None: pkey_ = sql.pkey(db, table, recover=True)
    
    try:
        cur = sql.insert(db, table, row, pkey_, recover=True)
        if row_ct_ref is not None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return sql.value(cur)
    except sql.DuplicateKeyException as e:
        # Look up the existing row using only the columns of the violated key
        row = sql_gen.ColDict(db, table,
            util.dict_subset_right_join(row, e.cols))
        return sql.value(sql.select(db, table, [pkey_], row, recover=True))
152

    
153
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Selects the pkey of the first row matching `row`.
    Recovers from errors.
    @param create If set, inserts the row with put() when no match is found;
        otherwise the StopIteration from the empty result propagates.
    '''
    try:
        cur = sql.select(db, table, [pkey], row, limit=1, recover=True)
        return sql.value(cur)
    except StopIteration:
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
161

    
162
def is_func_result(col):
    '''Whether col is named 'result' and its table's name contains a "("'''
    return '(' in col.table.name and col.name == 'result'
164

    
165
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Builds the default name of the table that receives the pkeys.
    * is_func: out_table's name followed by the "value" input column in
      parentheses, or by all `out_col=in_col` pairs if there is no "value"
    * otherwise, with a "rank" column in the mapping: out_table's name followed
      by `[rank=in_col]`
    * otherwise: out_table's name followed by `_pkeys`
    '''
    def describe(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            if in_col.table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col):
                in_col = in_col.table # omit col name
        return str(in_col)
    
    name = str(out_table)
    
    if not is_func:
        rank_key = 'rank'
        try: rank_col = mapping[rank_key]
        except KeyError: return name+'_pkeys'
        # has a rank column, so hierarchical
        return name+'['+str(rank_key)+'='+describe(rank_col)+']'
    
    try: value_in_col = mapping['value']
    except KeyError:
        args = ', '.join((str(k)+'='+describe(v)
            for k, v in mapping.iteritems()))
    else: args = describe(value_in_col)
    return name+'('+args+')'
193

    
194
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False, on_error=exc.raise_):
    '''Inserts the mapped input columns into out_table (or calls out_table as a
    function on them) and builds a table pairing each input pkey with the
    corresponding output pkey.
    Recovers from errors: the insert is retried in a loop, and each handled
    exception adds a constraint (a cast, a join on the violated key, a filter,
    or a nulled-out value) before the next attempt.
    Only works under PostgreSQL (uses INSERT RETURNING).
    IMPORTANT: Must be run at the *beginning* of a transaction.
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param row_ct_ref If set, its first element is incremented by the row
        count of the final insert (when the cursor reports one)
    @param into The table to contain the output and input pkeys.
        Defaults to the name built by into_table_name() (in the simplest case,
        out_table's name followed by `_pkeys`).
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @param on_error Called with the exception when no specific handler
        applies. If it returns, all rows are ignored.
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    is_function = sql.function_exists(db, out_table)
    
    if is_function: out_pkey = 'result'
    else: out_pkey = sql.pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    if mapping == {}: # need at least one column for INSERT SELECT
        mapping = {out_pkey: None} # ColDict will replace with default value
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    errors_table_ = errors_table(db, in_tables0)
    in_pkey = sql.pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into is None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    # Set column sources
    # (a real list, because in_cols is iterated again further down)
    in_cols = [c for c in mapping.values() if sql_gen.is_table_col(c)]
    for col in in_cols:
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
        in_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    if default is not None:
        try: default = mapping[default]
        except KeyError:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    # One-element lists are used as mutable cells that the nested functions
    # below can assign to
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols, distinct=False):
        '''Adds rows to the pkeys table, creating it on first use'''
        kw_args = {}
        if distinct: kw_args.update(dict(distinct_on=[in_pkey_col]))
        query = sql.mk_select(db, joins, cols, order_by=None, start=0,
            **kw_args)
        
        if pkeys_table_exists_ref[0]:
            sql.insert_select(db, into, pkeys_names, query)
        else:
            sql.run_query_into(db, query, into=into)
            pkeys_table_exists_ref[0] = True
    
    limit_ref = [None]
    conds = set()
    distinct_on = sql_gen.ColDict(db, out_table)
    def mk_main_select(joins, cols):
        '''Selects the input rows, subject to the constraints added so far'''
        distinct_on_cols = [c.to_Col() for c in distinct_on.values()]
        return sql.mk_select(db, joins, cols, conds, distinct_on_cols,
            limit=limit_ref[0], start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    
    def remove_all_rows():
        log_debug('Ignoring all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    
    def ignore(in_col, value, e):
        track_data_error(db, errors_table_, in_col.srcs, value,
            e.cause.pgcode, e.cause.pgerror)
        log_debug('Ignoring rows with '+strings.as_tt(repr(in_col))+' = '
            +strings.as_tt(repr(value)))
    
    def remove_rows(in_col, value, e):
        '''Filters the rows having this value out of the main select'''
        ignore(in_col, value, e)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    
    def invalid2null(in_col, value, e):
        '''Sets this value to NULL in the temp input table'''
        ignore(in_col, value, e)
        sql.update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.concat(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects
    join_cols = sql_gen.ColDict(db, out_table)
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = sql.run_query_into(db, sql.mk_select(db, out_table,
                [out_pkey], limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        log_debug('Trying to insert new rows')
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            insert_args.update(dict(ignore=True))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select(insert_joins, mapping.values())
        
        def main_insert():
            if is_function:
                log_debug('Calling function on input rows')
                args = dict(((k.name, v) for k, v in mapping.iteritems()))
                func_call = sql_gen.NamedCol(out_pkey,
                    sql_gen.FunctionCall(out_table, **args))
                insert_into_pkeys(input_joins, [in_pkey_col, func_call])
                return None
            else:
                return sql.insert_select(db, out_table, mapping.keys(),
                    main_select, **insert_args)
        
        try:
            cur = sql.with_savepoint(db, main_insert)
            break # insert successful
        except sql.MissingCastException as e:
            log_exc(e)
            
            out_col = e.col
            type_ = e.type
            
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
                +strings.as_tt(type_))
            mapping[out_col] = cast_temp_col(db, type_, mapping[out_col],
                errors_table_)
        except sql.DuplicateKeyException as e:
            log_exc(e)
            
            old_join_cols = join_cols.copy()
            distinct_on.update(util.dict_subset(mapping, e.cols))
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
        except sql.NullValueException as e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None, e)
        except sql.FunctionValueException as e:
            log_exc(e)
            
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                in_col = sql_gen.unwrap_func_call(in_col, func_name)
                invalid2null(in_col, value, e)
        except sql.DatabaseErrors as e:
            log_exc(e)
            
            log_debug('No handler for exception')
            on_error(e)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if cur is not None and row_ct_ref is not None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if is_function: pass # pkeys table already created
    elif has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols, distinct=True)
    else:
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
            db, insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {sql.row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
        
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    sql.add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert sql.table_row_count(db, into) == sql.table_row_count(db, in_table)
    
    sql.empty_temp(db, in_table)
    
    srcs = []
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
    return sql_gen.Col(out_pkey, into, srcs)
(26-26/37)