Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import os.path
9
import sys
10
import xml.dom.minidom as minidom
11

    
12
sys.path.append(os.path.dirname(__file__)+"/../lib")
13

    
14
import csvs
15
import exc
16
import opts
17
import Parser
18
import profiling
19
import sql
20
import strings
21
import term
22
import util
23
import xpath
24
import xml_dom
25
import xml_func
26

    
27
def get_with_prefix(map_, prefixes, key):
28
    '''Gets all entries for the given key with any of the given prefixes'''
29
    values = []
30
    for prefix in ['']+prefixes: # also lookup with no prefix
31
        try: value = map_[prefix+key]
32
        except KeyError, e: continue # keep going
33
        values.append(value)
34
    
35
    if values != []: return values
36
    else: raise e # re-raise last KeyError
37

    
38
def metadata_value(name): return None # this feature has been removed
39

    
40
def cleanup(val):
41
    if val == None: return val
42
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
43

    
44
def main_():
45
    env_names = []
46
    def usage_err():
47
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
48
            +sys.argv[0]+' [map_path...] [<input] [>output]')
49
    
50
    ## Get config from env vars
51
    
52
    # Modes
53
    test = opts.env_flag('test', False, env_names)
54
    commit = opts.env_flag('commit', False, env_names) and not test
55
        # never commit in test mode
56
    redo = opts.env_flag('redo', test, env_names) and not commit
57
        # never redo in commit mode (manually run `make empty_db` instead)
58
    
59
    # Ranges
60
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
61
    if test: end_default = 1
62
    else: end_default = None
63
    end = util.cast(int, util.none_if(
64
        opts.get_env_var('n', end_default, env_names), u''))
65
    if end != None: end += start
66
    
67
    # Debugging
68
    debug = opts.env_flag('debug', False, env_names)
69
    sql.run_raw_query.debug = debug
70
    verbose = debug or opts.env_flag('verbose', not test, env_names)
71
    opts.get_env_var('profile_to', None, env_names) # add to env_names
72
    
73
    # DB
74
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
75
    def get_db_config(prefix):
76
        return opts.get_env_vars(db_config_names, prefix, env_names)
77
    in_db_config = get_db_config('in')
78
    out_db_config = get_db_config('out')
79
    in_is_db = 'engine' in in_db_config
80
    out_is_db = 'engine' in out_db_config
81
    
82
    ##
83
    
84
    # Logging
85
    def log(msg, on=verbose):
86
        if on: sys.stderr.write(msg)
87
    def log_start(action, on=verbose): log(action+'...\n', on)
88
    
89
    # Parse args
90
    map_paths = sys.argv[1:]
91
    if map_paths == []:
92
        if in_is_db or not out_is_db: usage_err()
93
        else: map_paths = [None]
94
    
95
    def connect_db(db_config):
96
        log_start('Connecting to '+sql.db_config_str(db_config))
97
        return sql.connect(db_config)
98
    
99
    if end != None: end_str = str(end-1)
100
    else: end_str = 'end'
101
    log_start('Processing input rows '+str(start)+'-'+end_str)
102
    
103
    ex_tracker = exc.ExPercentTracker(iter_text='row')
104
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
105
    
106
    doc = xml_dom.create_doc()
107
    root = doc.documentElement
108
    out_is_xml_ref = [False]
109
    in_label_ref = [None]
110
    def update_in_label():
111
        if in_label_ref[0] != None:
112
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
113
    def prep_root():
114
        root.clear()
115
        update_in_label()
116
    prep_root()
117
    
118
    def process_input(root, row_ready, map_path):
119
        '''Inputs datasource to XML tree, mapping if needed'''
120
        # Load map header
121
        in_is_xpaths = True
122
        out_is_xpaths = True
123
        out_label = None
124
        if map_path != None:
125
            metadata = []
126
            mappings = []
127
            stream = open(map_path, 'rb')
128
            reader = csv.reader(stream)
129
            in_label, out_label = reader.next()[:2]
130
            
131
            def split_col_name(name):
132
                label, sep, root = name.partition(':')
133
                label, sep2, prefixes_str = label.partition('[')
134
                prefixes_str = strings.remove_suffix(']', prefixes_str)
135
                prefixes = strings.split(',', prefixes_str)
136
                return label, sep != '', root, prefixes
137
                    # extract datasrc from "datasrc[data_format]"
138
            
139
            in_label, in_is_xpaths, in_root, prefixes = split_col_name(in_label)
140
            in_label_ref[0] = in_label
141
            update_in_label()
142
            out_label, out_is_xpaths, out_root = split_col_name(out_label)[:3]
143
            has_types = out_root.startswith('/*s/') # outer elements are types
144
            
145
            for row in reader:
146
                in_, out = row[:2]
147
                if out != '':
148
                    if out_is_xpaths: out = xpath.parse(out_root+out)
149
                    mappings.append((in_, out))
150
            
151
            stream.close()
152
            
153
            root.ownerDocument.documentElement.tagName = out_label
154
        in_is_xml = in_is_xpaths and not in_is_db
155
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
156
        
157
        if in_is_xml:
158
            doc0 = minidom.parse(sys.stdin)
159
            doc0_root = doc0.documentElement
160
            if out_label == None: out_label = doc0_root.tagName
161
        
162
        def process_rows(process_row, rows):
163
            '''Processes input rows
164
            @param process_row(in_row, i)
165
            '''
166
            i = -1 # in case for loop does not execute
167
            for i, row in enumerate(rows):
168
                if i < start: continue
169
                if end != None and i >= end: break
170
                process_row(row, i)
171
                row_ready(i, row)
172
            row_ct = i-start+1
173
            return row_ct
174
        
175
        def map_rows(get_value, rows):
176
            '''Maps input rows
177
            @param get_value(in_, row):str
178
            '''
179
            def process_row(row, i):
180
                row_id = str(i)
181
                for in_, out in mappings:
182
                    value = metadata_value(in_)
183
                    if value == None:
184
                        log_start('Getting '+str(in_), debug)
185
                        value = cleanup(get_value(in_, row))
186
                    if value != None:
187
                        log_start('Putting '+str(out), debug)
188
                        xpath.put_obj(root, out, row_id, has_types, value)
189
            return process_rows(process_row, rows)
190
        
191
        def map_table(col_names, rows):
192
            col_names_ct = len(col_names)
193
            col_idxs = util.list_flip(col_names)
194
            
195
            i = 0
196
            while i < len(mappings): # mappings len changes in loop
197
                in_, out = mappings[i]
198
                if metadata_value(in_) == None:
199
                    try: mappings[i] = (
200
                        get_with_prefix(col_idxs, prefixes, in_), out)
201
                    except KeyError:
202
                        del mappings[i]
203
                        continue # keep i the same
204
                i += 1
205
            
206
            def get_value(in_, row):
207
                return util.coalesce(*util.list_subset(row.list, in_))
208
            def wrap_row(row):
209
                return util.ListDict(util.list_as_length(row, col_names_ct),
210
                    col_names, col_idxs) # handle CSV rows of different lengths
211
            
212
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
213
        
214
        if map_path == None:
215
            iter_ = xml_dom.NodeElemIter(doc0_root)
216
            util.skip(iter_, xml_dom.is_text) # skip metadata
217
            row_ct = process_rows(lambda row, i: root.appendChild(row), iter_)
218
        elif in_is_db:
219
            assert in_is_xpaths
220
            
221
            in_db = connect_db(in_db_config)
222
            in_pkeys = {}
223
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
224
                limit=end, start=0)
225
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
226
            
227
            in_db.close()
228
        elif in_is_xml:
229
            def get_value(in_, row):
230
                nodes = xpath.get(row, in_, allow_rooted=False)
231
                if nodes != []: return xml_dom.value(nodes[0])
232
                else: return None
233
            rows = xpath.get(doc0_root, in_root, limit=end)
234
            if rows == []: raise SystemExit('Map error: Root "'+in_root
235
                +'" not found in input')
236
            row_ct = map_rows(get_value, rows)
237
        else: # input is CSV
238
            map_ = dict(mappings)
239
            reader, col_names = csvs.reader_and_header(sys.stdin)
240
            row_ct = map_table(col_names, reader)
241
        
242
        return row_ct
243
    
244
    def process_inputs(root, row_ready):
245
        row_ct = 0
246
        for map_path in map_paths:
247
            row_ct += process_input(root, row_ready, map_path)
248
        return row_ct
249
    
250
    if out_is_db:
251
        import db_xml
252
        
253
        out_db = connect_db(out_db_config)
254
        out_pkeys = {}
255
        try:
256
            if redo: sql.empty_db(out_db)
257
            row_ins_ct_ref = [0]
258
            
259
            def row_ready(row_num, input_row):
260
                def on_error(e):
261
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
262
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
263
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
264
                    ex_tracker.track(e, row_num)
265
                
266
                xml_func.process(root, on_error)
267
                if not xml_dom.is_empty(root):
268
                    assert xml_dom.has_one_child(root)
269
                    try:
270
                        sql.with_savepoint(out_db,
271
                            lambda: db_xml.put(out_db, root.firstChild,
272
                                out_pkeys, row_ins_ct_ref, on_error))
273
                        if commit: out_db.commit()
274
                    except sql.DatabaseErrors, e: on_error(e)
275
                prep_root()
276
            
277
            row_ct = process_inputs(root, row_ready)
278
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
279
                ' new rows into database\n')
280
        finally:
281
            out_db.rollback()
282
            out_db.close()
283
    else:
284
        def on_error(e): ex_tracker.track(e)
285
        def row_ready(row_num, input_row): pass
286
        row_ct = process_inputs(root, row_ready)
287
        xml_func.process(root, on_error)
288
        if out_is_xml_ref[0]:
289
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
290
        else: # output is CSV
291
            raise NotImplementedError('CSV output not supported yet')
292
    
293
    profiler.stop(row_ct)
294
    ex_tracker.add_iters(row_ct)
295
    if verbose:
296
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
297
        sys.stderr.write(profiler.msg()+'\n')
298
        sys.stderr.write(ex_tracker.msg()+'\n')
299
    ex_tracker.exit()
300

    
301
def main():
302
    try: main_()
303
    except Parser.SyntaxException, e: raise SystemExit(str(e))
304

    
305
if __name__ == '__main__':
306
    profile_to = opts.get_env_var('profile_to', None)
307
    if profile_to != None:
308
        import cProfile
309
        sys.stderr.write('Profiling to '+profile_to+'\n')
310
        cProfile.run(main.func_code, profile_to)
311
    else: main()
(20-20/39)