Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import db_xml
17
import exc
18
import iters
19
import maps
20
import opts
21
import parallelproc
22
import Parser
23
import profiling
24
import sql
25
import streams
26
import strings
27
import term
28
import util
29
import xpath
30
import xml_dom
31
import xml_func
32
import xml_parse
33

    
34
def get_with_prefix(map_, prefixes, key):
35
    '''Gets all entries for the given key with any of the given prefixes
36
    @return tuple(found_key, found_value)
37
    '''
38
    values = []
39
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
40
        try: value = map_[key_]
41
        except KeyError, e: continue # keep going
42
        values.append((key_, value))
43
    
44
    if values != []: return values
45
    else: raise e # re-raise last KeyError
46

    
47
def metadata_value(name):
    '''Looks up a metadata value by name.
    
    The metadata-value feature has been removed, so this is now a stub that
    always returns None regardless of the name given. Kept so that callers
    (e.g. map_rows()/map_table()) do not need to change.
    '''
    return None
48

    
49
def cleanup(val):
    '''Normalizes an input value for import.
    
    @param val The raw value; None is passed through unchanged
    @return unicode cleaned-up value, or None if the cleaned value is the
        empty string or the PostgreSQL NULL marker '\\N'
    '''
    if val is None: return val
    # Convert to unicode, clean whitespace, then map ''/'\N' to None
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
52

    
53
def main_():
    '''Runs the import: reads all config from env vars, maps each input
    datasource (a DB table, XML documents on stdin, or CSV on stdin) through
    the map spreadsheets given as command-line arguments, and writes the
    result to an output DB or as XML on stdout. Ends by calling
    ex_tracker.exit(), which reports errors and sets the exit status.
    '''
    env_names = [] # every env var read is appended here for the usage message
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]\n'
            'Note: Row #s start with 1')
    
    ## Get config from env vars
    
    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)
    
    # Ranges
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
    # Make start interally 0-based.
    # It's 1-based to the user to match up with the staging table row #s.
    start -= 1
    if test: n_default = 1
    else: n_default = None
    # n = max # of rows to process; empty string env var means no limit
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
        u''))
    end = n
    if end != None: end += start # end is one past the last 0-based row #
    
    # Debugging
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', not test, env_names)
    verbose_errors = opts.env_flag('verbose_errors', test, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names
        # actually used in the __main__ guard; read here only for the usage msg
    
    # DB
    def get_db_config(prefix):
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config
    in_schema = opts.get_env_var('in_schema', None, env_names)
    in_table = opts.get_env_var('in_table', None, env_names)
    
    # Optimization
    by_col = in_db_config == out_db_config and opts.env_flag('by_col', False,
        env_names) # by-column optimization only applies if mapping to same DB
    if test: cpus_default = 0
    else: cpus_default = 0 # or None to use parallel processing by default
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
        env_names), u''))
    
    ##
    
    # Logging
    def log(msg, on=verbose):
        if on: sys.stderr.write(msg+'\n')
    if debug: log_debug = lambda msg: log(msg, debug)
    else: log_debug = sql.log_debug_none
    
    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None] # direct XML-to-DB import needs no map
    
    def connect_db(db_config):
        log('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config, log_debug=log_debug)
    
    if end != None: end_str = str(end-1) # end is one past the last #
    else: end_str = 'end'
    log('Processing input rows '+str(start)+'-'+end_str)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    # Parallel processing
    pool = parallelproc.MultiProducerPool(cpus)
    log('Using '+str(pool.process_ct)+' parallel CPUs')
    
    # Output XML tree; rows are built under (or templated into) this root
    doc = xml_dom.create_doc()
    root = doc.documentElement
    out_is_xml_ref = [False] # 1-elem lists used as mutable cells for closures
    in_label_ref = [None]
    def update_in_label():
        if in_label_ref[0] != None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        root.clear()
        update_in_label()
    prep_root()
    
    # Define before the out_is_db section because it's used by by_col
    row_ins_ct_ref = [0]
    
    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed.
        
        @param root XML node to build output under
        @param row_ready callback(row_num, input_row) invoked per mapped row
        @param map_path path to a map spreadsheet CSV, or None for no mapping
        @return row_ct # of input rows processed
        '''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = [] # NOTE(review): appears unused below
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            in_label, out_label = reader.next()[:2]
            
            # NOTE(review): appears unused — maps.col_info() is used instead
            def split_col_name(name):
                label, sep, root = name.partition(':')
                label, sep2, prefixes_str = label.partition('[')
                prefixes_str = strings.remove_suffix(']', prefixes_str)
                prefixes = strings.split(',', prefixes_str)
                return label, sep != '', root, prefixes
                    # extract datasrc from "datasrc[data_format]"
            
            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            # NOTE(review): has_types is only bound when out_is_xpaths; it is
            # read later in map_rows()/process_row() — presumably those paths
            # are only reached when out_is_xpaths held. TODO confirm.
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types
            
            # Remaining spreadsheet rows: (input col, output XPath) pairs
            for row in reader:
                in_, out = row[:2]
                if out != '': mappings.append([in_, out_root+out])
            
            stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        def process_rows(process_row, rows, rows_start=0):
            '''Processes input rows.
            
            @param process_row callback(in_row, i) applied to each kept row
            @param rows iterable of input rows
            @param rows_start The (0-based) row # of the first row in rows.
                Set this only if the pre-start rows have already been skipped.
            @return row_ct # of rows processed (start..end window only)
            '''
            rows = iter(rows)
            
            if end != None: row_nums = xrange(rows_start, end)
            else: row_nums = itertools.count(rows_start)
            i = -1 # so row_ct computes to <= 0 if no rows at all
            for i in row_nums:
                try: row = rows.next()
                except StopIteration:
                    i -= 1 # last row # didn't count
                    break # no more rows
                if i < start: continue # not at start row yet
                
                process_row(row, i)
                row_ready(i, row)
            row_ct = i-start+1
            return row_ct
        
        def map_rows(get_value, rows, **kw_args):
            '''Maps input rows through the loaded mappings.
            
            @param get_value callback(in_, row):str extracting a column value
            @param rows iterable of input rows
            @param kw_args passed through to process_rows()
            '''
            # Prevent collisions if multiple inputs mapping to same output
            outputs_idxs = dict()
            for i, mapping in enumerate(mappings):
                in_, out = mapping
                default = util.NamedTuple(count=1, first=i)
                idxs = outputs_idxs.setdefault(out, default)
                if idxs is not default: # key existed, so there was a collision
                    if idxs.count == 1: # first key does not yet have /_alt/#
                        mappings[idxs.first][1] += '/_alt/0'
                    mappings[i][1] += '/_alt/'+str(idxs.count)
                    idxs.count += 1
            
            id_node = None
            if out_is_db:
                # Pre-build the XML template once; per row, only values change
                for i, mapping in enumerate(mappings):
                    in_, out = mapping
                    # All put_obj()s should return the same id_node
                    nodes, id_node = xpath.put_obj(root, out, '-1', has_types,
                        '$'+str(in_)) # value is placeholder that documents name
                    mappings[i] = [in_, nodes]
                assert id_node != None
                
                if debug: # only str() if debug
                    log_debug('Put template:\n'+str(root))
            
            def process_row(row, i):
                row_id = str(i)
                if id_node != None: xml_dom.set_value(id_node, row_id)
                for in_, out in mappings:
                    log_debug('Getting '+str(in_))
                    value = metadata_value(in_)
                    if value == None: value = cleanup(get_value(in_, row))
                    log_debug('Putting '+repr(value)+' to '+str(out))
                    if out_is_db: # out is list of XML nodes
                        for node in out: xml_dom.set_value(node, value)
                    elif value != None: # out is XPath
                        xpath.put_obj(root, out, row_id, has_types, value)
                if debug: log_debug('Putting:\n'+str(root))# only str() if debug
            return process_rows(process_row, rows, **kw_args)
        
        def map_table(col_names, rows, **kw_args):
            '''Maps rows of a named-column table (DB cursor or CSV).'''
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names)
            
            # Resolve prefixes
            mappings_orig = mappings[:] # save a copy
            mappings[:] = [] # empty existing elements
            for in_, out in mappings_orig:
                if metadata_value(in_) == None:
                    try: cols = get_with_prefix(col_idxs, prefixes, in_)
                    except KeyError: pass # input column not in this table
                    else: mappings[len(mappings):] = [[db_xml.ColRef(*col), out]
                        for col in cols] # can't use += because that uses =
            
            def get_value(in_, row): return row.list[in_.idx]
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths
            
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
        
        stdin = streams.LineCountStream(sys.stdin)
        def on_error(e):
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
            ex_tracker.track(e)
        
        if in_is_db:
            in_db = connect_db(in_db_config)
            
            # Get table and schema name
            schema = in_schema # modified, so can't have same name as outer var
            table = in_table # modified, so can't have same name as outer var
            if table == None:
                assert in_is_xpaths
                # Fall back to the map's input root, as "schema.table"
                schema, sep, table = in_root.partition('.')
                if sep == '': # only the table name was specified
                    table = schema
                    schema = None
            table_is_esc = False
            if schema != None:
                table = sql.qual_name(in_db, schema, table)
                table_is_esc = True
            
            # Fetch rows
            if by_col: limit = 0 # only fetch column names
            else: limit = n
            cur = sql.select(in_db, table, limit=limit, start=start,
                cacheable=False, table_is_esc=table_is_esc)
            col_names = list(sql.col_names(cur))
            
            if by_col:
                map_table(col_names, []) # just create the template
                xml_func.strip(root)
                db_xml.put_table(in_db, root.firstChild, table, commit,
                    row_ins_ct_ref, table_is_esc)
                # NOTE(review): row_ct is not set on this branch; presumably
                # put_table() handles all rows at once. TODO confirm return.
            else:
                # Use normal by-row method
                row_ct = map_table(col_names, sql.rows(cur), rows_start=start)
                    # rows_start: pre-start rows have been skipped
            
            in_db.db.close()
        elif in_is_xml:
            def get_rows(doc2rows):
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))
            
            if map_path == None:
                # No map: copy input elements straight into the output tree
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_
                
                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: raise SystemExit('Map error: Root "'
                        +in_root+'" not found in input')
                    return rows
                
                def get_value(in_, row):
                    # '{a,b}' XPath alternation over all prefixed names
                    in_ = './{'+(','.join(strings.with_prefixes(
                        ['']+prefixes, in_)))+'}' # also with no prefix
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None
                
                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            map_ = dict(mappings) # NOTE(review): appears unused
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        # Process every map spreadsheet; total the row counts
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    pool.share_vars(locals()) # make closure state available to worker procs
    if out_is_db:
        out_db = connect_db(out_db_config)
        try:
            if redo: sql.empty_db(out_db)
            pool.share_vars(locals())
            
            def row_ready(row_num, input_row):
                def on_error(e):
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num+1))
                        # row # is interally 0-based, but 1-based to the user
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    if verbose_errors:
                        exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e, row_num, detail=verbose_errors)
                pool.share_vars(locals())
                
                row_root = root.cloneNode(True) # deep copy so don't modify root
                xml_func.process(row_root, on_error)
                if not xml_dom.is_empty(row_root):
                    assert xml_dom.has_one_child(row_root)
                    try:
                        # Savepoint so one bad row doesn't abort the whole txn
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, row_root.firstChild,
                                row_ins_ct_ref, on_error))
                        if commit: out_db.db.commit()
                    except sql.DatabaseErrors, e: on_error(e)
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
            
            # Consume asynchronous tasks
            pool.main_loop()
        finally:
            out_db.db.rollback() # no-op if everything was committed
            out_db.db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    # Consume any asynchronous tasks not already consumed above
    pool.main_loop()
    
    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit() # exit status = # of errors, up to the max exit status
416

    
417
def main():
418
    try: main_()
419
    except Parser.SyntaxError, e: raise SystemExit(str(e))
420

    
421
if __name__ == '__main__':
    # profile_to is also read inside main_() (with env_names) so that it
    # appears in the usage message; here it selects whether to profile.
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to != None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        # NOTE(review): passes main's code object rather than the usual
        # string 'main()'. cProfile.run() exec()s its argument, and Python 2
        # exec accepts a code object — presumably done to avoid a name
        # lookup; verify profiling output is as intended.
        cProfile.run(main.func_code, profile_to)
    else: main()
(25-25/47)