Project

General

Profile

1 1942 aaronmk
#!/usr/bin/env python
2
# Loads a command's CSV output stream into a PostgreSQL table.
3
# The command may be run more than once: it is re-run for each client encoding
# that is tried (UTF8, then LATIN1) until one load succeeds.
4
5
import csv
6
import os.path
7
import subprocess
8
import sys
9
10
sys.path.append(os.path.dirname(__file__)+"/../lib")
11
12
import csvs
13
import exc
14
import opts
15
import sql
16
import streams
17
18
def main():
    """Load the CSV output of a command into a new PostgreSQL table.

    The command to run is taken from the command-line arguments. The target
    table, schema, DB connection settings and debug flag are read from env vars
    via opts. The table is created with a serial `row_num` primary key plus one
    text column per CSV header field, and is then filled with
    COPY ... FROM STDIN.

    The load is attempted once per client encoding (UTF8, then LATIN1), and the
    input command is re-run for every attempt. Each attempt is run through
    sql.with_savepoint(); the first attempt that succeeds is committed and no
    further encodings are tried.

    Raises:
        SystemExit: with a usage message if no command was given, or if the
            table name or the DB engine is not configured.
        sql.DatabaseErrors: any such error whose message does not mention an
            invalid byte sequence is re-raised unchanged.

    NOTE(review): if every encoding fails with an invalid byte sequence error,
    nothing is committed and the function still returns normally.
    """
    # Usage
    # Passed to every opts call below; presumably they record the env var names
    # in it so that opts.env_usage() can list them -- confirm against opts.
    env_names = []
    def usage_err():
        raise SystemExit('Usage: '+opts.env_usage(env_names)+' '+sys.argv[0]
            +' input_cmd [args...]')

    # Parse args
    input_cmd = sys.argv[1:]
    if not input_cmd: usage_err()

    # Get config from env vars
    table = opts.get_env_var('table', None, env_names)
    schema = opts.get_env_var('schema', 'public', env_names)
    db_config = opts.get_env_vars(sql.db_config_names, None, env_names)
    debug = opts.env_flag('debug', False, env_names)
    if table is None or 'engine' not in db_config: usage_err()

    # Connect to DB
    db = sql.connect(db_config)

    def try_load():
        """Run input_cmd once and COPY its CSV output into a new table.

        The subprocess's stdout is always closed and the subprocess is always
        waited for, whichever step fails.
        """
        # Open input stream
        proc = subprocess.Popen(input_cmd, stdout=subprocess.PIPE, bufsize=-1)
        in_ = proc.stdout
        to_close = in_ # replaced by the wrapping stream once that exists
        try:
            # Get format info
            info = csvs.stream_info(in_, parse_header=True)
            dialect = info.dialect

            # Escape names
            def esc_name(name):
                return sql.esc_name(db, name, preserve_case=True)
            qual_table = esc_name(schema)+'.'+esc_name(table)
            # Must be a real list, because it is iterated twice below
            esc_cols = [esc_name(col) for col in info.header]

            # Create CREATE TABLE statement
            pkey = esc_name(table+'_pkey')
            create_table = 'CREATE TABLE '+qual_table+' (\n'
            create_table += '    row_num serial NOT NULL,\n'
            create_table += ''.join('    '+esc_col+' text,\n'
                for esc_col in esc_cols)
            create_table += '    CONSTRAINT '+pkey+' PRIMARY KEY (row_num)\n'
            create_table += ');\n'
            if debug: sys.stderr.write(create_table)

            # Create COPY FROM statement
            cur = db.db.cursor() # used only to fill in the COPY parameters
            copy_from = ('COPY '+qual_table+' ('+(', '.join(esc_cols))
                +') FROM STDIN DELIMITER %(delimiter)s NULL %(null)s')
            if not csvs.is_tsv(dialect):
                copy_from += ' CSV'
                if dialect.quoting != csv.QUOTE_NONE:
                    copy_from += ' QUOTE %(quotechar)s'
                    # A doubled quote char acts as the escape sequence
                    if dialect.doublequote:
                        copy_from += ' ESCAPE %(quotechar)s'
            copy_from += ';\n'
            copy_from = cur.mogrify(copy_from, dict(
                delimiter=dialect.delimiter, null=r'\N',
                quotechar=dialect.quotechar))
            if debug: sys.stderr.write(copy_from)

            # Create table
            sql.run_query(db, create_table)

            # COPY FROM the data
            line_in = streams.ProgressInputStream(in_, sys.stderr,
                'Processed %d row(s)', n=10000)
            to_close = line_in # closing it also closes proc.stdout
            db.db.cursor().copy_expert(copy_from, line_in)
        finally:
            to_close.close()
            # NOTE(review): the command's exit status is ignored here
            proc.wait()

    for encoding in ['UTF8', 'LATIN1']:
        db.db.set_client_encoding(encoding)

        try: sql.with_savepoint(db, try_load)
        except sql.DatabaseErrors as e:
            if 'invalid byte sequence for encoding' not in str(e):
                raise # bare raise keeps the original traceback
            exc.print_ex(e, plain=True)
            # now, continue to next encoding
        else:
            db.db.commit() # commit must occur outside of with_savepoint()
            break # don't try further encodings
99
100
# Script entry point: main() is called unconditionally when this file is run
main()