Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
19
    if hasattr(cur, 'query'): return cur.query
20
    elif hasattr(cur, '_last_executed'): return cur._last_executed
21
    else: return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
26 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
27 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
28 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException): pass
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
33
    def __init__(self, cols, cause=None):
34
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
35
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException): pass
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning): pass
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur): return (col[0] for col in cur.description)
50
51
def rows(cur): return iter(lambda: cur.fetchone(), None)
52
53
def consume_rows(cur):
54
    '''Used to fetch all rows so result will be cached'''
55
    iters.consume_iter(rows(cur))
56
57
def next_row(cur): return rows(cur).next()
58
59
def row(cur):
60
    row_ = next_row(cur)
61
    consume_rows(cur)
62
    return row_
63
64
def next_value(cur): return next_row(cur)[0]
65
66
def value(cur): return row(cur)[0]
67
68
def values(cur): return iters.func_iter(lambda: next_value(cur))
69
70
def value_or_none(cur):
71
    try: return value(cur)
72
    except StopIteration: return None
73
74 1869 aaronmk
##### Database connections
75 1849 aaronmk
76 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
77 1926 aaronmk
78 1869 aaronmk
db_engines = {
79
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
80
    'PostgreSQL': ('psycopg2', {}),
81
}
82
83
DatabaseErrors_set = set([DbException])
84
DatabaseErrors = tuple(DatabaseErrors_set)
85
86
def _add_module(module):
87
    DatabaseErrors_set.add(module.DatabaseError)
88
    global DatabaseErrors
89
    DatabaseErrors = tuple(DatabaseErrors_set)
90
91
def db_config_str(db_config):
92
    return db_config['engine']+' database '+db_config['database']
93
94 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
95 1894 aaronmk
96 1901 aaronmk
log_debug_none = lambda msg: None
97
98 1849 aaronmk
class DbConn:
99 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
100 2050 aaronmk
        caching=True):
101 1869 aaronmk
        self.db_config = db_config
102
        self.serializable = serializable
103 1901 aaronmk
        self.log_debug = log_debug
104 2047 aaronmk
        self.caching = caching
105 1869 aaronmk
106
        self.__db = None
107 1889 aaronmk
        self.query_results = {}
108 1869 aaronmk
109
    def __getattr__(self, name):
110
        if name == '__dict__': raise Exception('getting __dict__')
111
        if name == 'db': return self._db()
112
        else: raise AttributeError()
113
114
    def __getstate__(self):
115
        state = copy.copy(self.__dict__) # shallow copy
116 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
117 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
118
        return state
119
120
    def _db(self):
121
        if self.__db == None:
122
            # Process db_config
123
            db_config = self.db_config.copy() # don't modify input!
124 2097 aaronmk
            schemas = db_config.pop('schemas', None)
125 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
126
            module = __import__(module_name)
127
            _add_module(module)
128
            for orig, new in mappings.iteritems():
129
                try: util.rename_key(db_config, orig, new)
130
                except KeyError: pass
131
132
            # Connect
133
            self.__db = module.connect(**db_config)
134
135
            # Configure connection
136 2097 aaronmk
            if schemas != None:
137
                schemas = schemas[:] # don't modify input!
138
                schemas.append('current_setting(search_path)')
139
                run_raw_query(self, 'SET search_path = '+(', '.join(schemas)))
140 1869 aaronmk
            if self.serializable: run_raw_query(self,
141
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
142
143
        return self.__db
144 1889 aaronmk
145 1891 aaronmk
    class DbCursor(Proxy):
146 1927 aaronmk
        def __init__(self, outer):
147 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
148 1927 aaronmk
            self.query_results = outer.query_results
149 1894 aaronmk
            self.query_lookup = None
150 1891 aaronmk
            self.result = []
151 1889 aaronmk
152 1894 aaronmk
        def execute(self, query, params=None):
153 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
154 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
155 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
156
            except Exception, e:
157
                self.result = e # cache the exception as the result
158
                self._cache_result()
159
                raise
160
            finally: self.query = get_cur_query(self.inner)
161 1930 aaronmk
            # Fetch all rows so result will be cached
162
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
163 1894 aaronmk
            return return_value
164
165 1891 aaronmk
        def fetchone(self):
166
            row = self.inner.fetchone()
167 1899 aaronmk
            if row != None: self.result.append(row)
168
            # otherwise, fetched all rows
169 1904 aaronmk
            else: self._cache_result()
170
            return row
171
172
        def _cache_result(self):
173 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
174
            # idempotent, but an invalid insert will always be invalid
175 1930 aaronmk
            if self.query_results != None and (not self._is_insert
176 1906 aaronmk
                or isinstance(self.result, Exception)):
177
178 1894 aaronmk
                assert self.query_lookup != None
179 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
180
                    util.dict_subset(dicts.AttrsDictView(self),
181
                    ['query', 'result', 'rowcount', 'description']))
182 1906 aaronmk
183 1916 aaronmk
        class CacheCursor:
184
            def __init__(self, cached_result): self.__dict__ = cached_result
185
186 1927 aaronmk
            def execute(self, *args, **kw_args):
187 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
188
                # otherwise, result is a rows list
189
                self.iter = iter(self.result)
190
191
            def fetchone(self):
192
                try: return self.iter.next()
193
                except StopIteration: return None
194 1891 aaronmk
195 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
196 2047 aaronmk
        if not self.caching: cacheable = False
197 1903 aaronmk
        used_cache = False
198
        try:
199 1927 aaronmk
            # Get cursor
200
            if cacheable:
201
                query_lookup = _query_lookup(query, params)
202
                try:
203
                    cur = self.query_results[query_lookup]
204
                    used_cache = True
205
                except KeyError: cur = self.DbCursor(self)
206
            else: cur = self.db.cursor()
207
208
            # Run query
209
            try: cur.execute(query, params)
210
            except Exception, e:
211
                _add_cursor_info(e, cur)
212
                raise
213 1903 aaronmk
        finally:
214
            if self.log_debug != log_debug_none: # only compute msg if needed
215
                if used_cache: cache_status = 'Cache hit'
216
                elif cacheable: cache_status = 'Cache miss'
217
                else: cache_status = 'Non-cacheable'
218 1927 aaronmk
                self.log_debug(cache_status+': '
219
                    +strings.one_line(get_cur_query(cur)))
220 1903 aaronmk
221
        return cur
222 1914 aaronmk
223
    def is_cached(self, query, params=None):
224
        return _query_lookup(query, params) in self.query_results
225 1849 aaronmk
226 1869 aaronmk
connect = DbConn
227
228 1919 aaronmk
##### Input validation
229
230 2077 aaronmk
def clean_name(name): return re.sub(r'\W', r'', name)
231
232 1919 aaronmk
def check_name(name):
233
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
234
        +'" may contain only alphanumeric characters and _')
235
236 2061 aaronmk
def esc_name_by_module(module, name, ignore_case=False):
237 1919 aaronmk
    if module == 'psycopg2':
238 2061 aaronmk
        if ignore_case:
239
            # Don't enclose in quotes because this disables case-insensitivity
240 2057 aaronmk
            check_name(name)
241
            return name
242 2061 aaronmk
        else: quote = '"'
243 1919 aaronmk
    elif module == 'MySQLdb': quote = '`'
244
    else: raise NotImplementedError("Can't escape name for "+module+' database')
245
    return quote + name.replace(quote, '') + quote
246
247
def esc_name_by_engine(engine, name, **kw_args):
248
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
249
250
def esc_name(db, name, **kw_args):
251
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
252
253 1968 aaronmk
def qual_name(db, schema, table):
254 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
255 2051 aaronmk
    table = esc_name_(table)
256
    if schema != None: return esc_name_(schema)+'.'+table
257
    else: return table
258 1968 aaronmk
259 832 aaronmk
##### Querying
260
261 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
262 2085 aaronmk
    '''For params, see DbConn.run_query()'''
263 1894 aaronmk
    return db.run_query(*args, **kw_args)
264 11 aaronmk
265 2068 aaronmk
def mogrify(db, query, params):
266
    module = util.root_module(db.db)
267
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
268
    else: raise NotImplementedError("Can't mogrify query for "+module+
269
        ' database')
270
271 832 aaronmk
##### Recoverable querying
272 15 aaronmk
273 11 aaronmk
def with_savepoint(db, func):
274 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
275 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
276 11 aaronmk
    try: return_val = func()
277
    except:
278 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
279 11 aaronmk
        raise
280
    else:
281 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
282 11 aaronmk
        return return_val
283
284 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
285 830 aaronmk
    if recover == None: recover = False
286
287 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
288 1914 aaronmk
    if recover and not db.is_cached(query, params):
289
        return with_savepoint(db, run)
290
    else: return run() # don't need savepoint if cached
291 830 aaronmk
292 832 aaronmk
##### Basic queries
293
294 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
295
    '''Outputs a query to a temp table.
296
    For params, see run_query().
297
    '''
298
    if into == None: return run_query(db, query, params, *args, **kw_args)
299
    else: # place rows in temp table
300
        check_name(into)
301
302
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', *args, **kw_args)
303
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
304
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
305
306 2054 aaronmk
def mk_select(db, table, fields=None, conds=None, limit=None, start=None,
307
    table_is_esc=False):
308 1981 aaronmk
    '''
309
    @param fields Use None to select all fields in the table
310
    @param table_is_esc Whether the table name has already been escaped
311 2054 aaronmk
    @return tuple(query, params)
312 1981 aaronmk
    '''
313 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
314 2058 aaronmk
315 1135 aaronmk
    if conds == None: conds = {}
316 135 aaronmk
    assert limit == None or type(limit) == int
317 865 aaronmk
    assert start == None or type(start) == int
318 2058 aaronmk
    if not table_is_esc: table = esc_name_(table)
319 865 aaronmk
320 2056 aaronmk
    params = []
321
322
    def parse_col(field):
323
        '''Parses fields'''
324
        if isinstance(field, tuple): # field is literal values
325
            value, col = field
326
            sql_ = '%s'
327
            params.append(value)
328 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
329
        else: sql_ = esc_name_(field) # field is col name
330 2056 aaronmk
        return sql_
331 11 aaronmk
    def cond(entry):
332 2056 aaronmk
        '''Parses conditions'''
333 13 aaronmk
        col, value = entry
334 2058 aaronmk
        cond_ = esc_name_(col)+' '
335 11 aaronmk
        if value == None: cond_ += 'IS'
336
        else: cond_ += '='
337
        cond_ += ' %s'
338
        return cond_
339 2056 aaronmk
340 1135 aaronmk
    query = 'SELECT '
341
    if fields == None: query += '*'
342 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
343 2055 aaronmk
    query += ' FROM '+table
344 865 aaronmk
345
    missing = True
346 89 aaronmk
    if conds != {}:
347
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
348 2056 aaronmk
        params += conds.values()
349 865 aaronmk
        missing = False
350
    if limit != None: query += ' LIMIT '+str(limit); missing = False
351
    if start != None:
352
        if start != 0: query += ' OFFSET '+str(start)
353
        missing = False
354
    if missing: warnings.warn(DbWarning(
355
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
356
357 2056 aaronmk
    return (query, params)
358 11 aaronmk
359 2054 aaronmk
def select(db, *args, **kw_args):
360
    '''For params, see mk_select() and run_query()'''
361
    recover = kw_args.pop('recover', None)
362
    cacheable = kw_args.pop('cacheable', True)
363
364
    query, params = mk_select(db, *args, **kw_args)
365
    return run_query(db, query, params, recover, cacheable)
366
367 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
368 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
369 1960 aaronmk
    '''
370
    @param returning str|None An inserted column (such as pkey) to return
371 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
372 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
373
        query will be fully cached, not just if it raises an exception.
374 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
375
    '''
376 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
377
    if cols == []: cols = None # no cols (all defaults) = unknown col names
378 1960 aaronmk
    if not table_is_esc: check_name(table)
379 2063 aaronmk
380
    # Build query
381
    query = 'INSERT INTO '+table
382
    if cols != None:
383
        map(check_name, cols)
384
        query += ' ('+', '.join(cols)+')'
385
    query += ' '+select_query
386
387
    if returning != None:
388
        check_name(returning)
389
        query += ' RETURNING '+returning
390
391 2070 aaronmk
    if embeddable:
392
        # Create function
393 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
394
            ['insert', table] + cols)))
395 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
396
        function_query = '''\
397 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
398 2070 aaronmk
    LANGUAGE sql
399
    AS $$'''+mogrify(db, query, params)+''';$$;
400
'''
401
        run_query(db, function_query, cacheable=True)
402
403
        # Return query that uses function
404 2083 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')',
405 2080 aaronmk
            table_is_esc=True) # function alias is required in AS clause
406 2070 aaronmk
407 2066 aaronmk
    return (query, params)
408
409
def insert_select(db, *args, **kw_args):
410 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
411 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
412
    '''
413
    into = kw_args.pop('into', None)
414
    if into != None: kw_args['embeddable'] = True
415 2066 aaronmk
    recover = kw_args.pop('recover', None)
416
    cacheable = kw_args.pop('cacheable', True)
417
418
    query, params = mk_insert_select(db, *args, **kw_args)
419 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
420 2063 aaronmk
421 2066 aaronmk
default = object() # tells insert() to use the default value for a column
422
423 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
424 2085 aaronmk
    '''For params, see insert_select()'''
425 1960 aaronmk
    if lists.is_seq(row): cols = None
426
    else:
427
        cols = row.keys()
428
        row = row.values()
429
    row = list(row) # ensure that "!= []" works
430
431 1961 aaronmk
    # Check for special values
432
    labels = []
433
    values = []
434
    for value in row:
435
        if value == default: labels.append('DEFAULT')
436
        else:
437
            labels.append('%s')
438
            values.append(value)
439
440
    # Build query
441 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
442
    else: query = None
443 1554 aaronmk
444 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
445 11 aaronmk
446 135 aaronmk
def last_insert_id(db):
447 1849 aaronmk
    module = util.root_module(db.db)
448 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
449
    elif module == 'MySQLdb': return db.insert_id()
450
    else: return None
451 13 aaronmk
452 1968 aaronmk
def truncate(db, table, schema='public'):
453
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
454 832 aaronmk
455
##### Database structure queries
456
457 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
458 832 aaronmk
    '''Assumed to be first column in table'''
459 2084 aaronmk
    return col_names(select(db, table, limit=0, recover=recover,
460
        table_is_esc=table_is_esc)).next()
461 832 aaronmk
462 853 aaronmk
def index_cols(db, table, index):
463
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
464
    automatically created. When you don't know whether something is a UNIQUE
465
    constraint or a UNIQUE index, use this function.'''
466
    check_name(table)
467
    check_name(index)
468 1909 aaronmk
    module = util.root_module(db.db)
469
    if module == 'psycopg2':
470
        return list(values(run_query(db, '''\
471 853 aaronmk
SELECT attname
472 866 aaronmk
FROM
473
(
474
        SELECT attnum, attname
475
        FROM pg_index
476
        JOIN pg_class index ON index.oid = indexrelid
477
        JOIN pg_class table_ ON table_.oid = indrelid
478
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
479
        WHERE
480
            table_.relname = %(table)s
481
            AND index.relname = %(index)s
482
    UNION
483
        SELECT attnum, attname
484
        FROM
485
        (
486
            SELECT
487
                indrelid
488
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
489
                    AS indkey
490
            FROM pg_index
491
            JOIN pg_class index ON index.oid = indexrelid
492
            JOIN pg_class table_ ON table_.oid = indrelid
493
            WHERE
494
                table_.relname = %(table)s
495
                AND index.relname = %(index)s
496
        ) s
497
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
498
) s
499 853 aaronmk
ORDER BY attnum
500
''',
501 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
502
    else: raise NotImplementedError("Can't list index columns for "+module+
503
        ' database')
504 853 aaronmk
505 464 aaronmk
def constraint_cols(db, table, constraint):
506
    check_name(table)
507
    check_name(constraint)
508 1849 aaronmk
    module = util.root_module(db.db)
509 464 aaronmk
    if module == 'psycopg2':
510
        return list(values(run_query(db, '''\
511
SELECT attname
512
FROM pg_constraint
513
JOIN pg_class ON pg_class.oid = conrelid
514
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
515
WHERE
516
    relname = %(table)s
517
    AND conname = %(constraint)s
518
ORDER BY attnum
519
''',
520
            {'table': table, 'constraint': constraint})))
521
    else: raise NotImplementedError("Can't list constraint columns for "+module+
522
        ' database')
523
524 2096 aaronmk
row_num_col = '_row_num'
525
526 2086 aaronmk
def add_row_num(db, table):
527 2096 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col.'''
528 2086 aaronmk
    check_name(table)
529 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
530
        +' serial NOT NULL')
531 2086 aaronmk
532 1968 aaronmk
def tables(db, schema='public', table_like='%'):
533 1849 aaronmk
    module = util.root_module(db.db)
534 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
535 832 aaronmk
    if module == 'psycopg2':
536 1968 aaronmk
        return values(run_query(db, '''\
537
SELECT tablename
538
FROM pg_tables
539
WHERE
540
    schemaname = %(schema)s
541
    AND tablename LIKE %(table_like)s
542
ORDER BY tablename
543
''',
544
            params, cacheable=True))
545
    elif module == 'MySQLdb':
546
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
547
            cacheable=True))
548 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
549 830 aaronmk
550 833 aaronmk
##### Database management
551
552 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
553
    '''For kw_args, see tables()'''
554
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
555 833 aaronmk
556 832 aaronmk
##### Heuristic queries
557
558 2076 aaronmk
def with_parsed_errors(db, func):
559
    '''Translates known DB errors to typed exceptions'''
560
    try: return func()
561 46 aaronmk
    except Exception, e:
562
        msg = str(e)
563 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
564
            r'"(([^\W_]+)_[^"]+)"', msg)
565
        if match:
566
            constraint, table = match.groups()
567 854 aaronmk
            try: cols = index_cols(db, table, constraint)
568 465 aaronmk
            except NotImplementedError: raise e
569 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
570 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
571
            'constraint', msg)
572 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
573 13 aaronmk
        raise # no specific exception raised
574 11 aaronmk
575 2076 aaronmk
def try_insert(db, table, row, returning=None):
576
    '''Recovers from errors'''
577
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
578
        recover=True))
579
580 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
581 1554 aaronmk
    '''Recovers from errors.
582 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
583
    '''
584 471 aaronmk
    try:
585 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
586
        if row_ct_ref != None and cur.rowcount >= 0:
587
            row_ct_ref[0] += cur.rowcount
588
        return value(cur)
589 471 aaronmk
    except DuplicateKeyException, e:
590 1069 aaronmk
        return value(select(db, table, [pkey],
591
            util.dict_subset_right_join(row, e.cols), recover=True))
592 471 aaronmk
593 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
594 830 aaronmk
    '''Recovers from errors'''
595
    try: return value(select(db, table, [pkey], row, 1, recover=True))
596 14 aaronmk
    except StopIteration:
597 40 aaronmk
        if not create: raise
598 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
599 2078 aaronmk
600 2087 aaronmk
def put_table(db, out_table, out_cols, in_tables, in_cols, pkey,
601
    row_ct_ref=None, table_is_esc=False):
602 2078 aaronmk
    '''Recovers from errors.
603
    Only works under PostgreSQL (uses INSERT RETURNING).
604 2081 aaronmk
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
605 2078 aaronmk
        available
606
    '''
607
    pkeys_table = clean_name(out_table)+'_pkeys'
608
    def insert_():
609
        return insert_select(db, out_table, out_cols,
610 2087 aaronmk
            *mk_select(db, in_tables[0], in_cols, table_is_esc=table_is_esc),
611 2078 aaronmk
            returning=pkey, into=pkeys_table, recover=True,
612
            table_is_esc=table_is_esc)
613
    try:
614
        cur = with_parsed_errors(db, insert_)
615
        if row_ct_ref != None and cur.rowcount >= 0:
616
            row_ct_ref[0] += cur.rowcount
617 2086 aaronmk
618
        # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
619
        add_row_num(db, pkeys_table)
620
621 2081 aaronmk
        return pkeys_table
622 2078 aaronmk
    except DuplicateKeyException, e: raise