Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 11 aaronmk
import re
5 865 aaronmk
import warnings
6 11 aaronmk
7 300 aaronmk
import exc
8 1909 aaronmk
import dicts
9 1893 aaronmk
import iters
10 1960 aaronmk
import lists
11 1889 aaronmk
from Proxy import Proxy
12 1872 aaronmk
import rand
13 862 aaronmk
import strings
14 131 aaronmk
import util
15 11 aaronmk
16 832 aaronmk
##### Exceptions
17
18 135 aaronmk
def get_cur_query(cur):
19
    if hasattr(cur, 'query'): return cur.query
20
    elif hasattr(cur, '_last_executed'): return cur._last_executed
21
    else: return None
22 14 aaronmk
23 300 aaronmk
def _add_cursor_info(e, cur): exc.add_msg(e, 'query: '+get_cur_query(cur))
24 135 aaronmk
25 300 aaronmk
class DbException(exc.ExceptionWithCause):
26 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
27 300 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause)
28 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
29
30 360 aaronmk
class NameException(DbException): pass
31
32 468 aaronmk
class ExceptionWithColumns(DbException):
33
    def __init__(self, cols, cause=None):
34
        DbException.__init__(self, 'columns: ' + ', '.join(cols), cause)
35
        self.cols = cols
36 11 aaronmk
37 468 aaronmk
class DuplicateKeyException(ExceptionWithColumns): pass
38 13 aaronmk
39 468 aaronmk
class NullValueException(ExceptionWithColumns): pass
40 13 aaronmk
41 89 aaronmk
class EmptyRowException(DbException): pass
42
43 865 aaronmk
##### Warnings
44
45
class DbWarning(UserWarning): pass
46
47 1930 aaronmk
##### Result retrieval
48
49
def col_names(cur): return (col[0] for col in cur.description)
50
51
def rows(cur): return iter(lambda: cur.fetchone(), None)
52
53
def consume_rows(cur):
54
    '''Used to fetch all rows so result will be cached'''
55
    iters.consume_iter(rows(cur))
56
57
def next_row(cur): return rows(cur).next()
58
59
def row(cur):
60
    row_ = next_row(cur)
61
    consume_rows(cur)
62
    return row_
63
64
def next_value(cur): return next_row(cur)[0]
65
66
def value(cur): return row(cur)[0]
67
68
def values(cur): return iters.func_iter(lambda: next_value(cur))
69
70
def value_or_none(cur):
71
    try: return value(cur)
72
    except StopIteration: return None
73
74 1869 aaronmk
##### Database connections
75 1849 aaronmk
76 1926 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database']
77
78 1869 aaronmk
db_engines = {
79
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
80
    'PostgreSQL': ('psycopg2', {}),
81
}
82
83
DatabaseErrors_set = set([DbException])
84
DatabaseErrors = tuple(DatabaseErrors_set)
85
86
def _add_module(module):
87
    DatabaseErrors_set.add(module.DatabaseError)
88
    global DatabaseErrors
89
    DatabaseErrors = tuple(DatabaseErrors_set)
90
91
def db_config_str(db_config):
92
    return db_config['engine']+' database '+db_config['database']
93
94 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
95 1894 aaronmk
96 1901 aaronmk
log_debug_none = lambda msg: None
97
98 1849 aaronmk
class DbConn:
99 2047 aaronmk
    def __init__(self, db_config, serializable=True, log_debug=log_debug_none,
100 2050 aaronmk
        caching=True):
101 1869 aaronmk
        self.db_config = db_config
102
        self.serializable = serializable
103 1901 aaronmk
        self.log_debug = log_debug
104 2047 aaronmk
        self.caching = caching
105 1869 aaronmk
106
        self.__db = None
107 1889 aaronmk
        self.query_results = {}
108 1869 aaronmk
109
    def __getattr__(self, name):
110
        if name == '__dict__': raise Exception('getting __dict__')
111
        if name == 'db': return self._db()
112
        else: raise AttributeError()
113
114
    def __getstate__(self):
115
        state = copy.copy(self.__dict__) # shallow copy
116 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
117 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
118
        return state
119
120
    def _db(self):
121
        if self.__db == None:
122
            # Process db_config
123
            db_config = self.db_config.copy() # don't modify input!
124
            module_name, mappings = db_engines[db_config.pop('engine')]
125
            module = __import__(module_name)
126
            _add_module(module)
127
            for orig, new in mappings.iteritems():
128
                try: util.rename_key(db_config, orig, new)
129
                except KeyError: pass
130
131
            # Connect
132
            self.__db = module.connect(**db_config)
133
134
            # Configure connection
135
            if self.serializable: run_raw_query(self,
136
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
137
138
        return self.__db
139 1889 aaronmk
140 1891 aaronmk
    class DbCursor(Proxy):
141 1927 aaronmk
        def __init__(self, outer):
142 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
143 1927 aaronmk
            self.query_results = outer.query_results
144 1894 aaronmk
            self.query_lookup = None
145 1891 aaronmk
            self.result = []
146 1889 aaronmk
147 1894 aaronmk
        def execute(self, query, params=None):
148 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
149 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
150 1904 aaronmk
            try: return_value = self.inner.execute(query, params)
151
            except Exception, e:
152
                self.result = e # cache the exception as the result
153
                self._cache_result()
154
                raise
155
            finally: self.query = get_cur_query(self.inner)
156 1930 aaronmk
            # Fetch all rows so result will be cached
157
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
158 1894 aaronmk
            return return_value
159
160 1891 aaronmk
        def fetchone(self):
161
            row = self.inner.fetchone()
162 1899 aaronmk
            if row != None: self.result.append(row)
163
            # otherwise, fetched all rows
164 1904 aaronmk
            else: self._cache_result()
165
            return row
166
167
        def _cache_result(self):
168 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
169
            # idempotent, but an invalid insert will always be invalid
170 1930 aaronmk
            if self.query_results != None and (not self._is_insert
171 1906 aaronmk
                or isinstance(self.result, Exception)):
172
173 1894 aaronmk
                assert self.query_lookup != None
174 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
175
                    util.dict_subset(dicts.AttrsDictView(self),
176
                    ['query', 'result', 'rowcount', 'description']))
177 1906 aaronmk
178 1916 aaronmk
        class CacheCursor:
179
            def __init__(self, cached_result): self.__dict__ = cached_result
180
181 1927 aaronmk
            def execute(self, *args, **kw_args):
182 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
183
                # otherwise, result is a rows list
184
                self.iter = iter(self.result)
185
186
            def fetchone(self):
187
                try: return self.iter.next()
188
                except StopIteration: return None
189 1891 aaronmk
190 1894 aaronmk
    def run_query(self, query, params=None, cacheable=False):
191 2047 aaronmk
        if not self.caching: cacheable = False
192 1903 aaronmk
        used_cache = False
193
        try:
194 1927 aaronmk
            # Get cursor
195
            if cacheable:
196
                query_lookup = _query_lookup(query, params)
197
                try:
198
                    cur = self.query_results[query_lookup]
199
                    used_cache = True
200
                except KeyError: cur = self.DbCursor(self)
201
            else: cur = self.db.cursor()
202
203
            # Run query
204
            try: cur.execute(query, params)
205
            except Exception, e:
206
                _add_cursor_info(e, cur)
207
                raise
208 1903 aaronmk
        finally:
209
            if self.log_debug != log_debug_none: # only compute msg if needed
210
                if used_cache: cache_status = 'Cache hit'
211
                elif cacheable: cache_status = 'Cache miss'
212
                else: cache_status = 'Non-cacheable'
213 1927 aaronmk
                self.log_debug(cache_status+': '
214
                    +strings.one_line(get_cur_query(cur)))
215 1903 aaronmk
216
        return cur
217 1914 aaronmk
218
    def is_cached(self, query, params=None):
219
        return _query_lookup(query, params) in self.query_results
220 1849 aaronmk
221 1869 aaronmk
connect = DbConn
222
223 1919 aaronmk
##### Input validation
224
225 2077 aaronmk
def clean_name(name): return re.sub(r'\W', r'', name)
226
227 1919 aaronmk
def check_name(name):
228
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
229
        +'" may contain only alphanumeric characters and _')
230
231 2061 aaronmk
def esc_name_by_module(module, name, ignore_case=False):
232 1919 aaronmk
    if module == 'psycopg2':
233 2061 aaronmk
        if ignore_case:
234
            # Don't enclose in quotes because this disables case-insensitivity
235 2057 aaronmk
            check_name(name)
236
            return name
237 2061 aaronmk
        else: quote = '"'
238 1919 aaronmk
    elif module == 'MySQLdb': quote = '`'
239
    else: raise NotImplementedError("Can't escape name for "+module+' database')
240
    return quote + name.replace(quote, '') + quote
241
242
def esc_name_by_engine(engine, name, **kw_args):
243
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
244
245
def esc_name(db, name, **kw_args):
246
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
247
248 1968 aaronmk
def qual_name(db, schema, table):
249 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
250 2051 aaronmk
    table = esc_name_(table)
251
    if schema != None: return esc_name_(schema)+'.'+table
252
    else: return table
253 1968 aaronmk
254 832 aaronmk
##### Querying
255
256 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
257 2085 aaronmk
    '''For params, see DbConn.run_query()'''
258 1894 aaronmk
    return db.run_query(*args, **kw_args)
259 11 aaronmk
260 2068 aaronmk
def mogrify(db, query, params):
261
    module = util.root_module(db.db)
262
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
263
    else: raise NotImplementedError("Can't mogrify query for "+module+
264
        ' database')
265
266 832 aaronmk
##### Recoverable querying
267 15 aaronmk
268 11 aaronmk
def with_savepoint(db, func):
269 1872 aaronmk
    savepoint = 'savepoint_'+str(rand.rand_int()) # must be unique
270 830 aaronmk
    run_raw_query(db, 'SAVEPOINT '+savepoint)
271 11 aaronmk
    try: return_val = func()
272
    except:
273 830 aaronmk
        run_raw_query(db, 'ROLLBACK TO SAVEPOINT '+savepoint)
274 11 aaronmk
        raise
275
    else:
276 830 aaronmk
        run_raw_query(db, 'RELEASE SAVEPOINT '+savepoint)
277 11 aaronmk
        return return_val
278
279 1894 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False):
280 830 aaronmk
    if recover == None: recover = False
281
282 1894 aaronmk
    def run(): return run_raw_query(db, query, params, cacheable)
283 1914 aaronmk
    if recover and not db.is_cached(query, params):
284
        return with_savepoint(db, run)
285
    else: return run() # don't need savepoint if cached
286 830 aaronmk
287 832 aaronmk
##### Basic queries
288
289 2085 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
290
    '''Outputs a query to a temp table.
291
    For params, see run_query().
292
    '''
293
    if into == None: return run_query(db, query, params, *args, **kw_args)
294
    else: # place rows in temp table
295
        check_name(into)
296
297
        run_query(db, 'DROP TABLE IF EXISTS '+into+' CASCADE', *args, **kw_args)
298
        return run_query(db, 'CREATE TEMP TABLE '+into+' AS '+query, params,
299
            *args, **kw_args) # CREATE TABLE sets rowcount to # rows in query
300
301 2054 aaronmk
def mk_select(db, table, fields=None, conds=None, limit=None, start=None,
302
    table_is_esc=False):
303 1981 aaronmk
    '''
304
    @param fields Use None to select all fields in the table
305
    @param table_is_esc Whether the table name has already been escaped
306 2054 aaronmk
    @return tuple(query, params)
307 1981 aaronmk
    '''
308 2060 aaronmk
    def esc_name_(name): return esc_name(db, name)
309 2058 aaronmk
310 1135 aaronmk
    if conds == None: conds = {}
311 135 aaronmk
    assert limit == None or type(limit) == int
312 865 aaronmk
    assert start == None or type(start) == int
313 2058 aaronmk
    if not table_is_esc: table = esc_name_(table)
314 865 aaronmk
315 2056 aaronmk
    params = []
316
317
    def parse_col(field):
318
        '''Parses fields'''
319
        if isinstance(field, tuple): # field is literal values
320
            value, col = field
321
            sql_ = '%s'
322
            params.append(value)
323 2058 aaronmk
            if col != None: sql_ += ' AS '+esc_name_(col)
324
        else: sql_ = esc_name_(field) # field is col name
325 2056 aaronmk
        return sql_
326 11 aaronmk
    def cond(entry):
327 2056 aaronmk
        '''Parses conditions'''
328 13 aaronmk
        col, value = entry
329 2058 aaronmk
        cond_ = esc_name_(col)+' '
330 11 aaronmk
        if value == None: cond_ += 'IS'
331
        else: cond_ += '='
332
        cond_ += ' %s'
333
        return cond_
334 2056 aaronmk
335 1135 aaronmk
    query = 'SELECT '
336
    if fields == None: query += '*'
337 2056 aaronmk
    else: query += ', '.join(map(parse_col, fields))
338 2055 aaronmk
    query += ' FROM '+table
339 865 aaronmk
340
    missing = True
341 89 aaronmk
    if conds != {}:
342
        query += ' WHERE '+' AND '.join(map(cond, conds.iteritems()))
343 2056 aaronmk
        params += conds.values()
344 865 aaronmk
        missing = False
345
    if limit != None: query += ' LIMIT '+str(limit); missing = False
346
    if start != None:
347
        if start != 0: query += ' OFFSET '+str(start)
348
        missing = False
349
    if missing: warnings.warn(DbWarning(
350
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
351
352 2056 aaronmk
    return (query, params)
353 11 aaronmk
354 2054 aaronmk
def select(db, *args, **kw_args):
355
    '''For params, see mk_select() and run_query()'''
356
    recover = kw_args.pop('recover', None)
357
    cacheable = kw_args.pop('cacheable', True)
358
359
    query, params = mk_select(db, *args, **kw_args)
360
    return run_query(db, query, params, recover, cacheable)
361
362 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
363 2070 aaronmk
    returning=None, embeddable=False, table_is_esc=False):
364 1960 aaronmk
    '''
365
    @param returning str|None An inserted column (such as pkey) to return
366 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
367 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
368
        query will be fully cached, not just if it raises an exception.
369 1960 aaronmk
    @param table_is_esc Whether the table name has already been escaped
370
    '''
371 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
372
    if cols == []: cols = None # no cols (all defaults) = unknown col names
373 1960 aaronmk
    if not table_is_esc: check_name(table)
374 2063 aaronmk
375
    # Build query
376
    query = 'INSERT INTO '+table
377
    if cols != None:
378
        map(check_name, cols)
379
        query += ' ('+', '.join(cols)+')'
380
    query += ' '+select_query
381
382
    if returning != None:
383
        check_name(returning)
384
        query += ' RETURNING '+returning
385
386 2070 aaronmk
    if embeddable:
387
        # Create function
388 2083 aaronmk
        function = 'pg_temp.'+('_'.join(map(clean_name,
389
            ['insert', table] + cols)))
390 2070 aaronmk
        return_type = 'SETOF '+table+'.'+returning+'%TYPE'
391
        function_query = '''\
392 2083 aaronmk
CREATE OR REPLACE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
393 2070 aaronmk
    LANGUAGE sql
394
    AS $$'''+mogrify(db, query, params)+''';$$;
395
'''
396
        run_query(db, function_query, cacheable=True)
397
398
        # Return query that uses function
399 2083 aaronmk
        return mk_select(db, function+'() AS f ('+returning+')',
400 2080 aaronmk
            table_is_esc=True) # function alias is required in AS clause
401 2070 aaronmk
402 2066 aaronmk
    return (query, params)
403
404
def insert_select(db, *args, **kw_args):
405 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
406 2072 aaronmk
    @param into Name of temp table to place RETURNING values in
407
    '''
408
    into = kw_args.pop('into', None)
409
    if into != None: kw_args['embeddable'] = True
410 2066 aaronmk
    recover = kw_args.pop('recover', None)
411
    cacheable = kw_args.pop('cacheable', True)
412
413
    query, params = mk_insert_select(db, *args, **kw_args)
414 2085 aaronmk
    return run_query_into(db, query, params, into, recover, cacheable)
415 2063 aaronmk
416 2066 aaronmk
default = object() # tells insert() to use the default value for a column
417
418 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
419 2085 aaronmk
    '''For params, see insert_select()'''
420 1960 aaronmk
    if lists.is_seq(row): cols = None
421
    else:
422
        cols = row.keys()
423
        row = row.values()
424
    row = list(row) # ensure that "!= []" works
425
426 1961 aaronmk
    # Check for special values
427
    labels = []
428
    values = []
429
    for value in row:
430
        if value == default: labels.append('DEFAULT')
431
        else:
432
            labels.append('%s')
433
            values.append(value)
434
435
    # Build query
436 2063 aaronmk
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
437
    else: query = None
438 1554 aaronmk
439 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
440 11 aaronmk
441 135 aaronmk
def last_insert_id(db):
442 1849 aaronmk
    module = util.root_module(db.db)
443 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
444
    elif module == 'MySQLdb': return db.insert_id()
445
    else: return None
446 13 aaronmk
447 1968 aaronmk
def truncate(db, table, schema='public'):
448
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
449 832 aaronmk
450
##### Database structure queries
451
452 2084 aaronmk
def pkey(db, table, recover=None, table_is_esc=False):
453 832 aaronmk
    '''Assumed to be first column in table'''
454 2084 aaronmk
    return col_names(select(db, table, limit=0, recover=recover,
455
        table_is_esc=table_is_esc)).next()
456 832 aaronmk
457 853 aaronmk
def index_cols(db, table, index):
458
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
459
    automatically created. When you don't know whether something is a UNIQUE
460
    constraint or a UNIQUE index, use this function.'''
461
    check_name(table)
462
    check_name(index)
463 1909 aaronmk
    module = util.root_module(db.db)
464
    if module == 'psycopg2':
465
        return list(values(run_query(db, '''\
466 853 aaronmk
SELECT attname
467 866 aaronmk
FROM
468
(
469
        SELECT attnum, attname
470
        FROM pg_index
471
        JOIN pg_class index ON index.oid = indexrelid
472
        JOIN pg_class table_ ON table_.oid = indrelid
473
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
474
        WHERE
475
            table_.relname = %(table)s
476
            AND index.relname = %(index)s
477
    UNION
478
        SELECT attnum, attname
479
        FROM
480
        (
481
            SELECT
482
                indrelid
483
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
484
                    AS indkey
485
            FROM pg_index
486
            JOIN pg_class index ON index.oid = indexrelid
487
            JOIN pg_class table_ ON table_.oid = indrelid
488
            WHERE
489
                table_.relname = %(table)s
490
                AND index.relname = %(index)s
491
        ) s
492
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
493
) s
494 853 aaronmk
ORDER BY attnum
495
''',
496 1909 aaronmk
            {'table': table, 'index': index}, cacheable=True)))
497
    else: raise NotImplementedError("Can't list index columns for "+module+
498
        ' database')
499 853 aaronmk
500 464 aaronmk
def constraint_cols(db, table, constraint):
501
    check_name(table)
502
    check_name(constraint)
503 1849 aaronmk
    module = util.root_module(db.db)
504 464 aaronmk
    if module == 'psycopg2':
505
        return list(values(run_query(db, '''\
506
SELECT attname
507
FROM pg_constraint
508
JOIN pg_class ON pg_class.oid = conrelid
509
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
510
WHERE
511
    relname = %(table)s
512
    AND conname = %(constraint)s
513
ORDER BY attnum
514
''',
515
            {'table': table, 'constraint': constraint})))
516
    else: raise NotImplementedError("Can't list constraint columns for "+module+
517
        ' database')
518
519 2086 aaronmk
def add_row_num(db, table):
520
    '''Adds a row number column named `row_num` to a table'''
521
    check_name(table)
522
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN row_num serial NOT NULL')
523
524 1968 aaronmk
def tables(db, schema='public', table_like='%'):
525 1849 aaronmk
    module = util.root_module(db.db)
526 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
527 832 aaronmk
    if module == 'psycopg2':
528 1968 aaronmk
        return values(run_query(db, '''\
529
SELECT tablename
530
FROM pg_tables
531
WHERE
532
    schemaname = %(schema)s
533
    AND tablename LIKE %(table_like)s
534
ORDER BY tablename
535
''',
536
            params, cacheable=True))
537
    elif module == 'MySQLdb':
538
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
539
            cacheable=True))
540 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
541 830 aaronmk
542 833 aaronmk
##### Database management
543
544 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
545
    '''For kw_args, see tables()'''
546
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
547 833 aaronmk
548 832 aaronmk
##### Heuristic queries
549
550 2076 aaronmk
def with_parsed_errors(db, func):
551
    '''Translates known DB errors to typed exceptions'''
552
    try: return func()
553 46 aaronmk
    except Exception, e:
554
        msg = str(e)
555 465 aaronmk
        match = re.search(r'duplicate key value violates unique constraint '
556
            r'"(([^\W_]+)_[^"]+)"', msg)
557
        if match:
558
            constraint, table = match.groups()
559 854 aaronmk
            try: cols = index_cols(db, table, constraint)
560 465 aaronmk
            except NotImplementedError: raise e
561 851 aaronmk
            else: raise DuplicateKeyException(cols, e)
562 13 aaronmk
        match = re.search(r'null value in column "(\w+)" violates not-null '
563
            'constraint', msg)
564 470 aaronmk
        if match: raise NullValueException([match.group(1)], e)
565 13 aaronmk
        raise # no specific exception raised
566 11 aaronmk
567 2076 aaronmk
def try_insert(db, table, row, returning=None):
568
    '''Recovers from errors'''
569
    return with_parsed_errors(db, lambda: insert(db, table, row, returning,
570
        recover=True))
571
572 471 aaronmk
def put(db, table, row, pkey, row_ct_ref=None):
573 1554 aaronmk
    '''Recovers from errors.
574 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
575
    '''
576 471 aaronmk
    try:
577 1554 aaronmk
        cur = try_insert(db, table, row, pkey)
578
        if row_ct_ref != None and cur.rowcount >= 0:
579
            row_ct_ref[0] += cur.rowcount
580
        return value(cur)
581 471 aaronmk
    except DuplicateKeyException, e:
582 1069 aaronmk
        return value(select(db, table, [pkey],
583
            util.dict_subset_right_join(row, e.cols), recover=True))
584 471 aaronmk
585 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
586 830 aaronmk
    '''Recovers from errors'''
587
    try: return value(select(db, table, [pkey], row, 1, recover=True))
588 14 aaronmk
    except StopIteration:
589 40 aaronmk
        if not create: raise
590 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
591 2078 aaronmk
592 2087 aaronmk
def put_table(db, out_table, out_cols, in_tables, in_cols, pkey,
593
    row_ct_ref=None, table_is_esc=False):
594 2078 aaronmk
    '''Recovers from errors.
595
    Only works under PostgreSQL (uses INSERT RETURNING).
596 2081 aaronmk
    @return Name of the table where the pkeys (from INSERT RETURNING) are made
597 2078 aaronmk
        available
598
    '''
599
    pkeys_table = clean_name(out_table)+'_pkeys'
600
    def insert_():
601
        return insert_select(db, out_table, out_cols,
602 2087 aaronmk
            *mk_select(db, in_tables[0], in_cols, table_is_esc=table_is_esc),
603 2078 aaronmk
            returning=pkey, into=pkeys_table, recover=True,
604
            table_is_esc=table_is_esc)
605
    try:
606
        cur = with_parsed_errors(db, insert_)
607
        if row_ct_ref != None and cur.rowcount >= 0:
608
            row_ct_ref[0] += cur.rowcount
609 2086 aaronmk
610
        # Add row_num to pkeys_table, so it can be joined with in_table's pkeys
611
        add_row_num(db, pkeys_table)
612
613 2081 aaronmk
        return pkeys_table
614 2078 aaronmk
    except DuplicateKeyException, e: raise