#!/usr/bin/env python
# Maps one datasource to another, using a map spreadsheet if needed
# Exit status is the # of errors in the import, up to the maximum exit status
# For outputting an XML file to a PostgreSQL database, use the general format of
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
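#
# Summary of the interface (as implemented below): map spreadsheet paths are
# passed as command-line arguments, the datasource itself is read from stdin
# (XML or CSV) or from a database, and all other settings come from
# environment variables parsed in main_() (test, commit, redo, start, n, cpus,
# debug, verbose, profile_to, plus the in/out database settings read via
# opts.get_env_vars()).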

import csv
import itertools
import os.path
import sys
import xml.dom.minidom as minidom

sys.path.append(os.path.dirname(__file__)+"/../lib")

import csvs
import exc
import iters
import maps
import opts
import Parser
import profiling
import sql
import streams
import strings
import term
import util
import xpath
import xml_dom
import xml_func
import xml_parse

def get_with_prefix(map_, prefixes, key):
    '''Gets all entries for the given key with any of the given prefixes'''
    values = []
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
        try: value = map_[key_]
        except KeyError, e: continue # keep going
        values.append(value)

    if values != []: return values
    else: raise e # re-raise last KeyError
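
# Usage sketch for get_with_prefix() above (hypothetical values), assuming
# strings.with_prefixes(prefixes, key) yields prefix+key for each prefix:
#     get_with_prefix({'id': 2, 'acct.id': 1}, ['acct.'], 'id')  # -> [2, 1]
#     get_with_prefix({}, ['acct.'], 'id')  # re-raises the last KeyError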

def metadata_value(name): return None # this feature has been removed

def cleanup(val):
    if val == None: return val
    # treat '' and '\N' (the PostgreSQL NULL literal) as missing values
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')

def main_():
    env_names = []
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]')

    ## Get config from env vars

    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)

    # Ranges
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
    if test: end_default = 1
    else: end_default = None
    end = util.cast(int, util.none_if(
        opts.get_env_var('n', end_default, env_names), u''))
    if end != None: end += start
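    # e.g. (illustrative) start=2 and n=3 process 0-based input rows 2 through 4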

    # Optimization
    if test: cpus_default = 0
    else: cpus_default = None
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
        env_names), u''))

    # Debugging
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', not test, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names

    # DB
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
    def get_db_config(prefix):
        return opts.get_env_vars(db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config

    ##

    # Logging
    def log(msg, on=verbose):
        if on: sys.stderr.write(msg)
    def log_start(action, on=verbose): log(action+'...\n', on)

    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]

    def connect_db(db_config):
        log_start('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config)

    if end != None: end_str = str(end-1)
    else: end_str = 'end'
    log_start('Processing input rows '+str(start)+'-'+end_str)

    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')

    # Parallel processing
    if cpus != 0:
        try:
            import multiprocessing
            import multiprocessing.pool
        except ImportError: pass
        else:
            if cpus == None: cpus = multiprocessing.cpu_count()
            job_server = multiprocessing.pool.Pool(processes=cpus)
            log_start('Using '+str(cpus)+' CPUs')

    doc = xml_dom.create_doc()
    root = doc.documentElement
    out_is_xml_ref = [False]
    in_label_ref = [None]
    def update_in_label():
        if in_label_ref[0] != None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        root.clear()
        update_in_label()
    prep_root()

    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed'''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = []
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            in_label, out_label = reader.next()[:2]

            def split_col_name(name):
                label, sep, root = name.partition(':')
                label, sep2, prefixes_str = label.partition('[')
                prefixes_str = strings.remove_suffix(']', prefixes_str)
                prefixes = strings.split(',', prefixes_str)
                return label, sep != '', root, prefixes
                    # extract datasrc from "datasrc[data_format]"
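
            # The map spreadsheet's header row appears to use column names of
            # the form "label[prefix1,prefix2,...]:root", where the optional
            # bracketed prefixes qualify input column names and the part after
            # ":" is the root (an XPath, or a table name for database input)
            # prepended to each mapping; maps.col_info() below returns
            # (label, root, prefixes).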

            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types

            for row in reader:
                in_, out = row[:2]
                if out != '':
                    if out_is_xpaths: out = xpath.parse(out_root+out)
                    mappings.append((in_, out))

            stream.close()

            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db

        def process_rows(process_row, rows):
            '''Processes input rows
            @param process_row(in_row, i)
            '''
            i = 0
            while end == None or i < end:
                try: row = rows.next()
                except StopIteration: break # no more rows
                if i < start: # not at start row yet
                    i += 1 # count the skipped row so processing begins at `start`
                    continue

                process_row(row, i)
                row_ready(i, row)
                i += 1
            row_ct = i-start
            return row_ct

        def map_rows(get_value, rows):
            '''Maps input rows
            @param get_value(in_, row):str
            '''
            def process_row(row, i):
                row_id = str(i)
                for in_, out in mappings:
                    value = metadata_value(in_)
                    if value == None:
                        log_start('Getting '+str(in_), debug)
                        value = cleanup(get_value(in_, row))
                    if value != None:
                        log_start('Putting '+str(out), debug)
                        xpath.put_obj(root, out, row_id, has_types, value)
            return process_rows(process_row, rows)
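
        # map_table() below rewrites each mapping's input side from a column
        # name to the list of matching column indexes (via get_with_prefix()),
        # drops mappings whose column is missing, and then maps the rows;
        # get_value() coalesces the cells at those indexes. Hypothetical
        # example: columns ['id', 'acct.id'] with prefixes ['acct.'] resolve
        # the input name 'id' to indexes [0, 1].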

        def map_table(col_names, rows):
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names)

            i = 0
            while i < len(mappings): # mappings len changes in loop
                in_, out = mappings[i]
                if metadata_value(in_) == None:
                    try: mappings[i] = (
                        get_with_prefix(col_idxs, prefixes, in_), out)
                    except KeyError:
                        del mappings[i]
                        continue # keep i the same
                i += 1

            def get_value(in_, row):
                return util.coalesce(*util.list_subset(row.list, in_))
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths

            return map_rows(get_value, util.WrapIter(wrap_row, rows))

        stdin = streams.LineCountStream(sys.stdin)
        def on_error(e):
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
            ex_tracker.track(e)

        if in_is_db:
            assert in_is_xpaths

            in_db = connect_db(in_db_config)
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
                limit=end, start=0)
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))

            in_db.db.close()
        elif in_is_xml:
            def get_rows(doc2rows):
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))

            if map_path == None:
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_

                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: raise SystemExit('Map error: Root "'
                        +in_root+'" not found in input')
                    return rows

                def get_value(in_, row):
                    in_ = './{'+(','.join(strings.with_prefixes(
                        ['']+prefixes, in_)))+'}' # also with no prefix
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None

                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            map_ = dict(mappings)
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)

        return row_ct

    def process_inputs(root, row_ready):
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
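
    # In the database-output path below, each completed row is inserted inside
    # a savepoint by db_xml.put(); the transaction is committed per row only
    # when the `commit` flag is set, and the final rollback() discards
    # anything left uncommitted (e.g. in test mode, which never commits).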

    if out_is_db:
        import db_xml

        out_db = connect_db(out_db_config)
        try:
            if redo: sql.empty_db(out_db)
            row_ins_ct_ref = [0]

            def row_ready(row_num, input_row):
                def on_error(e):
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e, row_num)

                xml_func.process(root, on_error)
                if not xml_dom.is_empty(root):
                    assert xml_dom.has_one_child(root)
                    try:
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, root.firstChild,
                                row_ins_ct_ref, on_error))
                        if commit: out_db.db.commit()
                    except sql.DatabaseErrors, e: on_error(e)
                prep_root()

            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
        finally:
            out_db.db.rollback()
            out_db.db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')

    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit()

def main():
    try: main_()
    except Parser.SyntaxError, e: raise SystemExit(str(e))

if __name__ == '__main__':
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to != None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        cProfile.run(main.func_code, profile_to)
    else: main()