Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import os.path
9
import sys
10
import xml.dom.minidom as minidom
11

    
12
sys.path.append(os.path.dirname(__file__)+"/../lib")
13

    
14
import csvs
15
import exc
16
import opts
17
import Parser
18
import profiling
19
import sql
20
import strings
21
import term
22
import util
23
import xpath
24
import xml_dom
25
import xml_func
26

    
27
def get_with_prefix(map_, prefixes, key):
28
    '''Gets all entries for the given key with any of the given prefixes'''
29
    values = []
30
    for key_ in strings.with_prefixes(['']+prefixes, key): # also with no prefix
31
        try: value = map_[key_]
32
        except KeyError, e: continue # keep going
33
        values.append(value)
34
    
35
    if values != []: return values
36
    else: raise e # re-raise last KeyError
37

    
38
def metadata_value(name): return None # this feature has been removed
39

    
40
def cleanup(val):
41
    if val == None: return val
42
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
43

    
44
def main_():
45
    env_names = []
46
    def usage_err():
47
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
48
            +sys.argv[0]+' [map_path...] [<input] [>output]')
49
    
50
    ## Get config from env vars
51
    
52
    # Modes
53
    test = opts.env_flag('test', False, env_names)
54
    commit = opts.env_flag('commit', False, env_names) and not test
55
        # never commit in test mode
56
    redo = opts.env_flag('redo', test, env_names) and not commit
57
        # never redo in commit mode (manually run `make empty_db` instead)
58
    
59
    # Ranges
60
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
61
    if test: end_default = 1
62
    else: end_default = None
63
    end = util.cast(int, util.none_if(
64
        opts.get_env_var('n', end_default, env_names), u''))
65
    if end != None: end += start
66
    
67
    # Debugging
68
    debug = opts.env_flag('debug', False, env_names)
69
    sql.run_raw_query.debug = debug
70
    verbose = debug or opts.env_flag('verbose', not test, env_names)
71
    opts.get_env_var('profile_to', None, env_names) # add to env_names
72
    
73
    # DB
74
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
75
    def get_db_config(prefix):
76
        return opts.get_env_vars(db_config_names, prefix, env_names)
77
    in_db_config = get_db_config('in')
78
    out_db_config = get_db_config('out')
79
    in_is_db = 'engine' in in_db_config
80
    out_is_db = 'engine' in out_db_config
81
    
82
    ##
83
    
84
    # Logging
85
    def log(msg, on=verbose):
86
        if on: sys.stderr.write(msg)
87
    def log_start(action, on=verbose): log(action+'...\n', on)
88
    
89
    # Parse args
90
    map_paths = sys.argv[1:]
91
    if map_paths == []:
92
        if in_is_db or not out_is_db: usage_err()
93
        else: map_paths = [None]
94
    
95
    def connect_db(db_config):
96
        log_start('Connecting to '+sql.db_config_str(db_config))
97
        return sql.connect(db_config)
98
    
99
    if end != None: end_str = str(end-1)
100
    else: end_str = 'end'
101
    log_start('Processing input rows '+str(start)+'-'+end_str)
102
    
103
    ex_tracker = exc.ExPercentTracker(iter_text='row')
104
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
105
    
106
    doc = xml_dom.create_doc()
107
    root = doc.documentElement
108
    out_is_xml_ref = [False]
109
    in_label_ref = [None]
110
    def update_in_label():
111
        if in_label_ref[0] != None:
112
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
113
    def prep_root():
114
        root.clear()
115
        update_in_label()
116
    prep_root()
117
    
118
    def process_input(root, row_ready, map_path):
119
        '''Inputs datasource to XML tree, mapping if needed'''
120
        # Load map header
121
        in_is_xpaths = True
122
        out_is_xpaths = True
123
        out_label = None
124
        if map_path != None:
125
            metadata = []
126
            mappings = []
127
            stream = open(map_path, 'rb')
128
            reader = csv.reader(stream)
129
            in_label, out_label = reader.next()[:2]
130
            
131
            def split_col_name(name):
132
                label, sep, root = name.partition(':')
133
                label, sep2, prefixes_str = label.partition('[')
134
                prefixes_str = strings.remove_suffix(']', prefixes_str)
135
                prefixes = strings.split(',', prefixes_str)
136
                return label, sep != '', root, prefixes
137
                    # extract datasrc from "datasrc[data_format]"
138
            
139
            in_label, in_is_xpaths, in_root, prefixes = split_col_name(in_label)
140
            in_label_ref[0] = in_label
141
            update_in_label()
142
            out_label, out_is_xpaths, out_root = split_col_name(out_label)[:3]
143
            has_types = out_root.startswith('/*s/') # outer elements are types
144
            
145
            for row in reader:
146
                in_, out = row[:2]
147
                if out != '':
148
                    if out_is_xpaths: out = xpath.parse(out_root+out)
149
                    mappings.append((in_, out))
150
            
151
            stream.close()
152
            
153
            root.ownerDocument.documentElement.tagName = out_label
154
        in_is_xml = in_is_xpaths and not in_is_db
155
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
156
        
157
        def process_rows(process_row, rows):
158
            '''Processes input rows
159
            @param process_row(in_row, i)
160
            '''
161
            i = -1 # in case for loop does not execute
162
            for i, row in enumerate(rows):
163
                if i < start: continue
164
                if end != None and i >= end: break
165
                process_row(row, i)
166
                row_ready(i, row)
167
            row_ct = i-start+1
168
            return row_ct
169
        
170
        def map_rows(get_value, rows):
171
            '''Maps input rows
172
            @param get_value(in_, row):str
173
            '''
174
            def process_row(row, i):
175
                row_id = str(i)
176
                for in_, out in mappings:
177
                    value = metadata_value(in_)
178
                    if value == None:
179
                        log_start('Getting '+str(in_), debug)
180
                        value = cleanup(get_value(in_, row))
181
                    if value != None:
182
                        log_start('Putting '+str(out), debug)
183
                        xpath.put_obj(root, out, row_id, has_types, value)
184
            return process_rows(process_row, rows)
185
        
186
        def map_table(col_names, rows):
187
            col_names_ct = len(col_names)
188
            col_idxs = util.list_flip(col_names)
189
            
190
            i = 0
191
            while i < len(mappings): # mappings len changes in loop
192
                in_, out = mappings[i]
193
                if metadata_value(in_) == None:
194
                    try: mappings[i] = (
195
                        get_with_prefix(col_idxs, prefixes, in_), out)
196
                    except KeyError:
197
                        del mappings[i]
198
                        continue # keep i the same
199
                i += 1
200
            
201
            def get_value(in_, row):
202
                return util.coalesce(*util.list_subset(row.list, in_))
203
            def wrap_row(row):
204
                return util.ListDict(util.list_as_length(row, col_names_ct),
205
                    col_names, col_idxs) # handle CSV rows of different lengths
206
            
207
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
208
        
209
        if in_is_db:
210
            assert in_is_xpaths
211
            
212
            in_db = connect_db(in_db_config)
213
            in_pkeys = {}
214
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
215
                limit=end, start=0)
216
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
217
            
218
            in_db.close()
219
        elif in_is_xml:
220
            doc0_root = minidom.parse(sys.stdin).documentElement
221
            if map_path == None:
222
                iter_ = xml_dom.NodeElemIter(doc0_root)
223
                util.skip(iter_, xml_dom.is_text) # skip metadata
224
                row_ct = process_rows(lambda row, i: root.appendChild(row),
225
                    iter_)
226
            else:
227
                rows = xpath.get(doc0_root, in_root, limit=end)
228
                if rows == []: raise SystemExit('Map error: Root "'+in_root
229
                    +'" not found in input')
230
                
231
                def get_value(in_, row):
232
                    in_ = './{'+(','.join(strings.with_prefixes(['']+prefixes,
233
                        in_)))+'}' # also with no prefix
234
                    nodes = xpath.get(row, in_, allow_rooted=False)
235
                    if nodes != []: return xml_dom.value(nodes[0])
236
                    else: return None
237
                
238
                row_ct = map_rows(get_value, rows)
239
        else: # input is CSV
240
            map_ = dict(mappings)
241
            reader, col_names = csvs.reader_and_header(sys.stdin)
242
            row_ct = map_table(col_names, reader)
243
        
244
        return row_ct
245
    
246
    def process_inputs(root, row_ready):
247
        row_ct = 0
248
        for map_path in map_paths:
249
            row_ct += process_input(root, row_ready, map_path)
250
        return row_ct
251
    
252
    if out_is_db:
253
        import db_xml
254
        
255
        out_db = connect_db(out_db_config)
256
        out_pkeys = {}
257
        try:
258
            if redo: sql.empty_db(out_db)
259
            row_ins_ct_ref = [0]
260
            
261
            def row_ready(row_num, input_row):
262
                def on_error(e):
263
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
264
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
265
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
266
                    ex_tracker.track(e, row_num)
267
                
268
                xml_func.process(root, on_error)
269
                if not xml_dom.is_empty(root):
270
                    assert xml_dom.has_one_child(root)
271
                    try:
272
                        sql.with_savepoint(out_db,
273
                            lambda: db_xml.put(out_db, root.firstChild,
274
                                out_pkeys, row_ins_ct_ref, on_error))
275
                        if commit: out_db.commit()
276
                    except sql.DatabaseErrors, e: on_error(e)
277
                prep_root()
278
            
279
            row_ct = process_inputs(root, row_ready)
280
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
281
                ' new rows into database\n')
282
        finally:
283
            out_db.rollback()
284
            out_db.close()
285
    else:
286
        def on_error(e): ex_tracker.track(e)
287
        def row_ready(row_num, input_row): pass
288
        row_ct = process_inputs(root, row_ready)
289
        xml_func.process(root, on_error)
290
        if out_is_xml_ref[0]:
291
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
292
        else: # output is CSV
293
            raise NotImplementedError('CSV output not supported yet')
294
    
295
    profiler.stop(row_ct)
296
    ex_tracker.add_iters(row_ct)
297
    if verbose:
298
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
299
        sys.stderr.write(profiler.msg()+'\n')
300
        sys.stderr.write(ex_tracker.msg()+'\n')
301
    ex_tracker.exit()
302

    
303
def main():
304
    try: main_()
305
    except Parser.SyntaxException, e: raise SystemExit(str(e))
306

    
307
if __name__ == '__main__':
308
    profile_to = opts.get_env_var('profile_to', None)
309
    if profile_to != None:
310
        import cProfile
311
        sys.stderr.write('Profiling to '+profile_to+'\n')
312
        cProfile.run(main.func_code, profile_to)
313
    else: main()
(21-21/40)