Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# Multi-safe (supports an input appearing multiple times).
5
# For outputting an XML file to a PostgreSQL database, use the general format of
6
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
7
# Duplicate-column safe (supports multiple columns of the same name, which will
8
# be combined)
9
# Case- and punctuation-insensitive.
10

    
11
import copy
12
import csv
13
import itertools
14
import os.path
15
import warnings
16
import sys
17
import xml.dom.minidom as minidom
18

    
19
sys.path.append(os.path.dirname(__file__)+"/../lib")
20

    
21
import csvs
22
import db_xml
23
import exc
24
import ints
25
import iters
26
import maps
27
import opts
28
import parallelproc
29
import Parser
30
import profiling
31
import sql
32
import sql_gen
33
import sql_io
34
import streams
35
import strings
36
import term
37
import util
38
import xpath
39
import xml_dom
40
import xml_func
41
import xml_parse
42

    
43
metadata_prefix = ':'
44
collision_suffix = '/_alt/'
45

    
46
def get_with_prefix(map_, prefixes, key):
47
    '''Gets all entries for the given key with any of the given prefixes
48
    @return tuple(found_key, found_value)
49
    '''
50
    values = []
51
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
52
        try: value = map_[key_]
53
        except KeyError, e: continue # keep going
54
        values.append((key_, value))
55
    
56
    if values != []: return values
57
    else: raise e # re-raise last KeyError
58

    
59
def is_metadata(str_): return str_.startswith(metadata_prefix)
60

    
61
def metadata_value(name):
62
    removed_ref = [False]
63
    name = strings.remove_prefix(metadata_prefix, name, removed_ref)
64
    if removed_ref[0]: return name
65
    else: return None
66

    
67
def cleanup(val):
68
    if val == None: return val
69
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
70

    
71
def main_():
72
    env_names = []
73
    def usage_err():
74
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
75
            +sys.argv[0]+' [map_path...] [<input] [>output]\n'
76
            'Note: Row #s start with 1')
77
    
78
    ## Get config from env vars
79
    
80
    # Modes
81
    test = opts.env_flag('test', False, env_names)
82
    commit = opts.env_flag('commit', False, env_names) and not test
83
        # never commit in test mode
84
    redo = opts.env_flag('redo', False, env_names) and not commit
85
        # never redo in commit mode (run `make schemas/reinstall` instead)
86
    
87
    # Ranges
88
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
89
    # Make start interally 0-based.
90
    # It's 1-based to the user to match up with the staging table row #s.
91
    start -= 1
92
    if test: n_default = 1
93
    else: n_default = None
94
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
95
        u''))
96
    end = n
97
    if end != None: end += start
98
    
99
    # Debugging
100
    verbosity = util.cast(float, opts.get_env_var('verbosity', None, env_names))
101
    opts.get_env_var('profile_to', None, env_names) # add to env_names
102
    
103
    # DB
104
    def get_db_config(prefix):
105
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
106
    in_db_config = get_db_config('in')
107
    out_db_config = get_db_config('out')
108
    in_is_db = 'engine' in in_db_config
109
    out_is_db = 'engine' in out_db_config
110
    in_schema = opts.get_env_var('in_schema', None, env_names)
111
    in_table = opts.get_env_var('in_table', None, env_names)
112
    if in_schema != None:
113
        for config in [in_db_config, out_db_config]:
114
            config['schemas'] += ','+in_schema
115
    
116
    # Optimization
117
    cache_sql = opts.env_flag('cache_sql', True, env_names)
118
    by_col = in_db_config == out_db_config and opts.env_flag('by_col', False,
119
        env_names) # by-column optimization only applies if mapping to same DB
120
    if test: cpus_default = 0
121
    else: cpus_default = 0 # or None to use parallel processing by default
122
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
123
        env_names), u''))
124
    
125
    # Set default verbosity. Must happen after by_col is set.
126
    if verbosity == None:
127
        if test: verbosity = 0.5 # automated tests should not be verbose
128
        elif by_col: verbosity = 3 # show all queries to assist debugging
129
        else: verbosity = 1.1 # just show row progress
130
    
131
    # fix verbosity
132
    if by_col and not test: verbosity = ints.set_min(verbosity, 2)
133
        # live column-based import MUST be run with verbosity 2+ (3 preferred)
134
        # to provide debugging information for often-complex errors.
135
        # without this, debugging is effectively impossible.
136
        # automated tests are exempt from this because they output to the screen
137
    
138
    ##
139
    
140
    # Logging
141
    verbose_errors = test and verbosity > 0
142
    debug = verbosity >= 1.5
143
    def log(msg, level=1):
144
        '''Higher level -> more verbose'''
145
        if level <= verbosity:
146
            if verbosity <= 2:
147
                if level == 1.5: msg = '# '+msg # msg is Redmine list item
148
                elif msg.startswith('DB query:'): # remove extra debug info
149
                    first_line, nl, msg = msg.partition('\n')
150
            elif level > 1: msg = '['+str(level)+'] '+msg # include level in msg
151
            
152
            sys.stderr.write(strings.to_raw_str(msg.rstrip('\n')+'\n'))
153
    if debug: log_debug = lambda msg, level=2: log(msg, level)
154
    else: log_debug = sql.log_debug_none
155
    
156
    # Parse args
157
    map_paths = sys.argv[1:]
158
    if map_paths == []:
159
        if in_is_db or not out_is_db: usage_err()
160
        else: map_paths = [None]
161
    
162
    def connect_db(db_config):
163
        log('Connecting to '+sql.db_config_str(db_config))
164
        return sql.connect(db_config, caching=cache_sql, autocommit=commit,
165
            debug_temp=verbosity > 3 and commit, log_debug=log_debug)
166
    
167
    if end != None: end_str = str(end-1) # end is one past the last #
168
    else: end_str = 'end'
169
    log('Processing input rows '+str(start)+'-'+end_str)
170
    
171
    ex_tracker = exc.ExPercentTracker(iter_text='row')
172
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
173
    
174
    # Parallel processing
175
    pool = parallelproc.MultiProducerPool(cpus)
176
    log('Using '+str(pool.process_ct)+' parallel CPUs')
177
    
178
    # Set up DB access
179
    row_ins_ct_ref = [0]
180
    if out_is_db:
181
        out_db = connect_db(out_db_config)
182
        def is_rel_func(name):
183
            return (name in db_xml.put_special_funcs
184
                or sql.function_exists(out_db, sql_gen.Function(name)))
185
    
186
    doc = xml_dom.create_doc()
187
    root = doc.documentElement
188
    out_is_xml_ref = [False]
189
    
190
    in_label_ref = [None]
191
    col_defaults = {}
192
    def update_in_label():
193
        if in_schema != None: os.environ['source'] = in_schema
194
        elif in_label_ref[0] != None: os.environ['source'] = in_label_ref[0]
195
    
196
    def prep_root():
197
        root.clear()
198
        update_in_label()
199
    prep_root()
200
    
201
    def process_input(root, row_ready, map_path):
202
        '''Inputs datasource to XML tree, mapping if needed'''
203
        # Load map header
204
        in_is_xpaths = True
205
        out_is_xpaths = True
206
        out_label = None
207
        if map_path != None:
208
            metadata = []
209
            mappings = []
210
            stream = open(map_path, 'rb')
211
            reader = csv.reader(stream)
212
            in_label, out_label = reader.next()[:2]
213
            
214
            def split_col_name(name):
215
                label, sep, root = name.partition(':')
216
                return label, sep != '', root, []
217
            
218
            in_label, in_root, prefixes = maps.col_info(in_label)
219
            in_is_xpaths = in_root != None
220
            in_label_ref[0] = in_label
221
            update_in_label()
222
            out_label, out_root = maps.col_info(out_label)[:2]
223
            out_is_xpaths = out_root != None
224
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
225
                # outer elements are types
226
            
227
            for row in reader:
228
                in_, out = row[:2]
229
                if out != '': mappings.append([in_, out_root+out])
230
            
231
            stream.close()
232
            
233
            root.ownerDocument.documentElement.tagName = out_label
234
        in_is_xml = in_is_xpaths and not in_is_db
235
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
236
        
237
        def process_rows(process_row, rows, rows_start=0):
238
            '''Processes input rows
239
            @param process_row(in_row, i)
240
            @rows_start The (0-based) row # of the first row in rows. Set this
241
                only if the pre-start rows have already been skipped.
242
            '''
243
            rows = iter(rows)
244
            
245
            if end != None: row_nums = xrange(rows_start, end)
246
            else: row_nums = itertools.count(rows_start)
247
            i = -1
248
            for i in row_nums:
249
                try: row = rows.next()
250
                except StopIteration:
251
                    i -= 1 # last row # didn't count
252
                    break # no more rows
253
                if i < start: continue # not at start row yet
254
                
255
                # Row # is interally 0-based, but 1-based to the user
256
                log('Processing input row #'+str(i+1), level=1.1)
257
                process_row(row, i)
258
                row_ready(i, row)
259
            row_ct = i-start+1
260
            return row_ct
261
        
262
        def map_rows(get_value, rows, **kw_args):
263
            '''Maps input rows
264
            @param get_value(in_, row):str
265
            '''
266
            # Prevent collisions if multiple inputs mapping to same output
267
            outputs_idxs = dict()
268
            for i, mapping in enumerate(mappings):
269
                in_, out = mapping
270
                default = util.NamedTuple(count=1, first=i)
271
                idxs = outputs_idxs.setdefault(out, default)
272
                if idxs is not default: # key existed, so there was a collision
273
                    if idxs.count == 1: # first key does not yet have suffix
274
                        mappings[idxs.first][1] += collision_suffix+'0'
275
                    mappings[i][1] += collision_suffix+str(idxs.count)
276
                    idxs.count += 1
277
            
278
            id_node = None
279
            if out_is_db:
280
                mappings_orig = mappings[:] # save a copy
281
                mappings[:] = [] # empty existing elements
282
                for in_, out in mappings_orig:
283
                    in_str = strings.ustr(in_)
284
                    is_metadata_ = is_metadata(in_str)
285
                    if is_metadata_: value = metadata_value(in_str)
286
                    else: value = '$'+in_str # mark as name
287
                    
288
                    # All put_obj()s should return the same id_node
289
                    nodes, id_node = xpath.put_obj(root, out, '-1', has_types,
290
                        value) # value is placeholder that documents name
291
                    if not is_metadata_: mappings.append([in_, nodes])
292
                if id_node == None:
293
                    warnings.warn(UserWarning('Map warning: No mappings or no '
294
                        'column name matches. Are you importing the correct '
295
                        'input table?'))
296
                xml_func.simplify(root)
297
                sys.stdout.write(strings.to_raw_str('Put template:\n'
298
                    +strings.ustr(root)))
299
                sys.stdout.flush()
300
            
301
            def process_row(row, i):
302
                row_id = str(i)
303
                if id_node != None: xml_dom.set_value(id_node, row_id)
304
                for in_, out in mappings:
305
                    log_debug('Getting '+strings.ustr(in_))
306
                    value = cleanup(get_value(in_, row))
307
                    log_debug('Putting '+strings.urepr(value)+' to '
308
                        +strings.ustr(out))
309
                    if out_is_db: # out is list of XML nodes
310
                        for node in out: xml_dom.set_value(node, value)
311
                    elif value != None: # out is XPath
312
                        xpath.put_obj(root, out, row_id, has_types, value)
313
            return process_rows(process_row, rows, **kw_args)
314
        
315
        def map_table(col_names, rows, **kw_args):
316
            col_names_ct = len(col_names)
317
            col_idxs = util.list_flip(col_names)
318
            col_names_map = dict(zip(col_names, col_names))
319
            prefixes_simp = map(maps.simplify, prefixes)
320
            
321
            # Resolve prefixes
322
            mappings_orig = mappings[:] # save a copy
323
            mappings[:] = [] # empty existing elements
324
            for in_, out in mappings_orig:
325
                if is_metadata(in_): mappings.append([in_, out])
326
                else:
327
                    try:
328
                        cols = get_with_prefix(col_names_map, prefixes_simp,
329
                            in_)
330
                    except KeyError: pass
331
                    else:
332
                        cols = [(orig, col_idxs[orig]) for simp, orig in cols]
333
                        mappings[len(mappings):] = [[db_xml.ColRef(*col), out]
334
                            for col in cols] # can't use += because that uses =
335
            
336
            def get_value(in_, row): return row.list[in_.idx]
337
            def wrap_row(row):
338
                return util.ListDict(util.list_as_length(row, col_names_ct),
339
                    col_names, col_idxs) # handle CSV rows of different lengths
340
            
341
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
342
        
343
        if in_is_db:
344
            def on_error(e): ex_tracker.track(e)
345
            
346
            if by_col: in_db = out_db
347
            else: in_db = connect_db(in_db_config)
348
            
349
            # Get table and schema name
350
            schema = in_schema # modified, so can't have same name as outer var
351
            table = in_table # modified, so can't have same name as outer var
352
            if table == None:
353
                assert in_is_xpaths
354
                schema, sep, table = in_root.partition('.')
355
                if sep == '': # only the table name was specified
356
                    table = schema
357
                    schema = None
358
            table = sql_gen.Table(table, schema)
359
            
360
            # Fetch rows
361
            if by_col: limit = 0 # only fetch column names
362
            else: limit = n
363
            try:
364
                cur = sql.select(in_db, table, limit=limit, start=start,
365
                    recover=True, cacheable=False)
366
            except sql.DoesNotExistException:
367
                table = None
368
                col_names = []
369
                rows = []
370
            else:
371
                col_names = list(sql.col_names(cur))
372
                rows = sql.rows(cur)
373
            
374
            if by_col:
375
                map_table(col_names, []) # just create the template
376
                
377
                if table != None and start == 0 and n == None: # full re-import
378
                    log('Clearing errors table')
379
                    errors_table_ = sql_io.errors_table(in_db, table)
380
                    if errors_table_ != None:
381
                        sql.drop_table(in_db, errors_table_)
382
                
383
                # Strip XML functions not in the DB
384
                xml_func.process(root, is_rel_func=is_rel_func)
385
                if debug: log_debug('Putting stripped:\n'+strings.ustr(root))
386
                    # only calc if debug
387
                
388
                # Import rows
389
                in_row_ct_ref = [0]
390
                db_xml.put_table(in_db, root.firstChild, table, in_row_ct_ref,
391
                    row_ins_ct_ref, n, start, on_error, col_defaults)
392
                row_ct = in_row_ct_ref[0]
393
            else:
394
                # Use normal by-row method
395
                row_ct = map_table(col_names, rows, rows_start=start)
396
                    # rows_start: pre-start rows have been skipped
397
                
398
                in_db.db.close()
399
        elif in_is_xml:
400
            stdin = streams.LineCountStream(sys.stdin)
401
            def on_error(e):
402
                exc.add_msg(e, term.emph('input line #:')+' '
403
                    +str(stdin.line_num))
404
                ex_tracker.track(e)
405
            
406
            def get_rows(doc2rows):
407
                return iters.flatten(itertools.imap(doc2rows,
408
                    xml_parse.docs_iter(stdin, on_error)))
409
            
410
            if map_path == None:
411
                def doc2rows(in_xml_root):
412
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
413
                    util.skip(iter_, xml_dom.is_text) # skip metadata
414
                    return iter_
415
                
416
                row_ct = process_rows(lambda row, i: root.appendChild(row),
417
                    get_rows(doc2rows))
418
            else:
419
                def doc2rows(in_xml_root):
420
                    rows = xpath.get(in_xml_root, in_root, limit=end)
421
                    if rows == []: warnings.warn(UserWarning('Map warning: '
422
                        'Root "'+in_root+'" not found in input'))
423
                    return rows
424
                
425
                def get_value(in_, row):
426
                    in_ = './{'+(','.join(strings.with_prefixes(
427
                        ['']+prefixes, in_)))+'}' # also with no prefix
428
                    nodes = xpath.get(row, in_, allow_rooted=False)
429
                    if nodes != []: return xml_dom.value(nodes[0])
430
                    else: return None
431
                
432
                row_ct = map_rows(get_value, get_rows(doc2rows))
433
        else: # input is CSV
434
            reader, col_names = csvs.reader_and_header(sys.stdin)
435
            row_ct = map_table(col_names, reader)
436
        
437
        return row_ct
438
    
439
    def process_inputs(root, row_ready):
440
        row_ct = 0
441
        for map_path in map_paths:
442
            row_ct += process_input(root, row_ready, map_path)
443
        return row_ct
444
    
445
    pool.share_vars(locals())
446
    if out_is_db:
447
        try:
448
            if redo: sql.empty_db(out_db)
449
            pool.share_vars(locals())
450
            
451
            def row_ready(row_num, input_row):
452
                row_str_ = [None]
453
                def row_str():
454
                    if row_str_[0] == None:
455
                        # Row # is interally 0-based, but 1-based to the user
456
                        row_str_[0] = (term.emph('row #:')+' '+str(row_num+1)
457
                            +'\n'+term.emph('input row:')+'\n'
458
                            +strings.ustr(input_row))
459
                        if verbose_errors: row_str_[0] += ('\n'
460
                            +term.emph('output row:')+'\n'+strings.ustr(root))
461
                    return row_str_[0]
462
                
463
                if debug: log_debug(row_str()) # only calc if debug
464
                
465
                def on_error(e):
466
                    exc.add_msg(e, row_str())
467
                    ex_tracker.track(e, row_num, detail=verbose_errors)
468
                pool.share_vars(locals())
469
                
470
                row_root = root.cloneNode(True) # deep copy so don't modify root
471
                xml_func.process(row_root, on_error, is_rel_func, out_db)
472
                if debug: log_debug('Putting processed:\n'
473
                    +strings.ustr(row_root)) # only calc if debug
474
                if not xml_dom.is_empty(row_root):
475
                    assert xml_dom.has_one_child(row_root)
476
                    try:
477
                        sql.with_savepoint(out_db,
478
                            lambda: db_xml.put(out_db, row_root.firstChild,
479
                                row_ins_ct_ref, on_error, col_defaults))
480
                    except sql.DatabaseErrors, e: on_error(e)
481
            
482
            row_ct = process_inputs(root, row_ready)
483
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
484
                ' new rows into database\n')
485
            sys.stdout.flush()
486
            
487
            # Consume asynchronous tasks
488
            pool.main_loop()
489
        finally: out_db.close()
490
    else:
491
        def on_error(e): ex_tracker.track(e)
492
        def row_ready(row_num, input_row): pass
493
        row_ct = process_inputs(root, row_ready)
494
        xml_func.process(root, on_error)
495
        if out_is_xml_ref[0]:
496
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
497
        else: # output is CSV
498
            raise NotImplementedError('CSV output not supported yet')
499
    
500
    # Consume any asynchronous tasks not already consumed above
501
    pool.main_loop()
502
    
503
    profiler.stop(row_ct)
504
    if not by_col: ex_tracker.add_iters(row_ct) # only if errors are done by row
505
    log('Processed '+str(row_ct)+' input rows')
506
    log(profiler.msg())
507
    log(ex_tracker.msg())
508
    ex_tracker.exit()
509

    
510
def main():
511
    try: main_()
512
    except Parser.SyntaxError, e: raise SystemExit(strings.ustr(e))
513

    
514
if __name__ == '__main__':
515
    profile_to = opts.get_env_var('profile_to', None)
516
    if profile_to != None:
517
        import cProfile
518
        sys.stderr.write('Profiling to '+profile_to+'\n')
519
        cProfile.run(main.func_code, profile_to)
520
    else: main()
(42-42/84)