Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 2217 aaronmk
import sql_gen
15 862 aaronmk
import strings
16 131 aaronmk
import util
17 11 aaronmk
18 832 aaronmk
##### Exceptions
19
20 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
    '''Gets the text of the query associated with a cursor.
    @param cur The cursor. Its `query` attribute is used if it has one;
        otherwise its `_last_executed` attribute is used if it has one.
    @param input_query The query as passed in by the caller. Only used when
        the cursor does not provide the query text itself.
    @param input_params The params as passed in by the caller. Only used
        together with input_query.
    @return The cursor's own query text, or a '[input] ...' description built
        from input_query and input_params
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed

    # Identity test: the cursor either recorded a query or it did not
    if raw_query is not None: return raw_query
    return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27 14 aaronmk
28 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
    '''Adds the text of a cursor's query to an exception's message.
    @param e The exception to annotate
    For the remaining params, see get_cur_query()
    '''
    query_text = str(get_cur_query(*args, **kw_args))
    exc.add_msg(e, 'query: '+query_text)
31 135 aaronmk
32 300 aaronmk
class DbException(exc.ExceptionWithCause):
    '''Base class of the database exceptions defined in this module'''

    def __init__(self, msg, cause=None, cur=None):
        '''
        @param msg The error message
        @param cause The underlying exception, if any
        @param cur If given, the cursor whose query text will be added to the
            message (see get_cur_query())
        '''
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
36
37 2143 aaronmk
class ExceptionWithName(DbException):
    '''A DbException about something identified by a name.
    The name is kept in the `name` attribute.
    '''

    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
41 360 aaronmk
42 2240 aaronmk
class ExceptionWithNameValue(DbException):
    '''A DbException about a named thing and an associated value.
    Both are kept as attributes (`name`, `value`).
    '''

    def __init__(self, name, value, cause=None):
        name_part = 'for name: '+strings.as_tt(str(name))
        value_part = '; value: '+strings.as_tt(repr(value))
        DbException.__init__(self, name_part+value_part, cause)
        self.name = name
        self.value = value
48
49 2306 aaronmk
class ConstraintException(DbException):
    '''A DbException for a violated constraint.
    @param name The constraint's name (kept in `name`)
    @param cols The names of the columns the constraint is on (kept in `cols`)
    '''

    def __init__(self, name, cols, cause=None):
        cols_str = ', '.join(cols)
        msg = ('Violated '+strings.as_tt(name)+' constraint on columns: '
            +strings.as_tt(cols_str))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols
55 11 aaronmk
56 2523 aaronmk
class MissingCastException(DbException):
    '''A DbException for a column that needs a cast to a given type.
    @param type_ The type to cast to (kept in `type`)
    @param col The column's name (kept in `col`)
    '''

    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
62
63 2143 aaronmk
class NameException(DbException):
    '''NOTE(review): not raised in the code visible here; confirm usage'''


class DuplicateKeyException(ConstraintException):
    '''Raised by run_query() for a "duplicate key value violates unique
    constraint" error
    '''


class NullValueException(ConstraintException):
    '''Raised by run_query() for a "null value in column ... violates
    not-null constraint" error
    '''


class FunctionValueException(ExceptionWithNameValue):
    '''Raised by run_query() when a function was given an invalid input value.
    `name` is the function's name and `value` is the offending value.
    '''


class DuplicateTableException(ExceptionWithName):
    '''Raised by run_query() for a "relation ... already exists" error'''


class DuplicateFunctionException(ExceptionWithName):
    '''Raised by run_query() for a "function ... already exists" error'''


class EmptyRowException(DbException):
    '''NOTE(review): not raised in the code visible here; confirm usage'''
76
77 865 aaronmk
##### Warnings
78
79
class DbWarning(UserWarning):
    '''Category of the warnings issued by this module (see mk_select())'''
80
81 1930 aaronmk
##### Result retrieval
82
83
def col_names(cur):
    '''Iterates over the first item (the name) of each entry in the cursor's
    `description`
    '''
    return (col_desc[0] for col_desc in cur.description)
84
85
def rows(cur):
    '''Iterates over the cursor's remaining rows by calling fetchone() until it
    returns None
    '''
    return iter(cur.fetchone, None)
86
87
def consume_rows(cur):
    '''Fetches all of the cursor's remaining rows, discarding them here.
    Used so that a caching cursor sees the whole result and can cache it.
    '''
    remaining = rows(cur)
    iters.consume_iter(remaining)
90
91
def next_row(cur):
    '''Fetches the cursor's next row.
    @raise StopIteration if there are no more rows
    '''
    # next() builtin rather than the Python-2-only .next() method
    return next(rows(cur))
92
93
def row(cur):
    '''Fetches the cursor's next row and then consumes all remaining rows.
    @raise StopIteration if there are no more rows
    '''
    first = next_row(cur)
    consume_rows(cur) # fetch the rest so the result will be cached
    return first
97
98
def next_value(cur):
    '''Fetches the first column of the cursor's next row'''
    first_col = next_row(cur)[0]
    return first_col
99
100
def value(cur):
    '''Fetches the first column of the cursor's next row, consuming all
    remaining rows (see row())
    '''
    first_col = row(cur)[0]
    return first_col
101
102
def values(cur):
    '''Iterates over the first column of each of the cursor's remaining rows.
    NOTE(review): relies on iters.func_iter() calling the function repeatedly
    until it raises StopIteration -- confirm against iters.
    '''
    def fetch_next(): return next_value(cur)
    return iters.func_iter(fetch_next)
103
104
def value_or_none(cur):
    '''Like value(), but returns None instead of raising StopIteration when
    the cursor has no more rows
    '''
    try:
        result = value(cur)
    except StopIteration:
        result = None
    return result
107
108 2101 aaronmk
##### Input validation
109
110 2573 aaronmk
def esc_name_by_module(module, name):
    '''Escapes an identifier using the quote character for a driver module.
    @param module The driver module's name: 'psycopg2' or None for double
        quotes, 'MySQLdb' for backticks
    @param name The identifier to escape
    @return The result of sql_gen.esc_name() with the chosen quote character
    @raise NotImplementedError for any other module
    '''
    if module is None or module == 'psycopg2': quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
115 2101 aaronmk
116
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes an identifier for a database engine.
    @param engine A key of db_engines
    For other params, see esc_name_by_module()
    '''
    module_name = db_engines[engine][0]
    return esc_name_by_module(module_name, name, **kw_args)
118
119
def esc_name(db, name, **kw_args):
    '''Escapes an identifier for the database a connection is connected to.
    @param db The DbConn, whose driver module determines the quote character
    For other params, see esc_name_by_module()
    '''
    module_name = util.root_module(db.db)
    return esc_name_by_module(module_name, name, **kw_args)
121
122
def qual_name(db, schema, table):
    '''Builds the escaped, optionally schema-qualified name of a table.
    @param schema The schema's name, or None to leave the table unqualified
    @return 'schema.table' with each part escaped, or just the escaped table
    '''
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema is not None: return esc_name_(schema)+'.'+table
    return table
127
128 1869 aaronmk
##### Database connections
129 1849 aaronmk
130 2097 aaronmk
# NOTE(review): presumably the keys a db_config dict may contain; not read by
# the code visible here
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps an engine name to (driver module name, renames of db_config keys to the
# keyword names passed to that module's connect()). See DbConn._db().
db_engines = dict(
    MySQL=('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    PostgreSQL=('psycopg2', {}),
)

# The exception classes that count as database errors. Each driver module's
# DatabaseError is added by _add_module() once the module is loaded.
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set) # tuple form of DatabaseErrors_set
139
140
def _add_module(module):
    '''Registers a driver module's DatabaseError class in DatabaseErrors_set and
    rebuilds the DatabaseErrors tuple from it
    '''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
144
145
def db_config_str(db_config):
    '''Describes a db_config as '<engine> database <database>' '''
    engine = db_config['engine']
    database = db_config['database']
    return ' database '.join([engine, database])
147
148 1909 aaronmk
def _query_lookup(query, params):
    '''Builds the key under which a query's result is stored in
    DbConn.query_results: the query plus a hashable form of its params
    '''
    hashable_params = dicts.make_hashable(params)
    return (query, hashable_params)
149 1894 aaronmk
150 2448 aaronmk
def log_debug_none(msg, level=2):
    '''The default log_debug callback of DbConn: ignores the message.
    DbConn compares its callback against this function to decide whether
    debugging is on.
    '''
    return None
151 1901 aaronmk
152 1849 aaronmk
class DbConn:
    '''A database connection that connects on first use, optionally caches
    query results, and autocommits outside of savepoints if requested.
    The underlying driver connection is available as the `db` attribute.
    '''

    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict with an 'engine' key (a key of db_engines), an
            optional 'schemas' key (comma-separated schema names to prepend to
            the search_path), and the settings to pass to the driver module's
            connect()
        @param serializable Whether to use the SERIALIZABLE isolation level
            (only applied when not autocommitting)
        @param autocommit Whether to commit after each query that is run
            outside of a savepoint
        @param caching Whether cacheable queries may actually be cached
        @param log_debug Callback taking (msg, level) for debug messages
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none

        self.__db = None # the driver connection; created by _db()
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # the current savepoint nesting level

    def __getattr__(self, name):
        '''Provides the lazily-connected `db` attribute'''
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError(name)

    def __getstate__(self):
        '''Leaves the debug callback and the connection out of the pickle'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state

    def connected(self):
        '''Whether the driver connection has been created yet'''
        return self.__db is not None

    def _db(self):
        '''Gets the driver connection, connecting and configuring it on the
        first call
        '''
        if self.__db is None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename the keys to the names the module's connect() expects
            for orig, new in mappings.items():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass # this setting was not provided

            # Connect
            self.__db = module.connect(**db_config)

            # Configure connection
            if self.serializable and not self.autocommit: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas is not None:
                # Each name is followed by ', ' because the list is prepended
                # to the existing search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', "
                    "%s || current_setting('search_path'), false)", [schemas_])

        return self.__db

    class DbCursor(Proxy):
        '''Wraps a driver cursor in order to record the rows (or the
        exception) a query produces and store them in the connection's
        query_results.
        NOTE(review): assumes Proxy exposes the wrapped cursor as `inner` and
        forwards other attributes (e.g. rowcount, description) to it.
        '''

        def __init__(self, outer):
            '''@param outer The DbConn this cursor belongs to'''
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []

        def execute(self, query, params=None):
            '''Runs the query on the wrapped cursor. If it raises, the
            exception is annotated with the query, cached, and re-raised.
            '''
            self._is_insert = 'INSERT' in query.upper()
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner)
            except Exception as e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value

        def fetchone(self):
            '''Fetches the next row, recording it. Once the rows run out, the
            recorded result is cached.
            '''
            row = self.inner.fetchone()
            if row is not None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row

        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results is not None and (not self._is_insert
                or isinstance(self.result, Exception)):

                assert self.query_lookup is not None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))

        class CacheCursor:
            '''Replays a cached result through the execute()/fetchone()
            interface
            '''

            def __init__(self, cached_result): self.__dict__ = cached_result

            def execute(self, *args, **kw_args):
                '''Re-raises the cached exception, or restarts iteration over
                the cached rows. The arguments are ignored.
                '''
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)

            def fetchone(self):
                '''Gets the next cached row, or None when there are no more'''
                return next(self.iter, None)

    def esc_value(self, value):
        '''Escapes a value using the connection's driver module.
        @raise NotImplementedError if the driver module is not supported
        '''
        module = util.root_module(self.db)
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
        elif module == 'MySQLdb':
            import _mysql
            str_ = _mysql.escape_string(value)
        else: raise NotImplementedError("Can't escape value for "+module
            +' database')
        return strings.to_unicode(str_)

    def esc_name(self, name): return esc_name(self, name) # calls global func

    def run_query(self, query, params=None, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a query, using the result cache if allowed.
        @param cacheable Whether the result may be read from and stored in the
            cache. Ignored (treated as False) if the connection's caching is
            off.
        @param log_level Not used by this method itself; the module-level
            run_query() passes it here positionally and does the logging.
        @param debug_msg_ref If given and debugging is on, a one-item list
            whose item 0 will be set to a debug message describing the query
            and its cache status
        @return The cursor the query was run on
        '''
        assert query is not None

        if not self.caching: cacheable = False
        used_cache = False
        # Must be bound before the try so that the finally clause can't fail
        # with an unbound local (hiding the real error) if creating the cursor
        # raises
        cur = None
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()

            # Run query
            cur.execute(query, params)
        finally:
            if self.debug and debug_msg_ref is not None:
                # only compute msg if needed
                if used_cache: cache_status = 'cache hit'
                elif cacheable: cache_status = 'cache miss'
                else: cache_status = 'non-cacheable'
                query_code = strings.as_code(str(get_cur_query(cur, query,
                    params)), 'SQL')
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code

        return cur

    def is_cached(self, query, params=None):
        '''Whether a result for this query and params is in the cache'''
        return _query_lookup(query, params) in self.query_results

    def with_savepoint(self, func):
        '''Runs func() inside a savepoint. The savepoint is rolled back to if
        func() raises (and the exception is re-raised), and released
        otherwise.
        @return func()'s return value
        '''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try:
            # The nesting level must be restored *before* the rollback/release
            # below run, so this can't be merged into the outer try statement
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except: # everything, since it's always re-raised
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val

    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommiting')
            self.db.commit()

connect = DbConn # module-level alias for the connection class
333
334 832 aaronmk
##### Querying
335
336 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
    '''Runs a query directly on a connection, without the error translation
    that run_query() does.
    For params, see DbConn.run_query()
    '''
    cur = db.run_query(*args, **kw_args)
    return cur
339 11 aaronmk
340 2068 aaronmk
def mogrify(db, query, params):
    '''Gets the query with its params filled in, using the driver cursor's
    mogrify().
    @raise NotImplementedError if the driver module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
345
346 832 aaronmk
##### Recoverable querying
347 15 aaronmk
348 2139 aaronmk
def with_savepoint(db, func):
    '''Runs func() inside a savepoint on db. See DbConn.with_savepoint().'''
    return_val = db.with_savepoint(func)
    return return_val
349 11 aaronmk
350 2464 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False,
    log_level=2, log_ignore_excs=None, **kw_args):
    '''Runs a query, logging it and (if recover is set) translating recognized
    database error messages into this module's exception classes.
    For params, see run_raw_query()
    @param recover Whether to run the query inside a savepoint (unless its
        result is already cached) and to translate recognized errors. If unset,
        any exception is re-raised unchanged.
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    @return The cursor the query was run on
    @raise DuplicateKeyException|NullValueException|FunctionValueException|
        MissingCastException|DuplicateTableException|
        DuplicateFunctionException for recognized errors, if recover is set
    '''
    if recover is None: recover = False
    if log_ignore_excs is None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)

    debug_msg_ref = [None] # item 0 is filled in by DbConn.run_query()
    try:
        try:
            def run(): return run_raw_query(db, query, params, cacheable,
                log_level, debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query, params):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception as e:
            if not recover: raise # need savepoint to run index_cols()
            msg = exc.str_(e)

            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                # The table name is taken from the start of the constraint name
                constraint, table = match.groups()
                try: cols = index_cols(db, table, constraint)
                except NotImplementedError: raise e
                else: raise DuplicateKeyException(constraint, cols, e)

            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)

            # re.DOTALL replaces the (?s) flag that used to be embedded in the
            # middle of this pattern. Python 2 applied that flag to the whole
            # pattern, but Python 3.11+ refuses to compile it there.
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:.*?)\bfunction "(.+?)"', msg, re.DOTALL)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)

            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)

            match = re.search(r'relation "(.+?)" already exists', msg)
            if match: raise DuplicateTableException(match.group(1), e)

            match = re.search(r'function "(.+?)" already exists', msg)
            if match: raise DuplicateFunctionException(match.group(1), e)

            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2 # takes effect in the finally clause below
        raise
    finally:
        if debug_msg_ref[0] is not None:
            db.log_debug(debug_msg_ref[0], log_level)
406 830 aaronmk
407 832 aaronmk
##### Basic queries
408
409 2153 aaronmk
def next_version(name):
    '''Gets the next version of a name, in the form '#<version>-<name>'.
    The version # is prepended so it won't be removed if the name is truncated.
    A name without a version prefix counts as version 0, so it becomes
    version 1.
    '''
    match = re.match(r'^#(\d+)-(.*)$', name)
    if match:
        prev_version, name = match.groups()
        version = int(prev_version)+1
    else: version = 1 # first existing name was version 0
    return '#'+str(version)+'-'+name
417 2153 aaronmk
418 2386 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table to create with the query's rows, or None to
        just run the query. If a table with that name already exists, the
        table is renamed (in place) to the next version of its name until
        creating it succeeds.
    '''
    if into is None: return run_query(db, query, params, *args, **kw_args)

    # place rows in temp table
    assert isinstance(into, sql_gen.Table)

    kw_args['recover'] = True # needed to get DuplicateTableException
    kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))

    temp = not db.autocommit # tables are permanent in autocommit mode
    # "temporary tables cannot specify a schema name", so remove schema
    if temp: into.schema = None

    while True:
        try:
            create_query = 'CREATE'
            if temp: create_query += ' TEMP'
            create_query += ' TABLE '+into.to_str(db)+' AS\n'+query

            return run_query(db, create_query, params, *args, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
        except DuplicateTableException:
            into.name = next_version(into.name)
            # try again with next version of name
444 2085 aaronmk
445 2120 aaronmk
# Sentinel: tells mk_select() to order by the pkey
order_by_pkey = object()

# Sentinel: tells mk_select() to SELECT DISTINCT ON all columns
distinct_on_all = object()
448
449 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds a SELECT query.
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit int|None The max # of rows
    @param start int|None The # of rows to skip
    @param order_by The column to order by. The default, order_by_pkey, orders
        by the first table's pkey, unless distinct_on is set, in which case
        there is no ordering. Use None for no ordering.
    @param default_table The table to use for columns that don't specify one
    @return tuple(query, params)
    '''
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate

    # Parse other params
    if conds is None: conds = []
    elif isinstance(conds, dict): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit is None or type(limit) == int
    assert start is None or type(start) == int
    if order_by is order_by_pkey:
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)

    query = 'SELECT'

    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)

    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'

    # Columns
    query += '\n'
    if fields is None: query += '*'
    else: query += '\n, '.join(map(parse_col, fields))

    # Main table
    query += '\nFROM '+table0.to_str(db)

    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table

        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))

        query += '\n'+join_.to_str(db, left_table)

        left_table = table

    # Whether the query lacks every clause that would restrict its rows
    missing = True
    if conds:
        query += '\n'+sql_gen.combine_conds(['('+sql_gen.ColValueCond(l, r)
            .to_str(db)+')' for l, r in conds], 'WHERE')
        missing = False
    if order_by is not None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit is not None: query += '\nLIMIT '+str(limit); missing = False
    if start is not None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False # start=0 is an explicit request for all the rows
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))

    return (query, [])
527 11 aaronmk
528 2054 aaronmk
def select(db, *args, **kw_args):
    '''Builds a SELECT query and runs it. Cacheable by default.
    For params, see mk_select() and run_query()
    '''
    # Separate run_query()'s options from mk_select()'s
    run_opts = dict(
        recover=kw_args.pop('recover', None),
        cacheable=kw_args.pop('cacheable', True),
        log_level=kw_args.pop('log_level', 2),
    )

    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, **run_opts)
536 2054 aaronmk
537 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False):
    '''Builds an INSERT query whose rows come from another query.
    @param cols The names of the columns to insert into, or None|[] to not
        list any
    @param select_query The query providing the rows. Defaults to
        'DEFAULT VALUES'.
    @param params The params of select_query
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This wraps the INSERT in a newly-created database function and returns
        a SELECT from that function. Requires returning.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @return tuple(query, params)
    '''
    table = sql_gen.as_Table(table)
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols is not None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
    if select_query is None: select_query = 'DEFAULT VALUES'
    if returning is not None: returning = sql_gen.as_Col(returning, table)

    # Build query
    first_line = 'INSERT INTO '+table.to_str(db)
    query = first_line
    if cols is not None: query += '\n('+', '.join(cols)+')'
    query += '\n'+select_query

    if returning is not None:
        # The RETURNING clause uses the column's name without its table
        returning_name = copy.copy(returning)
        returning_name.table = None
        returning_name = returning_name.to_str(db)
        query += '\nRETURNING '+returning_name

    if embeddable:
        assert returning is not None

        # Create function
        function_name = sql_gen.clean_name(first_line)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        while True:
            try:
                func_schema = None
                if not db.autocommit: func_schema = 'pg_temp'
                function = sql_gen.Table(function_name, func_schema).to_str(db)

                function_query = (
                    'CREATE FUNCTION '+function+'()\n'
                    +'RETURNS '+return_type+'\n'
                    +'LANGUAGE sql\n'
                    +'AS $$\n'
                    +mogrify(db, query, params)+';\n'
                    +'$$;\n')
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateFunctionException,))
                break # this version was successful
            except DuplicateFunctionException:
                function_name = next_version(function_name)
                # try again with next version of name

        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
            [returning_name]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)

    return (query, params)
596
597
def insert_select(db, *args, **kw_args):
    '''Builds an INSERT ... SELECT query and runs it. Cacheable by default.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in. Setting this makes the query embeddable (see
        mk_insert_select()).
    '''
    into = kw_args.pop('into', None)
    if into is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)

    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into, recover=recover,
        cacheable=cacheable)
610 2063 aaronmk
611 2066 aaronmk
default = object() # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''Inserts one row.
    For params, see insert_select()
    @param row The values to insert, as a sequence (in the table's column
        order) or as a mapping of column name to value. Use the `default`
        sentinel as a value to insert that column's default.
    '''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "!= []" works

    # Check for special values
    labels = []
    values = []
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)

    # Build query
    # This must check labels rather than values: a non-empty row made up only
    # of `default`s has no values, but still needs a VALUES clause, because
    # "(cols) DEFAULT VALUES" (which a query of None would produce when cols
    # are given) is not valid SQL. Only an empty row uses DEFAULT VALUES.
    if labels != []: query = 'VALUES ('+(', '.join(labels))+')'
    else: query = None

    return insert_select(db, table, cols, query, values, *args, **kw_args)
635 11 aaronmk
636 2402 aaronmk
def mk_update(db, table, changes=None, cond=None):
    '''Builds an UPDATE query.
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
        Must be provided, despite having a default.
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
        Use None to update all rows.
    @return str query
    '''
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
    if cond is not None: query += '\nWHERE\n'+cond.to_str(db)

    return query
651
652
def update(db, *args, **kw_args):
    '''Builds an UPDATE query and runs it with no params.
    For params, see mk_update() and run_query()
    '''
    recover = kw_args.pop('recover', None)

    query = mk_update(db, *args, **kw_args)
    return run_query(db, query, [], recover)
657
658 135 aaronmk
def last_insert_id(db):
    '''Gets the ID generated by the last insert on a connection.
    @param db The DbConn
    @return The ID, or None if the driver module is not supported
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    # insert_id() is a method of the driver connection (db.db). DbConn itself
    # only provides the `db` attribute, so db.insert_id() would always raise
    # AttributeError.
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None
663 13 aaronmk
664 1968 aaronmk
def truncate(db, table, schema='public'):
    '''Runs TRUNCATE ... CASCADE on a table.
    @param schema The table's schema, or None to not qualify the table's name
    '''
    table_name = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+table_name+' CASCADE')
666 832 aaronmk
667 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Maps original columns (whose names may collide across tables) to columns
    of a single table, for use when several tables are being joined together.
    @param into The table that the new columns belong to.
    @param cols The columns to combine. Duplicates are removed first.
    @param preserve [sql_gen.Col...] Columns to keep as-is instead of creating
        a new Col for them. They are part of the mapping even when absent from
        cols. Each of these Col objects is modified in place (its table is set
        to into); pass copies if the originals must keep their tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * Every new_col uses the into table, so its name can be changed for
          all columns at once
    '''
    unique_cols = lists.uniqify(cols)
    
    pairs = []
    for kept in preserve:
        snapshot = copy.copy(kept) # the key keeps the original table
        kept.table = into
        pairs.append((snapshot, kept))
    
    kept_cols = set(preserve)
    pairs.extend([(col, sql_gen.Col(str(col), into)) for col in unique_cols
        if col not in kept_cols])
    
    if as_items: return pairs
    return dict(pairs)
697 2383 aaronmk
698 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given columns from the joins into a new table, naming each
    output column according to mk_flatten_mapping().
    For params, see mk_flatten_mapping(). limit and start are passed on to
    mk_select().
    @return See return value of mk_flatten_mapping()
    '''
    pairs = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    
    # Select each original column under its new name
    renamed_cols = []
    for old, new in pairs:
        renamed_cols.append(sql_gen.NamedCol(new.name, old))
    
    select_args = mk_select(db, joins, renamed_cols, limit=limit, start=start)
    run_query_into(db, *select_args, into=into)
    
    return dict(pairs)
707 2391 aaronmk
708 2414 aaronmk
##### Database structure queries
709
710 2426 aaronmk
def table_row_count(db, table, recover=None):
    '''Selects sql_gen.row_count from a table.
    @param recover Passed on to run_query()
    @return The single value of the query result
    '''
    select_args = mk_select(db, table, [sql_gen.row_count], order_by=None,
        start=0)
    cur = run_query(db, *select_args, recover=recover, log_level=3)
    return value(cur)
713 2426 aaronmk
714 2414 aaronmk
def table_cols(db, table, recover=None):
    '''Lists a table's column names, using a zero-row SELECT of the table.
    @param recover Passed on to select()
    @return list of the names produced by col_names()
    '''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    names = col_names(cur)
    return list(names)
717 2414 aaronmk
718 2291 aaronmk
def pkey(db, table, recover=None):
    '''Gets a table's pkey, which is assumed to be its first column.
    @return The first entry of table_cols()
    '''
    cols = table_cols(db, table, recover)
    return cols[0]
721 832 aaronmk
722 2559 aaronmk
# Column name that table_not_null_col() looks for before falling back to pkey()
not_null_col = 'not_null_col'
723 2340 aaronmk
724
def table_not_null_col(db, table, recover=None):
    '''Picks a column of the table to use as its not-null column.
    @return The value of not_null_col if the table has a column by that name,
        otherwise the table's pkey()
    '''
    if not_null_col not in table_cols(db, table, recover):
        return pkey(db, table, recover)
    return not_null_col
728
729 853 aaronmk
def index_cols(db, table, index):
    '''Lists the names of the columns used by an index, ordered by attnum.
    This also works for UNIQUE constraints, because a UNIQUE index is
    automatically created for them. Use this function when it is not known
    whether something is a UNIQUE constraint or a UNIQUE index.
    The query combines the columns listed in indkey with the columns
    referenced (as :varattno) inside the index's expressions.
    @raise NotImplementedError if the database module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
'''
    params = {'table': table, 'index': index}
    cur = run_query(db, query, params, cacheable=True, log_level=4)
    return list(values(cur))
769 853 aaronmk
770 464 aaronmk
def constraint_cols(db, table, constraint):
    '''Lists the names of the columns covered by a constraint, ordered by
    attnum. The table and constraint are matched by name against pg_class and
    pg_constraint.
    @raise NotImplementedError if the database module is not psycopg2
    '''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    
    query = '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
'''
    params = {'table': table, 'constraint': constraint}
    cur = run_query(db, query, params)
    return list(values(cur))
786
787 2096 aaronmk
# Name of the serial primary-key column that add_row_num() adds to a table
row_num_col = '_row_num'
788
789 2538 aaronmk
def add_index(db, expr):
    '''Adds an index on a column or expression if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    The index is named after the string form of expr, and is created on the
    table of the column (for a function call, the table of its first argument).
    @param expr sql_gen table column, or sql_gen.FunctionCall whose first
        argument is a table column. The caller's object is never modified.
    '''
    # Deep copy, because the column whose table gets cleared below may be
    # nested inside a function call's args. A shallow copy would share that
    # column with the caller and silently strip its table.
    expr = copy.deepcopy(expr) # don't modify input!
    
    # Extract col
    if isinstance(expr, sql_gen.FunctionCall):
        col = expr.args[0]
        expr = sql_gen.Expr(expr)
    else: col = expr
    assert sql_gen.is_table_col(col)
    
    # The index name must be computed while the column still has its table
    index = sql_gen.as_Table(str(expr))
    table = col.table
    col.table = None # the table is given in the ON clause instead
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
        +' ('+expr.to_str(db)+')', recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
808
809 2406 aaronmk
def index_pkey(db, table, recover=None):
    '''Adds a PRIMARY KEY constraint on the first column of a table. The
    constraint is named after the table, with '_pkey' appended.
    @pre The table must not already have a primary key.
    @param recover Passed on to pkey() and run_query()
    '''
    table = sql_gen.as_Table(table)
    
    constraint = sql_gen.as_Table(table.name+'_pkey')
    pkey_col = sql_gen.to_name_only_col(pkey(db, table, recover))
    
    ddl = ('ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
        +constraint.to_str(db)+' PRIMARY KEY('+pkey_col.to_str(db)+')')
    run_query(db, ddl, recover=recover, log_level=3)
820 2406 aaronmk
821 2086 aaronmk
def add_row_num(db, table):
    '''Adds a row number column to a table. The column is named by row_num_col
    and is declared serial NOT NULL PRIMARY KEY, so it becomes the table's
    primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    ddl = ('ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY')
    run_query(db, ddl, log_level=3)
827 2086 aaronmk
828 2548 aaronmk
def tables(db, schema_like='public', table_like='%'):
    '''Lists the names of the tables matching the given LIKE patterns.
    @param schema_like LIKE pattern for the schema name. Only used in the
        psycopg2 query.
    @param table_like LIKE pattern for the table name
    @return The values() of the query result
    @raise NotImplementedError if the database module is neither psycopg2 nor
        MySQLdb
    '''
    module = util.root_module(db.db)
    params = {'schema_like': schema_like, 'table_like': table_like}
    
    if module == 'MySQLdb':
        cur = run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True)
        return values(cur)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list tables for "+module+' database')
    
    query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname LIKE %(schema_like)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    return values(run_query(db, query, params, cacheable=True))
845 830 aaronmk
846 833 aaronmk
##### Database management
847
848 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table that tables() lists for the schema.
    For kw_args, see tables()
    '''
    table_names = tables(db, schema, **kw_args)
    for table_name in table_names:
        truncate(db, table_name, schema)
851 833 aaronmk
852 832 aaronmk
##### Heuristic queries
853
854 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
855 1554 aaronmk
    '''Recovers from errors.
856 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
857
    '''
858 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
859
860 471 aaronmk
    try:
861 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
862 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
863
            row_ct_ref[0] += cur.rowcount
864
        return value(cur)
865 471 aaronmk
    except DuplicateKeyException, e:
866 2104 aaronmk
        return value(select(db, table, [pkey_],
867 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
868 471 aaronmk
869 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up the pkey of the row matching the given values.
    Recovers from errors.
    @param create Whether to insert the row with put() when it is not found
    @param row_ct_ref Passed on to put()
    @raise StopIteration if the row is not found and create is not set
    '''
    try:
        matches = select(db, table, [pkey], row, limit=1, recover=True)
        return value(matches)
    except StopIteration:
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
875 2078 aaronmk
876 2508 aaronmk
# put_table() overview:
# Inserts rows selected from the input tables into out_table with
# INSERT ... SELECT, and builds the `into` table, which pairs each input pkey
# with an output pkey. The insert is retried in a loop: each recognized
# database error adds a constraint (join columns, row filter, cast, ...) and
# the loop runs again. Input rows that end up with no output row get the
# `default` column's value as their output pkey.
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
877 2552 aaronmk
    default=None, is_func=False):
878 2078 aaronmk
    '''Recovers from errors.
879
    Only works under PostgreSQL (uses INSERT RETURNING).
880 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
881
        tables to join with it using the main input table's pkey
882 2312 aaronmk
    @param mapping dict(out_table_col=in_table_col, ...)
883
        * out_table_col: sql_gen.Col|str
884 2323 aaronmk
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
885 2489 aaronmk
    @param into The table to contain the output and input pkeys.
886 2574 aaronmk
        Defaults to `out_table.name+'_pkeys'`.
887 2509 aaronmk
    @param default The *output* column to use as the pkey for missing rows.
888
        If this output column does not exist in the mapping, uses None.
889 2552 aaronmk
    @param is_func Whether out_table is the name of a SQL function, not a table
890 2312 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
891 2078 aaronmk
    '''
892 2329 aaronmk
    out_table = sql_gen.as_Table(out_table)
893 2565 aaronmk
    mapping = sql_gen.ColDict(mapping)
894 2555 aaronmk
    if into == None:
895
        into = out_table.name
896
        if is_func: into += '()'
897 2574 aaronmk
        else: into += '_pkeys'
898 2489 aaronmk
    into = sql_gen.as_Table(into)
899 2312 aaronmk
900 2450 aaronmk
    def log_debug(msg): db.log_debug(msg, level=1.5)
901 2505 aaronmk
    def col_ustr(str_):
902 2567 aaronmk
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
903 2450 aaronmk
904 2486 aaronmk
    log_debug('********** New iteration **********')
905 2505 aaronmk
    log_debug('Inserting these input columns into '+strings.as_tt(
906
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
907 2463 aaronmk
908 2382 aaronmk
    # Create input joins from list of input tables
909
    in_tables_ = in_tables[:] # don't modify input!
910
    in_tables0 = in_tables_.pop(0) # first table is separate
911 2279 aaronmk
    in_pkey = pkey(db, in_tables0, recover=True)
912 2285 aaronmk
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
913 2460 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
914
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
915 2131 aaronmk
916 2486 aaronmk
    log_debug('Joining together input tables into temp table')
917 2395 aaronmk
    # Place in new table for speed and so don't modify input if values edited
918 2574 aaronmk
    in_table = sql_gen.Table(into.name.replace('_pkeys', '')+'_input')
919 2395 aaronmk
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
920
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
921
        flatten_cols, preserve=[in_pkey_col], start=0))
922
    # From here on, the flattened table is the only input
    input_joins = [in_table]
923 2486 aaronmk
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
924 2395 aaronmk
925 2509 aaronmk
    # Resolve default value column
926
    try: default = mapping[default]
927
    except KeyError:
928
        if default != None:
929
            db.log_debug('Default value column '
930
                +strings.as_tt(strings.repr_no_u(default))
931 2511 aaronmk
                +' does not exist in mapping, falling back to None', level=2.1)
932 2509 aaronmk
            default = None
933
934 2279 aaronmk
    out_pkey = pkey(db, out_table, recover=True)
935 2285 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
936 2142 aaronmk
937 2387 aaronmk
    pkeys_names = [in_pkey, out_pkey]
938 2236 aaronmk
    pkeys_cols = [in_pkey_col, out_pkey_col]
939
940 2201 aaronmk
    # One-element lists are used as mutable cells, so that the nested functions
    # below can change the value seen by the rest of put_table()
    pkeys_table_exists_ref = [False]
941 2420 aaronmk
    # The first call creates the `into` table from the SELECT; later calls
    # insert into the existing table
    def insert_into_pkeys(joins, cols):
942
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
943 2201 aaronmk
        if pkeys_table_exists_ref[0]:
944 2489 aaronmk
            insert_select(db, into, pkeys_names, query, params)
945 2201 aaronmk
        else:
946 2489 aaronmk
            run_query_into(db, query, params, into=into)
947 2201 aaronmk
            pkeys_table_exists_ref[0] = True
948
949 2429 aaronmk
    # State read by mk_main_select() and changed by the error handlers below
    limit_ref = [None]
950 2380 aaronmk
    conds = set()
951 2233 aaronmk
    distinct_on = []
952 2325 aaronmk
    def mk_main_select(joins, cols):
953 2429 aaronmk
        return mk_select(db, joins, cols, conds, distinct_on,
954
            limit=limit_ref[0], start=0)
955 2132 aaronmk
956 2519 aaronmk
    exc_strs = set()
957 2309 aaronmk
    # Logs the exception and asserts that the same first-line message has not
    # been seen before in this call
    def log_exc(e):
958 2519 aaronmk
        e_str = exc.str_(e, first_line_only=True)
959
        log_debug('Caught exception: '+e_str)
960
        assert e_str not in exc_strs # avoid infinite loops
961
        exc_strs.add(e_str)
962 2451 aaronmk
    def remove_all_rows():
963 2450 aaronmk
        log_debug('Returning NULL for all rows')
964 2429 aaronmk
        limit_ref[0] = 0 # just create an empty pkeys table
965 2409 aaronmk
    def ignore(in_col, value):
966 2545 aaronmk
        in_col_str = strings.as_tt(repr(in_col))
967 2544 aaronmk
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
968
            level=2.5)
969 2537 aaronmk
        add_index(db, in_col)
970 2545 aaronmk
        log_debug('Ignoring rows with '+in_col_str+' = '
971
            +strings.as_tt(repr(value)))
972 2403 aaronmk
    # Excludes input rows whose in_col equals value from the main select
    def remove_rows(in_col, value):
973 2409 aaronmk
        ignore(in_col, value)
974 2378 aaronmk
        cond = (in_col, sql_gen.CompareCond(value, '!='))
975
        assert cond not in conds # avoid infinite loops
976 2380 aaronmk
        conds.add(cond)
977 2403 aaronmk
    # Overwrites value with NULL in in_col of the temp input table
    def invalid2null(in_col, value):
978 2409 aaronmk
        ignore(in_col, value)
979 2403 aaronmk
        update(db, in_table, [(in_col, None)],
980
            sql_gen.ColValueCond(in_col, value))
981 2245 aaronmk
982 2206 aaronmk
    # Do inserts and selects
983 2565 aaronmk
    join_cols = sql_gen.ColDict()
984 2574 aaronmk
    insert_out_pkeys = sql_gen.Table(into.name+'_insert_out_pkeys')
985
    insert_in_pkeys = sql_gen.Table(into.name+'_insert_in_pkeys')
986 2206 aaronmk
    # Each pass either succeeds (break) or handles one exception by adjusting
    # the state above, and then tries again
    while True:
987 2521 aaronmk
        if limit_ref[0] == 0: # special case
988
            log_debug('Creating an empty pkeys table')
989
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
990
                limit=limit_ref[0]), into=insert_out_pkeys)
991
            break # don't do main case
992
993 2303 aaronmk
        # NOTE(review): when the special case above breaks out of the loop,
        # has_joins keeps its value from the previous pass -- confirm that the
        # code after the loop is meant to branch on that stale value
        has_joins = join_cols != {}
994
995 2305 aaronmk
        # Prepare to insert new rows
996 2325 aaronmk
        insert_joins = input_joins[:] # don't modify original!
997 2403 aaronmk
        insert_args = dict(recover=True, cacheable=False)
998 2303 aaronmk
        if has_joins:
999 2317 aaronmk
            distinct_on = [v.to_Col() for v in join_cols.values()]
1000 2325 aaronmk
            insert_joins.append(sql_gen.Join(out_table, join_cols,
1001
                sql_gen.filter_out))
1002
        else:
1003 2404 aaronmk
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
1004 2520 aaronmk
        main_select = mk_main_select(insert_joins, mapping.values())[0]
1005 2303 aaronmk
1006 2486 aaronmk
        log_debug('Trying to insert new rows')
1007 2206 aaronmk
        try:
1008 2518 aaronmk
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1009
                **insert_args)
1010 2357 aaronmk
            break # insert successful
1011 2206 aaronmk
        except DuplicateKeyException, e:
1012 2309 aaronmk
            log_exc(e)
1013
1014 2258 aaronmk
            # Add the mapping entries for the violated key's columns to the
            # columns that the next pass joins out_table on
            old_join_cols = join_cols.copy()
1015 2565 aaronmk
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
1016 2486 aaronmk
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1017 2505 aaronmk
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1018 2258 aaronmk
            assert join_cols != old_join_cols # avoid infinite loops
1019 2230 aaronmk
        except NullValueException, e:
1020 2309 aaronmk
            log_exc(e)
1021
1022 2230 aaronmk
            # Exactly one output column is expected here. If an input column
            # maps to it, filter out the rows where that input is NULL;
            # otherwise no row can be inserted.
            out_col, = e.cols
1023
            try: in_col = mapping[out_col]
1024 2356 aaronmk
            except KeyError:
1025 2486 aaronmk
                log_debug('Missing mapping for NOT NULL column '+out_col)
1026 2451 aaronmk
                remove_all_rows()
1027 2403 aaronmk
            else: remove_rows(in_col, None)
1028 2542 aaronmk
        except FunctionValueException, e:
1029
            log_exc(e)
1030
1031
            func_name = e.name
1032
            # this local shadows the module-level value() function within
            # put_table()
            value = e.value
1033
            # presumably unwrap_func_call() strips a func_name call from around
            # the column -- verify against sql_gen
            for out_col, in_col in mapping.iteritems():
1034 2562 aaronmk
                invalid2null(sql_gen.unwrap_func_call(in_col, func_name), value)
1035 2525 aaronmk
        except MissingCastException, e:
1036
            log_exc(e)
1037
1038
            # Wrap the input column in the function named by the missing type
            out_col = e.col
1039 2534 aaronmk
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
1040 2429 aaronmk
        except DatabaseErrors, e:
1041
            log_exc(e)
1042
1043 2531 aaronmk
            # Fallback for database errors with no dedicated handler: warn, then
            # give up on inserting any rows
            msg = 'No handler for exception: '+exc.str_(e)
1044 2451 aaronmk
            warnings.warn(DbWarning(msg))
1045
            log_debug(msg)
1046
            remove_all_rows()
1047 2358 aaronmk
        # after exception handled, rerun loop with additional constraints
1048 2132 aaronmk
1049 2357 aaronmk
    if row_ct_ref != None and cur.rowcount >= 0:
1050
        row_ct_ref[0] += cur.rowcount
1051
1052
    # With join columns, the output pkeys are found by joining the input with
    # out_table. Without them, the pkeys returned by the insert are paired with
    # the input pkeys by row number.
    if has_joins:
1053
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1054 2486 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
1055 2420 aaronmk
        insert_into_pkeys(select_joins, pkeys_cols)
1056 2357 aaronmk
    else:
1057 2404 aaronmk
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1058 2357 aaronmk
1059 2486 aaronmk
        log_debug('Getting input table pkeys of inserted rows')
1060 2357 aaronmk
        # presumably this select returns rows in the same order as the one used
        # for the insert, so that row numbers line up -- TODO confirm
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1061 2404 aaronmk
            into=insert_in_pkeys)
1062
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1063 2357 aaronmk
1064 2428 aaronmk
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1065
            insert_in_pkeys)
1066
1067 2486 aaronmk
        log_debug('Combining output and input pkeys in inserted order')
1068 2404 aaronmk
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1069 2357 aaronmk
            {row_num_col: sql_gen.join_same_not_null})]
1070 2420 aaronmk
        insert_into_pkeys(pkey_joins, pkeys_names)
1071 2357 aaronmk
1072 2486 aaronmk
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1073 2489 aaronmk
    index_pkey(db, into)
1074 2407 aaronmk
1075 2508 aaronmk
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1076 2489 aaronmk
    # Input rows whose pkey is not yet in the pkeys table
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1077 2357 aaronmk
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1078
        # must use join_same_not_null or query will take forever
1079 2420 aaronmk
    insert_into_pkeys(missing_rows_joins,
1080 2508 aaronmk
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1081 2357 aaronmk
1082 2489 aaronmk
    # Every input row must now have exactly one row in the pkeys table
    assert table_row_count(db, into) == table_row_count(db, in_table)
1083 2428 aaronmk
1084 2489 aaronmk
    return sql_gen.Col(out_pkey, into)
1085 2115 aaronmk
1086
##### Data cleanup
1087
1088 2290 aaronmk
def cleanup_table(db, table, cols):
    '''Cleans up the given columns of a table with a single UPDATE: each value
    is trimmed, and then set to NULL if the result is the empty string or the
    two-character string \\N.
    @param cols The names of the columns to clean up. Each is escaped with
        esc_name().
    '''
    table_str = sql_gen.as_Table(table).to_str(db)
    escaped_cols = [esc_name(db, name) for name in cols]
    
    assignments = []
    for col in escaped_cols:
        assignments.append('\n'+col+' = nullif(nullif(trim(both from '+col
            +"), %(null0)s), %(null1)s)")
    
    query = 'UPDATE '+table_str+' SET\n'+',\n'.join(assignments)
    run_query(db, query, dict(null0='', null1=r'\N'))