Project

General

Profile

« Previous | Next » 

Revision 4995

sql_io.py: Added import_csv()

View differences:

lib/sql_io.py
1 1
# Database import/export
2 2

  
3 3
import copy
4
import csv
4 5
import operator
5 6
import warnings
6 7

  
8
import csvs
7 9
import exc
8 10
import dicts
9 11
import sql
......
217 219

  
218 220
##### Import
219 221

  
222
def import_csv(db, table, stream, use_copy_from=True, has_row_num=True):
    '''Creates `table` and loads the rows of a CSV/TSV stream into it.
    Every column is created with type text. Empty header names are replaced
    with 'column_<i>', where i is the column's 0-based position.
    The load itself is run via sql.with_savepoint().
    @param db The database connection wrapper (must provide log_debug(),
        esc_value() and a `db` attribute whose cursor supports copy_expert())
    @param table The table to create and fill
    @param stream The input stream; its dialect and header are detected using
        csvs.stream_info()
    @param use_copy_from Whether to load with COPY FROM instead of one INSERT
        per row. Forced to False when the stream is detected as TSV.
    @param has_row_num Whether to call sql.add_row_num() on the table after
        loading
    '''
    def log(msg, level=1): db.log_debug(msg, level)
    
    # Get format info
    info = csvs.stream_info(stream, parse_header=True)
    dialect = info.dialect
    if csvs.is_tsv(dialect): use_copy_from = False
    # Build a real list (not a map() result) because entries are assigned by
    # index below
    col_names = [strings.to_unicode(v) for v in info.header]
    for i, col in enumerate(col_names): # replace empty column names
        if col == '': col_names[i] = 'column_'+str(i)
    
    typed_cols = [sql_gen.TypedCol(v, 'text') for v in col_names]
    
    log('Creating table')
    sql.create_table(db, table, typed_cols, has_pkey=False, col_indexes=False)
    
    # Load the data
    def load_():
        if use_copy_from:
            log('Using COPY FROM')
            
            # Create COPY FROM statement
            copy_from = ('COPY '+table.to_str(db)+' FROM STDIN DELIMITER '
                +db.esc_value(dialect.delimiter)+' NULL '+db.esc_value(''))
            # TSV input was already routed to the INSERT branch above
            assert not csvs.is_tsv(dialect)
            copy_from += ' CSV'
            if dialect.quoting != csv.QUOTE_NONE:
                quote_str = db.esc_value(dialect.quotechar)
                copy_from += ' QUOTE '+quote_str
                # A doubled quote char is the escape, so ESCAPE == QUOTE
                if dialect.doublequote: copy_from += ' ESCAPE '+quote_str
            copy_from += ';\n'
            
            log(copy_from, level=2)
            db.db.cursor().copy_expert(copy_from, stream)
        else:
            log('Using INSERT')
            cols_ct = len(col_names)
            for row in csvs.make_reader(stream, dialect):
                # Must be a list: list_set_length() is called on it next
                row = [strings.to_unicode(v) for v in row]
                util.list_set_length(row, cols_ct) # truncate extra cols
                sql.insert(db, table, row, cacheable=False, log_level=5)
    sql.with_savepoint(db, load_)
    
    if has_row_num: sql.add_row_num(db, table)
    cleanup_table(db, table)
270

  
220 271
def put(db, table, row, pkey_=None, row_ct_ref=None):
221 272
    '''Recovers from errors.
222 273
    Only works under PostgreSQL (uses INSERT RETURNING).

Also available in: Unified diff