Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 2217 aaronmk
import sql_gen
15 862 aaronmk
import strings
16 131 aaronmk
import util
17 11 aaronmk
18 832 aaronmk
##### Exceptions
19
20 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
21 2168 aaronmk
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24 2170 aaronmk
25
    if raw_query != None: return raw_query
26 2371 aaronmk
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27 14 aaronmk
28 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31 135 aaronmk
32 300 aaronmk
class DbException(exc.ExceptionWithCause):
33 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
34 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
36
37 2143 aaronmk
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40 2143 aaronmk
        self.name = name
41 360 aaronmk
42 2240 aaronmk
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46 2240 aaronmk
        self.name = name
47
        self.value = value
48
49 2306 aaronmk
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53 2306 aaronmk
        self.name = name
54 468 aaronmk
        self.cols = cols
55 11 aaronmk
56 2143 aaronmk
class NameException(DbException): pass
57
58 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
59 13 aaronmk
60 2306 aaronmk
class NullValueException(ConstraintException): pass
61 13 aaronmk
62 2240 aaronmk
class FunctionValueException(ExceptionWithNameValue): pass
63 2239 aaronmk
64 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
65
66 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName): pass
67
68 89 aaronmk
class EmptyRowException(DbException): pass
69
70 865 aaronmk
##### Warnings
71
72
class DbWarning(UserWarning): pass
73
74 1930 aaronmk
##### Result retrieval
75
76
def col_names(cur): return (col[0] for col in cur.description)
77
78
def rows(cur): return iter(lambda: cur.fetchone(), None)
79
80
def consume_rows(cur):
81
    '''Used to fetch all rows so result will be cached'''
82
    iters.consume_iter(rows(cur))
83
84
def next_row(cur): return rows(cur).next()
85
86
def row(cur):
87
    row_ = next_row(cur)
88
    consume_rows(cur)
89
    return row_
90
91
def next_value(cur): return next_row(cur)[0]
92
93
def value(cur): return row(cur)[0]
94
95
def values(cur): return iters.func_iter(lambda: next_value(cur))
96
97
def value_or_none(cur):
98
    try: return value(cur)
99
    except StopIteration: return None
100
101 2101 aaronmk
##### Input validation
102
103 2492 aaronmk
def clean_name(name): return name.replace('"', '')
104 2396 aaronmk
105 2101 aaronmk
def check_name(name):
106
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
107
        +'" may contain only alphanumeric characters and _')
108
109
def esc_name_by_module(module, name, ignore_case=False):
110 2388 aaronmk
    if module == 'psycopg2' or module == None:
111 2101 aaronmk
        if ignore_case:
112
            # Don't enclose in quotes because this disables case-insensitivity
113
            check_name(name)
114
            return name
115
        else: quote = '"'
116
    elif module == 'MySQLdb': quote = '`'
117
    else: raise NotImplementedError("Can't escape name for "+module+' database')
118
    return quote + name.replace(quote, '') + quote
119
120
def esc_name_by_engine(engine, name, **kw_args):
121
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
122
123
def esc_name(db, name, **kw_args):
124
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
125
126
def qual_name(db, schema, table):
127
    def esc_name_(name): return esc_name(db, name)
128
    table = esc_name_(table)
129
    if schema != None: return esc_name_(schema)+'.'+table
130
    else: return table
131
132 1869 aaronmk
##### Database connections
133 1849 aaronmk
134 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
135 1926 aaronmk
136 1869 aaronmk
db_engines = {
137
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
138
    'PostgreSQL': ('psycopg2', {}),
139
}
140
141
DatabaseErrors_set = set([DbException])
142
DatabaseErrors = tuple(DatabaseErrors_set)
143
144
def _add_module(module):
145
    DatabaseErrors_set.add(module.DatabaseError)
146
    global DatabaseErrors
147
    DatabaseErrors = tuple(DatabaseErrors_set)
148
149
def db_config_str(db_config):
150
    return db_config['engine']+' database '+db_config['database']
151
152 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
153 1894 aaronmk
154 2448 aaronmk
log_debug_none = lambda msg, level=2: None
155 1901 aaronmk
156 1849 aaronmk
class DbConn:
157 2190 aaronmk
    def __init__(self, db_config, serializable=True, autocommit=False,
158
        caching=True, log_debug=log_debug_none):
159 1869 aaronmk
        self.db_config = db_config
160
        self.serializable = serializable
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 1869 aaronmk
166
        self.__db = None
167 1889 aaronmk
        self.query_results = {}
168 2139 aaronmk
        self._savepoint = 0
169 1869 aaronmk
170
    def __getattr__(self, name):
171
        if name == '__dict__': raise Exception('getting __dict__')
172
        if name == 'db': return self._db()
173
        else: raise AttributeError()
174
175
    def __getstate__(self):
176
        state = copy.copy(self.__dict__) # shallow copy
177 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
178 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
179
        return state
180
181 2165 aaronmk
    def connected(self): return self.__db != None
182
183 1869 aaronmk
    def _db(self):
184
        if self.__db == None:
185
            # Process db_config
186
            db_config = self.db_config.copy() # don't modify input!
187 2097 aaronmk
            schemas = db_config.pop('schemas', None)
188 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
189
            module = __import__(module_name)
190
            _add_module(module)
191
            for orig, new in mappings.iteritems():
192
                try: util.rename_key(db_config, orig, new)
193
                except KeyError: pass
194
195
            # Connect
196
            self.__db = module.connect(**db_config)
197
198
            # Configure connection
199 2234 aaronmk
            if self.serializable and not self.autocommit: run_raw_query(self,
200
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
201 2101 aaronmk
            if schemas != None:
202
                schemas_ = ''.join((esc_name(self, s)+', '
203
                    for s in schemas.split(',')))
204
                run_raw_query(self, "SELECT set_config('search_path', \
205
%s || current_setting('search_path'), false)", [schemas_])
206 1869 aaronmk
207
        return self.__db
208 1889 aaronmk
209 1891 aaronmk
    class DbCursor(Proxy):
210 1927 aaronmk
        def __init__(self, outer):
211 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
212 2191 aaronmk
            self.outer = outer
213 1927 aaronmk
            self.query_results = outer.query_results
214 1894 aaronmk
            self.query_lookup = None
215 1891 aaronmk
            self.result = []
216 1889 aaronmk
217 1894 aaronmk
        def execute(self, query, params=None):
218 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
219 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
220 2148 aaronmk
            try:
221 2191 aaronmk
                try:
222
                    return_value = self.inner.execute(query, params)
223
                    self.outer.do_autocommit()
224 2148 aaronmk
                finally: self.query = get_cur_query(self.inner)
225 1904 aaronmk
            except Exception, e:
226 2170 aaronmk
                _add_cursor_info(e, self, query, params)
227 1904 aaronmk
                self.result = e # cache the exception as the result
228
                self._cache_result()
229
                raise
230 1930 aaronmk
            # Fetch all rows so result will be cached
231
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
232 1894 aaronmk
            return return_value
233
234 1891 aaronmk
        def fetchone(self):
235
            row = self.inner.fetchone()
236 1899 aaronmk
            if row != None: self.result.append(row)
237
            # otherwise, fetched all rows
238 1904 aaronmk
            else: self._cache_result()
239
            return row
240
241
        def _cache_result(self):
242 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
243
            # idempotent, but an invalid insert will always be invalid
244 1930 aaronmk
            if self.query_results != None and (not self._is_insert
245 1906 aaronmk
                or isinstance(self.result, Exception)):
246
247 1894 aaronmk
                assert self.query_lookup != None
248 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
249
                    util.dict_subset(dicts.AttrsDictView(self),
250
                    ['query', 'result', 'rowcount', 'description']))
251 1906 aaronmk
252 1916 aaronmk
        class CacheCursor:
253
            def __init__(self, cached_result): self.__dict__ = cached_result
254
255 1927 aaronmk
            def execute(self, *args, **kw_args):
256 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
257
                # otherwise, result is a rows list
258
                self.iter = iter(self.result)
259
260
            def fetchone(self):
261
                try: return self.iter.next()
262
                except StopIteration: return None
263 1891 aaronmk
264 2212 aaronmk
    def esc_value(self, value):
265 2215 aaronmk
        module = util.root_module(self.db)
266 2374 aaronmk
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
267 2212 aaronmk
        elif module == 'MySQLdb':
268
            import _mysql
269 2374 aaronmk
            str_ = _mysql.escape_string(value)
270 2212 aaronmk
        else: raise NotImplementedError("Can't escape value for "+module
271
            +' database')
272 2374 aaronmk
        return strings.to_unicode(str_)
273 2212 aaronmk
274 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
275
276 2445 aaronmk
    def run_query(self, query, params=None, cacheable=False, log_level=2,
277 2464 aaronmk
        debug_msg_ref=None):
278 2445 aaronmk
        '''
279 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
280
            throws one of these exceptions.
281 2445 aaronmk
        '''
282 2167 aaronmk
        assert query != None
283
284 2047 aaronmk
        if not self.caching: cacheable = False
285 1903 aaronmk
        used_cache = False
286
        try:
287 1927 aaronmk
            # Get cursor
288
            if cacheable:
289
                query_lookup = _query_lookup(query, params)
290
                try:
291
                    cur = self.query_results[query_lookup]
292
                    used_cache = True
293
                except KeyError: cur = self.DbCursor(self)
294
            else: cur = self.db.cursor()
295
296
            # Run query
297 2148 aaronmk
            cur.execute(query, params)
298 1903 aaronmk
        finally:
299 2464 aaronmk
            if self.debug and debug_msg_ref != None:# only compute msg if needed
300 2470 aaronmk
                if used_cache: cache_status = 'cache hit'
301
                elif cacheable: cache_status = 'cache miss'
302
                else: cache_status = 'non-cacheable'
303 2472 aaronmk
                query_code = strings.as_code(str(get_cur_query(cur, query,
304
                    params)), 'SQL')
305
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
306 1903 aaronmk
307
        return cur
308 1914 aaronmk
309
    def is_cached(self, query, params=None):
310
        return _query_lookup(query, params) in self.query_results
311 2139 aaronmk
312
    def with_savepoint(self, func):
313 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
314 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
315 2139 aaronmk
        self._savepoint += 1
316
        try:
317
            try: return_val = func()
318
            finally:
319
                self._savepoint -= 1
320
                assert self._savepoint >= 0
321
        except:
322 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
323 2139 aaronmk
            raise
324
        else:
325 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
326 2191 aaronmk
            self.do_autocommit()
327 2139 aaronmk
            return return_val
328 2191 aaronmk
329
    def do_autocommit(self):
330
        '''Autocommits if outside savepoint'''
331
        assert self._savepoint >= 0
332
        if self.autocommit and self._savepoint == 0:
333
            self.log_debug('Autocommiting')
334
            self.db.commit()
335 1849 aaronmk
336 1869 aaronmk
connect = DbConn
337
338 832 aaronmk
##### Querying
339
340 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
341 2085 aaronmk
    '''For params, see DbConn.run_query()'''
342 1894 aaronmk
    return db.run_query(*args, **kw_args)
343 11 aaronmk
344 2068 aaronmk
def mogrify(db, query, params):
345
    module = util.root_module(db.db)
346
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
347
    else: raise NotImplementedError("Can't mogrify query for "+module+
348
        ' database')
349
350 832 aaronmk
##### Recoverable querying
351 15 aaronmk
352 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
353 11 aaronmk
354 2464 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False,
355
    log_level=2, log_ignore_excs=None, **kw_args):
356 2441 aaronmk
    '''For params, see run_raw_query()'''
357 830 aaronmk
    if recover == None: recover = False
358 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
359
    log_ignore_excs = tuple(log_ignore_excs)
360 830 aaronmk
361 2464 aaronmk
    debug_msg_ref = [None]
362 2148 aaronmk
    try:
363 2464 aaronmk
        try:
364
            def run(): return run_raw_query(db, query, params, cacheable,
365
                log_level, debug_msg_ref, **kw_args)
366
            if recover and not db.is_cached(query, params):
367
                return with_savepoint(db, run)
368
            else: return run() # don't need savepoint if cached
369
        except Exception, e:
370
            if not recover: raise # need savepoint to run index_cols()
371
            msg = exc.str_(e)
372
373
            match = re.search(r'duplicate key value violates unique constraint '
374 2493 aaronmk
                r'"((_?[^\W_]+)_.+?)"', msg)
375 2464 aaronmk
            if match:
376
                constraint, table = match.groups()
377
                try: cols = index_cols(db, table, constraint)
378
                except NotImplementedError: raise e
379
                else: raise DuplicateKeyException(constraint, cols, e)
380
381 2493 aaronmk
            match = re.search(r'null value in column "(.+?)" violates not-null'
382 2464 aaronmk
                r' constraint', msg)
383
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
384
385
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
386
                r'|date/time field value out of range): "(.+?)"\n'
387 2493 aaronmk
                r'(?:(?s).*?)\bfunction "(.+?)".*?\bat assignment', msg)
388 2464 aaronmk
            if match:
389
                value, name = match.groups()
390
                raise FunctionValueException(name, strings.to_unicode(value), e)
391
392 2493 aaronmk
            match = re.search(r'relation "(.+?)" already exists', msg)
393 2464 aaronmk
            if match: raise DuplicateTableException(match.group(1), e)
394
395 2493 aaronmk
            match = re.search(r'function "(.+?)" already exists', msg)
396 2464 aaronmk
            if match: raise DuplicateFunctionException(match.group(1), e)
397
398
            raise # no specific exception raised
399
    except log_ignore_excs:
400
        log_level += 2
401
        raise
402
    finally:
403
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
404 830 aaronmk
405 832 aaronmk
##### Basic queries
406
407 2153 aaronmk
def next_version(name):
408
    '''Prepends the version # so it won't be removed if the name is truncated'''
409 2163 aaronmk
    version = 1 # first existing name was version 0
410 2153 aaronmk
    match = re.match(r'^v(\d+)_(.*)$', name)
411
    if match:
412
        version = int(match.group(1))+1
413
        name = match.group(2)
414
    return 'v'+str(version)+'_'+name
415
416 2386 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
417 2085 aaronmk
    '''Outputs a query to a temp table.
418
    For params, see run_query().
419
    '''
420 2386 aaronmk
    if into == None: return run_query(db, query, params, *args, **kw_args)
421 2085 aaronmk
    else: # place rows in temp table
422 2386 aaronmk
        assert isinstance(into, sql_gen.Table)
423 2385 aaronmk
424 2153 aaronmk
        kw_args['recover'] = True
425 2464 aaronmk
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
426 2440 aaronmk
427 2468 aaronmk
        temp = not db.autocommit # tables are permanent in autocommit mode
428 2440 aaronmk
        # "temporary tables cannot specify a schema name", so remove schema
429
        if temp: into.schema = None
430
431 2153 aaronmk
        while True:
432
            try:
433 2194 aaronmk
                create_query = 'CREATE'
434 2440 aaronmk
                if temp: create_query += ' TEMP'
435 2467 aaronmk
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
436 2194 aaronmk
437
                return run_query(db, create_query, params, *args, **kw_args)
438 2153 aaronmk
                    # CREATE TABLE AS sets rowcount to # rows in query
439
            except DuplicateTableException, e:
440 2386 aaronmk
                into.name = next_version(into.name)
441 2153 aaronmk
                # try again with next version of name
442 2085 aaronmk
443 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
444
445 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
446
447 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
448 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
449 1981 aaronmk
    '''
450 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
451 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
452 1981 aaronmk
    @param fields Use None to select all fields in the table
453 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
454 2379 aaronmk
        * container can be any iterable type
455 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
456
        * compare_right_side: sql_gen.ValueCond|literal value
457 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
458
        use all columns
459 2054 aaronmk
    @return tuple(query, params)
460 1981 aaronmk
    '''
461 2315 aaronmk
    # Parse tables param
462 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
463 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
464 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
465 2121 aaronmk
466 2315 aaronmk
    # Parse other params
467 2376 aaronmk
    if conds == None: conds = []
468
    elif isinstance(conds, dict): conds = conds.items()
469 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
470 135 aaronmk
    assert limit == None or type(limit) == int
471 865 aaronmk
    assert start == None or type(start) == int
472 2315 aaronmk
    if order_by is order_by_pkey:
473
        if distinct_on != []: order_by = None
474
        else: order_by = pkey(db, table0, recover=True)
475 865 aaronmk
476 2315 aaronmk
    query = 'SELECT'
477 2056 aaronmk
478 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
479 2056 aaronmk
480 2200 aaronmk
    # DISTINCT ON columns
481 2233 aaronmk
    if distinct_on != []:
482 2467 aaronmk
        query += '\nDISTINCT'
483 2254 aaronmk
        if distinct_on is not distinct_on_all:
484 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
485
486
    # Columns
487 2467 aaronmk
    query += '\n'
488 1135 aaronmk
    if fields == None: query += '*'
489 2479 aaronmk
    else: query += '\n, '.join(map(parse_col, fields))
490 2200 aaronmk
491
    # Main table
492 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
493 865 aaronmk
494 2122 aaronmk
    # Add joins
495 2271 aaronmk
    left_table = table0
496 2263 aaronmk
    for join_ in tables:
497
        table = join_.table
498 2238 aaronmk
499 2343 aaronmk
        # Parse special values
500
        if join_.type_ is sql_gen.filter_out: # filter no match
501 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
502
                None))
503 2343 aaronmk
504 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
505 2122 aaronmk
506
        left_table = table
507
508 865 aaronmk
    missing = True
509 2376 aaronmk
    if conds != []:
510 2467 aaronmk
        query += '\nWHERE\n'+('\nAND\n'.join(('('+sql_gen.ColValueCond(l, r)
511 2410 aaronmk
            .to_str(db)+')' for l, r in conds)))
512 865 aaronmk
        missing = False
513 2227 aaronmk
    if order_by != None:
514 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
515
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
516 865 aaronmk
    if start != None:
517 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
518 865 aaronmk
        missing = False
519
    if missing: warnings.warn(DbWarning(
520
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
521
522 2315 aaronmk
    return (query, [])
523 11 aaronmk
524 2054 aaronmk
def select(db, *args, **kw_args):
525
    '''For params, see mk_select() and run_query()'''
526
    recover = kw_args.pop('recover', None)
527
    cacheable = kw_args.pop('cacheable', True)
528 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
529 2054 aaronmk
530
    query, params = mk_select(db, *args, **kw_args)
531 2442 aaronmk
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
532 2054 aaronmk
533 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
534 2292 aaronmk
    returning=None, embeddable=False):
535 1960 aaronmk
    '''
536
    @param returning str|None An inserted column (such as pkey) to return
537 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
538 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
539
        query will be fully cached, not just if it raises an exception.
540 1960 aaronmk
    '''
541 2328 aaronmk
    table = sql_gen.as_Table(table)
542 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
543
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
544 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
545 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
546 2063 aaronmk
547
    # Build query
548 2328 aaronmk
    query = 'INSERT INTO '+table.to_str(db)
549 2467 aaronmk
    if cols != None: query += '\n('+', '.join(cols)+')'
550
    query += '\n'+select_query
551 2063 aaronmk
552
    if returning != None:
553 2327 aaronmk
        returning_name = copy.copy(returning)
554
        returning_name.table = None
555
        returning_name = returning_name.to_str(db)
556 2467 aaronmk
        query += '\nRETURNING '+returning_name
557 2063 aaronmk
558 2070 aaronmk
    if embeddable:
559 2327 aaronmk
        assert returning != None
560
561 2070 aaronmk
        # Create function
562 2330 aaronmk
        function_name = '_'.join(['insert', table.name] + cols)
563 2327 aaronmk
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
564 2189 aaronmk
        while True:
565
            try:
566 2327 aaronmk
                func_schema = None
567 2468 aaronmk
                if not db.autocommit: func_schema = 'pg_temp'
568 2327 aaronmk
                function = sql_gen.Table(function_name, func_schema).to_str(db)
569 2194 aaronmk
570 2189 aaronmk
                function_query = '''\
571 2467 aaronmk
CREATE FUNCTION '''+function+'''()
572
RETURNS '''+return_type+'''
573
LANGUAGE sql
574
AS $$
575
'''+mogrify(db, query, params)+''';
576
$$;
577 2070 aaronmk
'''
578 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
579 2464 aaronmk
                    log_ignore_excs=(DuplicateFunctionException,))
580 2189 aaronmk
                break # this version was successful
581
            except DuplicateFunctionException, e:
582
                function_name = next_version(function_name)
583
                # try again with next version of name
584 2070 aaronmk
585 2337 aaronmk
        # Return query that uses function
586
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
587
            [returning_name]) # AS clause requires function alias
588
        return mk_select(db, func_table, start=0, order_by=None)
589 2070 aaronmk
590 2066 aaronmk
    return (query, params)
591
592
def insert_select(db, *args, **kw_args):
593 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
594 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
595
        values in
596 2072 aaronmk
    '''
597 2386 aaronmk
    into = kw_args.pop('into', None)
598
    if into != None: kw_args['embeddable'] = True
599 2066 aaronmk
    recover = kw_args.pop('recover', None)
600
    cacheable = kw_args.pop('cacheable', True)
601
602
    query, params = mk_insert_select(db, *args, **kw_args)
603 2386 aaronmk
    return run_query_into(db, query, params, into, recover=recover,
604 2153 aaronmk
        cacheable=cacheable)
605 2063 aaronmk
606 2066 aaronmk
default = object() # tells insert() to use the default value for a column
607
608 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
609 2085 aaronmk
    '''For params, see insert_select()'''
610 1960 aaronmk
    if lists.is_seq(row): cols = None
611
    else:
612
        cols = row.keys()
613
        row = row.values()
614
    row = list(row) # ensure that "!= []" works
615
616 1961 aaronmk
    # Check for special values
617
    labels = []
618
    values = []
619
    for value in row:
620 2254 aaronmk
        if value is default: labels.append('DEFAULT')
621 1961 aaronmk
        else:
622
            labels.append('%s')
623
            values.append(value)
624
625
    # Build query
626 2467 aaronmk
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
627 2063 aaronmk
    else: query = None
628 1554 aaronmk
629 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
630 11 aaronmk
631 2402 aaronmk
def mk_update(db, table, changes=None, cond=None):
632
    '''
633
    @param changes [(col, new_value),...]
634
        * container can be any iterable type
635
        * col: sql_gen.Code|str (for col name)
636
        * new_value: sql_gen.Code|literal value
637
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
638
    @return str query
639
    '''
640
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
641 2405 aaronmk
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
642 2402 aaronmk
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
643 2467 aaronmk
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
644 2402 aaronmk
645
    return query
646
647
def update(db, *args, **kw_args):
648
    '''For params, see mk_update() and run_query()'''
649
    recover = kw_args.pop('recover', None)
650
651
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
652
653 135 aaronmk
def last_insert_id(db):
654 1849 aaronmk
    module = util.root_module(db.db)
655 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
656
    elif module == 'MySQLdb': return db.insert_id()
657
    else: return None
658 13 aaronmk
659 1968 aaronmk
def truncate(db, table, schema='public'):
660
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
661 832 aaronmk
662 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
663 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
664 2415 aaronmk
    to names that will be distinct among the columns' tables.
665 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
666 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
667
    @param into The table for the new columns.
668 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
669
        columns will be included in the mapping even if they are not in cols.
670
        The tables of the provided Col objects will be changed to into, so make
671
        copies of them if you want to keep the original tables.
672
    @param as_items Whether to return a list of dict items instead of a dict
673 2383 aaronmk
    @return dict(orig_col=new_col, ...)
674
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
675 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
676
        * All mappings use the into table so its name can easily be
677 2383 aaronmk
          changed for all columns at once
678
    '''
679 2415 aaronmk
    cols = lists.uniqify(cols)
680
681 2394 aaronmk
    items = []
682 2389 aaronmk
    for col in preserve:
683 2390 aaronmk
        orig_col = copy.copy(col)
684 2392 aaronmk
        col.table = into
685 2394 aaronmk
        items.append((orig_col, col))
686
    preserve = set(preserve)
687
    for col in cols:
688 2397 aaronmk
        if col not in preserve:
689
            items.append((col, sql_gen.Col(clean_name(str(col)), into)))
690 2394 aaronmk
691
    if not as_items: items = dict(items)
692
    return items
693 2383 aaronmk
694 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
695 2391 aaronmk
    '''For params, see mk_flatten_mapping()
696
    @return See return value of mk_flatten_mapping()
697
    '''
698 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
699
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
700 2391 aaronmk
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
701 2392 aaronmk
        into=into)
702 2394 aaronmk
    return dict(items)
703 2391 aaronmk
704 2414 aaronmk
##### Database structure queries
705
706 2426 aaronmk
def table_row_count(db, table, recover=None):
707
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
708 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
709 2426 aaronmk
710 2414 aaronmk
def table_cols(db, table, recover=None):
711
    return list(col_names(select(db, table, limit=0, order_by=None,
712 2443 aaronmk
        recover=recover, log_level=4)))
713 2414 aaronmk
714 2291 aaronmk
def pkey(db, table, recover=None):
715 832 aaronmk
    '''Assumed to be first column in table'''
716 2339 aaronmk
    return table_cols(db, table, recover)[0]
717 832 aaronmk
718 2340 aaronmk
not_null_col = 'not_null'
719
720
def table_not_null_col(db, table, recover=None):
721
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
722
    if not_null_col in table_cols(db, table, recover): return not_null_col
723
    else: return pkey(db, table, recover)
724
725 853 aaronmk
def index_cols(db, table, index):
726
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
727
    automatically created. When you don't know whether something is a UNIQUE
728
    constraint or a UNIQUE index, use this function.'''
729 1909 aaronmk
    module = util.root_module(db.db)
730
    if module == 'psycopg2':
731
        return list(values(run_query(db, '''\
732 853 aaronmk
SELECT attname
733 866 aaronmk
FROM
734
(
735
        SELECT attnum, attname
736
        FROM pg_index
737
        JOIN pg_class index ON index.oid = indexrelid
738
        JOIN pg_class table_ ON table_.oid = indrelid
739
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
740
        WHERE
741
            table_.relname = %(table)s
742
            AND index.relname = %(index)s
743
    UNION
744
        SELECT attnum, attname
745
        FROM
746
        (
747
            SELECT
748
                indrelid
749
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
750
                    AS indkey
751
            FROM pg_index
752
            JOIN pg_class index ON index.oid = indexrelid
753
            JOIN pg_class table_ ON table_.oid = indrelid
754
            WHERE
755
                table_.relname = %(table)s
756
                AND index.relname = %(index)s
757
        ) s
758
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
759
) s
760 853 aaronmk
ORDER BY attnum
761
''',
762 2443 aaronmk
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
763 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
764
        ' database')
765 853 aaronmk
766 464 aaronmk
def constraint_cols(db, table, constraint):
767 1849 aaronmk
    module = util.root_module(db.db)
768 464 aaronmk
    if module == 'psycopg2':
769
        return list(values(run_query(db, '''\
770
SELECT attname
771
FROM pg_constraint
772
JOIN pg_class ON pg_class.oid = conrelid
773
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
774
WHERE
775
    relname = %(table)s
776
    AND conname = %(constraint)s
777
ORDER BY attnum
778
''',
779
            {'table': table, 'constraint': constraint})))
780
    else: raise NotImplementedError("Can't list constraint columns for "+module+
781
        ' database')
782
783 2096 aaronmk
row_num_col = '_row_num'
784
785 2408 aaronmk
def index_col(db, col):
786
    '''Adds an index on a column if it doesn't already exist.'''
787
    assert sql_gen.is_table_col(col)
788
789
    table = col.table
790
    index = sql_gen.as_Table(clean_name(str(col)))
791
    col = sql_gen.to_name_only_col(col)
792
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
793 2443 aaronmk
        +' ('+col.to_str(db)+')', recover=True, cacheable=True, log_level=3)
794 2408 aaronmk
    except DuplicateTableException: pass # index already existed
795
796 2406 aaronmk
def index_pkey(db, table, recover=None):
797
    '''Makes the first column in a table the primary key.
798
    @pre The table must not already have a primary key.
799
    '''
800
    table = sql_gen.as_Table(table)
801
802
    index = sql_gen.as_Table(table.name+'_pkey')
803
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
804
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
805 2443 aaronmk
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
806
        log_level=3)
807 2406 aaronmk
808 2086 aaronmk
def add_row_num(db, table):
809 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
810
    be the primary key.'''
811 2320 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
812 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
813 2443 aaronmk
        +' serial NOT NULL PRIMARY KEY', log_level=3)
814 2086 aaronmk
815 1968 aaronmk
def tables(db, schema='public', table_like='%'):
816 1849 aaronmk
    module = util.root_module(db.db)
817 1968 aaronmk
    params = {'schema': schema, 'table_like': table_like}
818 832 aaronmk
    if module == 'psycopg2':
819 1968 aaronmk
        return values(run_query(db, '''\
820
SELECT tablename
821
FROM pg_tables
822
WHERE
823
    schemaname = %(schema)s
824
    AND tablename LIKE %(table_like)s
825
ORDER BY tablename
826
''',
827
            params, cacheable=True))
828
    elif module == 'MySQLdb':
829
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
830
            cacheable=True))
831 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
832 830 aaronmk
833 833 aaronmk
##### Database management
834
835 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
836
    '''For kw_args, see tables()'''
837
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
838 833 aaronmk
839 832 aaronmk
##### Heuristic queries
840
841 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
842 1554 aaronmk
    '''Recovers from errors.
843 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
844
    '''
845 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
846
847 471 aaronmk
    try:
848 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
849 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
850
            row_ct_ref[0] += cur.rowcount
851
        return value(cur)
852 471 aaronmk
    except DuplicateKeyException, e:
853 2104 aaronmk
        return value(select(db, table, [pkey_],
854 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
855 471 aaronmk
856 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
857 830 aaronmk
    '''Recovers from errors'''
858 2209 aaronmk
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
859 14 aaronmk
    except StopIteration:
860 40 aaronmk
        if not create: raise
861 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
862 2078 aaronmk
863 2489 aaronmk
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None):
864 2078 aaronmk
    '''Recovers from errors.
865
    Only works under PostgreSQL (uses INSERT RETURNING).
866 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
867
        tables to join with it using the main input table's pkey
868 2312 aaronmk
    @param mapping dict(out_table_col=in_table_col, ...)
869
        * out_table_col: sql_gen.Col|str
870 2323 aaronmk
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
871 2489 aaronmk
    @param into The table to contain the output and input pkeys.
872
        Defaults to `out_table.name+'_pkeys'`.
873 2312 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
874 2078 aaronmk
    '''
875 2329 aaronmk
    out_table = sql_gen.as_Table(out_table)
876 2312 aaronmk
    for in_table_col in mapping.itervalues():
877
        assert isinstance(in_table_col, sql_gen.Col)
878 2489 aaronmk
    if into == None: into = out_table.name+'_pkeys'
879
    into = sql_gen.as_Table(into)
880 2312 aaronmk
881 2450 aaronmk
    def log_debug(msg): db.log_debug(msg, level=1.5)
882
883 2486 aaronmk
    log_debug('********** New iteration **********')
884
    log_debug('Inserting these input columns into '
885
        +strings.as_tt(out_table.to_str(db))+':\n'+strings.as_table(mapping))
886 2463 aaronmk
887 2382 aaronmk
    # Create input joins from list of input tables
888
    in_tables_ = in_tables[:] # don't modify input!
889
    in_tables0 = in_tables_.pop(0) # first table is separate
890 2279 aaronmk
    in_pkey = pkey(db, in_tables0, recover=True)
891 2285 aaronmk
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
892 2460 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
893
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
894 2131 aaronmk
895 2486 aaronmk
    log_debug('Joining together input tables into temp table')
896 2395 aaronmk
    # Place in new table for speed and so don't modify input if values edited
897 2489 aaronmk
    in_table = sql_gen.Table(into.name+'_in')
898 2395 aaronmk
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
899
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
900
        flatten_cols, preserve=[in_pkey_col], start=0))
901
    input_joins = [in_table]
902 2486 aaronmk
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
903 2395 aaronmk
904 2279 aaronmk
    out_pkey = pkey(db, out_table, recover=True)
905 2285 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
906 2142 aaronmk
907 2387 aaronmk
    pkeys_names = [in_pkey, out_pkey]
908 2236 aaronmk
    pkeys_cols = [in_pkey_col, out_pkey_col]
909
910 2201 aaronmk
    pkeys_table_exists_ref = [False]
911 2420 aaronmk
    def insert_into_pkeys(joins, cols):
912
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
913 2201 aaronmk
        if pkeys_table_exists_ref[0]:
914 2489 aaronmk
            insert_select(db, into, pkeys_names, query, params)
915 2201 aaronmk
        else:
916 2489 aaronmk
            run_query_into(db, query, params, into=into)
917 2201 aaronmk
            pkeys_table_exists_ref[0] = True
918
919 2429 aaronmk
    limit_ref = [None]
920 2380 aaronmk
    conds = set()
921 2233 aaronmk
    distinct_on = []
922 2325 aaronmk
    def mk_main_select(joins, cols):
923 2429 aaronmk
        return mk_select(db, joins, cols, conds, distinct_on,
924
            limit=limit_ref[0], start=0)
925 2132 aaronmk
926 2309 aaronmk
    def log_exc(e):
927 2450 aaronmk
        log_debug('Caught exception: '+exc.str_(e, first_line_only=True))
928 2451 aaronmk
    def remove_all_rows():
929 2450 aaronmk
        log_debug('Returning NULL for all rows')
930 2429 aaronmk
        limit_ref[0] = 0 # just create an empty pkeys table
931 2409 aaronmk
    def ignore(in_col, value):
932
        in_col_str = str(in_col)
933 2450 aaronmk
        log_debug('Adding index on '+in_col_str+' to enable fast filtering')
934 2409 aaronmk
        index_col(db, in_col)
935 2450 aaronmk
        log_debug('Ignoring rows with '+in_col_str+' = '+repr(value))
936 2403 aaronmk
    def remove_rows(in_col, value):
937 2409 aaronmk
        ignore(in_col, value)
938 2378 aaronmk
        cond = (in_col, sql_gen.CompareCond(value, '!='))
939
        assert cond not in conds # avoid infinite loops
940 2380 aaronmk
        conds.add(cond)
941 2403 aaronmk
    def invalid2null(in_col, value):
942 2409 aaronmk
        ignore(in_col, value)
943 2403 aaronmk
        update(db, in_table, [(in_col, None)],
944
            sql_gen.ColValueCond(in_col, value))
945 2245 aaronmk
946 2206 aaronmk
    # Do inserts and selects
947 2257 aaronmk
    join_cols = {}
948 2489 aaronmk
    insert_out_pkeys = sql_gen.Table(into.name+'_insert_out_pkeys')
949
    insert_in_pkeys = sql_gen.Table(into.name+'_insert_in_pkeys')
950 2206 aaronmk
    while True:
951 2303 aaronmk
        has_joins = join_cols != {}
952
953 2305 aaronmk
        # Prepare to insert new rows
954 2325 aaronmk
        insert_joins = input_joins[:] # don't modify original!
955 2403 aaronmk
        insert_args = dict(recover=True, cacheable=False)
956 2303 aaronmk
        if has_joins:
957 2317 aaronmk
            distinct_on = [v.to_Col() for v in join_cols.values()]
958 2325 aaronmk
            insert_joins.append(sql_gen.Join(out_table, join_cols,
959
                sql_gen.filter_out))
960
        else:
961 2404 aaronmk
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
962 2303 aaronmk
963 2486 aaronmk
        log_debug('Trying to insert new rows')
964 2206 aaronmk
        try:
965
            cur = insert_select(db, out_table, mapping.keys(),
966 2325 aaronmk
                *mk_main_select(insert_joins, mapping.values()), **insert_args)
967 2357 aaronmk
            break # insert successful
968 2206 aaronmk
        except DuplicateKeyException, e:
969 2309 aaronmk
            log_exc(e)
970
971 2258 aaronmk
            old_join_cols = join_cols.copy()
972 2334 aaronmk
            join_cols.update(util.dict_subset(mapping, e.cols))
973 2486 aaronmk
            log_debug('Ignoring existing rows, comparing on these columns:\n'
974 2483 aaronmk
                +strings.as_inline_table(join_cols))
975 2258 aaronmk
            assert join_cols != old_join_cols # avoid infinite loops
976 2230 aaronmk
        except NullValueException, e:
977 2309 aaronmk
            log_exc(e)
978
979 2230 aaronmk
            out_col, = e.cols
980
            try: in_col = mapping[out_col]
981 2356 aaronmk
            except KeyError:
982 2486 aaronmk
                log_debug('Missing mapping for NOT NULL column '+out_col)
983 2451 aaronmk
                remove_all_rows()
984 2403 aaronmk
            else: remove_rows(in_col, None)
985 2243 aaronmk
        except FunctionValueException, e:
986 2309 aaronmk
            log_exc(e)
987
988 2344 aaronmk
            assert e.name == out_table.name
989 2243 aaronmk
            out_col = 'value' # assume function param was named "value"
990 2403 aaronmk
            invalid2null(mapping[out_col], e.value)
991 2429 aaronmk
        except DatabaseErrors, e:
992
            log_exc(e)
993
994 2451 aaronmk
            msg = 'No handler for exception: '+exc.str_(e, first_line_only=True)
995
            warnings.warn(DbWarning(msg))
996
            log_debug(msg)
997
            remove_all_rows()
998 2358 aaronmk
        # after exception handled, rerun loop with additional constraints
999 2132 aaronmk
1000 2357 aaronmk
    if row_ct_ref != None and cur.rowcount >= 0:
1001
        row_ct_ref[0] += cur.rowcount
1002
1003
    if has_joins:
1004
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1005 2486 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
1006 2420 aaronmk
        insert_into_pkeys(select_joins, pkeys_cols)
1007 2357 aaronmk
    else:
1008 2404 aaronmk
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1009 2357 aaronmk
1010 2486 aaronmk
        log_debug('Getting input table pkeys of inserted rows')
1011 2357 aaronmk
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1012 2404 aaronmk
            into=insert_in_pkeys)
1013
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1014 2357 aaronmk
1015 2428 aaronmk
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1016
            insert_in_pkeys)
1017
1018 2486 aaronmk
        log_debug('Combining output and input pkeys in inserted order')
1019 2404 aaronmk
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1020 2357 aaronmk
            {row_num_col: sql_gen.join_same_not_null})]
1021 2420 aaronmk
        insert_into_pkeys(pkey_joins, pkeys_names)
1022 2357 aaronmk
1023 2486 aaronmk
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1024 2489 aaronmk
    index_pkey(db, into)
1025 2407 aaronmk
1026 2485 aaronmk
    log_debug("Setting pkeys of missing rows to NULL")
1027 2489 aaronmk
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1028 2357 aaronmk
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1029
        # must use join_same_not_null or query will take forever
1030 2420 aaronmk
    insert_into_pkeys(missing_rows_joins,
1031
        [in_pkey_col, sql_gen.NamedCol(out_pkey, None)])
1032 2357 aaronmk
1033 2489 aaronmk
    assert table_row_count(db, into) == table_row_count(db, in_table)
1034 2428 aaronmk
1035 2489 aaronmk
    return sql_gen.Col(out_pkey, into)
1036 2115 aaronmk
1037
##### Data cleanup
1038
1039 2290 aaronmk
def cleanup_table(db, table, cols):
1040 2115 aaronmk
    def esc_name_(name): return esc_name(db, name)
1041
1042 2290 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
1043 2115 aaronmk
    cols = map(esc_name_, cols)
1044
1045
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1046
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1047
            for col in cols))),
1048
        dict(null0='', null1=r'\N'))