Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import os.path
9
import sys
10
import xml.dom.minidom as minidom
11

    
12
sys.path.append(os.path.dirname(__file__)+"/../lib")
13

    
14
import csvs
15
import exc
16
import opts
17
import Parser
18
import profiling
19
import sql
20
import strings
21
import term
22
import util
23
import xpath
24
import xml_dom
25
import xml_func
26

    
27
def get_with_prefix(map_, prefixes, key):
28
    '''Gets all entries for the given key with any of the given prefixes'''
29
    values = []
30
    for prefix in ['']+prefixes: # also lookup with no prefix
31
        try: value = map_[prefix+key]
32
        except KeyError, e: continue # keep going
33
        values.append(value)
34
    
35
    if values != []: return values
36
    else: raise e # re-raise last KeyError
37

    
38
def metadata_value(name): return None # this feature has been removed
39

    
40
def cleanup(val):
41
    if val == None: return val
42
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
43

    
44
def main_():
45
    env_names = []
46
    def usage_err():
47
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
48
            +sys.argv[0]+' [map_path...] [<input] [>output]')
49
    
50
    ## Get config from env vars
51
    
52
    # Modes
53
    test = opts.env_flag('test', False, env_names)
54
    commit = opts.env_flag('commit', False, env_names) and not test
55
        # never commit in test mode
56
    redo = opts.env_flag('redo', test, env_names) and not commit
57
        # never redo in commit mode (manually run `make empty_db` instead)
58
    
59
    # Ranges
60
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
61
    if test: end_default = 1
62
    else: end_default = None
63
    end = util.cast(int, util.none_if(
64
        opts.get_env_var('n', end_default, env_names), u''))
65
    if end != None: end += start
66
    
67
    # Debugging
68
    debug = opts.env_flag('debug', False, env_names)
69
    sql.run_raw_query.debug = debug
70
    verbose = debug or opts.env_flag('verbose', not test, env_names)
71
    opts.get_env_var('profile_to', None, env_names) # add to env_names
72
    
73
    # DB
74
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
75
    def get_db_config(prefix):
76
        return opts.get_env_vars(db_config_names, prefix, env_names)
77
    in_db_config = get_db_config('in')
78
    out_db_config = get_db_config('out')
79
    in_is_db = 'engine' in in_db_config
80
    out_is_db = 'engine' in out_db_config
81
    
82
    ##
83
    
84
    # Logging
85
    def log(msg, on=verbose):
86
        if on: sys.stderr.write(msg)
87
    def log_start(action, on=verbose): log(action+'...\n', on)
88
    
89
    # Parse args
90
    map_paths = sys.argv[1:]
91
    if map_paths == []:
92
        if in_is_db or not out_is_db: usage_err()
93
        else: map_paths = [None]
94
    
95
    def connect_db(db_config):
96
        log_start('Connecting to '+sql.db_config_str(db_config))
97
        return sql.connect(db_config)
98
    
99
    if end != None: end_str = str(end-1)
100
    else: end_str = 'end'
101
    log_start('Processing input rows '+str(start)+'-'+end_str)
102
    
103
    ex_tracker = exc.ExPercentTracker(iter_text='row')
104
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
105
    
106
    doc = xml_dom.create_doc()
107
    root = doc.documentElement
108
    out_is_xml_ref = [False]
109
    in_label_ref = [None]
110
    def update_in_label():
111
        if in_label_ref[0] != None:
112
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
113
    def prep_root():
114
        root.clear()
115
        update_in_label()
116
    prep_root()
117
    
118
    def process_input(root, row_ready, map_path):
119
        '''Inputs datasource to XML tree, mapping if needed'''
120
        # Load map header
121
        in_is_xpaths = True
122
        out_is_xpaths = True
123
        out_label = None
124
        if map_path != None:
125
            metadata = []
126
            mappings = []
127
            stream = open(map_path, 'rb')
128
            reader = csv.reader(stream)
129
            in_label, out_label = reader.next()[:2]
130
            
131
            def split_col_name(name):
132
                label, sep, root = name.partition(':')
133
                label, sep2, prefixes_str = label.partition('[')
134
                prefixes_str = strings.remove_suffix(']', prefixes_str)
135
                prefixes = strings.split(',', prefixes_str)
136
                return label, sep != '', root, prefixes
137
                    # extract datasrc from "datasrc[data_format]"
138
            
139
            in_label, in_is_xpaths, in_root, prefixes = split_col_name(in_label)
140
            in_label_ref[0] = in_label
141
            update_in_label()
142
            out_label, out_is_xpaths, out_root = split_col_name(out_label)[:3]
143
            has_types = out_root.startswith('/*s/') # outer elements are types
144
            
145
            for row in reader:
146
                in_, out = row[:2]
147
                if out != '':
148
                    if out_is_xpaths: out = xpath.parse(out_root+out)
149
                    mappings.append((in_, out))
150
            
151
            stream.close()
152
            
153
            root.ownerDocument.documentElement.tagName = out_label
154
        in_is_xml = in_is_xpaths and not in_is_db
155
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
156
        
157
        if in_is_xml:
158
            doc0 = minidom.parse(sys.stdin)
159
            doc0_root = doc0.documentElement
160
            if out_label == None: out_label = doc0_root.tagName
161
        
162
        def process_rows(process_row, rows):
163
            '''Processes input rows
164
            @param process_row(in_row, i)
165
            '''
166
            i = -1 # in case for loop does not execute
167
            for i, row in enumerate(rows):
168
                if i < start: continue
169
                if end != None and i >= end: break
170
                process_row(row, i)
171
                row_ready(i, row)
172
            row_ct = i-start+1
173
            return row_ct
174
        
175
        def map_rows(get_value, rows):
176
            '''Maps input rows
177
            @param get_value(in_, row):str
178
            '''
179
            def process_row(row, i):
180
                row_id = str(i)
181
                for in_, out in mappings:
182
                    value = metadata_value(in_)
183
                    if value == None:
184
                        log_start('Getting '+str(in_), debug)
185
                        value = cleanup(get_value(in_, row))
186
                    if value != None:
187
                        log_start('Putting '+str(out), debug)
188
                        xpath.put_obj(root, out, row_id, has_types, value)
189
            return process_rows(process_row, rows)
190
        
191
        def map_table(col_names, rows):
192
            col_names_ct = len(col_names)
193
            col_idxs = util.list_flip(col_names)
194
            
195
            i = 0
196
            while i < len(mappings): # mappings len changes in loop
197
                in_, out = mappings[i]
198
                if metadata_value(in_) == None:
199
                    try: mappings[i] = (
200
                        get_with_prefix(col_idxs, prefixes, in_), out)
201
                    except KeyError:
202
                        del mappings[i]
203
                        continue # keep i the same
204
                i += 1
205
            
206
            def get_value(in_, row):
207
                return util.coalesce(*util.list_subset(row.list, in_))
208
            def wrap_row(row):
209
                return util.ListDict(util.list_as_length(row, col_names_ct),
210
                    col_names, col_idxs) # handle CSV rows of different lengths
211
            
212
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
213
        
214
        if map_path == None:
215
            iter_ = xml_dom.NodeElemIter(doc0_root)
216
            util.skip(iter_, xml_dom.is_text) # skip metadata
217
            row_ct = process_rows(lambda row, i: root.appendChild(row), iter_)
218
        elif in_is_db:
219
            assert in_is_xpaths
220
            
221
            in_db = connect_db(in_db_config)
222
            in_pkeys = {}
223
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
224
                limit=end, start=0)
225
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
226
            
227
            in_db.close()
228
        elif in_is_xml:
229
            if prefixes != []: prefix = './{'+(','.join(['.']+prefixes))+'}/'
230
                # also lookup with no prefix
231
            else: prefix = ''
232
            
233
            rows = xpath.get(doc0_root, in_root, limit=end)
234
            if rows == []: raise SystemExit('Map error: Root "'+in_root
235
                +'" not found in input')
236
            def get_value(in_, row):
237
                nodes = xpath.get(row, prefix+in_, allow_rooted=False)
238
                if nodes != []: return xml_dom.value(nodes[0])
239
                else: return None
240
            row_ct = map_rows(get_value, rows)
241
        else: # input is CSV
242
            map_ = dict(mappings)
243
            reader, col_names = csvs.reader_and_header(sys.stdin)
244
            row_ct = map_table(col_names, reader)
245
        
246
        return row_ct
247
    
248
    def process_inputs(root, row_ready):
249
        row_ct = 0
250
        for map_path in map_paths:
251
            row_ct += process_input(root, row_ready, map_path)
252
        return row_ct
253
    
254
    if out_is_db:
255
        import db_xml
256
        
257
        out_db = connect_db(out_db_config)
258
        out_pkeys = {}
259
        try:
260
            if redo: sql.empty_db(out_db)
261
            row_ins_ct_ref = [0]
262
            
263
            def row_ready(row_num, input_row):
264
                def on_error(e):
265
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
266
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
267
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
268
                    ex_tracker.track(e, row_num)
269
                
270
                xml_func.process(root, on_error)
271
                if not xml_dom.is_empty(root):
272
                    assert xml_dom.has_one_child(root)
273
                    try:
274
                        sql.with_savepoint(out_db,
275
                            lambda: db_xml.put(out_db, root.firstChild,
276
                                out_pkeys, row_ins_ct_ref, on_error))
277
                        if commit: out_db.commit()
278
                    except sql.DatabaseErrors, e: on_error(e)
279
                prep_root()
280
            
281
            row_ct = process_inputs(root, row_ready)
282
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
283
                ' new rows into database\n')
284
        finally:
285
            out_db.rollback()
286
            out_db.close()
287
    else:
288
        def on_error(e): ex_tracker.track(e)
289
        def row_ready(row_num, input_row): pass
290
        row_ct = process_inputs(root, row_ready)
291
        xml_func.process(root, on_error)
292
        if out_is_xml_ref[0]:
293
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
294
        else: # output is CSV
295
            raise NotImplementedError('CSV output not supported yet')
296
    
297
    profiler.stop(row_ct)
298
    ex_tracker.add_iters(row_ct)
299
    if verbose:
300
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
301
        sys.stderr.write(profiler.msg()+'\n')
302
        sys.stderr.write(ex_tracker.msg()+'\n')
303
    ex_tracker.exit()
304

    
305
def main():
306
    try: main_()
307
    except Parser.SyntaxException, e: raise SystemExit(str(e))
308

    
309
if __name__ == '__main__':
310
    profile_to = opts.get_env_var('profile_to', None)
311
    if profile_to != None:
312
        import cProfile
313
        sys.stderr.write('Profiling to '+profile_to+'\n')
314
        cProfile.run(main.func_code, profile_to)
315
    else: main()
(20-20/39)