Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import os.path
9
import sys
10
import xml.dom.minidom as minidom
11

    
12
sys.path.append(os.path.dirname(__file__)+"/../lib")
13

    
14
import csvs
15
import exc
16
import opts
17
import Parser
18
import profiling
19
import sql
20
import strings
21
import term
22
import util
23
import xpath
24
import xml_dom
25
import xml_func
26

    
27
def metadata_value(name):
    '''Stub kept for callers of the removed metadata feature
    @param name Ignored
    @return Always None
    '''
    # this feature has been removed
    return None
28

    
29
def cleanup(val):
    '''Normalizes a raw input value before it is put into the output
    @param val The raw value, or None
    @return None if val is None; otherwise the result of passing the
        strings.cleanup()'d unicode form of val through util.none_if() with
        u'' and u'\\N' (presumably the values to treat as missing -- see
        util.none_if)
    '''
    if val == None: return val
    as_unicode = strings.ustr(val)
    cleaned = strings.cleanup(as_unicode)
    return util.none_if(cleaned, u'', u'\\N')
32

    
33
def main_():
    '''Maps the input datasource(s) to the output datasource
    
    Configuration is read from env vars via opts: the input and output DB
    configs (under the 'in' and 'out' prefixes) plus n, start, test, commit,
    redo, debug, verbose and profile_to. The map spreadsheet paths are the
    command line args; the input is read from stdin unless it is a DB, and XML
    output is written to stdout.
    Finishes by calling ex_tracker.exit().
    '''
    env_names = []
    def usage_err():
        '''Aborts with a usage message listing the supported env vars'''
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
            +sys.argv[0]+' [map_path...] [<input] [>output]')
    
    # Get db config from env vars
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
    def get_db_config(prefix):
        return opts.get_env_vars(db_config_names, prefix, env_names)
    in_db_config = get_db_config('in')
    out_db_config = get_db_config('out')
    # A side is a DB only if an engine was configured for it
    in_is_db = 'engine' in in_db_config
    out_is_db = 'engine' in out_db_config
    
    # Get other config from env vars
    end = util.cast(int, opts.get_env_var('n', None, env_names))
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
    if end is not None: end += start # n is a row count, so make it an end index
    test = opts.env_flag('test', False, env_names)
    commit = opts.env_flag('commit', False, env_names) and not test
        # never commit in test mode
    redo = opts.env_flag('redo', test, env_names) and not commit
        # never redo in commit mode (manually run `make empty_db` instead)
    debug = opts.env_flag('debug', False, env_names)
    sql.run_raw_query.debug = debug
    verbose = debug or opts.env_flag('verbose', False, env_names)
    opts.get_env_var('profile_to', None, env_names) # add to env_names
    
    # Logging
    def log(msg, on=verbose):
        '''Writes msg to stderr if on'''
        if on: sys.stderr.write(msg)
    def log_start(action, on=verbose): log(action+'...\n', on)
    
    # Parse args
    map_paths = sys.argv[1:]
    if not map_paths:
        # Without a map, only XML input going to a DB output is supported
        if in_is_db or not out_is_db: usage_err()
        else: map_paths = [None]
    
    def connect_db(db_config):
        log_start('Connecting to '+sql.db_config_str(db_config))
        return sql.connect(db_config)
    
    ex_tracker = exc.ExPercentTracker(iter_text='row')
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
    
    doc = xml_dom.create_doc()
    root = doc.documentElement
    # One-element lists let the nested functions below assign to these values
    out_is_xml_ref = [False]
    in_label_ref = [None]
    def update_in_label():
        '''Records the input label (once known) under the output root'''
        if in_label_ref[0] is not None:
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
    def prep_root():
        '''Resets the output root so it is ready for the next row'''
        root.clear()
        update_in_label()
    prep_root()
    
    def process_input(root, row_ready, map_path):
        '''Inputs datasource to XML tree, mapping if needed
        @param root The output XML element to populate
        @param row_ready(row_num, input_row) Called after each row is processed
        @param map_path Path of the map spreadsheet, or None for no mapping
        @return The # of input rows processed
        '''
        # Load map header
        in_is_xpaths = True
        out_is_xpaths = True
        out_label = None
        if map_path is not None:
            mappings = []
            stream = open(map_path, 'rb')
            try: # close the map even if it can't be parsed
                reader = csv.reader(stream)
                in_label, out_label = reader.next()[:2]
                
                def split_col_name(name):
                    '''Splits a map header cell of the form
                    "label[prefix,...]:root"
                    @return (label, whether a ":" was present, root, prefixes)
                    '''
                    label, sep, col_root = name.partition(':')
                    label, sep2, prefixes_str = label.partition('[')
                    prefixes_str = strings.remove_suffix(']', prefixes_str)
                    prefixes = strings.split(',', prefixes_str)
                    return label, sep != '', col_root, prefixes
                        # extract datasrc from "datasrc[data_format]"
                
                in_label, in_is_xpaths, in_root = split_col_name(in_label)[:3]
                in_label_ref[0] = in_label
                update_in_label()
                out_label, out_is_xpaths, out_root = split_col_name(
                    out_label)[:3]
                has_types = out_root.startswith('/*s/')
                    # outer elements are types
                
                for row in reader:
                    in_, out = row[:2]
                    if out != '': # skip input columns with no output mapping
                        if out_is_xpaths: out = xpath.parse(out_root+out)
                        mappings.append((in_, out))
            finally: stream.close()
            
            root.ownerDocument.documentElement.tagName = out_label
        in_is_xml = in_is_xpaths and not in_is_db
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
        
        if in_is_xml:
            doc0 = minidom.parse(sys.stdin)
            doc0_root = doc0.documentElement
            if out_label is None: out_label = doc0_root.tagName
        
        def process_rows(process_row, rows):
            '''Processes input rows
            Rows before `start` are skipped and rows from `end` on are not read.
            @param process_row(in_row, i)
            @return The # of rows actually processed (never negative)
            '''
            # Count explicitly instead of deriving the count from the last row
            # index: the index is one too high when the loop stops at `end`, and
            # is less than `start` when the input has fewer rows than `start`
            row_ct = 0
            for i, row in enumerate(rows):
                if i < start: continue
                if end is not None and i >= end: break
                process_row(row, i)
                row_ready(i, row)
                row_ct += 1
            return row_ct
        
        def map_rows(get_value, rows):
            '''Maps input rows
            @param get_value(in_, row):str
            @return The # of rows processed
            '''
            def process_row(row, i):
                row_id = str(i)
                for in_, out in mappings:
                    value = metadata_value(in_)
                    if value is None:
                        log_start('Getting '+str(in_), debug)
                        value = cleanup(get_value(in_, row))
                    if value is not None:
                        log_start('Putting '+str(out), debug)
                        xpath.put_obj(root, out, row_id, has_types, value)
            return process_rows(process_row, rows)
        
        def map_table(col_names, rows):
            '''Maps the rows of a table (DB result set or CSV file)
            @param col_names The table's column names, in column order
            @param rows Iterable of the table's rows
            @return The # of rows processed
            '''
            col_idxs = util.list_flip(col_names)
            
            # Replace each mapping's input column name with its index, dropping
            # mappings whose input column is not in this table
            mappings_new = []
            for in_, out in mappings:
                if metadata_value(in_) is None:
                    try: in_ = col_idxs[in_]
                    except KeyError: continue
                mappings_new.append((in_, out))
            mappings[:] = mappings_new # in place, so map_rows() sees it
            
            def get_value(in_, row):
                try: return row.list[in_]
                except IndexError: return None # row is missing this column
            def wrap_row(row): return util.ListDict(row, col_names, col_idxs)
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
        
        if map_path is None: # input is already in the output's XML format
            iter_ = xml_dom.NodeElemIter(doc0_root)
            util.skip(iter_, xml_dom.is_text) # skip metadata
            row_ct = process_rows(lambda row, i: root.appendChild(row), iter_)
        elif in_is_db:
            assert in_is_xpaths
            
            in_db = connect_db(in_db_config)
            try: # close the connection even if the import fails
                cur = sql.select(in_db, table=in_root, fields=None, conds=None,
                    limit=end, start=0)
                row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
            finally: in_db.close()
        elif in_is_xml:
            def get_value(in_, row):
                nodes = xpath.get(row, in_, allow_rooted=False)
                if nodes != []: return xml_dom.value(nodes[0])
                else: return None
            rows = xpath.get(doc0_root, in_root, limit=end)
            if rows == []: raise SystemExit('Map error: Root "'+in_root
                +'" not found in input')
            row_ct = map_rows(get_value, rows)
        else: # input is CSV
            reader, col_names = csvs.reader_and_header(sys.stdin)
            row_ct = map_table(col_names, reader)
        
        return row_ct
    
    def process_inputs(root, row_ready):
        '''Processes every map path in turn
        @return The total # of input rows processed
        '''
        row_ct = 0
        for map_path in map_paths:
            row_ct += process_input(root, row_ready, map_path)
        return row_ct
    
    if out_is_db:
        import db_xml
        
        out_db = connect_db(out_db_config)
        out_pkeys = {}
        try:
            if redo: sql.empty_db(out_db)
            row_ins_ct_ref = [0]
            
            def row_ready(row_num, input_row):
                '''Puts the current output row into the DB, then resets root'''
                def on_error(e):
                    # Attach the row's context to the error before tracking it
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
                    ex_tracker.track(e)
                
                xml_func.process(root, on_error)
                if not xml_dom.is_empty(root):
                    assert xml_dom.has_one_child(root)
                    try:
                        sql.with_savepoint(out_db,
                            lambda: db_xml.put(out_db, root.firstChild,
                                out_pkeys, row_ins_ct_ref, on_error))
                        if commit: out_db.commit()
                    except sql.DatabaseErrors as e: on_error(e)
                prep_root()
            
            row_ct = process_inputs(root, row_ready)
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
                ' new rows into database\n')
        finally:
            out_db.rollback() # discards anything that was not committed above
            out_db.close()
    else:
        def on_error(e): ex_tracker.track(e)
        def row_ready(row_num, input_row): pass
        row_ct = process_inputs(root, row_ready)
        xml_func.process(root, on_error)
        if out_is_xml_ref[0]:
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
        else: # output is CSV
            raise NotImplementedError('CSV output not supported yet')
    
    profiler.stop(row_ct)
    ex_tracker.add_iters(row_ct)
    if verbose:
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
        sys.stderr.write(profiler.msg()+'\n')
        sys.stderr.write(ex_tracker.msg()+'\n')
    ex_tracker.exit()
282

    
283
def main():
    '''Runs main_(), exiting with just the message on a Parser syntax error'''
    try:
        main_()
    except Parser.SyntaxException as e:
        # Exit with the error's text rather than letting it propagate
        raise SystemExit(str(e))
286

    
287
if __name__ == '__main__':
    # Run main() directly unless the profile_to env var is set, in which case
    # run it under cProfile with profile_to as the stats output file
    profile_to = opts.get_env_var('profile_to', None)
    if profile_to == None: main()
    else:
        import cProfile
        cProfile.run(main.func_code, profile_to)
(14-14/28)