Project

General

Profile

« Previous | Next » 

Revision 1849

sql.py: Wrapped db connection inside an object that can also store the cache of the pkeys and index_cols

View differences:

lib/sql.py
47 47
        +'" may contain only alphanumeric characters and _')
48 48

  
49 49
def esc_name(db, name):
50
    module = util.root_module(db)
50
    module = util.root_module(db.db)
51 51
    if module == 'psycopg2': return name
52 52
        # Don't enclose in quotes because this disables case-insensitivity
53 53
    elif module == 'MySQLdb': quote = '`'
54 54
    else: raise NotImplementedError("Can't escape name for "+module+' database')
55 55
    return quote + name.replace(quote, '') + quote
56 56

  
57
##### Connection object
58

  
59
class DbConn:
60
    def __init__(self, db):
61
        self.db = db
62
        self.pkeys = {}
63
        self.index_cols = {}
64

  
57 65
##### Querying
58 66

  
59 67
def run_raw_query(db, query, params=None):
60
    cur = db.cursor()
68
    cur = db.db.cursor()
61 69
    try: cur.execute(query, params)
62 70
    except Exception, e:
63 71
        _add_cursor_info(e, cur)
......
157 165
    return run_query(db, query, row.values(), recover)
158 166

  
159 167
def last_insert_id(db):
160
    module = util.root_module(db)
168
    module = util.root_module(db.db)
161 169
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
162 170
    elif module == 'MySQLdb': return db.insert_id()
163 171
    else: return None
......
182 190
    constraint or a UNIQUE index, use this function.'''
183 191
    check_name(table)
184 192
    check_name(index)
185
    module = util.root_module(db)
193
    module = util.root_module(db.db)
186 194
    if module == 'psycopg2':
187 195
        return list(values(run_query(db, '''\
188 196
SELECT attname
......
222 230
def constraint_cols(db, table, constraint):
223 231
    check_name(table)
224 232
    check_name(constraint)
225
    module = util.root_module(db)
233
    module = util.root_module(db.db)
226 234
    if module == 'psycopg2':
227 235
        return list(values(run_query(db, '''\
228 236
SELECT attname
......
239 247
        ' database')
240 248

  
241 249
def tables(db):
242
    module = util.root_module(db)
250
    module = util.root_module(db.db)
243 251
    if module == 'psycopg2':
244 252
        return values(run_query(db, "SELECT tablename from pg_tables "
245 253
            "WHERE schemaname = 'public' ORDER BY tablename"))
......
312 320
    for orig, new in mappings.iteritems():
313 321
        try: util.rename_key(db_config, orig, new)
314 322
        except KeyError: pass
315
    db = module.connect(**db_config)
323
    db = DbConn(module.connect(**db_config))
316 324
    if serializable:
317 325
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
318 326
    return db
bin/map
242 242
                limit=end, start=0)
243 243
            row_ct = map_table(list(sql.col_names(cur)), sql.rows(cur))
244 244
            
245
            in_db.close()
245
            in_db.db.close()
246 246
        elif in_is_xml:
247 247
            def get_rows(doc2rows):
248 248
                return iters.flatten(itertools.imap(doc2rows,
......
307 307
                        sql.with_savepoint(out_db,
308 308
                            lambda: db_xml.put(out_db, root.firstChild,
309 309
                                out_pkeys, row_ins_ct_ref, on_error))
310
                        if commit: out_db.commit()
310
                        if commit: out_db.db.commit()
311 311
                    except sql.DatabaseErrors, e: on_error(e)
312 312
                prep_root()
313 313
            
......
315 315
            sys.stdout.write('Inserted '+str(row_ins_ct_ref[0])+
316 316
                ' new rows into database\n')
317 317
        finally:
318
            out_db.rollback()
319
            out_db.close()
318
            out_db.db.rollback()
319
            out_db.db.close()
320 320
    else:
321 321
        def on_error(e): ex_tracker.track(e)
322 322
        def row_ready(row_num, input_row): pass

Also available in: Unified diff