Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 2217 aaronmk
import sql_gen
15 862 aaronmk
import strings
16 131 aaronmk
import util
17 11 aaronmk
18 832 aaronmk
##### Exceptions
19
20 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
21 2168 aaronmk
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24 2170 aaronmk
25
    if raw_query != None: return raw_query
26 2371 aaronmk
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27 14 aaronmk
28 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31 135 aaronmk
32 300 aaronmk
class DbException(exc.ExceptionWithCause):
33 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
34 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
36
37 2143 aaronmk
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40 2143 aaronmk
        self.name = name
41 360 aaronmk
42 2240 aaronmk
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46 2240 aaronmk
        self.name = name
47
        self.value = value
48
49 2306 aaronmk
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53 2306 aaronmk
        self.name = name
54 468 aaronmk
        self.cols = cols
55 11 aaronmk
56 2523 aaronmk
class MissingCastException(DbException):
57
    def __init__(self, type_, col, cause=None):
58
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
59
            +' on column: '+strings.as_tt(col), cause)
60
        self.type = type_
61
        self.col = col
62
63 2143 aaronmk
class NameException(DbException): pass
64
65 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
66 13 aaronmk
67 2306 aaronmk
class NullValueException(ConstraintException): pass
68 13 aaronmk
69 2240 aaronmk
class FunctionValueException(ExceptionWithNameValue): pass
70 2239 aaronmk
71 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
72
73 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName): pass
74
75 89 aaronmk
class EmptyRowException(DbException): pass
76
77 865 aaronmk
##### Warnings
78
79
class DbWarning(UserWarning): pass
80
81 1930 aaronmk
##### Result retrieval
82
83
def col_names(cur): return (col[0] for col in cur.description)
84
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86
87
def consume_rows(cur):
88
    '''Used to fetch all rows so result will be cached'''
89
    iters.consume_iter(rows(cur))
90
91
def next_row(cur): return rows(cur).next()
92
93
def row(cur):
94
    row_ = next_row(cur)
95
    consume_rows(cur)
96
    return row_
97
98
def next_value(cur): return next_row(cur)[0]
99
100
def value(cur): return row(cur)[0]
101
102
def values(cur): return iters.func_iter(lambda: next_value(cur))
103
104
def value_or_none(cur):
105
    try: return value(cur)
106
    except StopIteration: return None
107
108 2101 aaronmk
##### Input validation
109
110
def esc_name_by_module(module, name, ignore_case=False):
111 2388 aaronmk
    if module == 'psycopg2' or module == None:
112 2569 aaronmk
        if ignore_case and is_safe_name(name):
113 2101 aaronmk
            # Don't enclose in quotes because this disables case-insensitivity
114
            return name
115
        else: quote = '"'
116
    elif module == 'MySQLdb': quote = '`'
117
    else: raise NotImplementedError("Can't escape name for "+module+' database')
118 2500 aaronmk
    return sql_gen.esc_name(name, quote)
119 2101 aaronmk
120
def esc_name_by_engine(engine, name, **kw_args):
121
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
122
123
def esc_name(db, name, **kw_args):
124
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
125
126
def qual_name(db, schema, table):
127
    def esc_name_(name): return esc_name(db, name)
128
    table = esc_name_(table)
129
    if schema != None: return esc_name_(schema)+'.'+table
130
    else: return table
131
132 1869 aaronmk
##### Database connections
133 1849 aaronmk
134 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
135 1926 aaronmk
136 1869 aaronmk
db_engines = {
137
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
138
    'PostgreSQL': ('psycopg2', {}),
139
}
140
141
DatabaseErrors_set = set([DbException])
142
DatabaseErrors = tuple(DatabaseErrors_set)
143
144
def _add_module(module):
145
    DatabaseErrors_set.add(module.DatabaseError)
146
    global DatabaseErrors
147
    DatabaseErrors = tuple(DatabaseErrors_set)
148
149
def db_config_str(db_config):
150
    return db_config['engine']+' database '+db_config['database']
151
152 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
153 1894 aaronmk
154 2448 aaronmk
log_debug_none = lambda msg, level=2: None
155 1901 aaronmk
156 1849 aaronmk
class DbConn:
157 2190 aaronmk
    def __init__(self, db_config, serializable=True, autocommit=False,
158
        caching=True, log_debug=log_debug_none):
159 1869 aaronmk
        self.db_config = db_config
160
        self.serializable = serializable
161 2190 aaronmk
        self.autocommit = autocommit
162
        self.caching = caching
163 1901 aaronmk
        self.log_debug = log_debug
164 2193 aaronmk
        self.debug = log_debug != log_debug_none
165 1869 aaronmk
166
        self.__db = None
167 1889 aaronmk
        self.query_results = {}
168 2139 aaronmk
        self._savepoint = 0
169 1869 aaronmk
170
    def __getattr__(self, name):
171
        if name == '__dict__': raise Exception('getting __dict__')
172
        if name == 'db': return self._db()
173
        else: raise AttributeError()
174
175
    def __getstate__(self):
176
        state = copy.copy(self.__dict__) # shallow copy
177 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
178 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
179
        return state
180
181 2165 aaronmk
    def connected(self): return self.__db != None
182
183 1869 aaronmk
    def _db(self):
184
        if self.__db == None:
185
            # Process db_config
186
            db_config = self.db_config.copy() # don't modify input!
187 2097 aaronmk
            schemas = db_config.pop('schemas', None)
188 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
189
            module = __import__(module_name)
190
            _add_module(module)
191
            for orig, new in mappings.iteritems():
192
                try: util.rename_key(db_config, orig, new)
193
                except KeyError: pass
194
195
            # Connect
196
            self.__db = module.connect(**db_config)
197
198
            # Configure connection
199 2234 aaronmk
            if self.serializable and not self.autocommit: run_raw_query(self,
200
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
201 2101 aaronmk
            if schemas != None:
202
                schemas_ = ''.join((esc_name(self, s)+', '
203
                    for s in schemas.split(',')))
204
                run_raw_query(self, "SELECT set_config('search_path', \
205
%s || current_setting('search_path'), false)", [schemas_])
206 1869 aaronmk
207
        return self.__db
208 1889 aaronmk
209 1891 aaronmk
    class DbCursor(Proxy):
210 1927 aaronmk
        def __init__(self, outer):
211 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
212 2191 aaronmk
            self.outer = outer
213 1927 aaronmk
            self.query_results = outer.query_results
214 1894 aaronmk
            self.query_lookup = None
215 1891 aaronmk
            self.result = []
216 1889 aaronmk
217 1894 aaronmk
        def execute(self, query, params=None):
218 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
219 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
220 2148 aaronmk
            try:
221 2191 aaronmk
                try:
222
                    return_value = self.inner.execute(query, params)
223
                    self.outer.do_autocommit()
224 2148 aaronmk
                finally: self.query = get_cur_query(self.inner)
225 1904 aaronmk
            except Exception, e:
226 2170 aaronmk
                _add_cursor_info(e, self, query, params)
227 1904 aaronmk
                self.result = e # cache the exception as the result
228
                self._cache_result()
229
                raise
230 1930 aaronmk
            # Fetch all rows so result will be cached
231
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
232 1894 aaronmk
            return return_value
233
234 1891 aaronmk
        def fetchone(self):
235
            row = self.inner.fetchone()
236 1899 aaronmk
            if row != None: self.result.append(row)
237
            # otherwise, fetched all rows
238 1904 aaronmk
            else: self._cache_result()
239
            return row
240
241
        def _cache_result(self):
242 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
243
            # idempotent, but an invalid insert will always be invalid
244 1930 aaronmk
            if self.query_results != None and (not self._is_insert
245 1906 aaronmk
                or isinstance(self.result, Exception)):
246
247 1894 aaronmk
                assert self.query_lookup != None
248 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
249
                    util.dict_subset(dicts.AttrsDictView(self),
250
                    ['query', 'result', 'rowcount', 'description']))
251 1906 aaronmk
252 1916 aaronmk
        class CacheCursor:
253
            def __init__(self, cached_result): self.__dict__ = cached_result
254
255 1927 aaronmk
            def execute(self, *args, **kw_args):
256 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
257
                # otherwise, result is a rows list
258
                self.iter = iter(self.result)
259
260
            def fetchone(self):
261
                try: return self.iter.next()
262
                except StopIteration: return None
263 1891 aaronmk
264 2212 aaronmk
    def esc_value(self, value):
265 2215 aaronmk
        module = util.root_module(self.db)
266 2374 aaronmk
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
267 2212 aaronmk
        elif module == 'MySQLdb':
268
            import _mysql
269 2374 aaronmk
            str_ = _mysql.escape_string(value)
270 2212 aaronmk
        else: raise NotImplementedError("Can't escape value for "+module
271
            +' database')
272 2374 aaronmk
        return strings.to_unicode(str_)
273 2212 aaronmk
274 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
275
276 2445 aaronmk
    def run_query(self, query, params=None, cacheable=False, log_level=2,
277 2464 aaronmk
        debug_msg_ref=None):
278 2445 aaronmk
        '''
279 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
280
            throws one of these exceptions.
281 2445 aaronmk
        '''
282 2167 aaronmk
        assert query != None
283
284 2047 aaronmk
        if not self.caching: cacheable = False
285 1903 aaronmk
        used_cache = False
286
        try:
287 1927 aaronmk
            # Get cursor
288
            if cacheable:
289
                query_lookup = _query_lookup(query, params)
290
                try:
291
                    cur = self.query_results[query_lookup]
292
                    used_cache = True
293
                except KeyError: cur = self.DbCursor(self)
294
            else: cur = self.db.cursor()
295
296
            # Run query
297 2148 aaronmk
            cur.execute(query, params)
298 1903 aaronmk
        finally:
299 2464 aaronmk
            if self.debug and debug_msg_ref != None:# only compute msg if needed
300 2470 aaronmk
                if used_cache: cache_status = 'cache hit'
301
                elif cacheable: cache_status = 'cache miss'
302
                else: cache_status = 'non-cacheable'
303 2472 aaronmk
                query_code = strings.as_code(str(get_cur_query(cur, query,
304
                    params)), 'SQL')
305
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
306 1903 aaronmk
307
        return cur
308 1914 aaronmk
309
    def is_cached(self, query, params=None):
310
        return _query_lookup(query, params) in self.query_results
311 2139 aaronmk
312
    def with_savepoint(self, func):
313 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
314 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
315 2139 aaronmk
        self._savepoint += 1
316
        try:
317
            try: return_val = func()
318
            finally:
319
                self._savepoint -= 1
320
                assert self._savepoint >= 0
321
        except:
322 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
323 2139 aaronmk
            raise
324
        else:
325 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
326 2191 aaronmk
            self.do_autocommit()
327 2139 aaronmk
            return return_val
328 2191 aaronmk
329
    def do_autocommit(self):
330
        '''Autocommits if outside savepoint'''
331
        assert self._savepoint >= 0
332
        if self.autocommit and self._savepoint == 0:
333
            self.log_debug('Autocommiting')
334
            self.db.commit()
335 1849 aaronmk
336 1869 aaronmk
connect = DbConn
337
338 832 aaronmk
##### Querying
339
340 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
341 2085 aaronmk
    '''For params, see DbConn.run_query()'''
342 1894 aaronmk
    return db.run_query(*args, **kw_args)
343 11 aaronmk
344 2068 aaronmk
def mogrify(db, query, params):
345
    module = util.root_module(db.db)
346
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
347
    else: raise NotImplementedError("Can't mogrify query for "+module+
348
        ' database')
349
350 832 aaronmk
##### Recoverable querying
351 15 aaronmk
352 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
353 11 aaronmk
354 2464 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False,
355
    log_level=2, log_ignore_excs=None, **kw_args):
356 2441 aaronmk
    '''For params, see run_raw_query()'''
357 830 aaronmk
    if recover == None: recover = False
358 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
359
    log_ignore_excs = tuple(log_ignore_excs)
360 830 aaronmk
361 2464 aaronmk
    debug_msg_ref = [None]
362 2148 aaronmk
    try:
363 2464 aaronmk
        try:
364
            def run(): return run_raw_query(db, query, params, cacheable,
365
                log_level, debug_msg_ref, **kw_args)
366
            if recover and not db.is_cached(query, params):
367
                return with_savepoint(db, run)
368
            else: return run() # don't need savepoint if cached
369
        except Exception, e:
370
            if not recover: raise # need savepoint to run index_cols()
371
            msg = exc.str_(e)
372
373
            match = re.search(r'duplicate key value violates unique constraint '
374 2493 aaronmk
                r'"((_?[^\W_]+)_.+?)"', msg)
375 2464 aaronmk
            if match:
376
                constraint, table = match.groups()
377
                try: cols = index_cols(db, table, constraint)
378
                except NotImplementedError: raise e
379
                else: raise DuplicateKeyException(constraint, cols, e)
380
381 2493 aaronmk
            match = re.search(r'null value in column "(.+?)" violates not-null'
382 2464 aaronmk
                r' constraint', msg)
383
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
384
385
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
386
                r'|date/time field value out of range): "(.+?)"\n'
387 2535 aaronmk
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
388 2464 aaronmk
            if match:
389
                value, name = match.groups()
390
                raise FunctionValueException(name, strings.to_unicode(value), e)
391
392 2526 aaronmk
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
393 2523 aaronmk
                r'is of type', msg)
394
            if match:
395
                col, type_ = match.groups()
396
                raise MissingCastException(type_, col, e)
397
398 2493 aaronmk
            match = re.search(r'relation "(.+?)" already exists', msg)
399 2464 aaronmk
            if match: raise DuplicateTableException(match.group(1), e)
400
401 2493 aaronmk
            match = re.search(r'function "(.+?)" already exists', msg)
402 2464 aaronmk
            if match: raise DuplicateFunctionException(match.group(1), e)
403
404
            raise # no specific exception raised
405
    except log_ignore_excs:
406
        log_level += 2
407
        raise
408
    finally:
409
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
410 830 aaronmk
411 832 aaronmk
##### Basic queries
412
413 2153 aaronmk
def next_version(name):
414
    '''Prepends the version # so it won't be removed if the name is truncated'''
415 2163 aaronmk
    version = 1 # first existing name was version 0
416 2498 aaronmk
    match = re.match(r'^#(\d+)-(.*)$', name)
417 2153 aaronmk
    if match:
418
        version = int(match.group(1))+1
419
        name = match.group(2)
420 2498 aaronmk
    return '#'+str(version)+'-'+name
421 2153 aaronmk
422 2386 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
423 2085 aaronmk
    '''Outputs a query to a temp table.
424
    For params, see run_query().
425
    '''
426 2386 aaronmk
    if into == None: return run_query(db, query, params, *args, **kw_args)
427 2085 aaronmk
    else: # place rows in temp table
428 2386 aaronmk
        assert isinstance(into, sql_gen.Table)
429 2385 aaronmk
430 2153 aaronmk
        kw_args['recover'] = True
431 2464 aaronmk
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
432 2440 aaronmk
433 2468 aaronmk
        temp = not db.autocommit # tables are permanent in autocommit mode
434 2440 aaronmk
        # "temporary tables cannot specify a schema name", so remove schema
435
        if temp: into.schema = None
436
437 2153 aaronmk
        while True:
438
            try:
439 2194 aaronmk
                create_query = 'CREATE'
440 2440 aaronmk
                if temp: create_query += ' TEMP'
441 2467 aaronmk
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
442 2194 aaronmk
443
                return run_query(db, create_query, params, *args, **kw_args)
444 2153 aaronmk
                    # CREATE TABLE AS sets rowcount to # rows in query
445
            except DuplicateTableException, e:
446 2386 aaronmk
                into.name = next_version(into.name)
447 2153 aaronmk
                # try again with next version of name
448 2085 aaronmk
449 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
450
451 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
452
453 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
454 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
455 1981 aaronmk
    '''
456 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
457 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
458 1981 aaronmk
    @param fields Use None to select all fields in the table
459 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
460 2379 aaronmk
        * container can be any iterable type
461 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
462
        * compare_right_side: sql_gen.ValueCond|literal value
463 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
464
        use all columns
465 2054 aaronmk
    @return tuple(query, params)
466 1981 aaronmk
    '''
467 2315 aaronmk
    # Parse tables param
468 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
469 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
470 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
471 2121 aaronmk
472 2315 aaronmk
    # Parse other params
473 2376 aaronmk
    if conds == None: conds = []
474
    elif isinstance(conds, dict): conds = conds.items()
475 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
476 135 aaronmk
    assert limit == None or type(limit) == int
477 865 aaronmk
    assert start == None or type(start) == int
478 2315 aaronmk
    if order_by is order_by_pkey:
479
        if distinct_on != []: order_by = None
480
        else: order_by = pkey(db, table0, recover=True)
481 865 aaronmk
482 2315 aaronmk
    query = 'SELECT'
483 2056 aaronmk
484 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
485 2056 aaronmk
486 2200 aaronmk
    # DISTINCT ON columns
487 2233 aaronmk
    if distinct_on != []:
488 2467 aaronmk
        query += '\nDISTINCT'
489 2254 aaronmk
        if distinct_on is not distinct_on_all:
490 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
491
492
    # Columns
493 2467 aaronmk
    query += '\n'
494 1135 aaronmk
    if fields == None: query += '*'
495 2479 aaronmk
    else: query += '\n, '.join(map(parse_col, fields))
496 2200 aaronmk
497
    # Main table
498 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
499 865 aaronmk
500 2122 aaronmk
    # Add joins
501 2271 aaronmk
    left_table = table0
502 2263 aaronmk
    for join_ in tables:
503
        table = join_.table
504 2238 aaronmk
505 2343 aaronmk
        # Parse special values
506
        if join_.type_ is sql_gen.filter_out: # filter no match
507 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
508
                None))
509 2343 aaronmk
510 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
511 2122 aaronmk
512
        left_table = table
513
514 865 aaronmk
    missing = True
515 2376 aaronmk
    if conds != []:
516 2467 aaronmk
        query += '\nWHERE\n'+('\nAND\n'.join(('('+sql_gen.ColValueCond(l, r)
517 2410 aaronmk
            .to_str(db)+')' for l, r in conds)))
518 865 aaronmk
        missing = False
519 2227 aaronmk
    if order_by != None:
520 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
521
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
522 865 aaronmk
    if start != None:
523 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
524 865 aaronmk
        missing = False
525
    if missing: warnings.warn(DbWarning(
526
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
527
528 2315 aaronmk
    return (query, [])
529 11 aaronmk
530 2054 aaronmk
def select(db, *args, **kw_args):
531
    '''For params, see mk_select() and run_query()'''
532
    recover = kw_args.pop('recover', None)
533
    cacheable = kw_args.pop('cacheable', True)
534 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
535 2054 aaronmk
536
    query, params = mk_select(db, *args, **kw_args)
537 2442 aaronmk
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
538 2054 aaronmk
539 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
540 2292 aaronmk
    returning=None, embeddable=False):
541 1960 aaronmk
    '''
542
    @param returning str|None An inserted column (such as pkey) to return
543 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
544 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
545
        query will be fully cached, not just if it raises an exception.
546 1960 aaronmk
    '''
547 2328 aaronmk
    table = sql_gen.as_Table(table)
548 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
549
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
550 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
551 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
552 2063 aaronmk
553
    # Build query
554 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
555
    query = first_line
556 2467 aaronmk
    if cols != None: query += '\n('+', '.join(cols)+')'
557
    query += '\n'+select_query
558 2063 aaronmk
559
    if returning != None:
560 2327 aaronmk
        returning_name = copy.copy(returning)
561
        returning_name.table = None
562
        returning_name = returning_name.to_str(db)
563 2467 aaronmk
        query += '\nRETURNING '+returning_name
564 2063 aaronmk
565 2070 aaronmk
    if embeddable:
566 2327 aaronmk
        assert returning != None
567
568 2070 aaronmk
        # Create function
569 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
570 2327 aaronmk
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
571 2189 aaronmk
        while True:
572
            try:
573 2327 aaronmk
                func_schema = None
574 2468 aaronmk
                if not db.autocommit: func_schema = 'pg_temp'
575 2327 aaronmk
                function = sql_gen.Table(function_name, func_schema).to_str(db)
576 2194 aaronmk
577 2189 aaronmk
                function_query = '''\
578 2467 aaronmk
CREATE FUNCTION '''+function+'''()
579
RETURNS '''+return_type+'''
580
LANGUAGE sql
581
AS $$
582
'''+mogrify(db, query, params)+''';
583
$$;
584 2070 aaronmk
'''
585 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
586 2464 aaronmk
                    log_ignore_excs=(DuplicateFunctionException,))
587 2189 aaronmk
                break # this version was successful
588
            except DuplicateFunctionException, e:
589
                function_name = next_version(function_name)
590
                # try again with next version of name
591 2070 aaronmk
592 2337 aaronmk
        # Return query that uses function
593
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
594
            [returning_name]) # AS clause requires function alias
595
        return mk_select(db, func_table, start=0, order_by=None)
596 2070 aaronmk
597 2066 aaronmk
    return (query, params)
598
599
def insert_select(db, *args, **kw_args):
600 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
601 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
602
        values in
603 2072 aaronmk
    '''
604 2386 aaronmk
    into = kw_args.pop('into', None)
605
    if into != None: kw_args['embeddable'] = True
606 2066 aaronmk
    recover = kw_args.pop('recover', None)
607
    cacheable = kw_args.pop('cacheable', True)
608
609
    query, params = mk_insert_select(db, *args, **kw_args)
610 2386 aaronmk
    return run_query_into(db, query, params, into, recover=recover,
611 2153 aaronmk
        cacheable=cacheable)
612 2063 aaronmk
613 2066 aaronmk
default = object() # tells insert() to use the default value for a column
614
615 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
616 2085 aaronmk
    '''For params, see insert_select()'''
617 1960 aaronmk
    if lists.is_seq(row): cols = None
618
    else:
619
        cols = row.keys()
620
        row = row.values()
621
    row = list(row) # ensure that "!= []" works
622
623 1961 aaronmk
    # Check for special values
624
    labels = []
625
    values = []
626
    for value in row:
627 2254 aaronmk
        if value is default: labels.append('DEFAULT')
628 1961 aaronmk
        else:
629
            labels.append('%s')
630
            values.append(value)
631
632
    # Build query
633 2467 aaronmk
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
634 2063 aaronmk
    else: query = None
635 1554 aaronmk
636 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
637 11 aaronmk
638 2402 aaronmk
def mk_update(db, table, changes=None, cond=None):
639
    '''
640
    @param changes [(col, new_value),...]
641
        * container can be any iterable type
642
        * col: sql_gen.Code|str (for col name)
643
        * new_value: sql_gen.Code|literal value
644
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
645
    @return str query
646
    '''
647
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
648 2405 aaronmk
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
649 2402 aaronmk
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
650 2467 aaronmk
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
651 2402 aaronmk
652
    return query
653
654
def update(db, *args, **kw_args):
655
    '''For params, see mk_update() and run_query()'''
656
    recover = kw_args.pop('recover', None)
657
658
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
659
660 135 aaronmk
def last_insert_id(db):
661 1849 aaronmk
    module = util.root_module(db.db)
662 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
663
    elif module == 'MySQLdb': return db.insert_id()
664
    else: return None
665 13 aaronmk
666 1968 aaronmk
def truncate(db, table, schema='public'):
667
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
668 832 aaronmk
669 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
670 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
671 2415 aaronmk
    to names that will be distinct among the columns' tables.
672 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
673 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
674
    @param into The table for the new columns.
675 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
676
        columns will be included in the mapping even if they are not in cols.
677
        The tables of the provided Col objects will be changed to into, so make
678
        copies of them if you want to keep the original tables.
679
    @param as_items Whether to return a list of dict items instead of a dict
680 2383 aaronmk
    @return dict(orig_col=new_col, ...)
681
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
682 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
683
        * All mappings use the into table so its name can easily be
684 2383 aaronmk
          changed for all columns at once
685
    '''
686 2415 aaronmk
    cols = lists.uniqify(cols)
687
688 2394 aaronmk
    items = []
689 2389 aaronmk
    for col in preserve:
690 2390 aaronmk
        orig_col = copy.copy(col)
691 2392 aaronmk
        col.table = into
692 2394 aaronmk
        items.append((orig_col, col))
693
    preserve = set(preserve)
694
    for col in cols:
695 2515 aaronmk
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
696 2394 aaronmk
697
    if not as_items: items = dict(items)
698
    return items
699 2383 aaronmk
700 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
701 2391 aaronmk
    '''For params, see mk_flatten_mapping()
702
    @return See return value of mk_flatten_mapping()
703
    '''
704 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
705
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
706 2391 aaronmk
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
707 2392 aaronmk
        into=into)
708 2394 aaronmk
    return dict(items)
709 2391 aaronmk
710 2414 aaronmk
##### Database structure queries
711
712 2426 aaronmk
def table_row_count(db, table, recover=None):
713
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
714 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
715 2426 aaronmk
716 2414 aaronmk
def table_cols(db, table, recover=None):
717
    return list(col_names(select(db, table, limit=0, order_by=None,
718 2443 aaronmk
        recover=recover, log_level=4)))
719 2414 aaronmk
720 2291 aaronmk
def pkey(db, table, recover=None):
721 832 aaronmk
    '''Assumed to be first column in table'''
722 2339 aaronmk
    return table_cols(db, table, recover)[0]
723 832 aaronmk
724 2559 aaronmk
not_null_col = 'not_null_col'
725 2340 aaronmk
726
def table_not_null_col(db, table, recover=None):
727
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
728
    if not_null_col in table_cols(db, table, recover): return not_null_col
729
    else: return pkey(db, table, recover)
730
731 853 aaronmk
def index_cols(db, table, index):
732
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
733
    automatically created. When you don't know whether something is a UNIQUE
734
    constraint or a UNIQUE index, use this function.'''
735 1909 aaronmk
    module = util.root_module(db.db)
736
    if module == 'psycopg2':
737
        return list(values(run_query(db, '''\
738 853 aaronmk
SELECT attname
739 866 aaronmk
FROM
740
(
741
        SELECT attnum, attname
742
        FROM pg_index
743
        JOIN pg_class index ON index.oid = indexrelid
744
        JOIN pg_class table_ ON table_.oid = indrelid
745
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
746
        WHERE
747
            table_.relname = %(table)s
748
            AND index.relname = %(index)s
749
    UNION
750
        SELECT attnum, attname
751
        FROM
752
        (
753
            SELECT
754
                indrelid
755
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
756
                    AS indkey
757
            FROM pg_index
758
            JOIN pg_class index ON index.oid = indexrelid
759
            JOIN pg_class table_ ON table_.oid = indrelid
760
            WHERE
761
                table_.relname = %(table)s
762
                AND index.relname = %(index)s
763
        ) s
764
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
765
) s
766 853 aaronmk
ORDER BY attnum
767
''',
768 2443 aaronmk
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
769 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
770
        ' database')
771 853 aaronmk
772 464 aaronmk
def constraint_cols(db, table, constraint):
773 1849 aaronmk
    module = util.root_module(db.db)
774 464 aaronmk
    if module == 'psycopg2':
775
        return list(values(run_query(db, '''\
776
SELECT attname
777
FROM pg_constraint
778
JOIN pg_class ON pg_class.oid = conrelid
779
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
780
WHERE
781
    relname = %(table)s
782
    AND conname = %(constraint)s
783
ORDER BY attnum
784
''',
785
            {'table': table, 'constraint': constraint})))
786
    else: raise NotImplementedError("Can't list constraint columns for "+module+
787
        ' database')
788
789 2096 aaronmk
row_num_col = '_row_num'
790
791 2538 aaronmk
def add_index(db, expr):
792
    '''Adds an index on a column or expression if it doesn't already exist.
793
    Currently, only function calls are supported as expressions.
794
    '''
795
    expr = copy.copy(expr) # don't modify input!
796
797
    # Extract col
798 2539 aaronmk
    if isinstance(expr, sql_gen.FunctionCall):
799
        col = expr.args[0]
800 2541 aaronmk
        expr = sql_gen.Expr(expr)
801 2538 aaronmk
    else: col = expr
802 2408 aaronmk
    assert sql_gen.is_table_col(col)
803
804 2538 aaronmk
    index = sql_gen.as_Table(str(expr))
805 2408 aaronmk
    table = col.table
806 2538 aaronmk
    col.table = None
807 2408 aaronmk
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
808 2538 aaronmk
        +' ('+expr.to_str(db)+')', recover=True, cacheable=True, log_level=3)
809 2408 aaronmk
    except DuplicateTableException: pass # index already existed
810
811 2406 aaronmk
def index_pkey(db, table, recover=None):
812
    '''Makes the first column in a table the primary key.
813
    @pre The table must not already have a primary key.
814
    '''
815
    table = sql_gen.as_Table(table)
816
817
    index = sql_gen.as_Table(table.name+'_pkey')
818
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
819
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
820 2443 aaronmk
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
821
        log_level=3)
822 2406 aaronmk
823 2086 aaronmk
def add_row_num(db, table):
824 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
825
    be the primary key.'''
826 2320 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
827 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
828 2443 aaronmk
        +' serial NOT NULL PRIMARY KEY', log_level=3)
829 2086 aaronmk
830 2548 aaronmk
def tables(db, schema_like='public', table_like='%'):
831 1849 aaronmk
    module = util.root_module(db.db)
832 2548 aaronmk
    params = {'schema_like': schema_like, 'table_like': table_like}
833 832 aaronmk
    if module == 'psycopg2':
834 1968 aaronmk
        return values(run_query(db, '''\
835
SELECT tablename
836
FROM pg_tables
837
WHERE
838 2548 aaronmk
    schemaname LIKE %(schema_like)s
839 1968 aaronmk
    AND tablename LIKE %(table_like)s
840
ORDER BY tablename
841
''',
842
            params, cacheable=True))
843
    elif module == 'MySQLdb':
844
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
845
            cacheable=True))
846 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
847 830 aaronmk
848 833 aaronmk
##### Database management
849
850 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
851
    '''For kw_args, see tables()'''
852
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
853 833 aaronmk
854 832 aaronmk
##### Heuristic queries
855
856 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
857 1554 aaronmk
    '''Recovers from errors.
858 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
859
    '''
860 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
861
862 471 aaronmk
    try:
863 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
864 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
865
            row_ct_ref[0] += cur.rowcount
866
        return value(cur)
867 471 aaronmk
    except DuplicateKeyException, e:
868 2104 aaronmk
        return value(select(db, table, [pkey_],
869 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
870 471 aaronmk
871 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
872 830 aaronmk
    '''Recovers from errors'''
873 2209 aaronmk
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
874 14 aaronmk
    except StopIteration:
875 40 aaronmk
        if not create: raise
876 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
877 2078 aaronmk
878 2508 aaronmk
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
879 2552 aaronmk
    default=None, is_func=False):
880 2078 aaronmk
    '''Recovers from errors.
881
    Only works under PostgreSQL (uses INSERT RETURNING).
882 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
883
        tables to join with it using the main input table's pkey
884 2312 aaronmk
    @param mapping dict(out_table_col=in_table_col, ...)
885
        * out_table_col: sql_gen.Col|str
886 2323 aaronmk
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
887 2489 aaronmk
    @param into The table to contain the output and input pkeys.
888 2495 aaronmk
        Defaults to `out_table.name+'-pkeys'`.
889 2509 aaronmk
    @param default The *output* column to use as the pkey for missing rows.
890
        If this output column does not exist in the mapping, uses None.
891 2552 aaronmk
    @param is_func Whether out_table is the name of a SQL function, not a table
892 2312 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
893 2078 aaronmk
    '''
894 2329 aaronmk
    out_table = sql_gen.as_Table(out_table)
895 2565 aaronmk
    mapping = sql_gen.ColDict(mapping)
896 2555 aaronmk
    if into == None:
897
        into = out_table.name
898
        if is_func: into += '()'
899
        else: into += '-pkeys'
900 2489 aaronmk
    into = sql_gen.as_Table(into)
901 2312 aaronmk
902 2450 aaronmk
    def log_debug(msg): db.log_debug(msg, level=1.5)
903 2505 aaronmk
    def col_ustr(str_):
904 2567 aaronmk
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
905 2450 aaronmk
906 2486 aaronmk
    log_debug('********** New iteration **********')
907 2505 aaronmk
    log_debug('Inserting these input columns into '+strings.as_tt(
908
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
909 2463 aaronmk
910 2382 aaronmk
    # Create input joins from list of input tables
911
    in_tables_ = in_tables[:] # don't modify input!
912
    in_tables0 = in_tables_.pop(0) # first table is separate
913 2279 aaronmk
    in_pkey = pkey(db, in_tables0, recover=True)
914 2285 aaronmk
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
915 2460 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
916
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
917 2131 aaronmk
918 2486 aaronmk
    log_debug('Joining together input tables into temp table')
919 2395 aaronmk
    # Place in new table for speed and so don't modify input if values edited
920 2546 aaronmk
    in_table = sql_gen.Table(into.name.replace('-pkeys', '')+'-input')
921 2395 aaronmk
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
922
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
923
        flatten_cols, preserve=[in_pkey_col], start=0))
924
    input_joins = [in_table]
925 2486 aaronmk
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
926 2395 aaronmk
927 2509 aaronmk
    # Resolve default value column
928
    try: default = mapping[default]
929
    except KeyError:
930
        if default != None:
931
            db.log_debug('Default value column '
932
                +strings.as_tt(strings.repr_no_u(default))
933 2511 aaronmk
                +' does not exist in mapping, falling back to None', level=2.1)
934 2509 aaronmk
            default = None
935
936 2279 aaronmk
    out_pkey = pkey(db, out_table, recover=True)
937 2285 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
938 2142 aaronmk
939 2387 aaronmk
    pkeys_names = [in_pkey, out_pkey]
940 2236 aaronmk
    pkeys_cols = [in_pkey_col, out_pkey_col]
941
942 2201 aaronmk
    pkeys_table_exists_ref = [False]
943 2420 aaronmk
    def insert_into_pkeys(joins, cols):
944
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
945 2201 aaronmk
        if pkeys_table_exists_ref[0]:
946 2489 aaronmk
            insert_select(db, into, pkeys_names, query, params)
947 2201 aaronmk
        else:
948 2489 aaronmk
            run_query_into(db, query, params, into=into)
949 2201 aaronmk
            pkeys_table_exists_ref[0] = True
950
951 2429 aaronmk
    limit_ref = [None]
952 2380 aaronmk
    conds = set()
953 2233 aaronmk
    distinct_on = []
954 2325 aaronmk
    def mk_main_select(joins, cols):
955 2429 aaronmk
        return mk_select(db, joins, cols, conds, distinct_on,
956
            limit=limit_ref[0], start=0)
957 2132 aaronmk
958 2519 aaronmk
    exc_strs = set()
959 2309 aaronmk
    def log_exc(e):
960 2519 aaronmk
        e_str = exc.str_(e, first_line_only=True)
961
        log_debug('Caught exception: '+e_str)
962
        assert e_str not in exc_strs # avoid infinite loops
963
        exc_strs.add(e_str)
964 2451 aaronmk
    def remove_all_rows():
965 2450 aaronmk
        log_debug('Returning NULL for all rows')
966 2429 aaronmk
        limit_ref[0] = 0 # just create an empty pkeys table
967 2409 aaronmk
    def ignore(in_col, value):
968 2545 aaronmk
        in_col_str = strings.as_tt(repr(in_col))
969 2544 aaronmk
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
970
            level=2.5)
971 2537 aaronmk
        add_index(db, in_col)
972 2545 aaronmk
        log_debug('Ignoring rows with '+in_col_str+' = '
973
            +strings.as_tt(repr(value)))
974 2403 aaronmk
    def remove_rows(in_col, value):
975 2409 aaronmk
        ignore(in_col, value)
976 2378 aaronmk
        cond = (in_col, sql_gen.CompareCond(value, '!='))
977
        assert cond not in conds # avoid infinite loops
978 2380 aaronmk
        conds.add(cond)
979 2403 aaronmk
    def invalid2null(in_col, value):
980 2409 aaronmk
        ignore(in_col, value)
981 2403 aaronmk
        update(db, in_table, [(in_col, None)],
982
            sql_gen.ColValueCond(in_col, value))
983 2245 aaronmk
984 2206 aaronmk
    # Do inserts and selects
985 2565 aaronmk
    join_cols = sql_gen.ColDict()
986 2495 aaronmk
    insert_out_pkeys = sql_gen.Table(into.name+'-insert_out_pkeys')
987
    insert_in_pkeys = sql_gen.Table(into.name+'-insert_in_pkeys')
988 2206 aaronmk
    while True:
989 2521 aaronmk
        if limit_ref[0] == 0: # special case
990
            log_debug('Creating an empty pkeys table')
991
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
992
                limit=limit_ref[0]), into=insert_out_pkeys)
993
            break # don't do main case
994
995 2303 aaronmk
        has_joins = join_cols != {}
996
997 2305 aaronmk
        # Prepare to insert new rows
998 2325 aaronmk
        insert_joins = input_joins[:] # don't modify original!
999 2403 aaronmk
        insert_args = dict(recover=True, cacheable=False)
1000 2303 aaronmk
        if has_joins:
1001 2317 aaronmk
            distinct_on = [v.to_Col() for v in join_cols.values()]
1002 2325 aaronmk
            insert_joins.append(sql_gen.Join(out_table, join_cols,
1003
                sql_gen.filter_out))
1004
        else:
1005 2404 aaronmk
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
1006 2520 aaronmk
        main_select = mk_main_select(insert_joins, mapping.values())[0]
1007 2303 aaronmk
1008 2486 aaronmk
        log_debug('Trying to insert new rows')
1009 2206 aaronmk
        try:
1010 2518 aaronmk
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1011
                **insert_args)
1012 2357 aaronmk
            break # insert successful
1013 2206 aaronmk
        except DuplicateKeyException, e:
1014 2309 aaronmk
            log_exc(e)
1015
1016 2258 aaronmk
            old_join_cols = join_cols.copy()
1017 2565 aaronmk
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
1018 2486 aaronmk
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1019 2505 aaronmk
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1020 2258 aaronmk
            assert join_cols != old_join_cols # avoid infinite loops
1021 2230 aaronmk
        except NullValueException, e:
1022 2309 aaronmk
            log_exc(e)
1023
1024 2230 aaronmk
            out_col, = e.cols
1025
            try: in_col = mapping[out_col]
1026 2356 aaronmk
            except KeyError:
1027 2486 aaronmk
                log_debug('Missing mapping for NOT NULL column '+out_col)
1028 2451 aaronmk
                remove_all_rows()
1029 2403 aaronmk
            else: remove_rows(in_col, None)
1030 2542 aaronmk
        except FunctionValueException, e:
1031
            log_exc(e)
1032
1033
            func_name = e.name
1034
            value = e.value
1035
            for out_col, in_col in mapping.iteritems():
1036 2562 aaronmk
                invalid2null(sql_gen.unwrap_func_call(in_col, func_name), value)
1037 2525 aaronmk
        except MissingCastException, e:
1038
            log_exc(e)
1039
1040
            out_col = e.col
1041 2534 aaronmk
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
1042 2429 aaronmk
        except DatabaseErrors, e:
1043
            log_exc(e)
1044
1045 2531 aaronmk
            msg = 'No handler for exception: '+exc.str_(e)
1046 2451 aaronmk
            warnings.warn(DbWarning(msg))
1047
            log_debug(msg)
1048
            remove_all_rows()
1049 2358 aaronmk
        # after exception handled, rerun loop with additional constraints
1050 2132 aaronmk
1051 2357 aaronmk
    if row_ct_ref != None and cur.rowcount >= 0:
1052
        row_ct_ref[0] += cur.rowcount
1053
1054
    if has_joins:
1055
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1056 2486 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
1057 2420 aaronmk
        insert_into_pkeys(select_joins, pkeys_cols)
1058 2357 aaronmk
    else:
1059 2404 aaronmk
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1060 2357 aaronmk
1061 2486 aaronmk
        log_debug('Getting input table pkeys of inserted rows')
1062 2357 aaronmk
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1063 2404 aaronmk
            into=insert_in_pkeys)
1064
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1065 2357 aaronmk
1066 2428 aaronmk
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1067
            insert_in_pkeys)
1068
1069 2486 aaronmk
        log_debug('Combining output and input pkeys in inserted order')
1070 2404 aaronmk
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1071 2357 aaronmk
            {row_num_col: sql_gen.join_same_not_null})]
1072 2420 aaronmk
        insert_into_pkeys(pkey_joins, pkeys_names)
1073 2357 aaronmk
1074 2486 aaronmk
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1075 2489 aaronmk
    index_pkey(db, into)
1076 2407 aaronmk
1077 2508 aaronmk
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1078 2489 aaronmk
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1079 2357 aaronmk
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1080
        # must use join_same_not_null or query will take forever
1081 2420 aaronmk
    insert_into_pkeys(missing_rows_joins,
1082 2508 aaronmk
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1083 2357 aaronmk
1084 2489 aaronmk
    assert table_row_count(db, into) == table_row_count(db, in_table)
1085 2428 aaronmk
1086 2489 aaronmk
    return sql_gen.Col(out_pkey, into)
1087 2115 aaronmk
1088
##### Data cleanup
1089
1090 2290 aaronmk
def cleanup_table(db, table, cols):
1091 2115 aaronmk
    def esc_name_(name): return esc_name(db, name)
1092
1093 2290 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
1094 2115 aaronmk
    cols = map(esc_name_, cols)
1095
1096
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1097
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1098
            for col in cols))),
1099
        dict(null0='', null1=r'\N'))