#!/usr/bin/env python
# Maps one datasource to another, using a map spreadsheet if needed
# Exit status is the # of errors in the import, up to the maximum exit status
# Multi-safe (supports an input appearing multiple times).
# For outputting an XML file to a PostgreSQL database, use the general format of
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
# Duplicate-column safe (supports multiple columns of the same name, which will
# be combined)

    
import copy
import csv
import itertools
import os.path
import warnings
import sys
import xml.dom.minidom as minidom

    
sys.path.append(os.path.dirname(__file__)+"/../lib")

    
import csvs
import db_xml
import exc
import ints
import iters
import maps
import opts
import parallelproc
import Parser
import profiling
import sql
import sql_gen
import sql_io
import streams
import strings
import term
import util
import xpath
import xml_dom
import xml_func
import xml_parse

    
metadata_prefix = ':'
collision_suffix = '/_alt/'

    
def get_with_prefix(map_, prefixes, key):
    '''Gets all entries for the given key with any of the given prefixes
    @return list of tuple(found_key, found_value)
    '''
    values = []
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
        try: value = map_[key_]
        except KeyError, e: continue # keep going
        values.append((key_, value))
    
    if values != []: return values
    else: raise e # re-raise the last KeyError if no key matched
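# Illustrative sketch of the lookup (assumes strings.with_prefixes() yields
# prefix+key for each given prefix, trying no prefix first):
#     get_with_prefix({'plot.id': 1, 'id': 2}, ['plot.'], 'id')
#     -> [('id', 2), ('plot.id', 1)]
#     get_with_prefix({}, [], 'id') # raises KeyError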

    
def is_metadata(str_): return str_.startswith(metadata_prefix)

    
def metadata_value(name):
    removed_ref = [False]
    name = strings.remove_prefix(metadata_prefix, name, removed_ref)
    if removed_ref[0]: return name
    else: return None

    
def cleanup(val):
    if val == None: return val
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
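# Illustrative examples (a sketch; assumes strings.remove_prefix() sets its
# removed_ref flag and util.none_if() maps the listed values to None):
#     metadata_value(':FIA') -> 'FIA'    # literal metadata value
#     metadata_value('plot') -> None     # regular column name
#     cleanup(u'\\N') -> None            # NULL placeholder becomes None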

    
def main_():
    env_names = []
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+''' [map_path...] [<input] [>output]
note: row #s start with 1
verbosity > 3 in commit mode turns on debug_temp mode, which creates real tables
instead of temp tables
''')
    
    ## Get config from env vars
    
    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', False, env_names) and not commit
        # never redo in commit mode (run `make schemas/reinstall` instead)
    
    # Ranges
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
    # Make start internally 0-based.
    # It's 1-based to the user to match up with the staging table row #s.
    start -= 1
    if test: n_default = 1
    else: n_default = None
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
        u''))
    end = n
    if end != None: end += start
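    # Worked example (illustrative): start=3, n=2 selects staging table rows
    # 3-4; internally start becomes 2 (0-based), so end = 2+2 = 4, one past
    # the last processed row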
    
    # Debugging
    verbosity = util.cast(float, opts.get_env_var('verbosity', None, env_names))
    opts.get_env_var('profile_to', None, env_names) # add to env_names
    
    # DB
    def get_db_config(prefix):
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config
    in_schema = opts.get_env_var('in_schema', None, env_names)
    in_table = opts.get_env_var('in_table', None, env_names)
    if in_schema != None:
        for config in [in_db_config, out_db_config]:
            config['schemas'] += ','+in_schema
    
    # Optimization
    cache_sql = opts.env_flag('cache_sql', True, env_names)
    by_col = in_db_config == out_db_config and opts.env_flag('by_col', False,
        env_names) # by-column optimization only applies if mapping to same DB
    if test: cpus_default = 0
    else: cpus_default = 0 # or None to use parallel processing by default
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
        env_names), u''))
    
    # Set default verbosity. Must happen after by_col is set.
    if verbosity == None:
        if test: verbosity = 0.5 # automated tests should not be verbose
        elif by_col: verbosity = 3 # show all queries to assist debugging
        else: verbosity = 1.1 # just show row progress
    
    # Enforce a minimum verbosity
    if by_col and not test: verbosity = ints.set_min(verbosity, 2)
        # live column-based import MUST be run with verbosity 2+ (3 preferred)
        # to provide debugging information for often-complex errors.
        # without this, debugging is effectively impossible.
        # automated tests are exempt from this because they output to the screen
    
    ##
    
    # Logging
    verbose_errors = test and verbosity > 0
    debug = verbosity >= 1.5
    def log(msg, level=1):
        '''Higher level -> more verbose'''
        if level <= verbosity:
            if verbosity <= 2:
                if level == 1.5: msg = '# '+msg # msg is Redmine list item
                elif msg.startswith('DB query:'): # remove extra debug info
                    first_line, nl, msg = msg.partition('\n')
            elif level > 1: msg = '['+str(level)+'] '+msg # include level in msg
            
            sys.stderr.write(strings.to_raw_str(msg.rstrip('\n')+'\n'))
    if debug: log_debug = lambda msg, level=2: log(msg, level)
    else: log_debug = sql.log_debug_none
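    # Usage sketch (illustrative): at the default verbosity of 1.1, log(msg)
    # (level 1) and log(msg, level=1.1) (row progress) print, while level-2
    # debug messages stay hidden until verbosity >= 2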
    
    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]
    
    def connect_db(db_config):
        log('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config, caching=cache_sql, autocommit=commit,
            debug_temp=verbosity > 3 and commit, log_debug=log_debug)
    
    if end != None: end_str = str(end-1) # end is one past the last #
    else: end_str = 'end'
    log('Processing input rows '+str(start)+'-'+end_str)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    # Parallel processing
    pool = parallelproc.MultiProducerPool(cpus)
    log('Using '+str(pool.process_ct)+' parallel CPUs')
    
    # Set up DB access
    row_ins_ct_ref = [0]
    if out_is_db:
        out_db = connect_db(out_db_config)
        def is_rel_func(name):
            return (name in db_xml.put_special_funcs
                or sql.function_exists(out_db, sql_gen.Function(name)))
    
    doc = xml_dom.create_doc()
    root = doc.documentElement
    out_is_xml_ref = [False]
    
    in_label_ref = [None]
    col_defaults = {}
    def update_in_label():
        # os.environ values must be strings, so skip values that are None
        if in_schema != None: os.environ.setdefault('source', in_schema)
        if in_label_ref[0] != None:
            os.environ.setdefault('source', in_label_ref[0])
    
    def prep_root():
        root.clear()
        update_in_label()
    prep_root()
    
    def process_input(root, row_ready, map_path):
        '''Imports a datasource into the XML tree, mapping it if needed'''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = []
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            in_label, out_label = reader.next()[:2]
            
            def split_col_name(name):
                label, sep, root = name.partition(':')
                return label, sep != '', root, []
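            # e.g. (illustrative): split_col_name('VegX:/plot') ->
            # ('VegX', True, '/plot', []); note that maps.col_info() is what's
            # actually used to parse the header below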
            
            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types
            
            for row in reader:
                in_, out = row[:2]
                if out != '': mappings.append([in_, out_root+out])
            
            stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        def process_rows(process_row, rows, rows_start=0):
            '''Processes input rows
            @param process_row(in_row, i)
            @param rows_start The (0-based) row # of the first row in rows. Set
                this only if the pre-start rows have already been skipped.
            '''
            rows = iter(rows)
            
            if end != None: row_nums = xrange(rows_start, end)
            else: row_nums = itertools.count(rows_start)
            i = -1
            for i in row_nums:
                try: row = rows.next()
                except StopIteration:
                    i -= 1 # last row # didn't count
                    break # no more rows
                if i < start: continue # not at start row yet
                
                # Row # is internally 0-based, but 1-based to the user
                log('Processing input row #'+str(i+1), level=1.1)
                process_row(row, i)
                row_ready(i, row)
            row_ct = i-start+1
            return row_ct
        
        def map_rows(get_value, rows, **kw_args):
            '''Maps input rows
            @param get_value(in_, row):str
            '''
            # Prevent collisions if multiple inputs map to the same output
            outputs_idxs = dict()
            for i, mapping in enumerate(mappings):
                in_, out = mapping
                default = util.NamedTuple(count=1, first=i)
                idxs = outputs_idxs.setdefault(out, default)
                if idxs is not default: # key existed, so there was a collision
                    if idxs.count == 1: # first key does not yet have suffix
                        mappings[idxs.first][1] += collision_suffix+'0'
                    mappings[i][1] += collision_suffix+str(idxs.count)
                    idxs.count += 1
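            # e.g. (illustrative): two inputs both mapped to 'plot/name'
            # become 'plot/name/_alt/0' and 'plot/name/_alt/1'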
            
            id_node = None
            if out_is_db:
                mappings_orig = mappings[:] # save a copy
                mappings[:] = [] # empty existing elements
                for in_, out in mappings_orig:
                    in_str = strings.ustr(in_)
                    is_metadata_ = is_metadata(in_str)
                    if is_metadata_: value = metadata_value(in_str)
                    else: value = '$'+in_str # mark as name
                    
                    # All put_obj()s should return the same id_node
                    nodes, id_node = xpath.put_obj(root, out, '-1', has_types,
                        value) # value is placeholder that documents name
                    if not is_metadata_: mappings.append([in_, nodes])
                if id_node == None:
                    warnings.warn(UserWarning('Map warning: No mappings or no '
                        'column name matches. Are you importing the correct '
                        'input table?'))
                xml_func.simplify(root)
                sys.stdout.write(strings.to_raw_str('<!--put template-->\n'
                    +strings.ustr(root)))
                sys.stdout.flush()
            
            def process_row(row, i):
                row_id = str(i)
                if id_node != None: xml_dom.set_value(id_node, row_id)
                for in_, out in mappings:
                    log_debug('Getting '+strings.ustr(in_))
                    value = cleanup(get_value(in_, row))
                    log_debug('Putting '+strings.urepr(value)+' to '
                        +strings.ustr(out))
                    if out_is_db: # out is list of XML nodes
                        for node in out: xml_dom.set_value(node, value)
                    elif value != None: # out is XPath
                        xpath.put_obj(root, out, row_id, has_types, value)
            return process_rows(process_row, rows, **kw_args)
        
        def map_table(col_names, rows, **kw_args):
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names)
            col_names_map = dict(zip(col_names, col_names))
            
            # Resolve names
            mappings_orig = mappings[:] # save a copy
            mappings[:] = [] # empty existing elements
            for in_, out in mappings_orig:
                if is_metadata(in_): mappings.append([in_, out])
                else:
                    try: cols = get_with_prefix(col_names_map, [], in_)
                    except KeyError: pass
                    else:
                        mappings[len(mappings):] = [[db_xml.ColRef(
                            orig, col_idxs[orig]), out] for simp, orig in cols]
                            # can't use += because it's an assignment, which
                            # would make mappings local to this function
            
            def get_value(in_, row): return row.list[in_.idx]
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths
            
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
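        # Illustrative sketch: with col_names ['a', 'b'], wrap_row() sizes each
        # CSV row to 2 fields (via util.list_as_length) and wraps it so fields
        # are reachable both by index and by column name (assuming that's the
        # access pattern util.ListDict provides)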
        
        if in_is_db:
            def on_error(e): ex_tracker.track(e)
            
            if by_col: in_db = out_db
            else: in_db = connect_db(in_db_config)
            
            # Get table and schema name
            schema = in_schema # modified, so can't have same name as outer var
            table = in_table # modified, so can't have same name as outer var
            if table == None:
                assert in_is_xpaths
                schema, sep, table = in_root.partition('.')
                if sep == '': # only the table name was specified
                    table = schema
                    schema = None
            table = sql_gen.Table(table, schema)
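            # e.g. (illustrative): an in_root of 'public.plots' yields
            # schema='public', table='plots'; a bare 'plots' yields
            # schema=None, table='plots'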
            
            # Fetch rows
            if by_col: limit = 0 # only fetch column names
            else: limit = n
            try:
                cur = sql.select(in_db, table, limit=limit, start=start,
                    recover=True, cacheable=False)
            except sql.DoesNotExistException:
                table = None
                col_names = []
                rows = []
            else:
                col_names = list(sql.col_names(cur))
                rows = sql.rows(cur)
            
            # Inline metadata value columns
            col_default_values = {}
            for col_name in col_names:
                col = sql_gen.Col(col_name, table)
                if sql.col_is_constant(in_db, col):
                    col_default_values[col_name] = (metadata_prefix +
                        sql.col_default_value(in_db, col))
            for i, mapping in enumerate(mappings):
                in_, out = mapping
                mappings[i] = (col_default_values.get(in_, in_), out)
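            # e.g. (illustrative): if column datasetName is constant with
            # default value 'FIA', the mapping input 'datasetName' is replaced
            # by the metadata value ':FIA'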
            
            if by_col:
                map_table(col_names, []) # just create the template
                
                if table != None and start == 0 and n == None: # full re-import
                    log('Clearing errors table')
                    errors_table_ = sql_io.errors_table(in_db, table)
                    if errors_table_ != None:
                        sql.drop_table(in_db, errors_table_)
                
                # Strip XML functions not in the DB
                xml_func.process(root, is_rel_func=is_rel_func)
                if debug: log_debug('Putting stripped:\n'+strings.ustr(root))
                    # only calc if debug
                
                # Import rows
                in_row_ct_ref = [0]
                db_xml.put_table(in_db, root.firstChild, table, in_row_ct_ref,
                    row_ins_ct_ref, n, start, on_error, col_defaults)
                row_ct = in_row_ct_ref[0]
            else:
                # Use normal by-row method
                row_ct = map_table(col_names, rows, rows_start=start)
                    # rows_start: pre-start rows have been skipped
                
                in_db.db.close()
        elif in_is_xml:
            stdin = streams.LineCountStream(sys.stdin)
            def on_error(e):
                exc.add_msg(e, term.emph('input line #:')+' '
                    +str(stdin.line_num))
                ex_tracker.track(e)
            
            def get_rows(doc2rows):
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))
            
            if map_path == None:
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_
                
                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: warnings.warn(UserWarning('Map warning: '
                        'Root "'+in_root+'" not found in input'))
                    return rows
                
                def get_value(in_, row):
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None
                
                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    pool.share_vars(locals())
    if out_is_db:
        try:
            if redo: sql.empty_db(out_db)
            pool.share_vars(locals())
            
            def row_ready(row_num, input_row):
                row_str_ = [None]
                def row_str():
                    if row_str_[0] == None:
                        # Row # is internally 0-based, but 1-based to the user
                        row_str_[0] = (term.emph('row #:')+' '+str(row_num+1)
                            +'\n'+term.emph('input row:')+'\n'
                            +strings.ustr(input_row))
                        if verbose_errors: row_str_[0] += ('\n'
                            +term.emph('output row:')+'\n'+strings.ustr(root))
                    return row_str_[0]
                
                if debug: log_debug(row_str()) # only calc if debug
                
                def on_error(e):
                    exc.add_msg(e, row_str())
                    ex_tracker.track(e, row_num, detail=verbose_errors)
                pool.share_vars(locals())
                
                row_root = root.cloneNode(True) # deep copy so don't modify root
                xml_func.process(row_root, on_error, is_rel_func, out_db)
                if debug: log_debug('Putting processed:\n'
                    +strings.ustr(row_root)) # only calc if debug
                if not xml_dom.is_empty(row_root):
                    assert xml_dom.has_one_child(row_root)
                    try:
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, row_root.firstChild,
                                row_ins_ct_ref, on_error, col_defaults))
                    except sql.DatabaseErrors, e: on_error(e)
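                # Design note (illustrative): wrapping each row's put in a
                # savepoint means a failed row is rolled back and tracked via
                # on_error() without aborting the surrounding transaction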
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
            sys.stdout.flush()
            
            # Consume asynchronous tasks
            pool.main_loop()
        finally: out_db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    # Consume any asynchronous tasks not already consumed above
    pool.main_loop()
    
    profiler.stop(row_ct)
    if not by_col: ex_tracker.add_iters(row_ct) # only if errors are done by row
    log('Processed '+str(row_ct)+' input rows')
    log(profiler.msg())
    log(ex_tracker.msg())
    ex_tracker.exit()

    
def main():
    try: main_()
    except Parser.SyntaxError, e: raise SystemExit(strings.ustr(e))

    
if __name__ == '__main__':
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to != None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        cProfile.run(main.func_code, profile_to)
    else: main()