# CSV I/O

import csv
import _csv
import StringIO

import exc
import streams
import strings
import util

delims = ',;\t|`'
tab_padded_delims = ['\t|\t']
tsv_delim = '\t'
escape = '\\'

ending_placeholder = r'\n'

def is_tsv(dialect): return dialect.delimiter.startswith(tsv_delim)

def sniff(line):
    '''Automatically detects the dialect'''
    line, ending = strings.extract_line_ending(line)
    dialect = csv.Sniffer().sniff(line, delims)
    
    if is_tsv(dialect):
        # Check multi-char delims using \t
        delim = strings.find_any(line, tab_padded_delims)
        if delim:
            dialect.delimiter = delim
            line_suffix = delim.rstrip('\t')
            if line.endswith(line_suffix): ending = line_suffix+ending
    else: dialect.doublequote = True # Sniffer doesn't turn this on by default
    dialect.lineterminator = ending
    
    return dialect
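
# Usage sketch (illustrative; assumes strings.extract_line_ending splits off
# the line ending):
#   dialect = sniff('name,value\r\n')
#   dialect.delimiter      # -> ','
#   dialect.lineterminator # -> '\r\n'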

def stream_info(stream, parse_header=False):
    '''Automatically detects the dialect based on the header line.
    Uses the Excel dialect if the CSV file is empty.
    @return NamedTuple {header_line, header, dialect}'''
    info = util.NamedTuple()
    info.header_line = stream.readline()
    info.header = None
    if info.header_line != '':
        info.dialect = sniff(info.header_line)
    else: info.dialect = csv.excel # line of '' indicates EOF = empty stream
    
    if parse_header:
        try: info.header = reader_class(info.dialect)(
            StringIO.StringIO(info.header_line), info.dialect).next()
        except StopIteration: info.header = []
    
    return info
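
# Usage sketch (illustrative):
#   info = stream_info(StringIO.StringIO('a,b\n1,2\n'), parse_header=True)
#   info.header            # -> ['a', 'b']
#   info.dialect.delimiter # -> ','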

tsv_encode_map = strings.json_encode_map[:]
tsv_encode_map.append(('\t', r'\t'))
tsv_decode_map = strings.flip_map(tsv_encode_map)
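
# These maps drive TsvReader's unescaping: tsv_decode_map flips the encode map,
# turning escape sequences such as r'\t' back into the characters they encode
# (the exact pairs depend on strings.json_encode_map).
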
class TsvReader:
    '''Reads TSVs. Unlike csv.reader, interprets \ as escaping a line ending,
    but leaves it intact before anything else (e.g. \N for NULL).
    Also expands tsv_encode_map escapes.
    '''
    def __init__(self, stream, dialect):
        assert is_tsv(dialect)
        self.stream = stream
        self.dialect = dialect
    
    def __iter__(self): return self
    
    def next(self):
        record = ''
        ending = None
        while True:
            line = self.stream.readline()
            if line == '': raise StopIteration
            
            line = strings.remove_suffix(self.dialect.lineterminator, line)
            contents = strings.remove_suffix(escape, line)
            record += contents
            if len(contents) == len(line): break # no line continuation
            record += ending_placeholder
        
        # Prevent "new-line character seen in unquoted field" errors
        record = record.replace('\r', ending_placeholder)
        
        # Split line
        if len(self.dialect.delimiter) > 1: # multi-char delims
            row = record.split(self.dialect.delimiter)
        else: row = csv.reader(StringIO.StringIO(record), self.dialect).next()
        
        return [strings.replace_all(tsv_decode_map, v) for v in row]
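
# Behavior sketch (illustrative): a row whose line ends in \ is continued onto
# the next line; the escaped line break is kept in the field (as
# ending_placeholder, then run through tsv_decode_map):
#   reader = TsvReader(StringIO.StringIO('a\tb\\\nstill b\n'), sniff('a\tb\n'))
#   reader.next() # -> a single 2-field row; 'still b' stays in the 2nd field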

def reader_class(dialect):
    if is_tsv(dialect): return TsvReader
    else: return csv.reader

def make_reader(stream, dialect): return reader_class(dialect)(stream, dialect)

def reader_and_header(stream):
    '''Automatically detects the dialect based on the header line
    @return tuple (reader, header)'''
    info = stream_info(stream, parse_header=True)
    return (make_reader(stream, info.dialect), info.header)
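
# Typical usage (illustrative; 'data.csv' and process() are placeholders):
#   reader, header = reader_and_header(open('data.csv'))
#   for row in reader: process(row)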

##### csv modifications

# Note that these methods only work on *instances* of Dialect classes
csv.Dialect.__eq__ = lambda self, other: self.__dict__ == other.__dict__
csv.Dialect.__ne__ = lambda self, other: not (self == other)

__Dialect__validate_orig = csv.Dialect._validate
def __Dialect__validate(self):
    try: __Dialect__validate_orig(self)
    except _csv.Error, e:
        if str(e) == '"delimiter" must be an 1-character string': pass # OK
        else: raise
csv.Dialect._validate = __Dialect__validate
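
# Effect of the patch (illustrative): dialects with a multi-character delimiter,
# such as the '\t|\t' delimiter sniff() can assign, are no longer rejected by
# _validate() (whether the error is swallowed depends on the exact _csv error
# message in the running Python version).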

##### Row filters

class Filter:
    '''Wraps a reader, filtering each row'''
    def __init__(self, filter_, reader):
        self.reader = reader
        self.filter = filter_
    
    def __iter__(self): return self
    
    def next(self): return self.filter(self.reader.next())
    
    def close(self): pass # support using as a stream

std_nulls = [r'\N']
empty_nulls = [''] + std_nulls

class NullFilter(Filter):
    '''Translates special string values to None'''
    def __init__(self, reader, nulls=std_nulls):
        map_ = dict.fromkeys(nulls, None)
        def filter_(row): return [map_.get(v, v) for v in row]
        Filter.__init__(self, filter_, reader)
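
# Example (illustrative):
#   NullFilter(iter([['1', r'\N']])).next() # -> ['1', None]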

class StripFilter(Filter):
    '''Strips whitespace'''
    def __init__(self, reader):
        def filter_(row): return [v.strip() for v in row]
        Filter.__init__(self, filter_, reader)

class ColCtFilter(Filter):
    '''Gives all rows the same # columns'''
    def __init__(self, reader, cols_ct):
        def filter_(row): return util.list_as_length(row, cols_ct)
        Filter.__init__(self, filter_, reader)
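
# Filters wrap any row iterator, so they can be chained (illustrative;
# raw_reader is a placeholder for a csv.reader or TsvReader instance):
#   reader = NullFilter(StripFilter(raw_reader))
#   for row in reader: pass # values stripped, then r'\N' mapped to None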

##### Translators

class StreamFilter(Filter):
    '''Wraps a reader so it can be used by a filter stream that does not
    require lines to be strings. Reports EOF as '' instead of StopIteration.'''
    def __init__(self, reader):
        Filter.__init__(self, None, reader)
    
    def readline(self):
        try: return self.reader.next()
        except StopIteration: return '' # EOF

class ColInsertFilter(Filter):
    '''Adds column(s) to each row
    @param mk_value(row, row_num)
    '''
    def __init__(self, reader, mk_value, index=0, n=1):
        def filter_(row):
            row = list(row) # make sure it's mutable; don't modify input!
            for i in xrange(n):
                row.insert(index+i, mk_value(row, self.reader.line_num))
            return row
        Filter.__init__(self, filter_,
            streams.LineCountInputStream(StreamFilter(reader)))

class RowNumFilter(ColInsertFilter):
    '''Adds a row # column at the beginning of each row'''
    def __init__(self, reader):
        def mk_value(row, row_num): return row_num
        ColInsertFilter.__init__(self, reader, mk_value, 0)
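
# Example (illustrative): prepends each row's number, as counted by
# streams.LineCountInputStream, to the row:
#   reader = RowNumFilter(csv.reader(StringIO.StringIO('a,b\nc,d\n')))
#   reader.next() # -> the first row with its row number prepended, e.g. [1, 'a', 'b']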

class InputRewriter(StreamFilter):
    '''Wraps a reader, writing each row back to CSV'''
    def __init__(self, reader, dialect=csv.excel):
        StreamFilter.__init__(self, reader)
        
        self.dialect = dialect
    
    def readline(self):
        try:
            row = self.reader.readline()
            if row == '': return row # EOF
            
            line_stream = StringIO.StringIO()
            csv.writer(line_stream, self.dialect).writerow(row)
            return line_stream.getvalue()
        except Exception, e:
            exc.print_ex(e)
            raise
    
    def read(self, n): return self.readline() # forward all reads to readline()
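
# Usage sketch (illustrative): re-serialize rows as CSV text for a line-oriented
# consumer (e.g. a database bulk-load command). Note that readline() calls
# self.reader.readline() and expects it to return a *row*, so the wrapped
# reader should expose that interface (e.g. a StreamFilter around a reader):
#   csv_stream = InputRewriter(StreamFilter(make_reader(stream, dialect)))
#   csv_stream.readline() # -> one CSV-encoded line per input row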