Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import db_xml
17
import exc
18
import iters
19
import maps
20
import opts
21
import parallelproc
22
import Parser
23
import profiling
24
import sql
25
import streams
26
import strings
27
import term
28
import util
29
import xpath
30
import xml_dom
31
import xml_func
32
import xml_parse
33

    
34
def get_with_prefix(map_, prefixes, key):
    '''Gets all entries for the given key with any of the given prefixes.
    @param map_ Dict-like mapping to look the prefixed keys up in
    @param prefixes List of prefix strings to try (the bare key, with no
        prefix, is always tried as well)
    @param key The base key
    @return list of (found_key, found_value) tuples, one per key that was
        present in map_
    @raise KeyError (the last one raised) if none of the prefixed keys were
        found
    '''
    values = []
    last_err = None # bind explicitly: the except-clause var doesn't outlive the
        # handler on Python 3, so `raise e` after the loop would break there
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
        try: value = map_[key_]
        except KeyError as e:
            last_err = e
            continue # keep going
        values.append((key_, value))
    
    if values != []: return values
    else: raise last_err # re-raise last KeyError
46

    
47
def metadata_value(name):
    '''Looks up a metadata entry by name.
    The metadata feature has been removed, so this always reports "no value".
    '''
    return None
48

    
49
def cleanup(val):
    '''Normalizes a raw input value for import.
    @param val The raw value; None is passed through unchanged
    @return The value coerced to unicode and cleaned up, with the empty string
        and the PostgreSQL NULL literal (u'\\N') mapped to None
    '''
    if val is None: return val # `is`, not `==`: identity test for the sentinel
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
52

    
53
def main_():
    '''Maps an input datasource (DB table, XML on stdin, or CSV on stdin) to an
    output (DB, or XML on stdout), optionally driven by map spreadsheets given
    as command-line args. Configuration comes from env vars (collected into
    env_names so the usage message can list them).
    Exits via ex_tracker.exit(); per the file header, the exit status reflects
    the # of errors in the import.
    '''
    env_names = [] # every env var read below registers itself here
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]\n'
            'Note: Row #s start with 1')
    
    ## Get config from env vars
    
    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)
    
    # Ranges
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
    # Make start internally 0-based.
    # It's 1-based to the user to match up with the staging table row #s.
    start -= 1
    if test: n_default = 1 # in test mode, process a single row by default
    else: n_default = None
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
        u''))
    end = n
    if end != None: end += start # end is exclusive (one past the last row #)
    
    # Debugging
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', not test, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names
    
    # DB
    def get_db_config(prefix):
        # e.g. prefix 'in' reads in_engine, in_host, ... (sql.db_config_names)
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config # an 'engine' key marks a DB endpoint
    out_is_db = 'engine' in out_db_config
    in_schema = opts.get_env_var('in_schema', None, env_names)
    in_table = opts.get_env_var('in_table', None, env_names)
    
    # Optimization
    by_col = in_db_config == out_db_config and opts.env_flag('by_col', False,
        env_names) # by-column optimization only applies if mapping to same DB
    if test: cpus_default = 0 # no parallelism in test mode
    else: cpus_default = None
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
        env_names), u''))
    
    ##
    
    # Logging
    def log(msg, on=verbose):
        if on: sys.stderr.write(msg+'\n')
    if debug: log_debug = lambda msg: log(msg, debug)
    else: log_debug = sql.log_debug_none
    
    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        # With no maps, only DB->XML passthrough makes sense; otherwise usage
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]
    
    def connect_db(db_config):
        log('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config, log_debug=log_debug)
    
    if end != None: end_str = str(end-1) # end is one past the last #
    else: end_str = 'end'
    log('Processing input rows '+str(start)+'-'+end_str)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    # Parallel processing
    pool = parallelproc.MultiProducerPool(cpus)
    log('Using '+str(pool.process_ct)+' parallel CPUs')
    
    # The XML tree that each input row is mapped into
    doc = xml_dom.create_doc()
    root = doc.documentElement
    # Single-element lists used as mutable refs shared with nested closures
    out_is_xml_ref = [False]
    in_label_ref = [None]
    def update_in_label():
        # Records the input's label on the tree (under /_ignore) for reference
        if in_label_ref[0] != None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        root.clear()
        update_in_label()
    prep_root()
    
    # Define before the out_is_db section because it's used by by_col
    row_ins_ct_ref = [0]
    
    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed'''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = [] # NOTE(review): appears unused below
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            # Map header row: first two cells are the input and output labels
            in_label, out_label = reader.next()[:2]
            
            # NOTE(review): split_col_name appears unused; maps.col_info() is
            # what's actually called below. Candidate for removal — confirm.
            def split_col_name(name):
                label, sep, root = name.partition(':')
                label, sep2, prefixes_str = label.partition('[')
                prefixes_str = strings.remove_suffix(']', prefixes_str)
                prefixes = strings.split(',', prefixes_str)
                return label, sep != '', root, prefixes
                    # extract datasrc from "datasrc[data_format]"
            
            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None # a root marks an XPath-style column
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types
            
            # Remaining rows: (input column, output path) mapping pairs
            for row in reader:
                in_, out = row[:2]
                if out != '': mappings.append([in_, out_root+out])
            
            stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        def process_rows(process_row, rows, rows_start=0):
            '''Processes input rows
            @param process_row(in_row, i)
            @rows_start The (0-based) row # of the first row in rows. Set this
                only if the pre-start rows have already been skipped.
            '''
            rows = iter(rows)
            
            if end != None: row_nums = xrange(rows_start, end)
            else: row_nums = itertools.count(rows_start)
            i = -1 # so row_ct == 0 if rows is empty and rows_start/start are 0
            for i in row_nums:
                try: row = rows.next()
                except StopIteration:
                    i -= 1 # last row # didn't count
                    break # no more rows
                if i < start: continue # not at start row yet
                
                process_row(row, i)
                row_ready(i, row)
            row_ct = i-start+1
            return row_ct
        
        def map_rows(get_value, rows, **kw_args):
            '''Maps input rows
            @param get_value(in_, row):str
            '''
            # Prevent collisions if multiple inputs mapping to same output
            outputs_idxs = dict()
            for i, mapping in enumerate(mappings):
                in_, out = mapping
                default = util.NamedTuple(count=1, first=i)
                idxs = outputs_idxs.setdefault(out, default)
                if idxs is not default: # key existed, so there was a collision
                    if idxs.count == 1: # first key does not yet have /_alt/#
                        mappings[idxs.first][1] += '/_alt/0'
                    mappings[i][1] += '/_alt/'+str(idxs.count)
                    idxs.count += 1
            
            id_node = None
            if out_is_db:
                # Pre-build the output template once; per-row work then only
                # sets values on the already-created nodes
                for i, mapping in enumerate(mappings):
                    in_, out = mapping
                    # All put_obj()s should return the same id_node
                    nodes, id_node = xpath.put_obj(root, out, '-1', has_types,
                        '$'+str(in_)) # value is placeholder that documents name
                    mappings[i] = [in_, nodes]
                assert id_node != None
                
                if debug: # only str() if debug
                    log_debug('Put template:\n'+str(root))
            
            def process_row(row, i):
                row_id = str(i)
                if id_node != None: xml_dom.set_value(id_node, row_id)
                for in_, out in mappings:
                    log_debug('Getting '+str(in_))
                    value = metadata_value(in_)
                    if value == None: value = cleanup(get_value(in_, row))
                    log_debug('Putting '+repr(value)+' to '+str(out))
                    if out_is_db: # out is list of XML nodes
                        for node in out: xml_dom.set_value(node, value)
                    elif value != None: # out is XPath
                        xpath.put_obj(root, out, row_id, has_types, value)
                if debug: log_debug('Putting:\n'+str(root))# only str() if debug
            return process_rows(process_row, rows, **kw_args)
        
        def map_table(col_names, rows, **kw_args):
            # Maps rows of a named-column table (DB cursor or CSV)
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names) # name -> column index
            
            # Resolve prefixes
            mappings_orig = mappings[:] # save a copy
            mappings[:] = [] # empty existing elements
            for in_, out in mappings_orig:
                if metadata_value(in_) == None:
                    try: cols = get_with_prefix(col_idxs, prefixes, in_)
                    except KeyError: pass # input column not present; skip
                    else: mappings[len(mappings):] = [[db_xml.ColRef(*col), out]
                        for col in cols] # can't use += because that uses =
            
            def get_value(in_, row): return row.list[in_.idx]
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths
            
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
        
        stdin = streams.LineCountStream(sys.stdin) # tracks line # for errors
        def on_error(e):
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
            ex_tracker.track(e)
        
        if in_is_db:
            in_db = connect_db(in_db_config)
            
            # Get table and schema name
            schema = in_schema # modified, so can't have same name as outer var
            table = in_table # modified, so can't have same name as outer var
            if table == None:
                assert in_is_xpaths
                # Fall back to the map's input root, as "schema.table" or just
                # "table"
                schema, sep, table = in_root.partition('.')
                if sep == '': # only the table name was specified
                    table = schema
                    schema = None
            table_is_esc = False
            if schema != None:
                table = sql.qual_name(in_db, schema, table)
                table_is_esc = True
            
            # Fetch rows
            if by_col: limit = 0 # only fetch column names
            else: limit = n
            cur = sql.select(in_db, table, limit=limit, start=start,
                table_is_esc=table_is_esc)
            col_names = list(sql.col_names(cur))
            
            if by_col:
                # By-column: map one sample row of '$col' placeholders, then
                # let the DB do the per-row work via put_table()
                row_ready = lambda row_num, input_row: None# disable row_ready()
                row = ['$'+v for v in col_names] # values are the column names
                map_table(col_names, [row]) # map just the sample row
                xml_func.strip(root)
                db_xml.put_table(in_db, root.firstChild, table, commit,
                    row_ins_ct_ref, table_is_esc)
                # NOTE(review): row_ct is not set on this path; presumably
                # by_col implies out_is_db so the later else-branch isn't
                # reached — confirm
            else:
                # Use normal by-row method
                row_ct = map_table(col_names, sql.rows(cur), rows_start=start)
                    # rows_start: pre-start rows have been skipped
            
            in_db.db.close()
        elif in_is_xml:
            def get_rows(doc2rows):
                # Chains the rows of every XML document on stdin
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))
            
            if map_path == None:
                # Passthrough: copy each input element straight into the tree
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_
                
                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: raise SystemExit('Map error: Root "'
                        +in_root+'" not found in input')
                    return rows
                
                def get_value(in_, row):
                    # Build an XPath alternation over all prefixed names
                    in_ = './{'+(','.join(strings.with_prefixes(
                        ['']+prefixes, in_)))+'}' # also with no prefix
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None
                
                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            map_ = dict(mappings) # NOTE(review): appears unused — confirm
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        # Runs every map in sequence against the same output tree
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    pool.share_vars(locals())
    if out_is_db:
        import db_xml
        
        out_db = connect_db(out_db_config)
        try:
            if redo: sql.empty_db(out_db)
            pool.share_vars(locals())
            
            def row_ready(row_num, input_row):
                # Called once per mapped row: inserts the row's subtree
                def on_error(e):
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num+1))
                        # row # is internally 0-based, but 1-based to the user
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e, row_num)
                pool.share_vars(locals())
                
                row_root = root.cloneNode(True) # deep copy so don't modify root
                xml_func.process(row_root, on_error)
                if not xml_dom.is_empty(row_root):
                    assert xml_dom.has_one_child(row_root)
                    try:
                        # Savepoint so one bad row doesn't abort the whole
                        # transaction
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, row_root.firstChild,
                                row_ins_ct_ref, on_error))
                        if commit: out_db.db.commit()
                    except sql.DatabaseErrors, e: on_error(e)
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
            
            # Consume asynchronous tasks
            pool.main_loop()
        finally:
            # Roll back anything uncommitted (a no-op after commit), then close
            out_db.db.rollback()
            out_db.db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    # Consume any asynchronous tasks not already consumed above
    pool.main_loop()
    
    # Report stats and exit with a status reflecting the error count
    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit()
418

    
419
def main():
    '''Top-level entry point: runs main_(), converting map-spreadsheet syntax
    errors into a clean exit with the error message.
    '''
    try:
        main_()
    except Parser.SyntaxError as e:
        raise SystemExit(str(e))
422

    
423
if __name__ == '__main__':
    # profile_to (env var) names a file to write cProfile stats to; unset
    # means run normally
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to is not None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        # Pass a command string, as documented for cProfile.run(); exec'ing
        # main.func_code directly was fragile (and func_code is __code__ on
        # Python 3)
        cProfile.run('main()', profile_to)
    else: main()
(25-25/47)