# Database import/export

import exc
import dicts
import sql
import sql_gen
import strings
import util

##### Data cleanup

def cleanup_table(db, table, cols):
    table = sql_gen.as_Table(table)
    cols = map(sql_gen.as_Col, cols)
    
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
        +db.esc_value(r'\N')+')')
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
        for v in cols]
    
    sql.update(db, table, changes, in_place=True)
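
# Hedged usage sketch (table and column names are hypothetical, not from this
# project): assuming `db` is the sql.py connection wrapper,
#     cleanup_table(db, 'plot', ['plotName', 'notes'])
# trims surrounding whitespace in those columns and rewrites '' and '\N'
# values to NULL, updating the table in place.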

##### Error tracking

def track_data_error(db, errors_table, cols, value, error_code, error):
    '''
    @param errors_table If None, does nothing.
    '''
    if errors_table == None or cols == (): return
    
    for col in cols:
        try:
            sql.insert(db, errors_table, dict(column=col.name, value=value,
                error_code=error_code, error=error), recover=True,
                cacheable=True, log_level=4)
        except sql.DuplicateKeyException: pass
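
# Hedged usage sketch: put_table() below calls this as
#     track_data_error(db, errors_table_, in_col.srcs, value,
#         e.cause.pgcode, e.cause.pgerror)
# to record one row per source column; rows that violate the errors table's
# unique constraint are silently skipped.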

def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table set and col has srcs, saves errors in errors_table (using
    col's srcs attr as the source columns) and converts errors to warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    '''
    col = sql_gen.as_Col(col)
    save_errors = (errors_table != None and isinstance(col, sql_gen.Col)
        and col.srcs != ())
    if not save_errors: return sql_gen.Cast(type_, col) # can't save errors
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    errors_table = sql_gen.as_Table(errors_table)
    srcs = map(sql_gen.to_name_only_col, col.srcs)
    function_name = str(sql_gen.FunctionCall(type_, *srcs))
    function = db.TempFunction(function_name)
    
    while True:
        # Create function definition
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
RETURNS '''+type_+'''
LANGUAGE plpgsql
STRICT
AS $$
BEGIN
    /* The explicit cast to the return type is needed to make the cast happen
    inside the try block. (Implicit casts to the return type happen at the end
    of the function, outside any block.) */
    RETURN value::'''+type_+''';
EXCEPTION
    WHEN data_exception THEN
        -- Save error in errors table.
        DECLARE
            error_code text := SQLSTATE;
            error text := SQLERRM;
            "column" text;
        BEGIN
            -- Insert the value and error for *each* source column.
            FOR "column" IN
'''+sql.mk_select(db, sql_gen.NamedValues('c', None, [[c.name] for c in srcs]),
    order_by=None, start=0)+'''
            LOOP
                BEGIN
'''+sql.mk_insert_select(db, errors_table, errors_table_cols,
    sql_gen.Values(errors_table_cols).to_str(db))+''';
                EXCEPTION
                    WHEN unique_violation THEN NULL; -- continue to next row
                END;
            END LOOP;
        END;
        
        RAISE WARNING '%', SQLERRM;
        RETURN NULL;
END;
$$;
'''
        
        # Create function
        try:
            sql.run_query(db, query, recover=True, cacheable=True,
                log_ignore_excs=(sql.DuplicateException,))
            break # successful
        except sql.DuplicateException:
            function.name = sql.next_version(function.name)
            # try again with next version of name
    
    return sql_gen.FunctionCall(function, col)
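
# Hedged usage sketch (hypothetical type and names): for a column whose srcs
# point back at the original input column, something like
#     sql_expr = cast(db, 'double precision', col, errors_table_)
# yields a call to a generated plpgsql temp function that performs the cast; on
# a data_exception it logs the offending value to errors_table_, raises a
# warning, and returns NULL instead of aborting the whole statement.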

def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but creates a new column with the cast values if the input
    is a column.
    @return The new column or cast value
    '''
    def cast_(col): return cast(db, type_, col, errors_table)
    
    try: col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(cast_, col)
    
    table = col.table
    new_col = sql_gen.Col(sql_gen.concat(col.name, '::'+type_), table, col.srcs)
    expr = cast_(col)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, new_typed_col, comment='src: '+repr(col))
    new_col.name = new_typed_col.name # propagate any renaming
    
    sql.update(db, table, [(new_col, expr)], in_place=True, cacheable=True)
    sql.add_index(db, new_col)
    
    return new_col
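
# Hedged note: put_table() below uses this when it catches a
# MissingCastException, adding a '<col>::<type>' column to the temp input
# table, filling it with cast values via cast(), and indexing it.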

def errors_table(db, table, if_exists=True):
    '''
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    if table.srcs != (): table = table.srcs[0]
    
    errors_table = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, errors_table): return None
    return errors_table
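
# Hedged example (hypothetical table name): for an input table 'plot', this
# returns the companion 'plot.errors' table as a sql_gen.Table, or None when
# if_exists is set and no such table exists.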

##### Import

def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    '''
    row = sql_gen.ColDict(db, table, row)
    if pkey_ == None: pkey_ = sql.pkey(db, table, recover=True)
    
    try:
        cur = sql.insert(db, table, row, pkey_, recover=True)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return sql.value(cur)
    except sql.DuplicateKeyException, e:
        row = sql_gen.ColDict(db, table,
            util.dict_subset_right_join(row, e.cols))
        return sql.value(sql.select(db, table, [pkey_], row, recover=True))
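
# Hedged usage sketch (hypothetical table/columns):
#     row_ct = [0]
#     party_id = put(db, 'party', dict(organizationname='NCEAS'),
#         row_ct_ref=row_ct)
# inserts the row and returns its pkey; on a duplicate key it instead selects
# and returns the existing row's pkey, leaving row_ct unchanged.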

def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors'''
    try:
        return sql.value(sql.select(db, table, [pkey], row, limit=1,
            recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
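
# Hedged usage sketch (hypothetical names): a read-first get-or-create, e.g.
#     party_id = get(db, 'party', dict(organizationname='NCEAS'), 'party_id',
#         create=True)
# falls back to put() only when the select finds no matching row.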

def is_func_result(col):
    return col.table.name.find('(') >= 0 and col.name == 'result'
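
# e.g. (hypothetical) a col named 'result' whose table name embeds a function
# call, such as '_merge(col_0, col_1)', counts as a function result.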

def into_table_name(out_table, in_tables0, mapping, is_func):
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return str(in_col)
    
    str_ = str(out_table)
    if is_func:
        str_ += '('
        
        try: value_in_col = mapping['value']
        except KeyError:
            str_ += ', '.join((str(k)+'='+in_col_str(v)
                for k, v in mapping.iteritems()))
        else: str_ += in_col_str(value_in_col)
        
        str_ += ')'
    else:
        out_col = 'rank'
        try: in_col = mapping[out_col]
        except KeyError: str_ += '_pkeys'
        else: # has a rank column, so hierarchical
            str_ += '['+str(out_col)+'='+in_col_str(in_col)+']'
    return str_
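
# Hedged examples of generated names (hypothetical tables/columns):
#     plain table:        'taxonconcept'                    -> 'taxonconcept_pkeys'
#     hierarchical table: 'taxonconcept' + a 'rank' mapping -> something like
#                         'taxonconcept[rank=taxonrank]'
#     SQL function:       '_merge' + a 'value' mapping      -> '_merge(value_col)'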

def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False, on_error=exc.raise_):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    IMPORTANT: Must be run at the *beginning* of a transaction.
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    is_function = sql.function_exists(db, out_table)
    
    if is_function: out_pkey = 'result'
    else: out_pkey = sql.pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    if mapping == {}: # need at least one column for INSERT SELECT
        mapping = {out_pkey: None} # ColDict will replace with default value
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    errors_table_ = errors_table(db, in_tables0)
    in_pkey = sql.pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    # Set column sources
    in_cols = filter(sql_gen.is_table_col, mapping.values())
    for col in in_cols:
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
        in_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    if default != None:
        try: default = mapping[default]
        except KeyError:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols, distinct=False):
        kw_args = {}
        if distinct: kw_args.update(dict(distinct_on=[in_pkey_col]))
        query = sql.mk_select(db, joins, cols, order_by=None, start=0,
            **kw_args)
        
        if pkeys_table_exists_ref[0]:
            sql.insert_select(db, into, pkeys_names, query)
        else:
            sql.run_query_into(db, query, into=into)
            pkeys_table_exists_ref[0] = True
    
    limit_ref = [None]
    conds = set()
    distinct_on = sql_gen.ColDict(db, out_table)
    def mk_main_select(joins, cols):
        return sql.mk_select(db, joins, cols, conds, limit=limit_ref[0],
            start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    
    def remove_all_rows():
        log_debug('Ignoring all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    
    def ignore(in_col, value, e):
        track_data_error(db, errors_table_, in_col.srcs, value,
            e.cause.pgcode, e.cause.pgerror)
        log_debug('Ignoring rows with '+strings.as_tt(repr(in_col))+' = '
            +strings.as_tt(repr(value)))
    
    def remove_rows(in_col, value, e):
        ignore(in_col, value, e)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    
    def invalid2null(in_col, value, e):
        ignore(in_col, value, e)
        sql.update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.concat(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects
    insert_in_table = in_table
    join_cols = sql_gen.ColDict(db, out_table)
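    # The loop below retries the INSERT SELECT after each handled exception
    # (missing cast, duplicate key, NULL violation, invalid function input),
    # tightening the mapping or select conditions each time; log_exc() asserts
    # that the same error is never handled twice, so the retries terminate.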
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = sql.run_query_into(db, sql.mk_select(db, out_table,
                [out_pkey], limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        log_debug('Trying to insert new rows')
        
        # Prepare to insert new rows
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            insert_args.update(dict(ignore=True))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select([insert_in_table],
            [sql_gen.with_table(c, insert_in_table) for c in mapping.values()])
        
        def main_insert():
            if is_function:
                log_debug('Calling function on input rows')
                args = dict(((k.name, v) for k, v in mapping.iteritems()))
                func_call = sql_gen.NamedCol(out_pkey,
                    sql_gen.FunctionCall(out_table, **args))
                insert_into_pkeys(input_joins, [in_pkey_col, func_call])
                return None
            else:
                return sql.insert_select(db, out_table, mapping.keys(),
                    main_select, **insert_args)
        
        try:
            cur = sql.with_savepoint(db, main_insert)
            break # insert successful
        except sql.MissingCastException, e:
            log_exc(e)
            
            out_col = e.col
            type_ = e.type
            
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
                +strings.as_tt(type_))
            mapping[out_col] = cast_temp_col(db, type_, mapping[out_col],
                errors_table_)
        except sql.DuplicateKeyException, e:
            log_exc(e)
            
            old_join_cols = join_cols.copy()
            distinct_on.update(util.dict_subset(mapping, e.cols))
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
            
            # Uniquify input table to avoid internal duplicate keys
            insert_in_table = sql.distinct_table(db, insert_in_table,
                filter(sql_gen.is_table_col, distinct_on.values()))
        except sql.NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None, e)
        except sql.FunctionValueException, e:
            log_exc(e)
            
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                in_col = sql_gen.unwrap_func_call(in_col, func_name)
                invalid2null(in_col, value, e)
        except sql.DatabaseErrors, e:
            log_exc(e)
            
            log_debug('No handler for exception')
            on_error(e)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if is_function: pass # pkeys table already created
    elif has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols, distinct=True)
    else:
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
            db, insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {sql.row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
        
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    sql.add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert sql.table_row_count(db, into) == sql.table_row_count(db, in_table)
    
    sql.empty_temp(db, set([in_table, insert_in_table]))
    
    srcs = []
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
    return sql_gen.Col(out_pkey, into, srcs)
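
# Hedged usage sketch (hypothetical tables, columns, and mapping): run near the
# beginning of a transaction (see docstring),
#     row_ct = [0]
#     pkeys_col = put_table(db, 'taxonconcept', [in_table],
#         {'scientificname': in_col}, row_ct_ref=row_ct)
# where in_table is a staging table and in_col one of its sql_gen.Col columns,
# returns a sql_gen.Col pointing into a 'taxonconcept_pkeys' temp table that
# maps each input pkey to the matching or newly inserted output pkey.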