#!/usr/bin/env python
# Maps one datasource to another, using a map spreadsheet if needed
# Exit status is the # of errors in the import, up to the maximum exit status
# For outputting an XML file to a PostgreSQL database, use the general format of
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
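#
# Configuration is read from env vars (see main_() below): modes (test, commit,
# redo), row range (start, n), cpus, debug/verbose, profile_to, and DB
# connection settings (engine, host, user, password, database) looked up with
# the 'in'/'out' prefixes.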

import csv
import itertools
import os.path
import sys
import xml.dom.minidom as minidom

sys.path.append(os.path.dirname(__file__)+"/../lib")

import csvs
import exc
import iters
import maps
import opts
import Parser
import profiling
import sql
import streams
import strings
import term
import util
import xpath
import xml_dom
import xml_func
import xml_parse

def get_with_prefix(map_, prefixes, key):
    '''Gets all entries for the given key with any of the given prefixes'''
    values = []
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
        try: value = map_[key_]
        except KeyError, e: continue # keep going
        values.append(value)
    
    if values != []: return values
    else: raise e # re-raise last KeyError
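# Illustrative note: with prefixes ['acc.'], key 'x' is looked up as 'x' and
# (assuming strings.with_prefixes() prepends each prefix to the key) as
# 'acc.x'; every value found is returned, otherwise the last KeyError is
# re-raised.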

def metadata_value(name): return None # this feature has been removed

def cleanup(val):
    if val == None: return val
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
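# Note: u'\\N' above is PostgreSQL's NULL marker in COPY text output; cleanup()
# presumably normalizes the value and turns it, and the empty string, into None.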

def main_():
    env_names = []
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]')
    
    ## Get config from env vars
    
    # Modes
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)
    
    # Ranges
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
    if test: end_default = 1
    else: end_default = None
    end = util.cast(int, util.none_if(
        opts.get_env_var('n', end_default, env_names), u''))
    if end != None: end += start
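    # 'n' is a row count, so 'end' becomes the exclusive upper bound (start+n);
    # process_rows() below stops once the row index reaches it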

    # Optimization
    if test: cpus_default = 0
    else: cpus_default = None
    cpus = util.cast(int, opts.get_env_var('cpus', cpus_default, env_names))
    
    # Debugging
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', not test, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names
    
    # DB
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
    def get_db_config(prefix):
        return opts.get_env_vars(db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config
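    # The presence of an 'engine' setting (looked up with the 'in'/'out'
    # prefix) selects database input/output; otherwise input is read from
    # stdin (XML or CSV) and output is written to stdout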

    ##
    
    # Logging
    def log(msg, on=verbose):
        if on: sys.stderr.write(msg)
    def log_start(action, on=verbose): log(action+'...\n', on)
    
    # Parse args
    map_paths = sys.argv[1:]
    if map_paths == []:
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]
    
    def connect_db(db_config):
        log_start('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config)
    
    if end != None: end_str = str(end-1)
    else: end_str = 'end'
    log_start('Processing input rows '+str(start)+'-'+end_str)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    if cpus != 0:
        try: import pp
        except ImportError: pass
        else:
            if cpus != None: args = [cpus]
            else: args = []
            job_server = pp.Server(*args)
            log_start('Using '+str(job_server.get_ncpus())+' CPUs')
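            # note: job_server is only created here; this script does not
            # currently submit jobs to it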

    doc = xml_dom.create_doc()
    root = doc.documentElement
    out_is_xml_ref = [False]
    in_label_ref = [None]
    def update_in_label():
        if in_label_ref[0] != None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        root.clear()
        update_in_label()
    prep_root()
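    # root holds the mapped output tree; prep_root() clears it and re-runs
    # update_in_label(), and is called again after each row in DB-output mode
    # (see row_ready() below)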

    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed'''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path != None:
            metadata = []
            mappings = []
            stream = open(map_path, 'rb')
            reader = csv.reader(stream)
            in_label, out_label = reader.next()[:2]
            
            def split_col_name(name):
                label, sep, root = name.partition(':')
                label, sep2, prefixes_str = label.partition('[')
                prefixes_str = strings.remove_suffix(']', prefixes_str)
                prefixes = strings.split(',', prefixes_str)
                return label, sep != '', root, prefixes
                    # extract datasrc from "datasrc[data_format]"
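            # note: split_col_name is not called in this file; the header
            # columns are parsed with maps.col_info() below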

            in_label, in_root, prefixes = maps.col_info(in_label)
            in_is_xpaths = in_root != None
            in_label_ref[0] = in_label
            update_in_label()
            out_label, out_root = maps.col_info(out_label)[:2]
            out_is_xpaths = out_root != None
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
                # outer elements are types
            
            for row in reader:
                in_, out = row[:2]
                if out != '':
                    if out_is_xpaths: out = xpath.parse(out_root+out)
                    mappings.append((in_, out))
            
            stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        def process_rows(process_row, rows):
            '''Processes input rows
            @param process_row(in_row, i)
            '''
            i = 0
            while end == None or i < end:
                try: row = rows.next()
                except StopIteration: break # no more rows
                if i < start:
                    i += 1 # count skipped rows so the start row can be reached
                    continue # not at start row yet
                
                process_row(row, i)
                row_ready(i, row)
                i += 1
            row_ct = i-start
            return row_ct
        
        def map_rows(get_value, rows):
            '''Maps input rows
            @param get_value(in_, row):str
            '''
            def process_row(row, i):
                row_id = str(i)
                for in_, out in mappings:
                    value = metadata_value(in_)
                    if value == None:
                        log_start('Getting '+str(in_), debug)
                        value = cleanup(get_value(in_, row))
                    if value != None:
                        log_start('Putting '+str(out), debug)
                        xpath.put_obj(root, out, row_id, has_types, value)
            return process_rows(process_row, rows)
        
        def map_table(col_names, rows):
            col_names_ct = len(col_names)
            col_idxs = util.list_flip(col_names)
            
            i = 0
            while i < len(mappings): # mappings len changes in loop
                in_, out = mappings[i]
                if metadata_value(in_) == None:
                    try: mappings[i] = (
                        get_with_prefix(col_idxs, prefixes, in_), out)
                    except KeyError:
                        del mappings[i]
                        continue # keep i the same
                i += 1
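            # mappings now pair lists of column indexes (one per matching
            # prefix) with their outputs; unmatched input columns were dropped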
            
            def get_value(in_, row):
                return util.coalesce(*util.list_subset(row.list, in_))
            def wrap_row(row):
                return util.ListDict(util.list_as_length(row, col_names_ct),
                    col_names, col_idxs) # handle CSV rows of different lengths
            
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
        
        stdin = streams.LineCountStream(sys.stdin)
        def on_error(e):
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
            ex_tracker.track(e)
        
        if in_is_db:
            assert in_is_xpaths
            
            in_db = connect_db(in_db_config)
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
                limit=end, start=0)
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
            
            in_db.db.close()
        elif in_is_xml:
            def get_rows(doc2rows):
                return iters.flatten(itertools.imap(doc2rows,
                    xml_parse.docs_iter(stdin, on_error)))
            
            if map_path == None:
                def doc2rows(in_xml_root):
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
                    util.skip(iter_, xml_dom.is_text) # skip metadata
                    return iter_
                
                row_ct = process_rows(lambda row, i: root.appendChild(row),
                    get_rows(doc2rows))
            else:
                def doc2rows(in_xml_root):
                    rows = xpath.get(in_xml_root, in_root, limit=end)
                    if rows == []: raise SystemExit('Map error: Root "'
                        +in_root+'" not found in input')
                    return rows
                
                def get_value(in_, row):
                    in_ = './{'+(','.join(strings.with_prefixes(
                        ['']+prefixes, in_)))+'}' # also with no prefix
                    nodes = xpath.get(row, in_, allow_rooted=False)
                    if nodes != []: return xml_dom.value(nodes[0])
                    else: return None
                
                row_ct = map_rows(get_value, get_rows(doc2rows))
        else: # input is CSV
            map_ = dict(mappings)
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    if out_is_db:
        import db_xml
        
        out_db = connect_db(out_db_config)
        try:
            if redo: sql.empty_db(out_db)
            row_ins_ct_ref = [0]
            
            def row_ready(row_num, input_row):
                def on_error(e):
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e, row_num)
                
                xml_func.process(root, on_error)
                if not xml_dom.is_empty(root):
                    assert xml_dom.has_one_child(root)
                    try:
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, root.firstChild,
                                row_ins_ct_ref, on_error))
                        if commit: out_db.db.commit()
                    except sql.DatabaseErrors, e: on_error(e)
                prep_root()
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
        finally:
            out_db.db.rollback()
            out_db.db.close()
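            # rollback() discards any uncommitted work; in commit mode each row
            # was already committed individually in row_ready() above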
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit()

def main():
    try: main_()
    except Parser.SyntaxError, e: raise SystemExit(str(e))

if __name__ == '__main__':
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to != None:
        import cProfile
        sys.stderr.write('Profiling to '+profile_to+'\n')
        cProfile.run(main.func_code, profile_to)
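        # the saved stats can be inspected later, e.g. with `python -m pstats`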
    else: main()