# Database import/export

import exc
import dicts
import sql
import sql_gen
import strings
import util

##### Data cleanup

def cleanup_table(db, table, cols):
    table = sql_gen.as_Table(table)
    cols = map(sql_gen.as_Col, cols)
    
    expr = ('nullif(nullif(trim(both from %s), '+db.esc_value('')+'), '
        +db.esc_value(r'\N')+')')
    changes = [(v, sql_gen.CustomCode(expr % v.to_str(db)))
        for v in cols]
    
    sql.update(db, table, changes, in_place=True)
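
# Example (sketch; 'plots_staging', 'name', and 'area' are hypothetical names):
# trim whitespace in the given text columns and replace '' and the CSV NULL
# marker r'\N' with NULL, in place:
#     cleanup_table(db, 'plots_staging', ['name', 'area'])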

##### Error tracking

def track_data_error(db, errors_table, cols, value, error_code, error):
    '''
    @param errors_table If None, does nothing.
    '''
    if errors_table == None or cols == (): return
    
    for col in cols:
        try:
            sql.insert(db, errors_table, dict(column=col.name, value=value,
                error_code=error_code, error=error), recover=True,
                cacheable=True, log_level=4)
        except sql.DuplicateKeyException: pass
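
# Example (sketch; the column and values are hypothetical): record that the
# value '12a' could not be parsed, once per source column, silently skipping
# entries already in the errors table:
#     track_data_error(db, errors_table_, (sql_gen.Col('height'),), '12a',
#         '22P02', 'invalid input syntax for type double precision')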

def cast(db, type_, col, errors_table=None):
    '''Casts an (unrenamed) column or value.
    If errors_table is set and col has srcs, saves cast errors in errors_table
    (using col's srcs attr as the source columns) and converts them to
    warnings.
    @param col str|sql_gen.Col|sql_gen.Literal
    @param errors_table None|sql_gen.Table|str
    '''
    col = sql_gen.as_Col(col)
    save_errors = (errors_table != None and isinstance(col, sql_gen.Col)
        and col.srcs != ())
    if not save_errors: return sql_gen.Cast(type_, col) # can't save errors
    
    assert not isinstance(col, sql_gen.NamedCol)
    
    errors_table = sql_gen.as_Table(errors_table)
    srcs = map(sql_gen.to_name_only_col, col.srcs)
    function_name = str(sql_gen.FunctionCall(type_, *srcs))
    function = db.TempFunction(function_name)
    
    while True:
        # Create function definition
        errors_table_cols = map(sql_gen.Col,
            ['column', 'value', 'error_code', 'error'])
        query = '''\
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
RETURNS '''+type_+'''
LANGUAGE plpgsql
STRICT
AS $$
BEGIN
    /* The explicit cast to the return type is needed to make the cast happen
    inside the try block. (Implicit casts to the return type happen at the end
    of the function, outside any block.) */
    RETURN value::'''+type_+''';
EXCEPTION
    WHEN data_exception THEN
        -- Save error in errors table.
        DECLARE
            error_code text := SQLSTATE;
            error text := SQLERRM;
            "column" text;
        BEGIN
            -- Insert the value and error for *each* source column.
            FOR "column" IN
'''+sql.mk_select(db, sql_gen.NamedValues('c', None, [[c.name] for c in srcs]),
    order_by=None, start=0)+'''
            LOOP
                BEGIN
'''+sql.mk_insert_select(db, errors_table, errors_table_cols,
    sql_gen.Values(errors_table_cols).to_str(db))+''';
                EXCEPTION
                    WHEN unique_violation THEN NULL; -- continue to next row
                END;
            END LOOP;
        END;
        
        RAISE WARNING '%', SQLERRM;
        RETURN NULL;
END;
$$;
'''
        
        # Create the function in the database
        try:
            sql.run_query(db, query, recover=True, cacheable=True,
                log_ignore_excs=(sql.DuplicateException,))
            break # successful
        except sql.DuplicateException:
            function.name = sql.next_version(function.name)
            # try again with next version of name
    
    return sql_gen.FunctionCall(function, col)
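
# Example (sketch; the names are hypothetical): wrap an input column whose
# srcs attr is set (as put_table() arranges) so that rows whose value doesn't
# parse as the target type yield NULL, a row in errors_table_, and a warning,
# instead of aborting the whole insert:
#     expr = cast(db, 'double precision', height_col, errors_table_)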

def cast_temp_col(db, type_, col, errors_table=None):
    '''Like cast(), but if the input is a column, creates a new column
    containing the cast values.
    @return The new column, or the cast value if the input is not a column
    '''
    def cast_(col): return cast(db, type_, col, errors_table)
    
    try: col = sql_gen.underlying_col(col)
    except sql_gen.NoUnderlyingTableException: return sql_gen.wrap(cast_, col)
    
    table = col.table
    new_col = sql_gen.Col(sql_gen.concat(col.name, '::'+type_), table, col.srcs)
    expr = cast_(col)
    
    # Add column
    new_typed_col = sql_gen.TypedCol(new_col.name, type_)
    sql.add_col(db, table, new_typed_col, comment='src: '+repr(col))
    new_col.name = new_typed_col.name # propagate any renaming
    
    sql.update(db, table, [(new_col, expr)], in_place=True, cacheable=True)
    sql.add_index(db, new_col)
    
    return new_col
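
# Example (sketch; the names are hypothetical): like cast(), but materializes
# the cast values in a new, indexed column (named like 'height::double
# precision') on the input column's table and returns that sql_gen.Col:
#     height_num = cast_temp_col(db, 'double precision', height_col,
#         errors_table_)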

def errors_table(db, table, if_exists=True):
    '''
    @param if_exists If set, returns None if the errors table doesn't exist
    @return None|sql_gen.Table
    '''
    table = sql_gen.as_Table(table)
    if table.srcs != (): table = table.srcs[0]
    
    errors_table = sql_gen.suffixed_table(table, '.errors')
    if if_exists and not sql.table_exists(db, errors_table): return None
    return errors_table
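
# Example (sketch; 'plots_staging' is a hypothetical staging table): returns
# the sql_gen.Table for 'plots_staging.errors' if it exists, else None:
#     errors_table_ = errors_table(db, 'plots_staging')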

##### Import

def put(db, table, row, pkey_=None, row_ct_ref=None):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    '''
    row = sql_gen.ColDict(db, table, row)
    if pkey_ == None: pkey_ = sql.pkey(db, table, recover=True)
    
    try:
        cur = sql.insert(db, table, row, pkey_, recover=True)
        if row_ct_ref != None and cur.rowcount >= 0:
            row_ct_ref[0] += cur.rowcount
        return sql.value(cur)
    except sql.DuplicateKeyException, e:
        row = sql_gen.ColDict(db, table,
            util.dict_subset_right_join(row, e.cols))
        return sql.value(sql.select(db, table, [pkey_], row, recover=True))
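
# Example (sketch; the table and columns are hypothetical): insert a row and
# return its pkey, or on a duplicate key look up the pkey of the existing row:
#     row_ct = [0] # incremented by the number of rows actually inserted
#     plot_id = put(db, 'plot', {'plotname': 'P1', 'area': 100},
#         row_ct_ref=row_ct)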

def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Recovers from errors.'''
    try:
        return sql.value(sql.select(db, table, [pkey], row, limit=1,
            recover=True))
    except StopIteration:
        if not create: raise
        return put(db, table, row, pkey, row_ct_ref) # insert new row
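
# Example (sketch; the names are hypothetical): look up a row's pkey by its
# column values, inserting the row first if create=True and it doesn't exist:
#     party_id = get(db, 'party', {'organizationname': 'NYBG'}, 'party_id',
#         create=True)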

def is_func_result(col):
    return col.table.name.find('(') >= 0 and col.name == 'result'
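
# e.g. is_func_result() is True for the 'result' column of a table named like
# 'f(x)' (the output of a SQL function call), and False otherwise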

def into_table_name(out_table, in_tables0, mapping, is_func):
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0:
                in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return str(in_col)
    
    str_ = str(out_table)
    if is_func:
        str_ += '('
        
        try: value_in_col = mapping['value']
        except KeyError:
            str_ += ', '.join((str(k)+'='+in_col_str(v)
                for k, v in mapping.iteritems()))
        else: str_ += in_col_str(value_in_col)
        
        str_ += ')'
    else:
        out_col = 'rank'
        try: in_col = mapping[out_col]
        except KeyError: str_ += '_pkeys'
        else: # has a rank column, so hierarchical
            str_ += '['+str(out_col)+'='+in_col_str(in_col)+']'
    return str_
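
# Example (sketch; the names are hypothetical): for a function, a mapping such
# as {'value': <col height>} yields a name like 'f(height)'; for a table, the
# name is '<out_table>_pkeys', or '<out_table>[rank=...]' if the mapping has a
# 'rank' column (i.e. the table is hierarchical).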

def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False, on_error=exc.raise_):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    IMPORTANT: Must be run at the *beginning* of a transaction.
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    is_function = sql.function_exists(db, out_table)
    
    if is_function: out_pkey = 'result'
    else: out_pkey = sql.pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)

    if mapping == {}: # need at least one column for INSERT SELECT
        mapping = {out_pkey: None} # ColDict will replace with default value
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    errors_table_ = errors_table(db, in_tables0)
    in_pkey = sql.pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    # Set column sources
    in_cols = filter(sql_gen.is_table_col, mapping.values())
    for col in in_cols:
        if col.table == in_tables0: col.set_srcs(sql_gen.src_self)
    
    log_debug('Joining together input tables into temp table')
    # Place in a new table for speed, and so the input isn't modified if values
    # are edited
    in_table = sql_gen.Table('in')
    mapping = dicts.join(mapping, sql.flatten(db, in_table, input_joins,
        in_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    if default != None:
        try: default = mapping[default]
        except KeyError:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols, distinct=False):
        kw_args = {}
        if distinct: kw_args.update(dict(distinct_on=[in_pkey_col]))
        query = sql.mk_select(db, joins, cols, order_by=None, start=0,
            **kw_args)
        
        if pkeys_table_exists_ref[0]:
            sql.insert_select(db, into, pkeys_names, query)
        else:
            sql.run_query_into(db, query, into=into)
            pkeys_table_exists_ref[0] = True

    limit_ref = [None]
    insert_in_table = in_table
    conds = set()
    distinct_on = sql_gen.ColDict(db, out_table)
    def mk_main_select(joins, cols):
        conds_ = [(sql_gen.with_table(k, insert_in_table), v) for k, v in conds]
        return sql.mk_select(db, joins, cols, conds_, limit=limit_ref[0],
            start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    
    def remove_all_rows():
        log_debug('Ignoring all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    
    def ignore(in_col, value, e):
        track_data_error(db, errors_table_, in_col.srcs, value,
            e.cause.pgcode, e.cause.pgerror)
        log_debug('Ignoring rows with '+strings.as_tt(repr(in_col))+' = '
            +strings.as_tt(repr(value)))
    
    def remove_rows(in_col, value, e):
        ignore(in_col, value, e)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    
    def invalid2null(in_col, value, e):
        ignore(in_col, value, e)
        sql.update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.concat(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')

    # Do inserts and selects
    join_cols = sql_gen.ColDict(db, out_table)
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = sql.run_query_into(db, sql.mk_select(db, out_table,
                [out_pkey], limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        log_debug('Trying to insert new rows')
        
        # Prepare to insert new rows
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            insert_args.update(dict(ignore=True))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select([insert_in_table],
            [sql_gen.with_table(c, insert_in_table) for c in mapping.values()])
        
        def main_insert():
            if is_function:
                log_debug('Calling function on input rows')
                args = dict(((k.name, v) for k, v in mapping.iteritems()))
                func_call = sql_gen.NamedCol(out_pkey,
                    sql_gen.FunctionCall(out_table, **args))
                insert_into_pkeys(input_joins, [in_pkey_col, func_call])
                return None
            else:
                return sql.insert_select(db, out_table, mapping.keys(),
                    main_select, **insert_args)
        
        try:
            cur = sql.with_savepoint(db, main_insert)
            break # insert successful
        except sql.MissingCastException, e:
            log_exc(e)
            
            out_col = e.col
            type_ = e.type
            
            log_debug('Casting '+strings.as_tt(out_col)+' input to '
                +strings.as_tt(type_))
            mapping[out_col] = cast_temp_col(db, type_, mapping[out_col],
                errors_table_)
        except sql.DuplicateKeyException, e:
            log_exc(e)
            
            old_join_cols = join_cols.copy()
            distinct_on.update(util.dict_subset(mapping, e.cols))
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
            
            # Uniquify input table to avoid internal duplicate keys
            insert_in_table = sql.distinct_table(db, insert_in_table,
                filter(sql_gen.is_table_col, distinct_on.values()))
        except sql.NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None, e)
        except sql.DatabaseErrors, e:
            log_exc(e)
            
            log_debug('No handler for exception')
            on_error(e)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints

    if cur != None and row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if is_function: pass # pkeys table already created
    elif has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols, distinct=True)
    else:
        sql.add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        sql.run_query_into(db, mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        sql.add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert sql.table_row_count(db, insert_out_pkeys) == sql.table_row_count(
            db, insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {sql.row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
        
        sql.empty_temp(db, [insert_out_pkeys, insert_in_pkeys])
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    sql.add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert sql.table_row_count(db, into) == sql.table_row_count(db, in_table)
    
    sql.empty_temp(db, set([in_table, insert_in_table]))
    
    srcs = []
    if is_func: srcs = sql_gen.cols_srcs(in_cols)
    return sql_gen.Col(out_pkey, into, srcs)
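
# Example (sketch; the table names and mapping are hypothetical): insert the
# rows of a staging table into out_table 'plot', working around cast errors,
# duplicate keys, and NULL values, and return a column of output pkeys keyed
# by the staging table's pkeys (must be run at the beginning of a transaction):
#     staging = sql_gen.Table('plots_staging')
#     row_ct = [0]
#     plot_pkeys = put_table(db, 'plot', [staging],
#         {'plotname': sql_gen.Col('plotname', staging),
#         'area': sql_gen.Col('area', staging)}, row_ct_ref=row_ct)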