Project

General

Profile

1
#!/usr/bin/env python
2
# Maps one datasource to another, using a map spreadsheet if needed
3
# Exit status is the # of errors in the import, up to the maximum exit status
4
# For outputting an XML file to a PostgreSQL database, use the general format of
5
# http://vegbank.org/vegdocs/xml/vegbank_example_ver1.0.2.xml
6

    
7
import csv
8
import os.path
9
import sys
10
import xml.dom.minidom as minidom
11

    
12
sys.path.append(os.path.dirname(__file__)+"/../lib")
13

    
14
import csvs
15
import exc
16
import opts
17
import Parser
18
import profiling
19
import sql
20
import strings
21
import term
22
import util
23
import xpath
24
import xml_dom
25
import xml_func
26

    
27
def get_with_prefix(map_, prefixes, key):
28
    '''Gets all entries for the given key with any of the given prefixes'''
29
    values = []
30
    for prefix in ['']+prefixes: # also lookup with no prefix
31
        try: value = map_[prefix+key]
32
        except KeyError, e: continue # keep going
33
        values.append(value)
34
    
35
    if values != []: return values
36
    else: raise e # re-raise last KeyError
37

    
38
def metadata_value(name): return None # this feature has been removed
39

    
40
def cleanup(val):
41
    if val == None: return val
42
    return util.none_if(strings.cleanup(strings.ustr(val)), u'', u'\\N')
43

    
44
def main_():
45
    env_names = []
46
    def usage_err():
47
        raise SystemExit('Usage: '+opts.env_usage(env_names, True)+' '
48
            +sys.argv[0]+' [map_path...] [<input] [>output]')
49
    
50
    ## Get config from env vars
51
    
52
    # Modes
53
    test = opts.env_flag('test', False, env_names)
54
    commit = opts.env_flag('commit', False, env_names) and not test
55
        # never commit in test mode
56
    redo = opts.env_flag('redo', test, env_names) and not commit
57
        # never redo in commit mode (manually run `make empty_db` instead)
58
    
59
    # Ranges
60
    start = util.cast(int, opts.get_env_var('start', '0', env_names))
61
    if test: end_default = 1
62
    else: end_default = None
63
    end = util.cast(int, util.none_if(
64
        opts.get_env_var('n', end_default, env_names), u''))
65
    if end != None: end += start
66
    
67
    # Debugging
68
    debug = opts.env_flag('debug', False, env_names)
69
    sql.run_raw_query.debug = debug
70
    verbose = debug or opts.env_flag('verbose', False, env_names)
71
    opts.get_env_var('profile_to', None, env_names) # add to env_names
72
    
73
    # DB
74
    db_config_names = ['engine', 'host', 'user', 'password', 'database']
75
    def get_db_config(prefix):
76
        return opts.get_env_vars(db_config_names, prefix, env_names)
77
    in_db_config = get_db_config('in')
78
    out_db_config = get_db_config('out')
79
    in_is_db = 'engine' in in_db_config
80
    out_is_db = 'engine' in out_db_config
81
    
82
    ##
83
    
84
    # Logging
85
    def log(msg, on=verbose):
86
        if on: sys.stderr.write(msg)
87
    def log_start(action, on=verbose): log(action+'...\n', on)
88
    
89
    # Parse args
90
    map_paths = sys.argv[1:]
91
    if map_paths == []:
92
        if in_is_db or not out_is_db: usage_err()
93
        else: map_paths = [None]
94
    
95
    def connect_db(db_config):
96
        log_start('Connecting to '+sql.db_config_str(db_config))
97
        return sql.connect(db_config)
98
    
99
    ex_tracker = exc.ExPercentTracker(iter_text='row')
100
    profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
101
    
102
    doc = xml_dom.create_doc()
103
    root = doc.documentElement
104
    out_is_xml_ref = [False]
105
    in_label_ref = [None]
106
    def update_in_label():
107
        if in_label_ref[0] != None:
108
            xpath.get(root, '/_ignore/inLabel="'+in_label_ref[0]+'"', True)
109
    def prep_root():
110
        root.clear()
111
        update_in_label()
112
    prep_root()
113
    
114
    def process_input(root, row_ready, map_path):
115
        '''Inputs datasource to XML tree, mapping if needed'''
116
        # Load map header
117
        in_is_xpaths = True
118
        out_is_xpaths = True
119
        out_label = None
120
        if map_path != None:
121
            metadata = []
122
            mappings = []
123
            stream = open(map_path, 'rb')
124
            reader = csv.reader(stream)
125
            in_label, out_label = reader.next()[:2]
126
            
127
            def split_col_name(name):
128
                label, sep, root = name.partition(':')
129
                label, sep2, prefixes_str = label.partition('[')
130
                prefixes_str = strings.remove_suffix(']', prefixes_str)
131
                prefixes = strings.split(',', prefixes_str)
132
                return label, sep != '', root, prefixes
133
                    # extract datasrc from "datasrc[data_format]"
134
            
135
            in_label, in_is_xpaths, in_root, prefixes = split_col_name(in_label)
136
            in_label_ref[0] = in_label
137
            update_in_label()
138
            out_label, out_is_xpaths, out_root = split_col_name(out_label)[:3]
139
            has_types = out_root.startswith('/*s/') # outer elements are types
140
            
141
            for row in reader:
142
                in_, out = row[:2]
143
                if out != '':
144
                    if out_is_xpaths: out = xpath.parse(out_root+out)
145
                    mappings.append((in_, out))
146
            
147
            stream.close()
148
            
149
            root.ownerDocument.documentElement.tagName = out_label
150
        in_is_xml = in_is_xpaths and not in_is_db
151
        out_is_xml_ref[0] = out_is_xpaths and not out_is_db
152
        
153
        if in_is_xml:
154
            doc0 = minidom.parse(sys.stdin)
155
            doc0_root = doc0.documentElement
156
            if out_label == None: out_label = doc0_root.tagName
157
        
158
        def process_rows(process_row, rows):
159
            '''Processes input rows
160
            @param process_row(in_row, i)
161
            '''
162
            i = -1 # in case for loop does not execute
163
            for i, row in enumerate(rows):
164
                if i < start: continue
165
                if end != None and i >= end: break
166
                process_row(row, i)
167
                row_ready(i, row)
168
            row_ct = i-start+1
169
            return row_ct
170
        
171
        def map_rows(get_value, rows):
172
            '''Maps input rows
173
            @param get_value(in_, row):str
174
            '''
175
            def process_row(row, i):
176
                row_id = str(i)
177
                for in_, out in mappings:
178
                    value = metadata_value(in_)
179
                    if value == None:
180
                        log_start('Getting '+str(in_), debug)
181
                        value = cleanup(get_value(in_, row))
182
                    if value != None:
183
                        log_start('Putting '+str(out), debug)
184
                        xpath.put_obj(root, out, row_id, has_types, value)
185
            return process_rows(process_row, rows)
186
        
187
        def map_table(col_names, rows):
188
            col_names_ct = len(col_names)
189
            col_idxs = util.list_flip(col_names)
190
            
191
            i = 0
192
            while i < len(mappings): # mappings len changes in loop
193
                in_, out = mappings[i]
194
                if metadata_value(in_) == None:
195
                    try: mappings[i] = (
196
                        get_with_prefix(col_idxs, prefixes, in_), out)
197
                    except KeyError:
198
                        del mappings[i]
199
                        continue # keep i the same
200
                i += 1
201
            
202
            def get_value(in_, row):
203
                return util.coalesce(*util.list_subset(row.list, in_))
204
            def wrap_row(row):
205
                return util.ListDict(util.list_as_length(row, col_names_ct),
206
                    col_names, col_idxs) # handle CSV rows of different lengths
207
            
208
            return map_rows(get_value, util.WrapIter(wrap_row, rows))
209
        
210
        if map_path == None:
211
            iter_ = xml_dom.NodeElemIter(doc0_root)
212
            util.skip(iter_, xml_dom.is_text) # skip metadata
213
            row_ct = process_rows(lambda row, i: root.appendChild(row), iter_)
214
        elif in_is_db:
215
            assert in_is_xpaths
216
            
217
            in_db = connect_db(in_db_config)
218
            in_pkeys = {}
219
            cur = sql.select(in_db, table=in_root, fields=None, conds=None,
220
                limit=end, start=0)
221
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
222
            
223
            in_db.close()
224
        elif in_is_xml:
225
            def get_value(in_, row):
226
                nodes = xpath.get(row, in_, allow_rooted=False)
227
                if nodes != []: return xml_dom.value(nodes[0])
228
                else: return None
229
            rows = xpath.get(doc0_root, in_root, limit=end)
230
            if rows == []: raise SystemExit('Map error: Root "'+in_root
231
                +'" not found in input')
232
            row_ct = map_rows(get_value, rows)
233
        else: # input is CSV
234
            map_ = dict(mappings)
235
            reader, col_names = csvs.reader_and_header(sys.stdin)
236
            row_ct = map_table(col_names, reader)
237
        
238
        return row_ct
239
    
240
    def process_inputs(root, row_ready):
241
        row_ct = 0
242
        for map_path in map_paths:
243
            row_ct += process_input(root, row_ready, map_path)
244
        return row_ct
245
    
246
    if out_is_db:
247
        import db_xml
248
        
249
        out_db = connect_db(out_db_config)
250
        out_pkeys = {}
251
        try:
252
            if redo: sql.empty_db(out_db)
253
            row_ins_ct_ref = [0]
254
            
255
            def row_ready(row_num, input_row):
256
                def on_error(e):
257
                    exc.add_msg(e, term.emph('row #:')+' '+str(row_num))
258
                    exc.add_msg(e, term.emph('input row:')+'\n'+str(input_row))
259
                    exc.add_msg(e, term.emph('output row:')+'\n'+str(root))
260
                    ex_tracker.track(e)
261
                
262
                xml_func.process(root, on_error)
263
                if not xml_dom.is_empty(root):
264
                    assert xml_dom.has_one_child(root)
265
                    try:
266
                        sql.with_savepoint(out_db,
267
                            lambda: db_xml.put(out_db, root.firstChild,
268
                                out_pkeys, row_ins_ct_ref, on_error))
269
                        if commit: out_db.commit()
270
                    except sql.DatabaseErrors, e: on_error(e)
271
                prep_root()
272
            
273
            row_ct = process_inputs(root, row_ready)
274
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
275
                ' new rows into database\n')
276
        finally:
277
            out_db.rollback()
278
            out_db.close()
279
    else:
280
        def on_error(e): ex_tracker.track(e)
281
        def row_ready(row_num, input_row): pass
282
        row_ct = process_inputs(root, row_ready)
283
        xml_func.process(root, on_error)
284
        if out_is_xml_ref[0]:
285
            doc.writexml(sys.stdout, **xml_dom.prettyxml_config)
286
        else: # output is CSV
287
            raise NotImplementedError('CSV output not supported yet')
288
    
289
    profiler.stop(row_ct)
290
    ex_tracker.add_iters(row_ct)
291
    if verbose:
292
        sys.stderr.write('Processed '+str(row_ct)+' input rows\n')
293
        sys.stderr.write(profiler.msg()+'\n')
294
        sys.stderr.write(ex_tracker.msg()+'\n')
295
    ex_tracker.exit()
296

    
297
def main():
298
    try: main_()
299
    except Parser.SyntaxException, e: raise SystemExit(str(e))
300

    
301
if __name__ == '__main__':
302
    profile_to = opts.get_env_var('profile_to', None)
303
    if profile_to != None:
304
        import cProfile
305
        cProfile.run(main.func_code, profile_to)
306
    else: main()
(18-18/36)