Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import db_xml
17
import exc
18
import iters
19
import maps
20
import opts
21
import parallelproc
22
import Parser
23
import profiling
24
import sql
25
import streams
26
import strings
27
import term
28
import util
29
import xpath
30
import xml_dom
31
import xml_func
32
import xml_parse
33

    
34
def get_with_prefix(map_, prefixes, key):
35
    '''Gets all entries for the given key with any of the given prefixes
36
    @return tuple(found_key, found_value)
37
    '''
38
    values = []
39
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
40
        try: value = map_[key_]
41
        except KeyError, e: continue # keep going
42
        values.append((key_, value))
43
    
44
    if values != []: return values
45
    else: raise e # re-raise last KeyError
46

    
47
def metadata_value(name): return None # this feature has been removed
48

    
49
def cleanup(val):
50
    if val == None: return val
51
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
52

    
53
def main_():
54
    env_names = []
55
    def usage_err():
56
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
57
            +sys.argv[0]+' [map_path...] [<input] [>output]\n'
58
            'Note: Row #s start with 1')
59
    
60
    ## Get config from env vars
61
    
62
    # Modes
63
    test = opts.env_flag('test', False, env_names)
64
    commit = opts.env_flag('commit', False, env_names) and not test
65
        # never commit in test mode
66
    redo = opts.env_flag('redo', test, env_names) and not commit
67
        # never redo in commit mode (manually run `make empty_db` instead)
68
    
69
    # Ranges
70
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
71
    # Make start interally 0-based.
72
    # It's 1-based to the user to match up with the staging table row #s.
73
    start -= 1
74
    if test: n_default = 1
75
    else: n_default = None
76
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
77
        u''))
78
    end = n
79
    if end != None: end += start
80
    
81
    # Debugging
82
    debug = opts.env_flag('debug', False, env_names)
83
    sql.run_raw_query.debug = debug
84
    verbose = debug or opts.env_flag('verbose', not test, env_names)
85
    opts.get_env_var('profile_to', None, env_names) # add to env_names
86
    
87
    # DB
88
    def get_db_config(prefix):
89
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
90
    in_db_config = get_db_config('in')
91
    out_db_config = get_db_config('out')
92
    in_is_db = 'engine' in in_db_config
93
    out_is_db = 'engine' in out_db_config
94
    in_schema = opts.get_env_var('in_schema', None, env_names)
95
    in_table = opts.get_env_var('in_table', None, env_names)
96
    
97
    # Optimization
98
    by_col = in_db_config == out_db_config and opts.env_flag('by_col', False,
99
        env_names) # by-column optimization only applies if mapping to same DB
100
    if test: cpus_default = 0
101
    else: cpus_default = None
102
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
103
        env_names), u''))
104
    
105
    ##
106
    
107
    # Logging
108
    def log(msg, on=verbose):
109
        if on: sys.stderr.write(msg+'\n')
110
    if debug: log_debug = lambda msg: log(msg, debug)
111
    else: log_debug = sql.log_debug_none
112
    
113
    # Parse args
114
    map_paths = sys.argv[1:]
115
    if map_paths == []:
116
        if in_is_db or not out_is_db: usage_err()
117
        else: map_paths = [None]
118
    
119
    def connect_db(db_config):
120
        log('Connecting to '+sql.db_config_str(db_config))
121
        return sql.connect(db_config, log_debug=log_debug)
122
    
123
    if end != None: end_str = str(end-1) # end is one past the last #
124
    else: end_str = 'end'
125
    log('Processing input rows '+str(start)+'-'+end_str)
126
    
127
    ex_tracker = exc.ExPercentTracker(iter_text='row')
128
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
129
    
130
    # Parallel processing
131
    pool = parallelproc.MultiProducerPool(cpus)
132
    log('Using '+str(pool.process_ct)+' parallel CPUs')
133
    
134
    doc = xml_dom.create_doc()
135
    root = doc.documentElement
136
    out_is_xml_ref = [False]
137
    in_label_ref = [None]
138
    def update_in_label():
139
        if in_label_ref[0] != None:
140
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
141
    def prep_root():
142
        root.clear()
143
        update_in_label()
144
    prep_root()
145
    
146
    # Define before the out_is_db section because it's used by by_col
147
    row_ins_ct_ref = [0]
148
    
149
    def process_input(root, row_ready, map_path):
150
        '''Inputs datasource to XML tree, mapping if needed'''
151
        # Load map header
152
        in_is_xpaths = True
153
        out_is_xpaths = True
154
        out_label = None
155
        if map_path != None:
156
            metadata = []
157
            mappings = []
158
            stream = open(map_path, 'rb')
159
            reader = csv.reader(stream)
160
            in_label, out_label = reader.next()[:2]
161
            
162
            def split_col_name(name):
163
                label, sep, root = name.partition(':')
164
                label, sep2, prefixes_str = label.partition('[')
165
                prefixes_str = strings.remove_suffix(']', prefixes_str)
166
                prefixes = strings.split(',', prefixes_str)
167
                return label, sep != '', root, prefixes
168
                    # extract datasrc from "datasrc[data_format]"
169
            
170
            in_label, in_root, prefixes = maps.col_info(in_label)
171
            in_is_xpaths = in_root != None
172
            in_label_ref[0] = in_label
173
            update_in_label()
174
            out_label, out_root = maps.col_info(out_label)[:2]
175
            out_is_xpaths = out_root != None
176
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
177
                # outer elements are types
178
            
179
            for row in reader:
180
                in_, out = row[:2]
181
                if out != '': mappings.append([in_, out_root+out])
182
            
183
            stream.close()
184
            
185
            root.ownerDocument.documentElement.tagName = out_label
186
        in_is_xml = in_is_xpaths and not in_is_db
187
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
188
        
189
        def process_rows(process_row, rows, rows_start=0):
190
            '''Processes input rows      
191
            @param process_row(in_row, i)
192
            @rows_start The (0-based) row # of the first row in rows. Set this
193
                only if the pre-start rows have already been skipped.
194
            '''
195
            rows = iter(rows)
196
            
197
            if end != None: row_nums = xrange(rows_start, end)
198
            else: row_nums = itertools.count(rows_start)
199
            i = -1
200
            for i in row_nums:
201
                try: row = rows.next()
202
                except StopIteration:
203
                    i -= 1 # last row # didn't count
204
                    break # no more rows
205
                if i < start: continue # not at start row yet
206
                
207
                process_row(row, i)
208
                row_ready(i, row)
209
            row_ct = i-start+1
210
            return row_ct
211
        
212
        def map_rows(get_value, rows, **kw_args):
213
            '''Maps input rows
214
            @param get_value(in_, row):str
215
            '''
216
            # Prevent collisions if multiple inputs mapping to same output
217
            outputs_idxs = dict()
218
            for i, mapping in enumerate(mappings):
219
                in_, out = mapping
220
                default = util.NamedTuple(count=1, first=i)
221
                idxs = outputs_idxs.setdefault(out, default)
222
                if idxs is not default: # key existed, so there was a collision
223
                    if idxs.count == 1: # first key does not yet have /_alt/#
224
                        mappings[idxs.first][1] += '/_alt/0'
225
                    mappings[i][1] += '/_alt/'+str(idxs.count)
226
                    idxs.count += 1
227
            
228
            id_node = None
229
            if out_is_db:
230
                for i, mapping in enumerate(mappings):
231
                    in_, out = mapping
232
                    # All put_obj()s should return the same id_node
233
                    nodes, id_node = xpath.put_obj(root, out, '-1', has_types,
234
                        '$'+str(in_)) # value is placeholder that documents name
235
                    mappings[i] = [in_, nodes]
236
                assert id_node != None
237
                
238
                if debug: # only str() if debug
239
                    log_debug('Put template:\n'+str(root))
240
            
241
            def process_row(row, i):
242
                row_id = str(i)
243
                if id_node != None: xml_dom.set_value(id_node, row_id)
244
                for in_, out in mappings:
245
                    log_debug('Getting '+str(in_))
246
                    value = metadata_value(in_)
247
                    if value == None: value = cleanup(get_value(in_, row))
248
                    log_debug('Putting '+repr(value)+' to '+str(out))
249
                    if out_is_db: # out is list of XML nodes
250
                        for node in out: xml_dom.set_value(node, value)
251
                    elif value != None: # out is XPath
252
                        xpath.put_obj(root, out, row_id, has_types, value)
253
                if debug: log_debug('Putting:\n'+str(root))# only str() if debug
254
            return process_rows(process_row, rows, **kw_args)
255
        
256
        def map_table(col_names, rows, **kw_args):
257
            col_names_ct = len(col_names)
258
            col_idxs = util.list_flip(col_names)
259
            
260
            # Resolve prefixes
261
            mappings_orig = mappings[:] # save a copy
262
            mappings[:] = [] # empty existing elements
263
            for in_, out in mappings_orig:
264
                if metadata_value(in_) == None:
265
                    try: cols = get_with_prefix(col_idxs, prefixes, in_)
266
                    except KeyError: pass
267
                    else: mappings[len(mappings):] = [[db_xml.ColRef(*col), out]
268
                        for col in cols] # can't use += because that uses =
269
            
270
            def get_value(in_, row): return row.list[in_.idx]
271
            def wrap_row(row):
272
                return util.ListDict(util.list_as_length(row, col_names_ct),
273
                    col_names, col_idxs) # handle CSV rows of different lengths
274
            
275
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
276
        
277
        stdin = streams.LineCountStream(sys.stdin)
278
        def on_error(e):
279
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
280
            ex_tracker.track(e)
281
        
282
        if in_is_db:
283
            in_db = connect_db(in_db_config)
284
            
285
            # Get table and schema name
286
            schema = in_schema # modified, so can't have same name as outer var
287
            table = in_table # modified, so can't have same name as outer var
288
            if table == None:
289
                assert in_is_xpaths
290
                schema, sep, table = in_root.partition('.')
291
                if sep == '': # only the table name was specified
292
                    table = schema
293
                    schema = None
294
            table_is_esc = False
295
            if schema != None:
296
                table = sql.qual_name(in_db, schema, table)
297
                table_is_esc = True
298
            
299
            # Fetch rows
300
            if by_col: limit = 0 # only fetch column names
301
            else: limit = n
302
            cur = sql.select(in_db, table, limit=limit, start=start,
303
                cacheable=False, table_is_esc=table_is_esc)
304
            col_names = list(sql.col_names(cur))
305
            
306
            if by_col:
307
                map_table(col_names, []) # just create the template
308
                xml_func.strip(root)
309
                db_xml.put_table(in_db, root.firstChild, table, commit,
310
                    row_ins_ct_ref, table_is_esc)
311
            else:
312
                # Use normal by-row method
313
                row_ct = map_table(col_names, sql.rows(cur), rows_start=start)
314
                    # rows_start: pre-start rows have been skipped
315
            
316
            in_db.db.close()
317
        elif in_is_xml:
318
            def get_rows(doc2rows):
319
                return iters.flatten(itertools.imap(doc2rows,
320
                    xml_parse.docs_iter(stdin, on_error)))
321
            
322
            if map_path == None:
323
                def doc2rows(in_xml_root):
324
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
325
                    util.skip(iter_, xml_dom.is_text) # skip metadata
326
                    return iter_
327
                
328
                row_ct = process_rows(lambda row, i: root.appendChild(row),
329
                    get_rows(doc2rows))
330
            else:
331
                def doc2rows(in_xml_root):
332
                    rows = xpath.get(in_xml_root, in_root, limit=end)
333
                    if rows == []: raise SystemExit('Map error: Root "'
334
                        +in_root+'" not found in input')
335
                    return rows
336
                
337
                def get_value(in_, row):
338
                    in_ = './{'+(','.join(strings.with_prefixes(
339
                        ['']+prefixes, in_)))+'}' # also with no prefix
340
                    nodes = xpath.get(row, in_, allow_rooted=False)
341
                    if nodes != []: return xml_dom.value(nodes[0])
342
                    else: return None
343
                
344
                row_ct = map_rows(get_value, get_rows(doc2rows))
345
        else: # input is CSV
346
            map_ = dict(mappings)
347
            reader, col_names = csvs.reader_and_header(sys.stdin)
348
            row_ct = map_table(col_names, reader)
349
        
350
        return row_ct
351
    
352
    def process_inputs(root, row_ready):
353
        row_ct = 0
354
        for map_path in map_paths:
355
            row_ct += process_input(root, row_ready, map_path)
356
        return row_ct
357
    
358
    pool.share_vars(locals())
359
    if out_is_db:
360
        out_db = connect_db(out_db_config)
361
        try:
362
            if redo: sql.empty_db(out_db)
363
            pool.share_vars(locals())
364
            
365
            def row_ready(row_num, input_row):
366
                def on_error(e):
367
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num+1))
368
                        # row # is interally 0-based, but 1-based to the user
369
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
370
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
371
                    ex_tracker.track(e, row_num)
372
                pool.share_vars(locals())
373
                
374
                row_root = root.cloneNode(True) # deep copy so don't modify root
375
                xml_func.process(row_root, on_error)
376
                if not xml_dom.is_empty(row_root):
377
                    assert xml_dom.has_one_child(row_root)
378
                    try:
379
                        sql.with_savepoint(out_db,
380
                            lambda: db_xml.put(out_db, row_root.firstChild,
381
                                row_ins_ct_ref, on_error))
382
                        if commit: out_db.db.commit()
383
                    except sql.DatabaseErrors, e: on_error(e)
384
            
385
            row_ct = process_inputs(root, row_ready)
386
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
387
                ' new rows into database\n')
388
            
389
            # Consume asynchronous tasks
390
            pool.main_loop()
391
        finally:
392
            out_db.db.rollback()
393
            out_db.db.close()
394
    else:
395
        def on_error(e): ex_tracker.track(e)
396
        def row_ready(row_num, input_row): pass
397
        row_ct = process_inputs(root, row_ready)
398
        xml_func.process(root, on_error)
399
        if out_is_xml_ref[0]:
400
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
401
        else: # output is CSV
402
            raise NotImplementedError('CSV output not supported yet')
403
    
404
    # Consume any asynchronous tasks not already consumed above
405
    pool.main_loop()
406
    
407
    profiler.stop(row_ct)
408
    ex_tracker.add_iters(row_ct)
409
    if verbose:
410
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
411
        sys.stderr.write(profiler.msg()+'\n')
412
        sys.stderr.write(ex_tracker.msg()+'\n')
413
    ex_tracker.exit()
414

    
415
def main():
416
    try: main_()
417
    except Parser.SyntaxError, e: raise SystemExit(str(e))
418

    
419
if __name__ == '__main__':
420
    profile_to = opts.get_env_var('profile_to', None)
421
    if profile_to != None:
422
        import cProfile
423
        sys.stderr.write('Profiling to '+profile_to+'\n')
424
        cProfile.run(main.func_code, profile_to)
425
    else: main()
(25-25/47)