Project

General

Profile

1
# XML "function" nodes that transform their contents
2

    
3
import datetime
4
import re
5
import sre_constants
6
import warnings
7

    
8
import angles
9
import dates
10
import exc
11
import format
12
import maps
13
import sql_io
14
import strings
15
import term
16
import units
17
import util
18
import xml_dom
19
import xpath
20

    
21
##### Exceptions
22

    
23
class SyntaxError(exc.ExceptionWithCause):
24
    def __init__(self, cause):
25
        exc.ExceptionWithCause.__init__(self, 'Invalid XML function syntax',
26
            cause)
27

    
28
class FormatException(exc.ExceptionWithCause):
29
    def __init__(self, cause):
30
        exc.ExceptionWithCause.__init__(self, 'Invalid input value', cause)
31

    
32
##### Helper functions
33

    
34
def map_items(func, items):
35
    return [(name, func(value)) for name, value in items]
36

    
37
def cast(type_, val):
38
    '''Throws FormatException if can't cast'''
39
    try: return type_(val)
40
    except ValueError, e: raise FormatException(e)
41

    
42
def conv_items(type_, items):
43
    return map_items(lambda val: cast(type_, val),
44
        xml_dom.TextEntryOnlyIter(items))
45

    
46
def pop_value(items, name='value'):
47
    '''@param name Name of value param, or None to accept any name'''
48
    try: last = items.pop() # last entry contains value
49
    except IndexError: return None # input is empty and no actions
50
    if name != None and last[0] != name: return None # input is empty
51
    return last[1]
52

    
53
def merge_tagged(root):
54
    '''Merges siblings in root that are marked as mergeable.
55
    Used to recombine pieces of nodes that were split apart in the mappings.
56
    '''
57
    for name in set((c.tagName for c in xpath.get(root, '*[@merge=1]'))):
58
        xml_dom.merge_by_name(root, name)
59
    
60
    # Recurse
61
    for child in xml_dom.NodeElemIter(root): merge_tagged(child)
62

    
63
funcs = {}
64

    
65
structural_funcs = set()
66

    
67
##### Public functions
68

    
69
def is_func_name(name):
70
    return name.startswith('_') and name != '_' # '_' is default root node name
71

    
72
def is_func(node): return is_func_name(node.tagName)
73

    
74
def is_xml_func_name(name): return is_func_name(name) and name in funcs
75

    
76
def is_xml_func(node): return is_xml_func_name(node.tagName)
77

    
78
def process(node, on_error=exc.raise_, is_rel_func=None, db=None):
79
    '''Evaluates the XML functions in an XML tree.
80
    @param is_rel_func None|f(str) Tests if a name is a relational function.
81
        * If != None: Non-relational functions are removed, or relational
82
          functions are treated specially, depending on the db param (below).
83
    @param db
84
        * If None: Non-relational functions other than structural functions are
85
          replaced with their last parameter (usually the value), not evaluated.
86
          This is used in column-based mode to remove XML-only functions.
87
        * If != None: Relational functions are evaluated directly. This is used
88
          in row-based mode to combine relational and XML functions.
89
    '''
90
    has_rel_funcs = is_rel_func != None
91
    assert db == None or has_rel_funcs # rel_funcs required if db set
92
    
93
    for child in xml_dom.NodeElemIter(node):
94
        process(child, on_error, is_rel_func, db)
95
    merge_tagged(node)
96
    
97
    name = node.tagName
98
    if not is_func_name(name): return node # not any kind of function
99
    
100
    row_mode = has_rel_funcs and db != None
101
    column_mode = has_rel_funcs and db == None
102
    items = list(xml_dom.NodeTextEntryIter(node))
103
    
104
    # Parse function
105
    if len(items) == 1 and items[0][0].isdigit(): # has single numeric param
106
        # pass-through optimization for aggregating functions with one arg
107
        value = items[0][1] # pass through first arg
108
    elif row_mode and is_rel_func(name): # row-based mode: evaluate using DB
109
        value = sql_io.put(db, name, dict(items))
110
    elif column_mode and not name in structural_funcs: # column-based mode
111
        if is_rel_func(name): return # preserve relational functions
112
        # otherwise XML-only, so just replace with last param
113
        value = pop_value(items, None)
114
    else: # local XML function
115
        try: value = funcs[name](items, node)
116
        except Exception, e: # also catch non-wrapped exceptions (XML func bugs)
117
            # Save in case another exception raised, overwriting sys.exc_info()
118
            exc.add_traceback(e)
119
            str_ = strings.ustr(node)
120
            exc.add_msg(e, 'function:\n'+str_)
121
            xml_dom.replace(node, xml_dom.mk_comment(node.ownerDocument,
122
                '\n'+term.emph_multiline(str_)))
123
                
124
            on_error(e)
125
            return # in case on_error() returns
126
    
127
    xml_dom.replace_with_text(node, value)
128

    
129
##### XML functions
130

    
131
# Function names must start with _ to avoid collisions with real tags
132
# Functions take arguments (items)
133

    
134
#### Structural
135

    
136
def _ignore(items, node):
137
    '''Used to "comment out" an XML subtree'''
138
    return None
139
funcs['_ignore'] = _ignore
140
structural_funcs.add('_ignore')
141

    
142
def _ref(items, node):
143
    '''Used to retrieve a value from another XML node
144
    @param items
145
        addr=<path> XPath to value, relative to the XML func's parent node
146
    '''
147
    items = dict(items)
148
    try: addr = items['addr']
149
    except KeyError, e: raise SyntaxError(e)
150
    
151
    value = xpath.get_value(node.parentNode, addr)
152
    if value == None:
153
        warnings.warn(UserWarning('_ref: XPath reference target missing: '
154
            +str(addr)))
155
    return value
156
funcs['_ref'] = _ref
157
structural_funcs.add('_ref')
158

    
159
#### Conditionals
160

    
161
def _eq(items, node):
162
    items = dict(items)
163
    try:
164
        left = items['left']
165
        right = items['right']
166
    except KeyError: return '' # a value was None
167
    return util.bool2str(left == right)
168
funcs['_eq'] = _eq
169

    
170
def _if(items, node):
171
    items = dict(items)
172
    try:
173
        cond = items['cond']
174
        then = items['then']
175
    except KeyError, e: raise SyntaxError(e)
176
    else_ = items.get('else', None)
177
    cond = bool(cast(strings.ustr, cond))
178
    if cond: return then
179
    else: return else_
180
funcs['_if'] = _if
181

    
182
#### Combining values
183

    
184
def _alt(items, node):
185
    items = list(items)
186
    items.sort()
187
    try: return items[0][1] # value of lowest-numbered item
188
    except IndexError: return None # input got removed by e.g. FormatException
189
funcs['_alt'] = _alt
190

    
191
def _merge(items, node):
192
    items = list(conv_items(strings.ustr, items))
193
        # get *once* from iter, check types
194
    items.sort()
195
    return maps.merge_values(*[v for k, v in items])
196
funcs['_merge'] = _merge
197

    
198
def _label(items, node):
199
    items = dict(conv_items(strings.ustr, items))
200
        # get *once* from iter, check types
201
    value = items.get('value', None)
202
    if value == None: return None # input is empty
203
    try: label = items['label']
204
    except KeyError, e: raise SyntaxError(e)
205
    return label+': '+value
206
funcs['_label'] = _label
207

    
208
#### Transforming values
209

    
210
def _collapse(items, node):
211
    '''Collapses a subtree if the "value" element in it is NULL'''
212
    items = dict(items)
213
    try: require = cast(strings.ustr, items['require'])
214
    except KeyError, e: raise SyntaxError(e)
215
    value = items.get('value', None)
216
    
217
    if xpath.get_value(value, require, allow_rooted=False) == None: return None
218
    else: return value
219
funcs['_collapse'] = _collapse
220

    
221
types_by_name = {None: strings.ustr, 'str': strings.ustr, 'float': float}
222

    
223
def _nullIf(items, node):
224
    items = dict(conv_items(strings.ustr, items))
225
    try: null = items['null']
226
    except KeyError, e: raise SyntaxError(e)
227
    value = items.get('value', None)
228
    type_str = items.get('type', None)
229
    
230
    try: type_ = types_by_name[type_str]
231
    except KeyError, e: raise SyntaxError(e)
232
    null = type_(null)
233
    
234
    try: return util.none_if(value, null)
235
    except ValueError: return value # value not convertible, so can't equal null
236
funcs['_nullIf'] = _nullIf
237

    
238
def repl(repls, value):
239
    '''Raises error if value not in map and no special '*' entry
240
    @param repls dict repl:with
241
        repl "*" means all other input values
242
        with "*" means keep input value the same
243
        with "" means ignore input value
244
    '''
245
    try: new_value = repls[value]
246
    except KeyError, e:
247
        # Save traceback right away in case another exception raised
248
        fe = FormatException(e)
249
        try: new_value = repls['*']
250
        except KeyError: raise fe
251
    if new_value == '*': new_value = value # '*' means keep input value the same
252
    return new_value
253

    
254
def _map(items, node):
255
    '''See repl()
256
    @param items
257
        <last_entry> Value
258
        <other_entries> name=value Mappings. Special values: See repl() repls.
259
    '''
260
    items = conv_items(strings.ustr, items) # get *once* from iter, check types
261
    value = pop_value(items)
262
    if value == None: return None # input is empty
263
    return util.none_if(repl(dict(items), value), u'') # empty value means None
264
funcs['_map'] = _map
265

    
266
def _replace(items, node):
267
    items = conv_items(strings.ustr, items) # get *once* from iter, check types
268
    value = pop_value(items)
269
    if value == None: return None # input is empty
270
    try:
271
        for repl, with_ in items:
272
            if re.match(r'^\w+$', repl):
273
                repl = r'(?<![^\W_])'+repl+r'(?![^\W_])' # match whole word
274
            value = re.sub(repl, with_, value)
275
    except sre_constants.error, e: raise SyntaxError(e)
276
    return util.none_if(value.strip(), u'') # empty strings always mean None
277
funcs['_replace'] = _replace
278

    
279
#### Quantities
280

    
281
def _units(items, node):
282
    items = conv_items(strings.ustr, items) # get *once* from iter, check types
283
    value = pop_value(items)
284
    if value == None: return None # input is empty
285
    
286
    quantity = units.str2quantity(value)
287
    try:
288
        for action, units_ in items:
289
            units_ = util.none_if(units_, u'')
290
            if action == 'default': units.set_default_units(quantity, units_)
291
            elif action == 'to':
292
                try: quantity = units.convert(quantity, units_)
293
                except ValueError, e: raise FormatException(e)
294
            else: raise SyntaxError(ValueError('Invalid action: '+action))
295
    except units.MissingUnitsException, e: raise FormatException(e)
296
    return units.quantity2str(quantity)
297
funcs['_units'] = _units
298

    
299
def parse_range(str_, range_sep='-'):
300
    default = (str_, None)
301
    start, sep, end = str_.partition(range_sep)
302
    if sep == '': return default # not a range
303
    if start == '' and range_sep == '-': return default # negative number
304
    return tuple(d.strip() for d in (start, end))
305

    
306
def _rangeStart(items, node):
307
    items = dict(conv_items(strings.ustr, items))
308
    try: value = items['value']
309
    except KeyError: return None # input is empty
310
    return parse_range(value)[0]
311
funcs['_rangeStart'] = _rangeStart
312

    
313
def _rangeEnd(items, node):
314
    items = dict(conv_items(strings.ustr, items))
315
    try: value = items['value']
316
    except KeyError: return None # input is empty
317
    return parse_range(value)[1]
318
funcs['_rangeEnd'] = _rangeEnd
319

    
320
def _range(items, node):
321
    items = dict(conv_items(float, items))
322
    from_ = items.get('from', None)
323
    to = items.get('to', None)
324
    if from_ == None or to == None: return None
325
    return str(to - from_)
326
funcs['_range'] = _range
327

    
328
def _avg(items, node):
329
    count = 0
330
    sum_ = 0.
331
    for name, value in conv_items(float, items):
332
        count += 1
333
        sum_ += value
334
    if count == 0: return None # input is empty
335
    else: return str(sum_/count)
336
funcs['_avg'] = _avg
337

    
338
class CvException(Exception):
339
    def __init__(self):
340
        Exception.__init__(self, 'CV (coefficient of variation) values are only'
341
            ' allowed for ratio scale data '
342
            '(see <http://en.wikipedia.org/wiki/Coefficient_of_variation>)')
343

    
344
def _noCV(items, node):
345
    try: name, value = items.pop() # last entry contains value
346
    except IndexError: return None # input is empty
347
    if re.match('^(?i)CV *\d+$', value): raise FormatException(CvException())
348
    return value
349
funcs['_noCV'] = _noCV
350

    
351
#### Dates
352

    
353
def _date(items, node):
354
    items = dict(conv_items(strings.ustr, items))
355
        # get *once* from iter, check types
356
    try: str_ = items['date']
357
    except KeyError:
358
        # Year is required
359
        try: items['year']
360
        except KeyError, e:
361
            if items == {}: return None # entire date is empty
362
            else: raise FormatException(e)
363
        
364
        # Convert month name to number
365
        try: month = items['month']
366
        except KeyError: pass
367
        else:
368
            if not month.isdigit(): # month is name
369
                try: items['month'] = str(dates.strtotime(month).month)
370
                except ValueError, e: raise FormatException(e)
371
        
372
        items = dict(conv_items(format.str2int, items.iteritems()))
373
        items.setdefault('month', 1)
374
        items.setdefault('day', 1)
375
        
376
        for try_num in xrange(2):
377
            try:
378
                date = datetime.date(**items)
379
                break
380
            except ValueError, e:
381
                if try_num > 0: raise FormatException(e)
382
                    # exception still raised after retry
383
                msg = strings.ustr(e)
384
                if msg == 'month must be in 1..12': # try swapping month and day
385
                    items['month'], items['day'] = items['day'], items['month']
386
                else: raise FormatException(e)
387
    else:
388
        try: year = float(str_)
389
        except ValueError:
390
            try: date = dates.strtotime(str_)
391
            except ImportError: return str_
392
            except ValueError, e: raise FormatException(e)
393
        else: date = (datetime.date(int(year), 1, 1) +
394
            datetime.timedelta(round((year % 1.)*365)))
395
    try: return dates.strftime('%Y-%m-%d', date)
396
    except ValueError, e: raise FormatException(e)
397
funcs['_date'] = _date
398

    
399
def _dateRangeStart(items, node):
400
    items = dict(conv_items(strings.ustr, items))
401
    try: value = items['value']
402
    except KeyError: return None # input is empty
403
    return dates.parse_date_range(value)[0]
404
funcs['_dateRangeStart'] = _dateRangeStart
405

    
406
def _dateRangeEnd(items, node):
407
    items = dict(conv_items(strings.ustr, items))
408
    try: value = items['value']
409
    except KeyError: return None # input is empty
410
    return dates.parse_date_range(value)[1]
411
funcs['_dateRangeEnd'] = _dateRangeEnd
412

    
413
#### Names
414

    
415
_name_parts_slices_items = [
416
    ('first', slice(None, 1)),
417
    ('middle', slice(1, -1)),
418
    ('last', slice(-1, None)),
419
]
420
name_parts_slices = dict(_name_parts_slices_items)
421
name_parts = [name for name, slice_ in _name_parts_slices_items]
422

    
423
def _name(items, node):
424
    items = dict(items)
425
    parts = []
426
    for part in name_parts:
427
        if part in items: parts.append(items[part])
428
    return ' '.join(parts)
429
funcs['_name'] = _name
430

    
431
def _namePart(items, node):
432
    out_items = []
433
    for part, value in items:
434
        try: slice_ = name_parts_slices[part]
435
        except KeyError, e: raise SyntaxError(e)
436
        out_items.append((part, ' '.join(value.split(' ')[slice_])))
437
    return _name(out_items, node)
438
funcs['_namePart'] = _namePart
439

    
440
#### Angles
441

    
442
def _compass(items, node):
443
    '''Converts a compass direction (N, NE, NNE, etc.) into a degree heading'''
444
    items = dict(conv_items(strings.ustr, items))
445
    try: value = items['value']
446
    except KeyError: return None # input is empty
447
    
448
    if not value.isupper(): return value # pass through other coordinate formats
449
    try: return util.cast(str, angles.compass2heading(value)) # ignore None
450
    except KeyError, e: raise FormatException(e)
451
funcs['_compass'] = _compass
452

    
453
#### Paths
454

    
455
def _simplifyPath(items, node):
456
    items = dict(items)
457
    try:
458
        next = cast(strings.ustr, items['next'])
459
        require = cast(strings.ustr, items['require'])
460
        root = items['path']
461
    except KeyError, e: raise SyntaxError(e)
462
    
463
    node = root
464
    while node != None:
465
        new_node = xpath.get_1(node, next, allow_rooted=False)
466
        if xpath.get_value(node, require, allow_rooted=False) == None: # empty
467
            xml_dom.replace(node, new_node) # remove current elem
468
            if node is root: root = new_node # also update root
469
        node = new_node
470
    return root
471
funcs['_simplifyPath'] = _simplifyPath
(34-34/37)