Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import itertools
9
import os.path
10
import sys
11
import xml.dom.minidom as minidom
12

    
13
sys.path.append(os.path.dirname(__file__)+"/../lib")
14

    
15
import csvs
16
import exc
17
import iters
18
import maps
19
import opts
20
import Parser
21
import profiling
22
import sql
23
import streams
24
import strings
25
import term
26
import util
27
import xpath
28
import xml_dom
29
import xml_func
30
import xml_parse
31

    
32
def get_with_prefix(map_, prefixes, key):
33
    '''Gets all entries for the given key with any of the given prefixes'''
34
    values = []
35
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
36
        try: value = map_[key_]
37
        except KeyError, e: continue # keep going
38
        values.append(value)
39
    
40
    if values != []: return values
41
    else: raise e # re-raise last KeyError
42

    
43
def metadata_value(name): return None # this feature has been removed
44

    
45
def cleanup(val):
46
    if val == None: return val
47
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
48

    
49
def main_():
50
    env_names = []
51
    def usage_err():
52
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
53
            +sys.argv[0]+' [map_path...] [<input] [>output]')
54
    
55
    ## Get config from env vars
56
    
57
    # Modes
58
    test = opts.env_flag('test', False, env_names)
59
    commit = opts.env_flag('commit', False, env_names) and not test
60
        # never commit in test mode
61
    redo = opts.env_flag('redo', test, env_names) and not commit
62
        # never redo in commit mode (manually run `make empty_db` instead)
63
    
64
    # Ranges
65
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
66
    if test: end_default = 1
67
    else: end_default = None
68
    end = util.cast(int, util.none_if(
69
        opts.get_env_var('n', end_default, env_names), u''))
70
    if end != None: end += start
71
    
72
    # Optimization
73
    cpus = util.cast(int, opts.get_env_var('cpus', None, env_names))
74
    
75
    # Debugging
76
    debug = opts.env_flag('debug', False, env_names)
77
    sql.run_raw_query.debug = debug
78
    verbose = debug or opts.env_flag('verbose', not test, env_names)
79
    opts.get_env_var('profile_to', None, env_names) # add to env_names
80
    
81
    # DB
82
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
83
    def get_db_config(prefix):
84
        return opts.get_env_vars(db_config_names, prefix, env_names)
85
    in_db_config = get_db_config('in')
86
    out_db_config = get_db_config('out')
87
    in_is_db = 'engine' in in_db_config
88
    out_is_db = 'engine' in out_db_config
89
    
90
    ##
91
    
92
    # Logging
93
    def log(msg, on=verbose):
94
        if on: sys.stderr.write(msg)
95
    def log_start(action, on=verbose): log(action+'...\n', on)
96
    
97
    # Parse args
98
    map_paths = sys.argv[1:]
99
    if map_paths == []:
100
        if in_is_db or not out_is_db: usage_err()
101
        else: map_paths = [None]
102
    
103
    def connect_db(db_config):
104
        log_start('Connecting to '+sql.db_config_str(db_config))
105
        return sql.connect(db_config)
106
    
107
    if end != None: end_str = str(end-1)
108
    else: end_str = 'end'
109
    log_start('Processing input rows '+str(start)+'-'+end_str)
110
    
111
    ex_tracker = exc.ExPercentTracker(iter_text='row')
112
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
113
    
114
    try: import pp
115
    except ImportError: cpus = 0
116
    if cpus != 0:
117
        if cpus != None: args = [cpus]
118
        else: args = []
119
        job_server = pp.Server(*args)
120
        log_start('Using '+str(job_server.get_ncpus())+' CPUs')
121
    
122
    doc = xml_dom.create_doc()
123
    root = doc.documentElement
124
    out_is_xml_ref = [False]
125
    in_label_ref = [None]
126
    def update_in_label():
127
        if in_label_ref[0] != None:
128
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
129
    def prep_root():
130
        root.clear()
131
        update_in_label()
132
    prep_root()
133
    
134
    def process_input(root, row_ready, map_path):
135
        '''Inputs datasource to XML tree, mapping if needed'''
136
        # Load map header
137
        in_is_xpaths = True
138
        out_is_xpaths = True
139
        out_label = None
140
        if map_path != None:
141
            metadata = []
142
            mappings = []
143
            stream = open(map_path, 'rb')
144
            reader = csv.reader(stream)
145
            in_label, out_label = reader.next()[:2]
146
            
147
            def split_col_name(name):
148
                label, sep, root = name.partition(':')
149
                label, sep2, prefixes_str = label.partition('[')
150
                prefixes_str = strings.remove_suffix(']', prefixes_str)
151
                prefixes = strings.split(',', prefixes_str)
152
                return label, sep != '', root, prefixes
153
                    # extract datasrc from "datasrc[data_format]"
154
            
155
            in_label, in_root, prefixes = maps.col_info(in_label)
156
            in_is_xpaths = in_root != None
157
            in_label_ref[0] = in_label
158
            update_in_label()
159
            out_label, out_root = maps.col_info(out_label)[:2]
160
            out_is_xpaths = out_root != None
161
            if out_is_xpaths: has_types = out_root.find('/*s/') >= 0
162
                # outer elements are types
163
            
164
            for row in reader:
165
                in_, out = row[:2]
166
                if out != '':
167
                    if out_is_xpaths: out = xpath.parse(out_root+out)
168
                    mappings.append((in_, out))
169
            
170
            stream.close()
171
            
172
            root.ownerDocument.documentElement.tagName = out_label
173
        in_is_xml = in_is_xpaths and not in_is_db
174
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
175
        
176
        def process_rows(process_row, rows):
177
            '''Processes input rows
178
            @param process_row(in_row, i)
179
            '''
180
            i = 0
181
            while end == None or i < end:
182
                try: row = rows.next()
183
                except StopIteration: break # no more rows
184
                if i < start: continue # not at start row yet
185
                
186
                process_row(row, i)
187
                row_ready(i, row)
188
                i += 1
189
            row_ct = i-start
190
            return row_ct
191
        
192
        def map_rows(get_value, rows):
193
            '''Maps input rows
194
            @param get_value(in_, row):str
195
            '''
196
            def process_row(row, i):
197
                row_id = str(i)
198
                for in_, out in mappings:
199
                    value = metadata_value(in_)
200
                    if value == None:
201
                        log_start('Getting '+str(in_), debug)
202
                        value = cleanup(get_value(in_, row))
203
                    if value != None:
204
                        log_start('Putting '+str(out), debug)
205
                        xpath.put_obj(root, out, row_id, has_types, value)
206
            return process_rows(process_row, rows)
207
        
208
        def map_table(col_names, rows):
209
            col_names_ct = len(col_names)
210
            col_idxs = util.list_flip(col_names)
211
            
212
            i = 0
213
            while i < len(mappings): # mappings len changes in loop
214
                in_, out = mappings[i]
215
                if metadata_value(in_) == None:
216
                    try: mappings[i] = (
217
                        get_with_prefix(col_idxs, prefixes, in_), out)
218
                    except KeyError:
219
                        del mappings[i]
220
                        continue # keep i the same
221
                i += 1
222
            
223
            def get_value(in_, row):
224
                return util.coalesce(*util.list_subset(row.list, in_))
225
            def wrap_row(row):
226
                return util.ListDict(util.list_as_length(row, col_names_ct),
227
                    col_names, col_idxs) # handle CSV rows of different lengths
228
            
229
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
230
        
231
        stdin = streams.LineCountStream(sys.stdin)
232
        def on_error(e):
233
            exc.add_msg(e, term.emph('input line #:')+' '+str(stdin.line_num))
234
            ex_tracker.track(e)
235
        
236
        if in_is_db:
237
            assert in_is_xpaths
238
            
239
            in_db = connect_db(in_db_config)
240
            in_pkeys = {}
241
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
242
                limit=end, start=0)
243
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
244
            
245
            in_db.db.close()
246
        elif in_is_xml:
247
            def get_rows(doc2rows):
248
                return iters.flatten(itertools.imap(doc2rows,
249
                    xml_parse.docs_iter(stdin, on_error)))
250
            
251
            if map_path == None:
252
                def doc2rows(in_xml_root):
253
                    iter_ = xml_dom.NodeElemIter(in_xml_root)
254
                    util.skip(iter_, xml_dom.is_text) # skip metadata
255
                    return iter_
256
                
257
                row_ct = process_rows(lambda row, i: root.appendChild(row),
258
                    get_rows(doc2rows))
259
            else:
260
                def doc2rows(in_xml_root):
261
                    rows = xpath.get(in_xml_root, in_root, limit=end)
262
                    if rows == []: raise SystemExit('Map error: Root "'
263
                        +in_root+'" not found in input')
264
                    return rows
265
                
266
                def get_value(in_, row):
267
                    in_ = './{'+(','.join(strings.with_prefixes(
268
                        ['']+prefixes, in_)))+'}' # also with no prefix
269
                    nodes = xpath.get(row, in_, allow_rooted=False)
270
                    if nodes != []: return xml_dom.value(nodes[0])
271
                    else: return None
272
                
273
                row_ct = map_rows(get_value, get_rows(doc2rows))
274
        else: # input is CSV
275
            map_ = dict(mappings)
276
            reader, col_names = csvs.reader_and_header(sys.stdin)
277
            row_ct = map_table(col_names, reader)
278
        
279
        return row_ct
280
    
281
    def process_inputs(root, row_ready):
282
        row_ct = 0
283
        for map_path in map_paths:
284
            row_ct += process_input(root, row_ready, map_path)
285
        return row_ct
286
    
287
    if out_is_db:
288
        import db_xml
289
        
290
        out_db = connect_db(out_db_config)
291
        out_pkeys = {}
292
        try:
293
            if redo: sql.empty_db(out_db)
294
            row_ins_ct_ref = [0]
295
            
296
            def row_ready(row_num, input_row):
297
                def on_error(e):
298
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
299
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
300
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
301
                    ex_tracker.track(e, row_num)
302
                
303
                xml_func.process(root, on_error)
304
                if not xml_dom.is_empty(root):
305
                    assert xml_dom.has_one_child(root)
306
                    try:
307
                        sql.with_savepoint(out_db,
308
                            lambda: db_xml.put(out_db, root.firstChild,
309
                                out_pkeys, row_ins_ct_ref, on_error))
310
                        if commit: out_db.db.commit()
311
                    except sql.DatabaseErrors, e: on_error(e)
312
                prep_root()
313
            
314
            row_ct = process_inputs(root, row_ready)
315
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
316
                ' new rows into database\n')
317
        finally:
318
            out_db.db.rollback()
319
            out_db.db.close()
320
    else:
321
        def on_error(e): ex_tracker.track(e)
322
        def row_ready(row_num, input_row): pass
323
        row_ct = process_inputs(root, row_ready)
324
        xml_func.process(root, on_error)
325
        if out_is_xml_ref[0]:
326
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
327
        else: # output is CSV
328
            raise NotImplementedError('CSV output not supported yet')
329
    
330
    profiler.stop(row_ct)
331
    ex_tracker.add_iters(row_ct)
332
    if verbose:
333
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
334
        sys.stderr.write(profiler.msg()+'\n')
335
        sys.stderr.write(ex_tracker.msg()+'\n')
336
    ex_tracker.exit()
337

    
338
def main():
339
    try: main_()
340
    except Parser.SyntaxError, e: raise SystemExit(str(e))
341

    
342
if __name__ == '__main__':
343
    profile_to = opts.get_env_var('profile_to', None)
344
    if profile_to != None:
345
        import cProfile
346
        sys.stderr.write('Profiling to '+profile_to+'\n')
347
        cProfile.run(main.func_code, profile_to)
348
    else: main()
(23-23/43)