Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import exc
17
import iters
18
import maps
19
import opts
20
import parallel
21
import Parser
22
import profiling
23
import sql
24
import streams
25
import strings
26
import term
27
import util
28
import xpath
29
import xml_dom
30
import xml_func
31
import xml_parse
32

    
33
def get_with_prefix(map_, prefixes, key):
    '''Gets all entries for the given key with any of the given prefixes.
    
    @param map_ Mapping to look up prefixed keys in
    @param prefixes Prefixes to try (the bare, unprefixed key is always tried
        as well)
    @param key The base key
    @return list of all values found (never empty)
    @raise KeyError the last KeyError raised, if no prefixed key was found
    '''
    values = []
    last_err = None # bind explicitly instead of relying on the except target
        # leaking out of the loop (Py2-only behavior; NameError under Py3)
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
        try: value = map_[key_]
        except KeyError as e:
            last_err = e
            continue # keep going
        values.append(value)
    
    if values != []: return values
    else: raise last_err # re-raise last KeyError
43

    
44
def metadata_value(name):
    '''Stub kept for caller compatibility: the metadata-lookup feature was
    removed, so every requested name now yields None.'''
    return None # this feature has been removed
45

    
46
def cleanup(val):
    '''Normalizes a raw input value for import.
    
    Coerces val to a cleaned-up unicode string, then maps the empty string
    and the PostgreSQL NULL escape (u'\\N') to None.
    @return the cleaned unicode value, or None if val is None or represents
        an empty/NULL value
    '''
    if val is None: return val # identity test is the idiomatic None check
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
49

    
50
def main_():
    '''Maps one datasource to another, using map spreadsheets if needed.
    
    Input may be a DB table, XML documents on stdin, or CSV on stdin; output
    is either a DB (via db_xml.put) or XML on stdout. All configuration comes
    from environment variables (collected into env_names for the usage
    message) plus map-spreadsheet paths in sys.argv.
    '''
    env_names = [] # accumulates every env var consulted, for the usage message
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]\n'
            'Note: Row #s start with 1')
    
    ## Get config from env vars
    
    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)
    
    # Ranges
    start = util.cast(int, opts.get_env_var('start', 1, env_names)) # 1-based
    # Make start interally 0-based.
    # It's 1-based to the user to match up with the staging table row #s.
    start -= 1
    if test: n_default = 1
    else: n_default = None
    n = util.cast(int, util.none_if(opts.get_env_var('n', n_default, env_names),
        u''))
    end = n
    if end != None: end += start # end is exclusive (one past the last row #)
    
    # Optimization
    if test: cpus_default = 0 # no parallelism in test mode
    else: cpus_default = None
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
        env_names), u''))
    
    # Debugging
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', not test, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names
    
    # DB
    def get_db_config(prefix):
        # e.g. prefix 'in' reads in_engine, in_host, ... per sql.db_config_names
        return opts.get_env_vars(sql.db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config
    in_schema = opts.get_env_var('in_schema', None, env_names)
    in_table = opts.get_env_var('in_table', None, env_names)
    
    ##
    
    # Logging
    def log(msg, on=verbose):
        if on: sys.stderr.write(msg+'\n')
    if debug: log_debug = lambda msg: log(msg, debug)
    else: log_debug = sql.log_debug_none
    
    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        # a map is optional only for DB output fed by non-DB input
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]
    
    def connect_db(db_config):
        log('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config, log_debug=log_debug)
    
    if end != None: end_str = str(end-1) # end is one past the last #
    else: end_str = 'end'
    log('Processing input rows '+str(start)+'-'+end_str)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    # Parallel processing
    pool = parallel.MultiProducerPool(cpus)
    log('Using '+str(pool.process_ct)+' parallel CPUs')
    
    doc = xml_dom.create_doc()
    root = doc.documentElement
    # single-element lists act as mutable cells shared with nested closures
    out_is_xml_ref = [False]
    in_label_ref = [None]
    def update_in_label():
        # records the input datasource label on the output tree (ignored by
        # downstream processing, per the /_ignore path)
        if in_label_ref[0] != None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        # resets the output tree between rows, re-attaching the input label
        root.clear()
        update_in_label()
    prep_root()
    
    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed'''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = []
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            # first spreadsheet row: input and output column labels
            in_label, out_label = reader.next()[:2]
            
            def split_col_name(name):
                label, sep, root = name.partition(':')
                label, sep2, prefixes_str = label.partition('[')
                prefixes_str = strings.remove_suffix(']', prefixes_str)
                prefixes = strings.split(',', prefixes_str)
                return label, sep != '', root, prefixes
                    # extract datasrc from "datasrc[data_format]"
            
            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types
            
            # remaining spreadsheet rows: (input column, output xpath) pairs
            for row in reader:
                in_, out = row[:2]
                if out != '':
                    if out_is_xpaths: out = xpath.parse(out_root+out)
                    mappings.append((in_, out))
            
            stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        def process_rows(process_row, rows, rows_start=0):
            '''Processes input rows
            @param process_row(in_row, i)
            @rows_start The (0-based) row # of the first row in rows. Set this
                only if the pre-start rows have already been skipped.
            '''
            rows = iter(rows)
            
            if end != None: row_nums = xrange(rows_start, end)
            else: row_nums = itertools.count(rows_start)
            for i in row_nums:
                try: row = rows.next()
                except StopIteration: break # no more rows
                if i < start: continue # not at start row yet
                
                process_row(row, i)
                row_ready(i, row)
            row_ct = i-start
            return row_ct
        
        def map_rows(get_value, rows, **kw_args):
            '''Maps input rows
            @param get_value(in_, row):str
            '''
            def process_row(row, i):
                row_id = str(i)
                for in_, out in mappings:
                    value = metadata_value(in_) # always None (feature removed)
                    if value == None:
                        log_debug('Getting '+str(in_))
                        value = cleanup(get_value(in_, row))
                    if value != None:
                        log_debug('Putting '+str(out))
                        xpath.put_obj(root, out, row_id, has_types, value)
            return process_rows(process_row, rows, **kw_args)
        
        def map_table(col_names, rows, **kw_args):
            # resolves mapping input names to column indexes up front, then
            # maps each row via those indexes
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names)
            
            i = 0
            while i < len(mappings): # mappings len changes in loop
                in_, out = mappings[i]
                if metadata_value(in_) == None:
                    try: mappings[i] = (
                        get_with_prefix(col_idxs, prefixes, in_), out)
                    except KeyError:
                        del mappings[i]
                        continue # keep i the same
                i += 1
            
            def get_value(in_, row):
                return util.coalesce(*util.list_subset(row.list, in_))
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths
            
            return map_rows(get_value, util.WrapIter(wrap_row, rows), **kw_args)
        
        stdin = streams.LineCountStream(sys.stdin)
        def on_error(e):
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
            ex_tracker.track(e)
        
        if in_is_db:
            in_db = connect_db(in_db_config)
            
            # Get table and schema name
            schema = in_schema # modified, so can't have same name as outer var
            table = in_table # modified, so can't have same name as outer var
            if table == None:
                assert in_is_xpaths
                schema, sep, table = in_root.partition('.')
                if sep == '': # only the table name was specified
                    table = schema
                    schema = None
            table_is_esc = False
            if schema != None:
                table = sql.qual_name(in_db, schema, table)
                table_is_esc = True
            
            cur = sql.select(in_db, table, limit=n, start=start,
                table_is_esc=table_is_esc)
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur),
                rows_start=start) # rows_start: pre-start rows have been skipped
            
            in_db.db.close()
        elif in_is_xml:
            def get_rows(doc2rows):
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))
            
            if map_path == None:
                # no map: copy input elements straight into the output tree
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_
                
                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: raise SystemExit('Map error: Root "'
                        +in_root+'" not found in input')
                    return rows
                
                def get_value(in_, row):
                    in_ = './{'+(','.join(strings.with_prefixes(
                        ['']+prefixes, in_)))+'}' # also with no prefix
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None
                
                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            map_ = dict(mappings)
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        # runs every map spreadsheet in turn against the same output tree
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    pool.share_vars(locals())
    if out_is_db:
        import db_xml
        
        out_db = connect_db(out_db_config)
        try:
            if redo: sql.empty_db(out_db)
            row_ins_ct_ref = [0] # mutable cell: total rows inserted
            pool.share_vars(locals())
            
            def row_ready(row_num, input_row):
                # called after each row is mapped into root; inserts it
                def on_error(e):
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num+1))
                        # row # is interally 0-based, but 1-based to the user
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e, row_num)
                pool.share_vars(locals())
                
                xml_func.process(root, on_error)
                if not xml_dom.is_empty(root):
                    assert xml_dom.has_one_child(root)
                    try:
                        # savepoint: a failed row rolls back alone, not the run
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, root.firstChild,
                                row_ins_ct_ref, on_error))
                        if commit: out_db.db.commit()
                    except sql.DatabaseErrors, e: on_error(e)
                prep_root()
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
            
            # Consume asynchronous tasks
            pool.main_loop()
        finally:
            # rollback is a no-op after a commit; ensures no partial work lands
            out_db.db.rollback()
            out_db.db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    # Consume any asynchronous tasks not already consumed above
    pool.main_loop()
    
    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit() # exit status = # of errors, up to the maximum
370

    
371
def main():
    '''Runs main_(), converting map-spreadsheet parse errors into a clean
    exit with just the error message instead of a full traceback.'''
    try: main_()
    except Parser.SyntaxError as e: raise SystemExit(str(e))
        # `as` form works on Python 2.6+ and 3; the old `, e` form does not
374

    
375
if __name__ == '__main__':
    # Read profile_to without env_names: the usage message is assembled
    # inside main_(), which registers this var there separately.
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to != None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        # cProfile.run exec()s its first argument, so passing main's code
        # object runs main's body in this module's namespace and writes the
        # stats to profile_to.
        # NOTE(review): func_code is the Python 2 attribute name; it became
        # __code__ in Python 3 — confirm when porting.
        cProfile.run(main.func_code, profile_to)
    else: main()
(25-25/47)