Project

General

Profile

« Previous | Next » 

Revision 1554

sql.py: insert() (and try_insert()): Added optional returning param to provide name of an inserted column (usually pkey) to return

View differences:

lib/sql.py
139 139
    
140 140
    return run_query(db, query, conds.values(), recover)
141 141

  
142
def insert(db, table, row, recover=None):
142
def insert(db, table, row, returning=None, recover=None):
143
    '''@param returning str|None An inserted column (such as pkey) to return'''
143 144
    check_name(table)
144 145
    cols = row.keys()
145 146
    map(check_name, cols)
146 147
    query = 'INSERT INTO '+table
148
    
147 149
    if row != {}: query += ' ('+', '.join(cols)+') VALUES ('\
148 150
        +', '.join(['%s']*len(cols))+')'
149 151
    else: query += ' DEFAULT VALUES'
152
    
153
    if returning != None:
154
        check_name(returning)
155
        query += ' RETURNING '+returning
156
    
150 157
    return run_query(db, query, row.values(), recover)
151 158

  
152 159
def last_insert_id(db):
......
246 253

  
247 254
##### Heuristic queries
248 255

  
249
def try_insert(db, table, row):
256
def try_insert(db, table, row, returning=None):
250 257
    '''Recovers from errors'''
251
    try: return insert(db, table, row, recover=True)
258
    try: return insert(db, table, row, returning, recover=True)
252 259
    except Exception, e:
253 260
        msg = str(e)
254 261
        match = re.search(r'duplicate key value violates unique constraint '
......
264 271
        raise # no specific exception raised
265 272

  
266 273
def put(db, table, row, pkey, row_ct_ref=None):
267
    '''Recovers from errors'''
274
    '''Recovers from errors.
275
    Only works under PostgreSQL (uses `INSERT ... RETURNING`)'''
268 276
    try:
269
        row_ct = try_insert(db, table, row).rowcount
270
        if row_ct_ref != None and row_ct >= 0: row_ct_ref[0] += row_ct
271
        return last_insert_id(db)
277
        cur = try_insert(db, table, row, pkey)
278
        if row_ct_ref != None and cur.rowcount >= 0:
279
            row_ct_ref[0] += cur.rowcount
280
        return value(cur)
272 281
    except DuplicateKeyException, e:
273 282
        return value(select(db, table, [pkey],
274 283
            util.dict_subset_right_join(row, e.cols), recover=True))

Also available in: Unified diff