Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import exc
17
import iters
18
import maps
19
import opts
20
import Parser
21
import profiling
22
import sql
23
import streams
24
import strings
25
import term
26
import util
27
import xpath
28
import xml_dom
29
import xml_func
30
import xml_parse
31

    
32
class Pool:
33
    def apply_async(func, args=None, kw_args=None, callback=None):
34
        if args == None: args = ()
35
        if kwds == None: kwds = {}
36
        if callback == None: callback = lambda v: None
37
        
38
        callback(func(*args, **kw_args))
39

    
40
def get_with_prefix(map_, prefixes, key):
41
    '''Gets all entries for the given key with any of the given prefixes'''
42
    values = []
43
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
44
        try: value = map_[key_]
45
        except KeyError, e: continue # keep going
46
        values.append(value)
47
    
48
    if values != []: return values
49
    else: raise e # re-raise last KeyError
50

    
51
def metadata_value(name): return None # this feature has been removed
52

    
53
def cleanup(val):
54
    if val == None: return val
55
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
56

    
57
def main_():
58
    env_names = []
59
    def usage_err():
60
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
61
            +sys.argv[0]+' [map_path...] [<input] [>output]')
62
    
63
    ## Get config from env vars
64
    
65
    # Modes
66
    test = opts.env_flag('test', False, env_names)
67
    commit = opts.env_flag('commit', False, env_names) and not test
68
        # never commit in test mode
69
    redo = opts.env_flag('redo', test, env_names) and not commit
70
        # never redo in commit mode (manually run `make empty_db` instead)
71
    
72
    # Ranges
73
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
74
    if test: end_default = 1
75
    else: end_default = None
76
    end = util.cast(int, util.none_if(
77
        opts.get_env_var('n', end_default, env_names), u''))
78
    if end != None: end += start
79
    
80
    # Optimization
81
    if test: cpus_default = 0
82
    else: cpus_default = None
83
    cpus = util.cast(int, util.none_if(opts.get_env_var('cpus', cpus_default,
84
        env_names), u''))
85
    
86
    # Debugging
87
    debug = opts.env_flag('debug', False, env_names)
88
    sql.run_raw_query.debug = debug
89
    verbose = debug or opts.env_flag('verbose', not test, env_names)
90
    opts.get_env_var('profile_to', None, env_names) # add to env_names
91
    
92
    # DB
93
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
94
    def get_db_config(prefix):
95
        return opts.get_env_vars(db_config_names, prefix, env_names)
96
    in_db_config = get_db_config('in')
97
    out_db_config = get_db_config('out')
98
    in_is_db = 'engine' in in_db_config
99
    out_is_db = 'engine' in out_db_config
100
    
101
    ##
102
    
103
    # Logging
104
    def log(msg, on=verbose):
105
        if on: sys.stderr.write(msg)
106
    def log_start(action, on=verbose): log(action+'...\n', on)
107
    
108
    # Parse args
109
    map_paths = sys.argv[1:]
110
    if map_paths == []:
111
        if in_is_db or not out_is_db: usage_err()
112
        else: map_paths = [None]
113
    
114
    def connect_db(db_config):
115
        log_start('Connecting to '+sql.db_config_str(db_config))
116
        return sql.connect(db_config)
117
    
118
    if end != None: end_str = str(end-1)
119
    else: end_str = 'end'
120
    log_start('Processing input rows '+str(start)+'-'+end_str)
121
    
122
    ex_tracker = exc.ExPercentTracker(iter_text='row')
123
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
124
    
125
    # Parallel processing
126
    try:
127
        if cpus == 0: raise ImportError('Parallel processing turned off')
128
        import multiprocessing
129
        import multiprocessing.pool
130
    except ImportError, e:
131
        log_start('Not using parallel processing: '+str(e))
132
        job_server = Pool()
133
    else:
134
        if cpus == None: cpus = multiprocessing.cpu_count()
135
        log_start('Using '+str(cpus)+' CPUs')
136
        job_server = multiprocessing.pool.Pool(processes=cpus)
137
    
138
    doc = xml_dom.create_doc()
139
    root = doc.documentElement
140
    out_is_xml_ref = [False]
141
    in_label_ref = [None]
142
    def update_in_label():
143
        if in_label_ref[0] != None:
144
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
145
    def prep_root():
146
        root.clear()
147
        update_in_label()
148
    prep_root()
149
    
150
    def process_input(root, row_ready, map_path):
151
        '''Inputs datasource to XML tree, mapping if needed'''
152
        # Load map header
153
        in_is_xpaths = True
154
        out_is_xpaths = True
155
        out_label = None
156
        if map_path != None:
157
            metadata = []
158
            mappings = []
159
            stream = open(map_path, 'rb')
160
            reader = csv.reader(stream)
161
            in_label, out_label = reader.next()[:2]
162
            
163
            def split_col_name(name):
164
                label, sep, root = name.partition(':')
165
                label, sep2, prefixes_str = label.partition('[')
166
                prefixes_str = strings.remove_suffix(']', prefixes_str)
167
                prefixes = strings.split(',', prefixes_str)
168
                return label, sep != '', root, prefixes
169
                    # extract datasrc from "datasrc[data_format]"
170
            
171
            in_label, in_root, prefixes = maps.col_info(in_label)
172
            in_is_xpaths = in_root != None
173
            in_label_ref[0] = in_label
174
            update_in_label()
175
            out_label, out_root = maps.col_info(out_label)[:2]
176
            out_is_xpaths = out_root != None
177
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
178
                # outer elements are types
179
            
180
            for row in reader:
181
                in_, out = row[:2]
182
                if out != '':
183
                    if out_is_xpaths: out = xpath.parse(out_root+out)
184
                    mappings.append((in_, out))
185
            
186
            stream.close()
187
            
188
            root.ownerDocument.documentElement.tagName = out_label
189
        in_is_xml = in_is_xpaths and not in_is_db
190
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
191
        
192
        def process_rows(process_row, rows):
193
            '''Processes input rows
194
            @param process_row(in_row, i)
195
            '''
196
            i = 0
197
            while end == None or i < end:
198
                try: row = rows.next()
199
                except StopIteration: break # no more rows
200
                if i < start: continue # not at start row yet
201
                
202
                process_row(row, i)
203
                row_ready(i, row)
204
                i += 1
205
            row_ct = i-start
206
            return row_ct
207
        
208
        def map_rows(get_value, rows):
209
            '''Maps input rows
210
            @param get_value(in_, row):str
211
            '''
212
            def process_row(row, i):
213
                row_id = str(i)
214
                for in_, out in mappings:
215
                    value = metadata_value(in_)
216
                    if value == None:
217
                        log_start('Getting '+str(in_), debug)
218
                        value = cleanup(get_value(in_, row))
219
                    if value != None:
220
                        log_start('Putting '+str(out), debug)
221
                        xpath.put_obj(root, out, row_id, has_types, value)
222
            return process_rows(process_row, rows)
223
        
224
        def map_table(col_names, rows):
225
            col_names_ct = len(col_names)
226
            col_idxs = util.list_flip(col_names)
227
            
228
            i = 0
229
            while i < len(mappings): # mappings len changes in loop
230
                in_, out = mappings[i]
231
                if metadata_value(in_) == None:
232
                    try: mappings[i] = (
233
                        get_with_prefix(col_idxs, prefixes, in_), out)
234
                    except KeyError:
235
                        del mappings[i]
236
                        continue # keep i the same
237
                i += 1
238
            
239
            def get_value(in_, row):
240
                return util.coalesce(*util.list_subset(row.list, in_))
241
            def wrap_row(row):
242
                return util.ListDict(util.list_as_length(row, col_names_ct),
243
                    col_names, col_idxs) # handle CSV rows of different lengths
244
            
245
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
246
        
247
        stdin = streams.LineCountStream(sys.stdin)
248
        def on_error(e):
249
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
250
            ex_tracker.track(e)
251
        
252
        if in_is_db:
253
            assert in_is_xpaths
254
            
255
            in_db = connect_db(in_db_config)
256
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
257
                limit=end, start=0)
258
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
259
            
260
            in_db.db.close()
261
        elif in_is_xml:
262
            def get_rows(doc2rows):
263
                return iters.flatten(itertools.imap(doc2rows,
264
                    xml_parse.docs_iter(stdin, on_error)))
265
            
266
            if map_path == None:
267
                def doc2rows(in_xml_root):
268
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
269
                    util.skip(iter_, xml_dom.is_text) # skip metadata
270
                    return iter_
271
                
272
                row_ct = process_rows(lambda row, i: root.appendChild(row),
273
                    get_rows(doc2rows))
274
            else:
275
                def doc2rows(in_xml_root):
276
                    rows = xpath.get(in_xml_root, in_root, limit=end)
277
                    if rows == []: raise SystemExit('Map error: Root "'
278
                        +in_root+'" not found in input')
279
                    return rows
280
                
281
                def get_value(in_, row):
282
                    in_ = './{'+(','.join(strings.with_prefixes(
283
                        ['']+prefixes, in_)))+'}' # also with no prefix
284
                    nodes = xpath.get(row, in_, allow_rooted=False)
285
                    if nodes != []: return xml_dom.value(nodes[0])
286
                    else: return None
287
                
288
                row_ct = map_rows(get_value, get_rows(doc2rows))
289
        else: # input is CSV
290
            map_ = dict(mappings)
291
            reader, col_names = csvs.reader_and_header(sys.stdin)
292
            row_ct = map_table(col_names, reader)
293
        
294
        return row_ct
295
    
296
    def process_inputs(root, row_ready):
297
        row_ct = 0
298
        for map_path in map_paths:
299
            row_ct += process_input(root, row_ready, map_path)
300
        return row_ct
301
    
302
    if out_is_db:
303
        import db_xml
304
        
305
        out_db = connect_db(out_db_config)
306
        try:
307
            if redo: sql.empty_db(out_db)
308
            row_ins_ct_ref = [0]
309
            
310
            def row_ready(row_num, input_row):
311
                def on_error(e):
312
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
313
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
314
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
315
                    ex_tracker.track(e, row_num)
316
                
317
                xml_func.process(root, on_error)
318
                if not xml_dom.is_empty(root):
319
                    assert xml_dom.has_one_child(root)
320
                    try:
321
                        sql.with_savepoint(out_db,
322
                            lambda: db_xml.put(out_db, root.firstChild,
323
                                row_ins_ct_ref, on_error))
324
                        if commit: out_db.db.commit()
325
                    except sql.DatabaseErrors, e: on_error(e)
326
                prep_root()
327
            
328
            row_ct = process_inputs(root, row_ready)
329
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
330
                ' new rows into database\n')
331
        finally:
332
            out_db.db.rollback()
333
            out_db.db.close()
334
    else:
335
        def on_error(e): ex_tracker.track(e)
336
        def row_ready(row_num, input_row): pass
337
        row_ct = process_inputs(root, row_ready)
338
        xml_func.process(root, on_error)
339
        if out_is_xml_ref[0]:
340
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
341
        else: # output is CSV
342
            raise NotImplementedError('CSV output not supported yet')
343
    
344
    profiler.stop(row_ct)
345
    ex_tracker.add_iters(row_ct)
346
    if verbose:
347
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
348
        sys.stderr.write(profiler.msg()+'\n')
349
        sys.stderr.write(ex_tracker.msg()+'\n')
350
    ex_tracker.exit()
351

    
352
def main():
353
    try: main_()
354
    except Parser.SyntaxError, e: raise SystemExit(str(e))
355

    
356
if __name__ == '__main__':
357
    profile_to = opts.get_env_var('profile_to', None)
358
    if profile_to != None:
359
        import cProfile
360
        sys.stderr.write('Profiling to '+profile_to+'\n')
361
        cProfile.run(main.func_code, profile_to)
362
    else: main()
(23-23/43)