# CSV I/O

import csv
import _csv
import StringIO

import streams
import strings
import util

delims = ',;\t|`'
tab_padded_delims = ['\t|\t']
tsv_delim = '\t'
escape = '\\'

ending_placeholder = r'\n'

def is_tsv(dialect): return dialect.delimiter.startswith(tsv_delim)

def sniff(line):
    '''Automatically detects the dialect'''
    line, ending = strings.extract_line_ending(line)
    dialect = csv.Sniffer().sniff(line, delims)
    
    if is_tsv(dialect):
        # Check multi-char delims using \t
        delim = strings.find_any(line, tab_padded_delims)
        if delim:
            dialect.delimiter = delim
            line_suffix = delim.rstrip('\t')
            if line.endswith(line_suffix): ending = line_suffix+ending
    else: dialect.doublequote = True # Sniffer doesn't turn this on by default
    dialect.lineterminator = ending
    
    return dialect
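
# A minimal usage sketch for sniff() (hypothetical input; assumes
# strings.extract_line_ending() splits the line from its trailing line ending):
#     dialect = sniff('name,rank,serial\r\n')
#     dialect.delimiter      # -> ','
#     dialect.lineterminator # -> '\r\n'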

def stream_info(stream, parse_header=False):
    '''Automatically detects the dialect based on the header line.
    Uses the Excel dialect if the CSV file is empty.
    @return NamedTuple {header_line, header, dialect}'''
    info = util.NamedTuple()
    info.header_line = stream.readline()
    info.header = None
    if info.header_line != '':
        info.dialect = sniff(info.header_line)
    else: info.dialect = csv.excel # line of '' indicates EOF = empty stream
    
    if parse_header:
        try: info.header = reader_class(info.dialect)(
            StringIO.StringIO(info.header_line), info.dialect).next()
        except StopIteration: info.header = []
    
    return info
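
# Usage sketch for stream_info() (hypothetical file):
#     info = stream_info(open('data.csv'), parse_header=True)
#     info.header_line # raw header line, including its line ending
#     info.header      # parsed header row (None unless parse_header=True)
#     info.dialect     # detected dialect (csv.excel if the file was empty)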

tsv_encode_map = strings.json_encode_map[:]
tsv_encode_map.append(('\t', r'\t'))
tsv_decode_map = strings.flip_map(tsv_encode_map)

class TsvReader:
    '''Unlike csv.reader, for TSVs, interprets \ as escaping a line ending but
    ignores it before everything else (e.g. \N for NULL).
    Also expands tsv_encode_map escapes.
    '''
    def __init__(self, stream, dialect):
        assert is_tsv(dialect)
        self.stream = stream
        self.dialect = dialect
    
    def __iter__(self): return self
    
    def next(self):
        record = ''
        ending = None
        while True:
            line = self.stream.readline()
            if line == '': raise StopIteration
            
            line = strings.remove_suffix(self.dialect.lineterminator, line)
            contents = strings.remove_suffix(escape, line)
            record += contents
            if len(contents) == len(line): break # no line continuation
            record += ending_placeholder
        
        # Prevent "new-line character seen in unquoted field" errors
        record = record.replace('\r', ending_placeholder)
        
        # Split line
        if len(self.dialect.delimiter) > 1: # multi-char delims
            row = record.split(self.dialect.delimiter)
        else: row = csv.reader(StringIO.StringIO(record), self.dialect).next()
        
        return [strings.replace_all(tsv_decode_map, v) for v in row]
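
# Usage sketch for TsvReader (hypothetical in-memory TSV). A trailing \ escapes
# the line ending, so the record continues onto the next physical line:
#     dialect = sniff('col_0\tcol_1\n')
#     reader = TsvReader(StringIO.StringIO('a\tb\n'), dialect)
#     reader.next() # -> ['a', 'b']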

def reader_class(dialect):
    if is_tsv(dialect): return TsvReader
    else: return csv.reader

def make_reader(stream, dialect): return reader_class(dialect)(stream, dialect)

def reader_and_header(stream):
    '''Automatically detects the dialect based on the header line
    @return tuple (reader, header)'''
    info = stream_info(stream, parse_header=True)
    return (make_reader(stream, info.dialect), info.header)
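
# Usage sketch tying the pieces together (hypothetical file):
#     reader, header = reader_and_header(open('data.csv'))
#     for row in reader: pass # each row is a list of strings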

##### csv modifications

# Note that these methods only work on *instances* of Dialect classes
csv.Dialect.__eq__ = lambda self, other: self.__dict__ == other.__dict__
csv.Dialect.__ne__ = lambda self, other: not (self == other)
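
# The override below relaxes Dialect validation so that dialects with
# multi-character delimiters (e.g. tab_padded_delims) still validate; only the
# 1-character-delimiter error is suppressed, everything else is re-raised.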
__Dialect__validate_orig = csv.Dialect._validate
def __Dialect__validate(self):
        try: __Dialect__validate_orig(self)
        except _csv.Error, e:
            if str(e) == '"delimiter" must be an 1-character string': pass # OK
            else: raise
csv.Dialect._validate = __Dialect__validate

##### Row filters

class Filter:
    '''Wraps a reader, filtering each row'''
    def __init__(self, filter_, reader):
        self.reader = reader
        self.filter = filter_
    
    def __iter__(self): return self
    
    def next(self): return self.filter(self.reader.next())
    
    def close(self): pass # support using as a stream

std_nulls = [r'\N']
empty_nulls = [''] + std_nulls

class NullFilter(Filter):
    '''Translates special string values to None'''
    def __init__(self, reader, nulls=std_nulls):
        map_ = dict.fromkeys(nulls, None)
        def filter_(row): return [map_.get(v, v) for v in row]
        Filter.__init__(self, filter_, reader)

class StripFilter(Filter):
    '''Strips whitespace'''
    def __init__(self, reader):
        def filter_(row): return [v.strip() for v in row]
        Filter.__init__(self, filter_, reader)

class ColCtFilter(Filter):
    '''Gives all rows the same # columns'''
    def __init__(self, reader, cols_ct):
        def filter_(row): return util.list_as_length(row, cols_ct)
        Filter.__init__(self, filter_, reader)
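
# Usage sketch: filters wrap a reader and can be chained (hypothetical file);
# here values are whitespace-stripped first, then \N is translated to None:
#     reader, header = reader_and_header(open('data.tsv'))
#     reader = NullFilter(StripFilter(reader))
#     for row in reader: pass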

##### Translators

class StreamFilter(Filter):
    '''Wraps a reader in a way that's usable by a filter stream that does not
    require lines to be strings. Reports EOF as '' instead of StopIteration.'''
    def __init__(self, reader):
        Filter.__init__(self, None, reader)
    
    def readline(self):
        try: return self.reader.next()
        except StopIteration: return '' # EOF

class ColInsertFilter(Filter):
    '''Adds column(s) to each row
    @param mk_value(row, row_num)
    '''
    def __init__(self, reader, mk_value, index=0, n=1):
        def filter_(row):
            row = list(row) # make sure it's mutable; don't modify input!
            for i in xrange(n):
                row.insert(index+i, mk_value(row, self.reader.line_num))
            return row
        Filter.__init__(self, filter_,
            streams.LineCountInputStream(StreamFilter(reader)))

class RowNumFilter(ColInsertFilter):
    '''Adds a row # column at the beginning of each row'''
    def __init__(self, reader):
        def mk_value(row, row_num): return row_num
        ColInsertFilter.__init__(self, reader, mk_value, 0)
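
# Usage sketch for RowNumFilter (hypothetical file): prepends the reader's
# current line number (as tracked by streams.LineCountInputStream) to each row:
#     reader, header = reader_and_header(open('data.csv'))
#     reader = RowNumFilter(reader)
#     header.insert(0, 'row_num')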

class InputRewriter(StreamFilter):
    '''Wraps a reader, writing each row back to CSV'''
    def __init__(self, reader, dialect=csv.excel):
        StreamFilter.__init__(self, reader)
        
        self.dialect = dialect
    
    def readline(self):
        row = self.reader.readline()
        if row == '': return row # EOF
        
        line_stream = StringIO.StringIO()
        csv.writer(line_stream, self.dialect).writerow(row)
        return line_stream.getvalue()
    
    def read(self, n): return self.readline() # forward all reads to readline()
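
# Usage sketch for InputRewriter (hypothetical): re-encode filtered rows as CSV
# lines for code that expects a file-like object. Its readline() calls
# readline() on the wrapped reader, so wrap a row reader in StreamFilter first:
#     reader, header = reader_and_header(open('data.csv'))
#     stream = InputRewriter(StreamFilter(NullFilter(reader)))
#     line = stream.readline() # one CSV-encoded row per call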