Project

General

Profile

« Previous | Next » 

Revision 830

sql.py: Added ability to recover from database errors so you don't get the error "InternalError: current transaction is aborted, commands ignored until end of transaction block"

View differences:

lib/sql.py
44 44
    else: raise NotImplementedError("Can't escape name for "+module+' database')
45 45
    return quote + name.replace(quote, '') + quote
46 46

  
47
def run_query(db, query, params=None):
47
def run_raw_query(db, query, params=None):
48 48
    cur = db.cursor()
49 49
    try: cur.execute(query, params)
50 50
    except Exception, e:
......
68 68

  
69 69
def with_savepoint(db, func):
70 70
    savepoint = 'savepoint_'+str(random.randint(0, sys.maxint)) # must be unique
71
    run_query(db, 'SAVEPOINT '+savepoint)
71
    run_raw_query(db, 'SAVEPOINT '+savepoint)
72 72
    try: return_val = func()
73 73
    except:
74
        run_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
74
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
75 75
        raise
76 76
    else:
77
        run_query(db, 'RELEASE SAVEPOINT '+savepoint)
77
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
78 78
        return return_val
79 79

  
80
def select(db, table, fields, conds, limit=None):
80
def run_query(db, query, params=None, recover=None):
81
    if recover == None: recover = False
82
    
83
    def run(): return run_raw_query(db, query, params)
84
    if recover: return with_savepoint(db, run)
85
    else: return run()
86

  
87
def select(db, table, fields, conds, limit=None, recover=None):
81 88
    assert limit == None or type(limit) == int
82 89
    check_name(table)
83 90
    map(check_name, fields)
......
94 101
    if conds != {}:
95 102
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
96 103
    if limit != None: query += ' LIMIT '+str(limit)
97
    return run_query(db, query, conds.values())
104
    return run_query(db, query, conds.values(), recover)
98 105

  
99
def insert(db, table, row):
106
def insert(db, table, row, recover=None):
100 107
    check_name(table)
101 108
    cols = row.keys()
102 109
    map(check_name, cols)
......
104 111
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
105 112
        +', '.join(['%s']*len(cols))+')'
106 113
    else: query += ' DEFAULT VALUES'
107
    return run_query(db, query, row.values())
114
    return run_query(db, query, row.values(), recover)
108 115

  
109 116
def last_insert_id(db):
110 117
    module = util.root_module(db)
......
131 138
    else: raise NotImplementedError("Can't list constraint columns for "+module+
132 139
        ' database')
133 140

  
141
def pkey(db, cache, table, recover=None):
142
    '''Assumed to be first column in table'''
143
    check_name(table)
144
    if table not in cache:
145
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0',
146
            recover), 0)
147
    return cache[table]
148

  
134 149
def try_insert(db, table, row):
135
    try: return with_savepoint(db, lambda: insert(db, table, row))
150
    '''Recovers from errors'''
151
    try: return insert(db, table, row, recover=True)
136 152
    except Exception, e:
137 153
        msg = str(e)
138 154
        match = re.search(r'duplicate key value violates unique constraint '
......
147 163
        if match: raise NullValueException([match.group(1)], e)
148 164
        raise # no specific exception raised
149 165

  
150
def pkey(db, cache, table): # Assumed to be first column in table
151
    check_name(table)
152
    if table not in cache:
153
        cache[table] = col(run_query(db, 'SELECT * FROM '+table+' LIMIT 0'), 0)
154
    return cache[table]
155

  
156 166
def put(db, table, row, pkey, row_ct_ref=None):
167
    '''Recovers from errors'''
157 168
    try:
158 169
        row_ct = try_insert(db, table, row).rowcount
159 170
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
160 171
        return last_insert_id(db)
161 172
    except DuplicateKeyException, e:
162
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols)))
173
        return value(select(db, table, [pkey], util.dict_subset(row, e.cols),
174
            recover=True))
163 175

  
164 176
def get(db, table, row, pkey, row_ct_ref=None, create=False):
165
    try: return value(select(db, table, [pkey], row, 1))
177
    '''Recovers from errors'''
178
    try: return value(select(db, table, [pkey], row, 1, recover=True))
166 179
    except StopIteration:
167 180
        if not create: raise
168 181
        return put(db, table, row, pkey, row_ct_ref) # insert new row
......
206 219
        except KeyError: pass
207 220
    db = module.connect(**db_config)
208 221
    if serializable:
209
        run_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
222
        run_raw_query(db, 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
210 223
    return db
211 224

  
212 225
def db_config_str(db_config):

Also available in: Unified diff