Project

General

Profile

« Previous | Next » 

Revision 5017

sql_io.py: import_csv(): Factored insertion code out into new append_csv()

View differences:

sql_io.py
219 219

  
220 220
##### Import
221 221

  
222
def append_csv(db, table, stream_info, stream, use_copy_from=True):
    '''Loads the rows available on `stream` into `table`.

    When `use_copy_from` is set and the dialect is not TSV, the whole stream is
    handed to a single COPY ... FROM STDIN statement. Otherwise each row is read
    with the CSV reader and inserted individually.
    @param db Connection wrapper; provides log_debug(), esc_value() and the
        underlying driver connection as `db.db`
    @param table Target table; rendered into the statement with to_str(db)
    @param stream_info Supplies the stream's `dialect` and `header`
    @param stream Input passed to copy_expert() or to csvs.make_reader()
    @param use_copy_from Whether COPY FROM may be used (ignored for TSV input)
    '''
    csv_dialect = stream_info.dialect

    # TSV input always takes the row-by-row INSERT path
    if csvs.is_tsv(csv_dialect): use_copy_from = False

    if not use_copy_from:
        db.log_debug('Using INSERT', 1)
        width = len(stream_info.header)
        for fields in csvs.make_reader(stream, csv_dialect):
            fields = [strings.to_unicode(field) for field in fields]
            util.list_set_length(fields, width) # truncate extra cols
            sql.insert(db, table, fields, cacheable=False, log_level=5)
        return

    db.log_debug('Using COPY FROM', 1)

    # Assemble the COPY FROM statement piece by piece
    stmt = 'COPY '+table.to_str(db)+' FROM STDIN DELIMITER '
    stmt += db.esc_value(csv_dialect.delimiter)
    stmt += ' NULL '+db.esc_value('') # empty fields are loaded as NULL
    assert not csvs.is_tsv(csv_dialect) # invariant: TSV was rerouted above
    stmt += ' CSV'
    if csv_dialect.quoting != csv.QUOTE_NONE:
        quote = db.esc_value(csv_dialect.quotechar)
        stmt += ' QUOTE '+quote
        # a doubled quote char escapes itself, so ESCAPE is the quote char too
        if csv_dialect.doublequote: stmt += ' ESCAPE '+quote
    stmt += ';\n'

    db.log_debug(stmt, 2)
    db.db.cursor().copy_expert(stmt, stream)
250

  
222 251
def import_csv(db, table, stream, use_copy_from=True, has_row_num=True):
223 252
    def log(msg, level=1): db.log_debug(msg, level)
224 253
    
225 254
    # Get format info
226 255
    info = csvs.stream_info(stream, parse_header=True)
227
    dialect = info.dialect
228
    if csvs.is_tsv(dialect): use_copy_from = False
229 256
    col_names = map(strings.to_unicode, info.header)
230 257
    for i, col in enumerate(col_names): # replace empty column names
231 258
        if col == '': col_names[i] = 'column_'+str(i)
......
246 273
    sql.truncate(db, table)
247 274
    
248 275
    # Load the data
249
    def load_():
250
        if use_copy_from:
251
            log('Using COPY FROM')
252
            
253
            # Create COPY FROM statement
254
            copy_from = ('COPY '+table.to_str(db)+' FROM STDIN DELIMITER '
255
                +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
256
            assert not csvs.is_tsv(dialect)
257
            copy_from += ' CSV'
258
            if dialect.quoting != csv.QUOTE_NONE:
259
                quote_str = db.esc_value(dialect.quotechar)
260
                copy_from += ' QUOTE '+quote_str
261
                if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
262
            copy_from += ';\n'
263
            
264
            log(copy_from, level=2)
265
            db.db.cursor().copy_expert(copy_from, stream)
266
        else:
267
            log('Using INSERT')
268
            cols_ct = len(col_names)
269
            for row in csvs.make_reader(stream, dialect):
270
                row = map(strings.to_unicode, row)
271
                util.list_set_length(row, cols_ct) # truncate extra cols
272
                sql.insert(db, table, row, cacheable=False, log_level=5)
273
    sql.with_savepoint(db, load_)
276
    sql.with_savepoint(db, lambda: append_csv(db, table, info, stream,
277
        use_copy_from))
274 278
    
275 279
    if has_row_num: sql.add_row_num(db, table, sql.pkey_col)
276 280
    cleanup_table(db, table)

Also available in: Unified diff