Project

General

Profile

1
# Database import/export
2

    
3
import exc
4
import dicts
5
import sql
6
import sql_gen
7
import strings
8
import util
9

    
10
##### Data cleanup
11

    
12
def cleanup_table(db, table, cols):
13
    table = sql_gen.as_Table(table)
14
    cols = map(sql_gen.as_Col, cols)
15
    
16
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
17
        +db.esc_value(r'\N')+')')
18
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
19
        for v in cols]
20
    
21
    sql.update(db, table, changes, in_place=True)
22

    
23
##### Error tracking
24

    
25
def track_data_error(db, errors_table, cols, value, error_code, error):
26
    '''
27
    @param errors_table If None, does nothing.
28
    '''
29
    if errors_table == None or cols == (): return
30
    
31
    for col in cols:
32
        try:
33
            sql.insert(db, errors_table, dict(column=col.name, value=value,
34
                error_code=error_code, error=error), recover=True,
35
                cacheable=True, log_level=4)
36
        except sql.DuplicateKeyException: pass
37

    
38
def cast(db, type_, col, errors_table=None):
39
    '''Casts an (unrenamed) column or value.
40
    If errors_table set and col has srcs, saves errors in errors_table (using
41
    col's srcs attr as the source columns) and converts errors to warnings.
42
    @param col str|sql_gen.Col|sql_gen.Literal
43
    @param errors_table None|sql_gen.Table|str
44
    '''
45
    col = sql_gen.as_Col(col)
46
    save_errors = (errors_table != None and isinstance(col, sql_gen.Col)
47
        and col.srcs != ())
48
    if not save_errors: return sql_gen.Cast(type_, col) # can't save errors
49
    
50
    assert not isinstance(col, sql_gen.NamedCol)
51
    
52
    errors_table = sql_gen.as_Table(errors_table)
53
    srcs = map(sql_gen.to_name_only_col, col.srcs)
54
    function_name = str(sql_gen.FunctionCall(type_, *srcs))
55
    function = db.TempFunction(function_name)
56
    
57
    while True:
58
        # Create function definition
59
        errors_table_cols = map(sql_gen.Col,
60
            ['column', 'value', 'error_code', 'error'])
61
        query = '''\
62
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
63
RETURNS '''+type_+'''
64
LANGUAGE plpgsql
65
STRICT
66
AS $$
67
BEGIN
68
    /* The explicit cast to the return type is needed to make the cast happen
69
    inside the try block. (Implicit casts to the return type happen at the end
70
    of the function, outside any block.) */
71
    RETURN value::'''+type_+''';
72
EXCEPTION
73
    WHEN data_exception THEN
74
        -- Save error in errors table.
75
        DECLARE
76
            error_code text := SQLSTATE;
77
            error text := SQLERRM;
78
            "column" text;
79
        BEGIN
80
            -- Insert the value and error for *each* source column.
81
            FOR "column" IN
82
'''+sql.mk_select(db, sql_gen.NamedValues('c', None, [[c.name] for c in srcs]),
83
    order_by=None, start=0)+'''
84
            LOOP
85
                BEGIN
86
'''+sql.mk_insert_select(db, errors_table, errors_table_cols,
87
    sql_gen.Values(errors_table_cols).to_str(db))+''';
88
                EXCEPTION
89
                    WHEN unique_violation THEN NULL; -- continue to next row
90
                END;
91
            END LOOP;
92
        END;
93
        
94
        RAISE WARNING '%', SQLERRM;
95
        RETURN NULL;
96
END;
97
$$;
98
'''
99
        
100
        # Create function
101
        try:
102
            sql.run_query(db, query, recover=True, cacheable=True,
103
                log_ignore_excs=(sql.DuplicateException,))
104
            break # successful
105
        except sql.DuplicateException:
106
            function.name = sql.next_version(function.name)
107
            # try again with next version of name
108
    
109
    return sql_gen.FunctionCall(function, col)
110

    
111
def cast_temp_col(db, type_, col, errors_table=None):
112
    '''Like cast(), but creates a new column with the cast values if the input
113
    is a column.
114
    @return The new column or cast value
115
    '''
116
    def cast_(col): return cast(db, type_, col, errors_table)
117
    
118
    try: col = sql_gen.underlying_col(col)
119
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(cast_, col)
120
    
121
    table = col.table
122
    new_col = sql_gen.Col(sql_gen.concat(col.name, '::'+type_), table, col.srcs)
123
    expr = cast_(col)
124
    
125
    # Add column
126
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
127
    sql.add_col(db, table, new_typed_col, comment='src: '+repr(col))
128
    new_col.name = new_typed_col.name # propagate any renaming
129
    
130
    sql.update(db, table, [(new_col, expr)], in_place=True, cacheable=True)
131
    sql.add_index(db, new_col)
132
    
133
    return new_col
134

    
135
def errors_table(db, table, if_exists=True):
136
    '''
137
    @param if_exists If set, returns None if the errors table doesn't exist
138
    @return None|sql_gen.Table
139
    '''
140
    table = sql_gen.as_Table(table)
141
    if table.srcs != (): table = table.srcs[0]
142
    
143
    errors_table = sql_gen.suffixed_table(table, '.errors')
144
    if if_exists and not sql.table_exists(db, errors_table): return None
145
    return errors_table
146

    
147
##### Import
148

    
149
def put(db, table, row, pkey_=None, row_ct_ref=None):
150
    '''Recovers from errors.
151
    Only works under PostgreSQL (uses INSERT RETURNING).
152
    '''
153
    row = sql_gen.ColDict(db, table, row)
154
    if pkey_ == None: pkey_ = sql.pkey(db, table, recover=True)
155
    
156
    try:
157
        cur = sql.insert(db, table, row, pkey_, recover=True)
158
        if row_ct_ref != None and cur.rowcount >= 0:
159
            row_ct_ref[0] += cur.rowcount
160
        return sql.value(cur)
161
    except sql.DuplicateKeyException, e:
162
        row = sql_gen.ColDict(db, table,
163
            util.dict_subset_right_join(row, e.cols))
164
        return sql.value(sql.select(db, table, [pkey_], row, recover=True))
165

    
166
def get(db, table, row, pkey, row_ct_ref=None, create=False):
167
    '''Recovers from errors'''
168
    try:
169
        return sql.value(sql.select(db, table, [pkey], row, limit=1,
170
            recover=True))
171
    except StopIteration:
172
        if not create: raise
173
        return put(db, table, row, pkey, row_ct_ref) # insert new row
174

    
175
def is_func_result(col):
176
    return col.table.name.find('(') >= 0 and col.name == 'result'
177

    
178
def into_table_name(out_table, in_tables0, mapping, is_func):
179
    def in_col_str(in_col):
180
        in_col = sql_gen.remove_col_rename(in_col)
181
        if isinstance(in_col, sql_gen.Col):
182
            table = in_col.table
183
            if table == in_tables0:
184
                in_col = sql_gen.to_name_only_col(in_col)
185
            elif is_func_result(in_col): in_col = table # omit col name
186
        return str(in_col)
187
    
188
    str_ = str(out_table)
189
    if is_func:
190
        str_ += '('
191
        
192
        try: value_in_col = mapping['value']
193
        except KeyError:
194
            str_ += ', '.join((str(k)+'='+in_col_str(v)
195
                for k, v in mapping.iteritems()))
196
        else: str_ += in_col_str(value_in_col)
197
        
198
        str_ += ')'
199
    else:
200
        out_col = 'rank'
201
        try: in_col = mapping[out_col]
202
        except KeyError: str_ += '_pkeys'
203
        else: # has a rank column, so hierarchical
204
            str_ += '['+str(out_col)+'='+in_col_str(in_col)+']'
205
    return str_
206

    
207
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
208
    default=None, is_func=False, on_error=exc.raise_):
209
    '''Recovers from errors.
210
    Only works under PostgreSQL (uses INSERT RETURNING).
211
    IMPORTANT: Must be run at the *beginning* of a transaction.
212
    @param in_tables The main input table to select from, followed by a list of
213
        tables to join with it using the main input table's pkey
214
    @param mapping dict(out_table_col=in_table_col, ...)
215
        * out_table_col: str (*not* sql_gen.Col)
216
        * in_table_col: sql_gen.Col|literal-value
217
    @param into The table to contain the output and input pkeys.
218
        Defaults to `out_table.name+'_pkeys'`.
219
    @param default The *output* column to use as the pkey for missing rows.
220
        If this output column does not exist in the mapping, uses None.
221
    @param is_func Whether out_table is the name of a SQL function, not a table
222
    @return sql_gen.Col Where the output pkeys are made available
223
    '''
224
    out_table = sql_gen.as_Table(out_table)
225
    
226
    def log_debug(msg): db.log_debug(msg, level=1.5)
227
    def col_ustr(str_):
228
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
229
    
230
    log_debug('********** New iteration **********')
231
    log_debug('Inserting these input columns into '+strings.as_tt(
232
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
233
    
234
    is_function = sql.function_exists(db, out_table)
235
    
236
    if is_function: out_pkey = 'result'
237
    else: out_pkey = sql.pkey(db, out_table, recover=True)
238
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
239
    
240
    if mapping == {}: # need at least one column for INSERT SELECT
241
        mapping = {out_pkey: None} # ColDict will replace with default value
242
    
243
    # Create input joins from list of input tables
244
    in_tables_ = in_tables[:] # don't modify input!
245
    in_tables0 = in_tables_.pop(0) # first table is separate
246
    errors_table_ = errors_table(db, in_tables0)
247
    in_pkey = sql.pkey(db, in_tables0, recover=True)
248
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
249
    input_joins = [in_tables0]+[sql_gen.Join(v,
250
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
251
    
252
    if into == None:
253
        into = into_table_name(out_table, in_tables0, mapping, is_func)
254
    into = sql_gen.as_Table(into)
255
    
256
    # Set column sources
257
    in_cols = filter(sql_gen.is_table_col, mapping.values())
258
    for col in in_cols:
259
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
260
    
261
    log_debug('Joining together input tables into temp table')
262
    # Place in new table for speed and so don't modify input if values edited
263
    in_table = sql_gen.Table('in')
264
    mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
265
        in_cols, preserve=[in_pkey_col], start=0))
266
    input_joins = [in_table]
267
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
268
    
269
    mapping = sql_gen.ColDict(db, out_table, mapping)
270
        # after applying dicts.join() because that returns a plain dict
271
    
272
    # Resolve default value column
273
    if default != None:
274
        try: default = mapping[default]
275
        except KeyError:
276
            db.log_debug('Default value column '
277
                +strings.as_tt(strings.repr_no_u(default))
278
                +' does not exist in mapping, falling back to None', level=2.1)
279
            default = None
280
    
281
    pkeys_names = [in_pkey, out_pkey]
282
    pkeys_cols = [in_pkey_col, out_pkey_col]
283
    
284
    pkeys_table_exists_ref = [False]
285
    def insert_into_pkeys(joins, cols, distinct=False):
286
        kw_args = {}
287
        if distinct: kw_args.update(dict(distinct_on=[in_pkey_col]))
288
        query = sql.mk_select(db, joins, cols, order_by=None, start=0,
289
            **kw_args)
290
        
291
        if pkeys_table_exists_ref[0]:
292
            sql.insert_select(db, into, pkeys_names, query)
293
        else:
294
            sql.run_query_into(db, query, into=into)
295
            pkeys_table_exists_ref[0] = True
296
    
297
    limit_ref = [None]
298
    conds = set()
299
    distinct_on = sql_gen.ColDict(db, out_table)
300
    def mk_main_select(joins, cols):
301
        distinct_on_cols = [c.to_Col() for c in distinct_on.values()]
302
        return sql.mk_select(db, joins, cols, conds, distinct_on_cols,
303
            limit=limit_ref[0], start=0)
304
    
305
    exc_strs = set()
306
    def log_exc(e):
307
        e_str = exc.str_(e, first_line_only=True)
308
        log_debug('Caught exception: '+e_str)
309
        assert e_str not in exc_strs # avoid infinite loops
310
        exc_strs.add(e_str)
311
    
312
    def remove_all_rows():
313
        log_debug('Ignoring all rows')
314
        limit_ref[0] = 0 # just create an empty pkeys table
315
    
316
    def ignore(in_col, value, e):
317
        track_data_error(db, errors_table_, in_col.srcs, value,
318
            e.cause.pgcode, e.cause.pgerror)
319
        log_debug('Ignoring rows with '+strings.as_tt(repr(in_col))+' = '
320
            +strings.as_tt(repr(value)))
321
    
322
    def remove_rows(in_col, value, e):
323
        ignore(in_col, value, e)
324
        cond = (in_col, sql_gen.CompareCond(value, '!='))
325
        assert cond not in conds # avoid infinite loops
326
        conds.add(cond)
327
    
328
    def invalid2null(in_col, value, e):
329
        ignore(in_col, value, e)
330
        sql.update(db, in_table, [(in_col, None)],
331
            sql_gen.ColValueCond(in_col, value))
332
    
333
    def insert_pkeys_table(which):
334
        return sql_gen.Table(sql_gen.concat(in_table.name,
335
            '_insert_'+which+'_pkeys'))
336
    insert_out_pkeys = insert_pkeys_table('out')
337
    insert_in_pkeys = insert_pkeys_table('in')
338
    
339
    # Do inserts and selects
340
    join_cols = sql_gen.ColDict(db, out_table)
341
    while True:
342
        if limit_ref[0] == 0: # special case
343
            log_debug('Creating an empty pkeys table')
344
            cur = sql.run_query_into(db, sql.mk_select(db, out_table,
345
                [out_pkey], limit=limit_ref[0]), into=insert_out_pkeys)
346
            break # don't do main case
347
        
348
        has_joins = join_cols != {}
349
        
350
        log_debug('Trying to insert new rows')
351
        
352
        # Prepare to insert new rows
353
        insert_joins = input_joins[:] # don't modify original!
354
        insert_args = dict(recover=True, cacheable=False)
355
        if has_joins:
356
            insert_args.update(dict(ignore=True))
357
        else:
358
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
359
        main_select = mk_main_select(insert_joins, mapping.values())
360
        
361
        def main_insert():
362
            if is_function:
363
                log_debug('Calling function on input rows')
364
                args = dict(((k.name, v) for k, v in mapping.iteritems()))
365
                func_call = sql_gen.NamedCol(out_pkey,
366
                    sql_gen.FunctionCall(out_table, **args))
367
                insert_into_pkeys(input_joins, [in_pkey_col, func_call])
368
                return None
369
            else:
370
                return sql.insert_select(db, out_table, mapping.keys(),
371
                    main_select, **insert_args)
372
        
373
        try:
374
            cur = sql.with_savepoint(db, main_insert)
375
            break # insert successful
376
        except sql.MissingCastException, e:
377
            log_exc(e)
378
            
379
            out_col = e.col
380
            type_ = e.type
381
            
382
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
383
                +strings.as_tt(type_))
384
            mapping[out_col] = cast_temp_col(db, type_, mapping[out_col],
385
                errors_table_)
386
        except sql.DuplicateKeyException, e:
387
            log_exc(e)
388
            
389
            old_join_cols = join_cols.copy()
390
            distinct_on.update(util.dict_subset(mapping, e.cols))
391
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
392
            log_debug('Ignoring existing rows, comparing on these columns:\n'
393
                +strings.as_inline_table(join_cols, ustr=col_ustr))
394
            assert join_cols != old_join_cols # avoid infinite loops
395
        except sql.NullValueException, e:
396
            log_exc(e)
397
            
398
            out_col, = e.cols
399
            try: in_col = mapping[out_col]
400
            except KeyError:
401
                log_debug('Missing mapping for NOT NULL column '+out_col)
402
                remove_all_rows()
403
            else: remove_rows(in_col, None, e)
404
        except sql.FunctionValueException, e:
405
            log_exc(e)
406
            
407
            func_name = e.name
408
            value = e.value
409
            for out_col, in_col in mapping.iteritems():
410
                in_col = sql_gen.unwrap_func_call(in_col, func_name)
411
                invalid2null(in_col, value, e)
412
        except sql.DatabaseErrors, e:
413
            log_exc(e)
414
            
415
            log_debug('No handler for exception')
416
            on_error(e)
417
            remove_all_rows()
418
        # after exception handled, rerun loop with additional constraints
419
    
420
    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
421
        row_ct_ref[0] += cur.rowcount
422
    
423
    if is_function: pass # pkeys table already created
424
    elif has_joins:
425
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
426
        log_debug('Getting output table pkeys of existing/inserted rows')
427
        insert_into_pkeys(select_joins, pkeys_cols, distinct=True)
428
    else:
429
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
430
        
431
        log_debug('Getting input table pkeys of inserted rows')
432
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
433
            into=insert_in_pkeys)
434
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
435
        
436
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
437
            db, insert_in_pkeys)
438
        
439
        log_debug('Combining output and input pkeys in inserted order')
440
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
441
            {sql.row_num_col: sql_gen.join_same_not_null})]
442
        insert_into_pkeys(pkey_joins, pkeys_names)
443
        
444
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
445
    
446
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
447
    sql.add_pkey(db, into)
448
    
449
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
450
    missing_rows_joins = input_joins+[sql_gen.Join(into,
451
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
452
        # must use join_same_not_null or query will take forever
453
    insert_into_pkeys(missing_rows_joins,
454
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
455
    
456
    assert sql.table_row_count(db, into) == sql.table_row_count(db, in_table)
457
    
458
    sql.empty_temp(db, in_table)
459
    
460
    srcs = []
461
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
462
    return sql_gen.Col(out_pkey, into, srcs)
(26-26/37)