Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
19
    if hasattr(cur, 'query'): return cur.query
20
    elif hasattr(cur, '_last_executed'): return cur._last_executed
21
    else: return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
26 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
27 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
28 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException): pass
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
33
    def __init__(self, cols, cause=None):
34
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
35
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException): pass
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning): pass
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur): return (col[0] for col in cur.description)
50
51
def rows(cur): return iter(lambda: cur.fetchone(), None)
52
53
def consume_rows(cur):
54
    '''Used to fetch all rows so result will be cached'''
55
    iters.consume_iter(rows(cur))
56
57
def next_row(cur): return rows(cur).next()
58
59
def row(cur):
60
    row_ = next_row(cur)
61
    consume_rows(cur)
62
    return row_
63
64
def next_value(cur): return next_row(cur)[0]
65
66
def value(cur): return row(cur)[0]
67
68
def values(cur): return iters.func_iter(lambda: next_value(cur))
69
70
def value_or_none(cur):
71
    try: return value(cur)
72
    except StopIteration: return None
73
74 2101 aaronmk
##### Input validation
75
76
def clean_name(name): return re.sub(r'\W', r'', name)
77
78
def check_name(name):
79
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
80
        +'" may contain only alphanumeric characters and _')
81
82
def esc_name_by_module(module, name, ignore_case=False):
83
    if module == 'psycopg2':
84
        if ignore_case:
85
            # Don't enclose in quotes because this disables case-insensitivity
86
            check_name(name)
87
            return name
88
        else: quote = '"'
89
    elif module == 'MySQLdb': quote = '`'
90
    else: raise NotImplementedError("Can't escape name for "+module+' database')
91
    return quote + name.replace(quote, '') + quote
92
93
def esc_name_by_engine(engine, name, **kw_args):
94
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
95
96
def esc_name(db, name, **kw_args):
97
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
98
99
def qual_name(db, schema, table):
100
    def esc_name_(name): return esc_name(db, name)
101
    table = esc_name_(table)
102
    if schema != None: return esc_name_(schema)+'.'+table
103
    else: return table
104
105 1869 aaronmk
##### Database connections
106 1849 aaronmk
107 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
108 1926 aaronmk
109 1869 aaronmk
db_engines = {
110
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
111
    'PostgreSQL': ('psycopg2', {}),
112
}
113
114
DatabaseErrors_set = set([DbException])
115
DatabaseErrors = tuple(DatabaseErrors_set)
116
117
def _add_module(module):
118
    DatabaseErrors_set.add(module.DatabaseError)
119
    global DatabaseErrors
120
    DatabaseErrors = tuple(DatabaseErrors_set)
121
122
def db_config_str(db_config):
123
    return db_config['engine']+' database '+db_config['database']
124
125 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
126 1894 aaronmk
127 1901 aaronmk
log_debug_none = lambda msg: None
128
129 1849 aaronmk
class DbConn:
130 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
131 2050 aaronmk
        caching=True):
132 1869 aaronmk
        self.db_config = db_config
133
        self.serializable = serializable
134 1901 aaronmk
        self.log_debug = log_debug
135 2047 aaronmk
        self.caching = caching
136 1869 aaronmk
137
        self.__db = None
138 1889 aaronmk
        self.query_results = {}
139 1869 aaronmk
140
    def __getattr__(self, name):
141
        if name == '__dict__': raise Exception('getting __dict__')
142
        if name == 'db': return self._db()
143
        else: raise AttributeError()
144
145
    def __getstate__(self):
146
        state = copy.copy(self.__dict__) # shallow copy
147 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
148 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
149
        return state
150
151
    def _db(self):
152
        if self.__db == None:
153
            # Process db_config
154
            db_config = self.db_config.copy() # don't modify input!
155 2097 aaronmk
            schemas = db_config.pop('schemas', None)
156 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
157
            module = __import__(module_name)
158
            _add_module(module)
159
            for orig, new in mappings.iteritems():
160
                try: util.rename_key(db_config, orig, new)
161
                except KeyError: pass
162
163
            # Connect
164
            self.__db = module.connect(**db_config)
165
166
            # Configure connection
167
            if self.serializable: run_raw_query(self,
168
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
169 2101 aaronmk
            if schemas != None:
170
                schemas_ = ''.join((esc_name(self, s)+', '
171
                    for s in schemas.split(',')))
172
                run_raw_query(self, "SELECT set_config('search_path', \
173
%s || current_setting('search_path'), false)", [schemas_])
174 1869 aaronmk
175
        return self.__db
176 1889 aaronmk
177 1891 aaronmk
    class DbCursor(Proxy):
178 1927 aaronmk
        def __init__(self, outer):
179 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
180 1927 aaronmk
            self.query_results = outer.query_results
181 1894 aaronmk
            self.query_lookup = None
182 1891 aaronmk
            self.result = []
183 1889 aaronmk
184 1894 aaronmk
        def execute(self, query, params=None):
185 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
186 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
187 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
188
            except Exception, e:
189
                self.result = e # cache the exception as the result
190
                self._cache_result()
191
                raise
192
            finally: self.query = get_cur_query(self.inner)
193 1930 aaronmk
            # Fetch all rows so result will be cached
194
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
195 1894 aaronmk
            return return_value
196
197 1891 aaronmk
        def fetchone(self):
198
            row = self.inner.fetchone()
199 1899 aaronmk
            if row != None: self.result.append(row)
200
            # otherwise, fetched all rows
201 1904 aaronmk
            else: self._cache_result()
202
            return row
203
204
        def _cache_result(self):
205 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
206
            # idempotent, but an invalid insert will always be invalid
207 1930 aaronmk
            if self.query_results != None and (not self._is_insert
208 1906 aaronmk
                or isinstance(self.result, Exception)):
209
210 1894 aaronmk
                assert self.query_lookup != None
211 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
212
                    util.dict_subset(dicts.AttrsDictView(self),
213
                    ['query', 'result', 'rowcount', 'description']))
214 1906 aaronmk
215 1916 aaronmk
        class CacheCursor:
216
            def __init__(self, cached_result): self.__dict__ = cached_result
217
218 1927 aaronmk
            def execute(self, *args, **kw_args):
219 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
220
                # otherwise, result is a rows list
221
                self.iter = iter(self.result)
222
223
            def fetchone(self):
224
                try: return self.iter.next()
225
                except StopIteration: return None
226 1891 aaronmk
227 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
228 2047 aaronmk
        if not self.caching: cacheable = False
229 1903 aaronmk
        used_cache = False
230
        try:
231 1927 aaronmk
            # Get cursor
232
            if cacheable:
233
                query_lookup = _query_lookup(query, params)
234
                try:
235
                    cur = self.query_results[query_lookup]
236
                    used_cache = True
237
                except KeyError: cur = self.DbCursor(self)
238
            else: cur = self.db.cursor()
239
240
            # Run query
241
            try: cur.execute(query, params)
242
            except Exception, e:
243
                _add_cursor_info(e, cur)
244
                raise
245 1903 aaronmk
        finally:
246
            if self.log_debug != log_debug_none: # only compute msg if needed
247
                if used_cache: cache_status = 'Cache hit'
248
                elif cacheable: cache_status = 'Cache miss'
249
                else: cache_status = 'Non-cacheable'
250 1927 aaronmk
                self.log_debug(cache_status+': '
251
                    +strings.one_line(get_cur_query(cur)))
252 1903 aaronmk
253
        return cur
254 1914 aaronmk
255
    def is_cached(self, query, params=None):
256
        return _query_lookup(query, params) in self.query_results
257 1849 aaronmk
258 1869 aaronmk
connect = DbConn
259
260 832 aaronmk
##### Querying
261
262 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
263 2085 aaronmk
    '''For params, see DbConn.run_query()'''
264 1894 aaronmk
    return db.run_query(*args, **kw_args)
265 11 aaronmk
266 2068 aaronmk
def mogrify(db, query, params):
267
    module = util.root_module(db.db)
268
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
269
    else: raise NotImplementedError("Can't mogrify query for "+module+
270
        ' database')
271
272 832 aaronmk
##### Recoverable querying
273 15 aaronmk
274 11 aaronmk
def with_savepoint(db, func):
275 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
276 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
277 11 aaronmk
    try: return_val = func()
278
    except:
279 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
280 11 aaronmk
        raise
281
    else:
282 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
283 11 aaronmk
        return return_val
284
285 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
286 830 aaronmk
    if recover == None: recover = False
287
288 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
289 1914 aaronmk
    if recover and not db.is_cached(query, params):
290
        return with_savepoint(db, run)
291
    else: return run() # don't need savepoint if cached
292 830 aaronmk
293 832 aaronmk
##### Basic queries
294
295 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
296
    '''Outputs a query to a temp table.
297
    For params, see run_query().
298
    '''
299
    if into == None: return run_query(db, query, params, *args, **kw_args)
300
    else: # place rows in temp table
301
        check_name(into)
302
303
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', *args, **kw_args)
304
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
305
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
306
307 2054 aaronmk
def mk_select(db, table, fields=None, conds=None, limit=None, start=None,
308
    table_is_esc=False):
309 1981 aaronmk
    '''
310
    @param fields Use None to select all fields in the table
311
    @param table_is_esc Whether the table name has already been escaped
312 2054 aaronmk
    @return tuple(query, params)
313 1981 aaronmk
    '''
314 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
315 2058 aaronmk
316 1135 aaronmk
    if conds == None: conds = {}
317 135 aaronmk
    assert limit == None or type(limit) == int
318 865 aaronmk
    assert start == None or type(start) == int
319 2058 aaronmk
    if not table_is_esc: table = esc_name_(table)
320 865 aaronmk
321 2056 aaronmk
    params = []
322
323
    def parse_col(field):
324
        '''Parses fields'''
325
        if isinstance(field, tuple): # field is literal values
326
            value, col = field
327
            sql_ = '%s'
328
            params.append(value)
329 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
330
        else: sql_ = esc_name_(field) # field is col name
331 2056 aaronmk
        return sql_
332 11 aaronmk
    def cond(entry):
333 2056 aaronmk
        '''Parses conditions'''
334 13 aaronmk
        col, value = entry
335 2058 aaronmk
        cond_ = esc_name_(col)+' '
336 11 aaronmk
        if value == None: cond_ += 'IS'
337
        else: cond_ += '='
338
        cond_ += ' %s'
339
        return cond_
340 2056 aaronmk
341 1135 aaronmk
    query = 'SELECT '
342
    if fields == None: query += '*'
343 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
344 2055 aaronmk
    query += ' FROM '+table
345 865 aaronmk
346
    missing = True
347 89 aaronmk
    if conds != {}:
348
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
349 2056 aaronmk
        params += conds.values()
350 865 aaronmk
        missing = False
351
    if limit != None: query += ' LIMIT '+str(limit); missing = False
352
    if start != None:
353
        if start != 0: query += ' OFFSET '+str(start)
354
        missing = False
355
    if missing: warnings.warn(DbWarning(
356
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
357
358 2056 aaronmk
    return (query, params)
359 11 aaronmk
360 2054 aaronmk
def select(db, *args, **kw_args):
361
    '''For params, see mk_select() and run_query()'''
362
    recover = kw_args.pop('recover', None)
363
    cacheable = kw_args.pop('cacheable', True)
364
365
    query, params = mk_select(db, *args, **kw_args)
366
    return run_query(db, query, params, recover, cacheable)
367
368 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
369 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
370 1960 aaronmk
    '''
371
    @param returning str|None An inserted column (such as pkey) to return
372 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
373 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
374
        query will be fully cached, not just if it raises an exception.
375 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
376
    '''
377 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
378
    if cols == []: cols = None # no cols (all defaults) = unknown col names
379 1960 aaronmk
    if not table_is_esc: check_name(table)
380 2063 aaronmk
381
    # Build query
382
    query = 'INSERT INTO '+table
383
    if cols != None:
384
        map(check_name, cols)
385
        query += ' ('+', '.join(cols)+')'
386
    query += ' '+select_query
387
388
    if returning != None:
389
        check_name(returning)
390
        query += ' RETURNING '+returning
391
392 2070 aaronmk
    if embeddable:
393
        # Create function
394 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
395
            ['insert', table] + cols)))
396 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
397
        function_query = '''\
398 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
399 2070 aaronmk
    LANGUAGE sql
400
    AS $$'''+mogrify(db, query, params)+''';$$;
401
'''
402
        run_query(db, function_query, cacheable=True)
403
404
        # Return query that uses function
405 2083 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')',
406 2080 aaronmk
            table_is_esc=True) # function alias is required in AS clause
407 2070 aaronmk
408 2066 aaronmk
    return (query, params)
409
410
def insert_select(db, *args, **kw_args):
411 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
412 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
413
    '''
414
    into = kw_args.pop('into', None)
415
    if into != None: kw_args['embeddable'] = True
416 2066 aaronmk
    recover = kw_args.pop('recover', None)
417
    cacheable = kw_args.pop('cacheable', True)
418
419
    query, params = mk_insert_select(db, *args, **kw_args)
420 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
421 2063 aaronmk
422 2066 aaronmk
default = object() # tells insert() to use the default value for a column
423
424 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
425 2085 aaronmk
    '''For params, see insert_select()'''
426 1960 aaronmk
    if lists.is_seq(row): cols = None
427
    else:
428
        cols = row.keys()
429
        row = row.values()
430
    row = list(row) # ensure that "!= []" works
431
432 1961 aaronmk
    # Check for special values
433
    labels = []
434
    values = []
435
    for value in row:
436
        if value == default: labels.append('DEFAULT')
437
        else:
438
            labels.append('%s')
439
            values.append(value)
440
441
    # Build query
442 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
443
    else: query = None
444 1554 aaronmk
445 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
446 11 aaronmk
447 135 aaronmk
def last_insert_id(db):
448 1849 aaronmk
    module = util.root_module(db.db)
449 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
450
    elif module == 'MySQLdb': return db.insert_id()
451
    else: return None
452 13 aaronmk
453 1968 aaronmk
def truncate(db, table, schema='public'):
454
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
455 832 aaronmk
456
##### Database structure queries
457
458 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
459 832 aaronmk
    '''Assumed to be first column in table'''
460 2084 aaronmk
    return col_names(select(db, table, limit=0, recover=recover,
461
        table_is_esc=table_is_esc)).next()
462 832 aaronmk
463 853 aaronmk
def index_cols(db, table, index):
464
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
465
    automatically created. When you don't know whether something is a UNIQUE
466
    constraint or a UNIQUE index, use this function.'''
467
    check_name(table)
468
    check_name(index)
469 1909 aaronmk
    module = util.root_module(db.db)
470
    if module == 'psycopg2':
471
        return list(values(run_query(db, '''\
472 853 aaronmk
SELECT attname
473 866 aaronmk
FROM
474
(
475
        SELECT attnum, attname
476
        FROM pg_index
477
        JOIN pg_class index ON index.oid = indexrelid
478
        JOIN pg_class table_ ON table_.oid = indrelid
479
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
480
        WHERE
481
            table_.relname = %(table)s
482
            AND index.relname = %(index)s
483
    UNION
484
        SELECT attnum, attname
485
        FROM
486
        (
487
            SELECT
488
                indrelid
489
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
490
                    AS indkey
491
            FROM pg_index
492
            JOIN pg_class index ON index.oid = indexrelid
493
            JOIN pg_class table_ ON table_.oid = indrelid
494
            WHERE
495
                table_.relname = %(table)s
496
                AND index.relname = %(index)s
497
        ) s
498
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
499
) s
500 853 aaronmk
ORDER BY attnum
501
''',
502 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
503
    else: raise NotImplementedError("Can't list index columns for "+module+
504
        ' database')
505 853 aaronmk
506 464 aaronmk
def constraint_cols(db, table, constraint):
507
    check_name(table)
508
    check_name(constraint)
509 1849 aaronmk
    module = util.root_module(db.db)
510 464 aaronmk
    if module == 'psycopg2':
511
        return list(values(run_query(db, '''\
512
SELECT attname
513
FROM pg_constraint
514
JOIN pg_class ON pg_class.oid = conrelid
515
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
516
WHERE
517
    relname = %(table)s
518
    AND conname = %(constraint)s
519
ORDER BY attnum
520
''',
521
            {'table': table, 'constraint': constraint})))
522
    else: raise NotImplementedError("Can't list constraint columns for "+module+
523
        ' database')
524
525 2096 aaronmk
row_num_col = '_row_num'
526
527 2086 aaronmk
def add_row_num(db, table):
528 2096 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col.'''
529 2086 aaronmk
    check_name(table)
530 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
531
        +' serial NOT NULL')
532 2086 aaronmk
533 1968 aaronmk
def tables(db, schema='public', table_like='%'):
534 1849 aaronmk
    module = util.root_module(db.db)
535 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
536 832 aaronmk
    if module == 'psycopg2':
537 1968 aaronmk
        return values(run_query(db, '''\
538
SELECT tablename
539
FROM pg_tables
540
WHERE
541
    schemaname = %(schema)s
542
    AND tablename LIKE %(table_like)s
543
ORDER BY tablename
544
''',
545
            params, cacheable=True))
546
    elif module == 'MySQLdb':
547
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
548
            cacheable=True))
549 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
550 830 aaronmk
551 833 aaronmk
##### Database management
552
553 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
554
    '''For kw_args, see tables()'''
555
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
556 833 aaronmk
557 832 aaronmk
##### Heuristic queries
558
559 2076 aaronmk
def with_parsed_errors(db, func):
560
    '''Translates known DB errors to typed exceptions'''
561
    try: return func()
562 46 aaronmk
    except Exception, e:
563
        msg = str(e)
564 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
565
            r'"(([^\W_]+)_[^"]+)"', msg)
566
        if match:
567
            constraint, table = match.groups()
568 854 aaronmk
            try: cols = index_cols(db, table, constraint)
569 465 aaronmk
            except NotImplementedError: raise e
570 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
571 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
572
            'constraint', msg)
573 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
574 13 aaronmk
        raise # no specific exception raised
575 11 aaronmk
576 2076 aaronmk
def try_insert(db, table, row, returning=None):
577
    '''Recovers from errors'''
578
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
579
        recover=True))
580
581 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
582 1554 aaronmk
    '''Recovers from errors.
583 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
584
    '''
585 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
586
587 471 aaronmk
    try:
588 2104 aaronmk
        cur = try_insert(db, table, row, pkey_)
589 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
590
            row_ct_ref[0] += cur.rowcount
591
        return value(cur)
592 471 aaronmk
    except DuplicateKeyException, e:
593 2104 aaronmk
        return value(select(db, table, [pkey_],
594 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
595 471 aaronmk
596 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
597 830 aaronmk
    '''Recovers from errors'''
598
    try: return value(select(db, table, [pkey], row, 1, recover=True))
599 14 aaronmk
    except StopIteration:
600 40 aaronmk
        if not create: raise
601 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
602 2078 aaronmk
603 2087 aaronmk
def put_table(db, out_table, out_cols, in_tables, in_cols, pkey,
604
    row_ct_ref=None, table_is_esc=False):
605 2078 aaronmk
    '''Recovers from errors.
606
    Only works under PostgreSQL (uses INSERT RETURNING).
607 2081 aaronmk
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
608 2078 aaronmk
        available
609
    '''
610
    pkeys_table = clean_name(out_table)+'_pkeys'
611
    def insert_():
612
        return insert_select(db, out_table, out_cols,
613 2087 aaronmk
            *mk_select(db, in_tables[0], in_cols, table_is_esc=table_is_esc),
614 2078 aaronmk
            returning=pkey, into=pkeys_table, recover=True,
615
            table_is_esc=table_is_esc)
616
    try:
617
        cur = with_parsed_errors(db, insert_)
618
        if row_ct_ref != None and cur.rowcount >= 0:
619
            row_ct_ref[0] += cur.rowcount
620 2086 aaronmk
621
        # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
622
        add_row_num(db, pkeys_table)
623
624 2081 aaronmk
        return pkeys_table
625 2078 aaronmk
    except DuplicateKeyException, e: raise