Revision 1963
Added by Aaron Marcuse-Kubitza over 12 years ago
bin/csv2db | ||
---|---|---|
4 | 4 |
|
5 | 5 |
import csv |
6 | 6 |
import os.path |
7 |
import re |
|
7 | 8 |
import subprocess |
8 | 9 |
import sys |
9 | 10 |
|
... | ... | |
14 | 15 |
import opts |
15 | 16 |
import sql |
16 | 17 |
import streams |
18 |
import strings |
|
17 | 19 |
|
18 | 20 |
def main(): |
19 | 21 |
# Usage |
... | ... | |
36 | 38 |
# Connect to DB |
37 | 39 |
db = sql.connect(db_config) |
38 | 40 |
|
39 |
def try_load(): |
|
41 |
use_copy_from = [True] |
|
42 |
|
|
43 |
# Loads data into the table using the currently-selected approach. |
|
44 |
def load_(): |
|
40 | 45 |
# Open input stream |
41 | 46 |
proc = subprocess.Popen(input_cmd, stdout=subprocess.PIPE, bufsize=-1) |
42 | 47 |
in_ = proc.stdout |
... | ... | |
44 | 49 |
# Get format info |
45 | 50 |
info = csvs.stream_info(in_, parse_header=True) |
46 | 51 |
dialect = info.dialect |
52 |
if csvs.is_tsv(dialect): use_copy_from[0] = False |
|
47 | 53 |
|
48 |
# Escape names
|
|
54 |
# Select schema and escape names
|
|
49 | 55 |
def esc_name(name): return sql.esc_name(db, name, preserve_case=True) |
50 |
qual_table = esc_name(schema)+'.'+esc_name(table) |
|
56 |
sql.run_query(db, 'SET search_path TO '+esc_name(schema)) |
|
57 |
esc_table = esc_name(table) |
|
51 | 58 |
esc_cols = map(esc_name, info.header) |
52 | 59 |
|
53 | 60 |
# Create CREATE TABLE statement |
54 | 61 |
pkey = esc_name(table+'_pkey') |
55 |
create_table = 'CREATE TABLE '+qual_table+' (\n'
|
|
62 |
create_table = 'CREATE TABLE '+esc_table+' (\n'
|
|
56 | 63 |
create_table += ' row_num serial NOT NULL,\n' |
57 | 64 |
for esc_col in esc_cols: create_table += ' '+esc_col+' text,\n' |
58 | 65 |
create_table += ' CONSTRAINT '+pkey+' PRIMARY KEY (row_num)\n' |
59 | 66 |
create_table += ');\n' |
60 | 67 |
if debug: sys.stderr.write(create_table) |
61 | 68 |
|
69 |
# Create table |
|
70 |
sql.run_query(db, create_table) |
|
71 |
|
|
62 | 72 |
# Create COPY FROM statement |
63 |
cur = db.db.cursor() |
|
64 |
copy_from = ('COPY '+qual_table+' ('+(', '.join(esc_cols)) |
|
65 |
+') FROM STDIN DELIMITER %(delimiter)s NULL %(null)s') |
|
66 |
if not csvs.is_tsv(dialect): |
|
73 |
if use_copy_from[0]: |
|
74 |
cur = db.db.cursor() |
|
75 |
copy_from = ('COPY '+esc_table+' ('+(', '.join(esc_cols)) |
|
76 |
+') FROM STDIN DELIMITER %(delimiter)s NULL %(null)s') |
|
77 |
assert not csvs.is_tsv(dialect) |
|
67 | 78 |
copy_from += ' CSV' |
68 | 79 |
if dialect.quoting != csv.QUOTE_NONE: |
69 | 80 |
copy_from += ' QUOTE %(quotechar)s' |
70 | 81 |
if dialect.doublequote: copy_from += ' ESCAPE %(quotechar)s' |
71 |
copy_from += ';\n' |
|
72 |
copy_from = cur.mogrify(copy_from, dict(delimiter=dialect.delimiter, |
|
73 |
null=r'\N', quotechar=dialect.quotechar)) |
|
74 |
if debug: sys.stderr.write(copy_from) |
|
82 |
copy_from += ';\n'
|
|
83 |
copy_from = cur.mogrify(copy_from, dict(delimiter=dialect.delimiter,
|
|
84 |
null=r'\N', quotechar=dialect.quotechar))
|
|
85 |
if debug: sys.stderr.write(copy_from)
|
|
75 | 86 |
|
76 |
# Create table |
|
77 |
sql.run_query(db, create_table) |
|
78 |
|
|
79 |
# COPY FROM the data |
|
87 |
# Load the data |
|
80 | 88 |
line_in = streams.ProgressInputStream(in_, sys.stderr, |
81 | 89 |
'Processed %d row(s)', n=10000) |
82 |
try: db.db.cursor().copy_expert(copy_from, line_in) |
|
90 |
try: |
|
91 |
if use_copy_from[0]: |
|
92 |
sys.stderr.write('Using COPY FROM\n') |
|
93 |
db.db.cursor().copy_expert(copy_from, line_in) |
|
94 |
else: |
|
95 |
sys.stderr.write('Using INSERT\n') |
|
96 |
for row in csvs.make_reader(line_in, dialect): |
|
97 |
row = map(strings.to_unicode, row) |
|
98 |
row.insert(0, sql.default) # leave space for autogen row_num |
|
99 |
sql.insert(db, esc_table, row, table_is_esc=True) |
|
83 | 100 |
finally: |
84 | 101 |
line_in.close() # also closes proc.stdout |
85 | 102 |
proc.wait() |
103 |
load = lambda: sql.with_savepoint(db, load_) |
|
86 | 104 |
|
87 |
for encoding in ['UTF8', 'LATIN1']: |
|
88 |
db.db.set_client_encoding(encoding) |
|
89 |
|
|
90 |
try: sql.with_savepoint(db, try_load) |
|
91 |
except sql.DatabaseErrors, e: |
|
92 |
if str(e).find('invalid byte sequence for encoding') >= 0: |
|
93 |
exc.print_ex(e, plain=True) |
|
94 |
# now, continue to next encoding |
|
95 |
else: raise e |
|
96 |
else: |
|
97 |
db.db.commit() # commit must occur outside of with_savepoint() |
|
98 |
break # don't try further encodings |
|
105 |
try: load() |
|
106 |
except sql.DatabaseErrors, e: |
|
107 |
if use_copy_from[0]: # first try |
|
108 |
exc.print_ex(e, plain=True) |
|
109 |
use_copy_from[0] = False |
|
110 |
load() # try again with different approach |
|
111 |
else: raise e |
|
112 |
db.db.commit() |
|
99 | 113 |
|
100 | 114 |
main() |
Also available in: Unified diff
csv2db: Fall back to inserting each row individually with INSERT (autodetecting the encoding for each field) if the input is TSV or if COPY FROM fails with a database error