Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 2217 aaronmk
import sql_gen
15 862 aaronmk
import strings
16 131 aaronmk
import util
17 11 aaronmk
18 832 aaronmk
##### Exceptions
19
20 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
21 2168 aaronmk
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24 2170 aaronmk
25
    if raw_query != None: return raw_query
26 2371 aaronmk
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27 14 aaronmk
28 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31 135 aaronmk
32 300 aaronmk
class DbException(exc.ExceptionWithCause):
33 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
34 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
36
37 2143 aaronmk
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40 2143 aaronmk
        self.name = name
41 360 aaronmk
42 2240 aaronmk
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46 2240 aaronmk
        self.name = name
47
        self.value = value
48
49 2306 aaronmk
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53 2306 aaronmk
        self.name = name
54 468 aaronmk
        self.cols = cols
55 11 aaronmk
56 2523 aaronmk
class MissingCastException(DbException):
57
    def __init__(self, type_, col, cause=None):
58
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
59
            +' on column: '+strings.as_tt(col), cause)
60
        self.type = type_
61
        self.col = col
62
63 2143 aaronmk
class NameException(DbException): pass
64
65 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
66 13 aaronmk
67 2306 aaronmk
class NullValueException(ConstraintException): pass
68 13 aaronmk
69 2240 aaronmk
class FunctionValueException(ExceptionWithNameValue): pass
70 2239 aaronmk
71 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
72
73 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName): pass
74
75 89 aaronmk
class EmptyRowException(DbException): pass
76
77 865 aaronmk
##### Warnings
78
79
class DbWarning(UserWarning): pass
80
81 1930 aaronmk
##### Result retrieval
82
83
def col_names(cur): return (col[0] for col in cur.description)
84
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86
87
def consume_rows(cur):
88
    '''Used to fetch all rows so result will be cached'''
89
    iters.consume_iter(rows(cur))
90
91
def next_row(cur): return rows(cur).next()
92
93
def row(cur):
94
    row_ = next_row(cur)
95
    consume_rows(cur)
96
    return row_
97
98
def next_value(cur): return next_row(cur)[0]
99
100
def value(cur): return row(cur)[0]
101
102
def values(cur): return iters.func_iter(lambda: next_value(cur))
103
104
def value_or_none(cur):
105
    try: return value(cur)
106
    except StopIteration: return None
107
108 2101 aaronmk
##### Input validation
109
110 2573 aaronmk
def esc_name_by_module(module, name):
111
    if module == 'psycopg2' or module == None: quote = '"'
112 2101 aaronmk
    elif module == 'MySQLdb': quote = '`'
113
    else: raise NotImplementedError("Can't escape name for "+module+' database')
114 2500 aaronmk
    return sql_gen.esc_name(name, quote)
115 2101 aaronmk
116
def esc_name_by_engine(engine, name, **kw_args):
117
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
118
119
def esc_name(db, name, **kw_args):
120
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
121
122
def qual_name(db, schema, table):
123
    def esc_name_(name): return esc_name(db, name)
124
    table = esc_name_(table)
125
    if schema != None: return esc_name_(schema)+'.'+table
126
    else: return table
127
128 1869 aaronmk
##### Database connections
129 1849 aaronmk
130 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
131 1926 aaronmk
132 1869 aaronmk
db_engines = {
133
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
134
    'PostgreSQL': ('psycopg2', {}),
135
}
136
137
DatabaseErrors_set = set([DbException])
138
DatabaseErrors = tuple(DatabaseErrors_set)
139
140
def _add_module(module):
141
    DatabaseErrors_set.add(module.DatabaseError)
142
    global DatabaseErrors
143
    DatabaseErrors = tuple(DatabaseErrors_set)
144
145
def db_config_str(db_config):
146
    return db_config['engine']+' database '+db_config['database']
147
148 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
149 1894 aaronmk
150 2448 aaronmk
log_debug_none = lambda msg, level=2: None
151 1901 aaronmk
152 1849 aaronmk
class DbConn:
153 2190 aaronmk
    def __init__(self, db_config, serializable=True, autocommit=False,
154
        caching=True, log_debug=log_debug_none):
155 1869 aaronmk
        self.db_config = db_config
156
        self.serializable = serializable
157 2190 aaronmk
        self.autocommit = autocommit
158
        self.caching = caching
159 1901 aaronmk
        self.log_debug = log_debug
160 2193 aaronmk
        self.debug = log_debug != log_debug_none
161 1869 aaronmk
162
        self.__db = None
163 1889 aaronmk
        self.query_results = {}
164 2139 aaronmk
        self._savepoint = 0
165 1869 aaronmk
166
    def __getattr__(self, name):
167
        if name == '__dict__': raise Exception('getting __dict__')
168
        if name == 'db': return self._db()
169
        else: raise AttributeError()
170
171
    def __getstate__(self):
172
        state = copy.copy(self.__dict__) # shallow copy
173 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
174 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
175
        return state
176
177 2165 aaronmk
    def connected(self): return self.__db != None
178
179 1869 aaronmk
    def _db(self):
180
        if self.__db == None:
181
            # Process db_config
182
            db_config = self.db_config.copy() # don't modify input!
183 2097 aaronmk
            schemas = db_config.pop('schemas', None)
184 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
185
            module = __import__(module_name)
186
            _add_module(module)
187
            for orig, new in mappings.iteritems():
188
                try: util.rename_key(db_config, orig, new)
189
                except KeyError: pass
190
191
            # Connect
192
            self.__db = module.connect(**db_config)
193
194
            # Configure connection
195 2234 aaronmk
            if self.serializable and not self.autocommit: run_raw_query(self,
196
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
197 2101 aaronmk
            if schemas != None:
198
                schemas_ = ''.join((esc_name(self, s)+', '
199
                    for s in schemas.split(',')))
200
                run_raw_query(self, "SELECT set_config('search_path', \
201
%s || current_setting('search_path'), false)", [schemas_])
202 1869 aaronmk
203
        return self.__db
204 1889 aaronmk
205 1891 aaronmk
    class DbCursor(Proxy):
206 1927 aaronmk
        def __init__(self, outer):
207 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
208 2191 aaronmk
            self.outer = outer
209 1927 aaronmk
            self.query_results = outer.query_results
210 1894 aaronmk
            self.query_lookup = None
211 1891 aaronmk
            self.result = []
212 1889 aaronmk
213 1894 aaronmk
        def execute(self, query, params=None):
214 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
215 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
216 2148 aaronmk
            try:
217 2191 aaronmk
                try:
218
                    return_value = self.inner.execute(query, params)
219
                    self.outer.do_autocommit()
220 2148 aaronmk
                finally: self.query = get_cur_query(self.inner)
221 1904 aaronmk
            except Exception, e:
222 2170 aaronmk
                _add_cursor_info(e, self, query, params)
223 1904 aaronmk
                self.result = e # cache the exception as the result
224
                self._cache_result()
225
                raise
226 1930 aaronmk
            # Fetch all rows so result will be cached
227
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
228 1894 aaronmk
            return return_value
229
230 1891 aaronmk
        def fetchone(self):
231
            row = self.inner.fetchone()
232 1899 aaronmk
            if row != None: self.result.append(row)
233
            # otherwise, fetched all rows
234 1904 aaronmk
            else: self._cache_result()
235
            return row
236
237
        def _cache_result(self):
238 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
239
            # idempotent, but an invalid insert will always be invalid
240 1930 aaronmk
            if self.query_results != None and (not self._is_insert
241 1906 aaronmk
                or isinstance(self.result, Exception)):
242
243 1894 aaronmk
                assert self.query_lookup != None
244 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
245
                    util.dict_subset(dicts.AttrsDictView(self),
246
                    ['query', 'result', 'rowcount', 'description']))
247 1906 aaronmk
248 1916 aaronmk
        class CacheCursor:
249
            def __init__(self, cached_result): self.__dict__ = cached_result
250
251 1927 aaronmk
            def execute(self, *args, **kw_args):
252 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
253
                # otherwise, result is a rows list
254
                self.iter = iter(self.result)
255
256
            def fetchone(self):
257
                try: return self.iter.next()
258
                except StopIteration: return None
259 1891 aaronmk
260 2212 aaronmk
    def esc_value(self, value):
261 2215 aaronmk
        module = util.root_module(self.db)
262 2374 aaronmk
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
263 2212 aaronmk
        elif module == 'MySQLdb':
264
            import _mysql
265 2374 aaronmk
            str_ = _mysql.escape_string(value)
266 2212 aaronmk
        else: raise NotImplementedError("Can't escape value for "+module
267
            +' database')
268 2374 aaronmk
        return strings.to_unicode(str_)
269 2212 aaronmk
270 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
271
272 2445 aaronmk
    def run_query(self, query, params=None, cacheable=False, log_level=2,
273 2464 aaronmk
        debug_msg_ref=None):
274 2445 aaronmk
        '''
275 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
276
            throws one of these exceptions.
277 2445 aaronmk
        '''
278 2167 aaronmk
        assert query != None
279
280 2047 aaronmk
        if not self.caching: cacheable = False
281 1903 aaronmk
        used_cache = False
282
        try:
283 1927 aaronmk
            # Get cursor
284
            if cacheable:
285
                query_lookup = _query_lookup(query, params)
286
                try:
287
                    cur = self.query_results[query_lookup]
288
                    used_cache = True
289
                except KeyError: cur = self.DbCursor(self)
290
            else: cur = self.db.cursor()
291
292
            # Run query
293 2148 aaronmk
            cur.execute(query, params)
294 1903 aaronmk
        finally:
295 2464 aaronmk
            if self.debug and debug_msg_ref != None:# only compute msg if needed
296 2470 aaronmk
                if used_cache: cache_status = 'cache hit'
297
                elif cacheable: cache_status = 'cache miss'
298
                else: cache_status = 'non-cacheable'
299 2472 aaronmk
                query_code = strings.as_code(str(get_cur_query(cur, query,
300
                    params)), 'SQL')
301
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
302 1903 aaronmk
303
        return cur
304 1914 aaronmk
305
    def is_cached(self, query, params=None):
306
        return _query_lookup(query, params) in self.query_results
307 2139 aaronmk
308
    def with_savepoint(self, func):
309 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
310 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
311 2139 aaronmk
        self._savepoint += 1
312
        try:
313
            try: return_val = func()
314
            finally:
315
                self._savepoint -= 1
316
                assert self._savepoint >= 0
317
        except:
318 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
319 2139 aaronmk
            raise
320
        else:
321 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
322 2191 aaronmk
            self.do_autocommit()
323 2139 aaronmk
            return return_val
324 2191 aaronmk
325
    def do_autocommit(self):
326
        '''Autocommits if outside savepoint'''
327
        assert self._savepoint >= 0
328
        if self.autocommit and self._savepoint == 0:
329
            self.log_debug('Autocommiting')
330
            self.db.commit()
331 1849 aaronmk
332 1869 aaronmk
connect = DbConn
333
334 832 aaronmk
##### Querying
335
336 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
337 2085 aaronmk
    '''For params, see DbConn.run_query()'''
338 1894 aaronmk
    return db.run_query(*args, **kw_args)
339 11 aaronmk
340 2068 aaronmk
def mogrify(db, query, params):
341
    module = util.root_module(db.db)
342
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
343
    else: raise NotImplementedError("Can't mogrify query for "+module+
344
        ' database')
345
346 832 aaronmk
##### Recoverable querying
347 15 aaronmk
348 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
349 11 aaronmk
350 2464 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False,
351
    log_level=2, log_ignore_excs=None, **kw_args):
352 2441 aaronmk
    '''For params, see run_raw_query()'''
353 830 aaronmk
    if recover == None: recover = False
354 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
355
    log_ignore_excs = tuple(log_ignore_excs)
356 830 aaronmk
357 2464 aaronmk
    debug_msg_ref = [None]
358 2148 aaronmk
    try:
359 2464 aaronmk
        try:
360
            def run(): return run_raw_query(db, query, params, cacheable,
361
                log_level, debug_msg_ref, **kw_args)
362
            if recover and not db.is_cached(query, params):
363
                return with_savepoint(db, run)
364
            else: return run() # don't need savepoint if cached
365
        except Exception, e:
366
            if not recover: raise # need savepoint to run index_cols()
367
            msg = exc.str_(e)
368
369
            match = re.search(r'duplicate key value violates unique constraint '
370 2493 aaronmk
                r'"((_?[^\W_]+)_.+?)"', msg)
371 2464 aaronmk
            if match:
372
                constraint, table = match.groups()
373
                try: cols = index_cols(db, table, constraint)
374
                except NotImplementedError: raise e
375
                else: raise DuplicateKeyException(constraint, cols, e)
376
377 2493 aaronmk
            match = re.search(r'null value in column "(.+?)" violates not-null'
378 2464 aaronmk
                r' constraint', msg)
379
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
380
381
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
382
                r'|date/time field value out of range): "(.+?)"\n'
383 2535 aaronmk
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
384 2464 aaronmk
            if match:
385
                value, name = match.groups()
386
                raise FunctionValueException(name, strings.to_unicode(value), e)
387
388 2526 aaronmk
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
389 2523 aaronmk
                r'is of type', msg)
390
            if match:
391
                col, type_ = match.groups()
392
                raise MissingCastException(type_, col, e)
393
394 2493 aaronmk
            match = re.search(r'relation "(.+?)" already exists', msg)
395 2464 aaronmk
            if match: raise DuplicateTableException(match.group(1), e)
396
397 2493 aaronmk
            match = re.search(r'function "(.+?)" already exists', msg)
398 2464 aaronmk
            if match: raise DuplicateFunctionException(match.group(1), e)
399
400
            raise # no specific exception raised
401
    except log_ignore_excs:
402
        log_level += 2
403
        raise
404
    finally:
405
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
406 830 aaronmk
407 832 aaronmk
##### Basic queries
408
409 2153 aaronmk
def next_version(name):
410 2163 aaronmk
    version = 1 # first existing name was version 0
411 2586 aaronmk
    match = re.match(r'^(.*)#(\d+)$', name)
412 2153 aaronmk
    if match:
413 2586 aaronmk
        name, version = match.groups()
414
        version = int(version)+1
415 2588 aaronmk
    return sql_gen.add_suffix(name, '#'+str(version))
416 2153 aaronmk
417 2386 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
418 2085 aaronmk
    '''Outputs a query to a temp table.
419
    For params, see run_query().
420
    '''
421 2386 aaronmk
    if into == None: return run_query(db, query, params, *args, **kw_args)
422 2085 aaronmk
    else: # place rows in temp table
423 2386 aaronmk
        assert isinstance(into, sql_gen.Table)
424 2385 aaronmk
425 2153 aaronmk
        kw_args['recover'] = True
426 2464 aaronmk
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
427 2440 aaronmk
428 2468 aaronmk
        temp = not db.autocommit # tables are permanent in autocommit mode
429 2440 aaronmk
        # "temporary tables cannot specify a schema name", so remove schema
430
        if temp: into.schema = None
431
432 2153 aaronmk
        while True:
433
            try:
434 2194 aaronmk
                create_query = 'CREATE'
435 2440 aaronmk
                if temp: create_query += ' TEMP'
436 2467 aaronmk
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
437 2194 aaronmk
438
                return run_query(db, create_query, params, *args, **kw_args)
439 2153 aaronmk
                    # CREATE TABLE AS sets rowcount to # rows in query
440
            except DuplicateTableException, e:
441 2386 aaronmk
                into.name = next_version(into.name)
442 2153 aaronmk
                # try again with next version of name
443 2085 aaronmk
444 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
445
446 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
447
448 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
449 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
450 1981 aaronmk
    '''
451 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
452 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
453 1981 aaronmk
    @param fields Use None to select all fields in the table
454 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
455 2379 aaronmk
        * container can be any iterable type
456 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
457
        * compare_right_side: sql_gen.ValueCond|literal value
458 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
459
        use all columns
460 2054 aaronmk
    @return tuple(query, params)
461 1981 aaronmk
    '''
462 2315 aaronmk
    # Parse tables param
463 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
464 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
465 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
466 2121 aaronmk
467 2315 aaronmk
    # Parse other params
468 2376 aaronmk
    if conds == None: conds = []
469
    elif isinstance(conds, dict): conds = conds.items()
470 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
471 135 aaronmk
    assert limit == None or type(limit) == int
472 865 aaronmk
    assert start == None or type(start) == int
473 2315 aaronmk
    if order_by is order_by_pkey:
474
        if distinct_on != []: order_by = None
475
        else: order_by = pkey(db, table0, recover=True)
476 865 aaronmk
477 2315 aaronmk
    query = 'SELECT'
478 2056 aaronmk
479 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
480 2056 aaronmk
481 2200 aaronmk
    # DISTINCT ON columns
482 2233 aaronmk
    if distinct_on != []:
483 2467 aaronmk
        query += '\nDISTINCT'
484 2254 aaronmk
        if distinct_on is not distinct_on_all:
485 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
486
487
    # Columns
488 2467 aaronmk
    query += '\n'
489 1135 aaronmk
    if fields == None: query += '*'
490 2479 aaronmk
    else: query += '\n, '.join(map(parse_col, fields))
491 2200 aaronmk
492
    # Main table
493 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
494 865 aaronmk
495 2122 aaronmk
    # Add joins
496 2271 aaronmk
    left_table = table0
497 2263 aaronmk
    for join_ in tables:
498
        table = join_.table
499 2238 aaronmk
500 2343 aaronmk
        # Parse special values
501
        if join_.type_ is sql_gen.filter_out: # filter no match
502 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
503
                None))
504 2343 aaronmk
505 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
506 2122 aaronmk
507
        left_table = table
508
509 865 aaronmk
    missing = True
510 2376 aaronmk
    if conds != []:
511 2576 aaronmk
        if len(conds) == 1: whitespace = ' '
512
        else: whitespace = '\n'
513 2578 aaronmk
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
514
            .to_str(db) for l, r in conds], 'WHERE')
515 865 aaronmk
        missing = False
516 2227 aaronmk
    if order_by != None:
517 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
518
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
519 865 aaronmk
    if start != None:
520 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
521 865 aaronmk
        missing = False
522
    if missing: warnings.warn(DbWarning(
523
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
524
525 2315 aaronmk
    return (query, [])
526 11 aaronmk
527 2054 aaronmk
def select(db, *args, **kw_args):
528
    '''For params, see mk_select() and run_query()'''
529
    recover = kw_args.pop('recover', None)
530
    cacheable = kw_args.pop('cacheable', True)
531 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
532 2054 aaronmk
533
    query, params = mk_select(db, *args, **kw_args)
534 2442 aaronmk
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
535 2054 aaronmk
536 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
537 2292 aaronmk
    returning=None, embeddable=False):
538 1960 aaronmk
    '''
539
    @param returning str|None An inserted column (such as pkey) to return
540 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
541 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
542
        query will be fully cached, not just if it raises an exception.
543 1960 aaronmk
    '''
544 2328 aaronmk
    table = sql_gen.as_Table(table)
545 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
546
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
547 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
548 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
549 2063 aaronmk
550
    # Build query
551 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
552
    query = first_line
553 2467 aaronmk
    if cols != None: query += '\n('+', '.join(cols)+')'
554
    query += '\n'+select_query
555 2063 aaronmk
556
    if returning != None:
557 2327 aaronmk
        returning_name = copy.copy(returning)
558
        returning_name.table = None
559
        returning_name = returning_name.to_str(db)
560 2467 aaronmk
        query += '\nRETURNING '+returning_name
561 2063 aaronmk
562 2070 aaronmk
    if embeddable:
563 2327 aaronmk
        assert returning != None
564
565 2070 aaronmk
        # Create function
566 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
567 2327 aaronmk
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
568 2189 aaronmk
        while True:
569
            try:
570 2327 aaronmk
                func_schema = None
571 2468 aaronmk
                if not db.autocommit: func_schema = 'pg_temp'
572 2327 aaronmk
                function = sql_gen.Table(function_name, func_schema).to_str(db)
573 2194 aaronmk
574 2189 aaronmk
                function_query = '''\
575 2467 aaronmk
CREATE FUNCTION '''+function+'''()
576
RETURNS '''+return_type+'''
577
LANGUAGE sql
578
AS $$
579
'''+mogrify(db, query, params)+''';
580
$$;
581 2070 aaronmk
'''
582 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
583 2464 aaronmk
                    log_ignore_excs=(DuplicateFunctionException,))
584 2189 aaronmk
                break # this version was successful
585
            except DuplicateFunctionException, e:
586
                function_name = next_version(function_name)
587
                # try again with next version of name
588 2070 aaronmk
589 2337 aaronmk
        # Return query that uses function
590
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
591
            [returning_name]) # AS clause requires function alias
592
        return mk_select(db, func_table, start=0, order_by=None)
593 2070 aaronmk
594 2066 aaronmk
    return (query, params)
595
596
def insert_select(db, *args, **kw_args):
597 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
598 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
599
        values in
600 2072 aaronmk
    '''
601 2386 aaronmk
    into = kw_args.pop('into', None)
602
    if into != None: kw_args['embeddable'] = True
603 2066 aaronmk
    recover = kw_args.pop('recover', None)
604
    cacheable = kw_args.pop('cacheable', True)
605
606
    query, params = mk_insert_select(db, *args, **kw_args)
607 2386 aaronmk
    return run_query_into(db, query, params, into, recover=recover,
608 2153 aaronmk
        cacheable=cacheable)
609 2063 aaronmk
610 2066 aaronmk
default = object() # tells insert() to use the default value for a column
611
612 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
613 2085 aaronmk
    '''For params, see insert_select()'''
614 1960 aaronmk
    if lists.is_seq(row): cols = None
615
    else:
616
        cols = row.keys()
617
        row = row.values()
618
    row = list(row) # ensure that "!= []" works
619
620 1961 aaronmk
    # Check for special values
621
    labels = []
622
    values = []
623
    for value in row:
624 2254 aaronmk
        if value is default: labels.append('DEFAULT')
625 1961 aaronmk
        else:
626
            labels.append('%s')
627
            values.append(value)
628
629
    # Build query
630 2467 aaronmk
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
631 2063 aaronmk
    else: query = None
632 1554 aaronmk
633 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
634 11 aaronmk
635 2402 aaronmk
def mk_update(db, table, changes=None, cond=None):
636
    '''
637
    @param changes [(col, new_value),...]
638
        * container can be any iterable type
639
        * col: sql_gen.Code|str (for col name)
640
        * new_value: sql_gen.Code|literal value
641
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
642
    @return str query
643
    '''
644
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
645 2405 aaronmk
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
646 2402 aaronmk
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
647 2467 aaronmk
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
648 2402 aaronmk
649
    return query
650
651
def update(db, *args, **kw_args):
652
    '''For params, see mk_update() and run_query()'''
653
    recover = kw_args.pop('recover', None)
654
655
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
656
657 135 aaronmk
def last_insert_id(db):
658 1849 aaronmk
    module = util.root_module(db.db)
659 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
660
    elif module == 'MySQLdb': return db.insert_id()
661
    else: return None
662 13 aaronmk
663 1968 aaronmk
def truncate(db, table, schema='public'):
664
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
665 832 aaronmk
666 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
667 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
668 2415 aaronmk
    to names that will be distinct among the columns' tables.
669 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
670 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
671
    @param into The table for the new columns.
672 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
673
        columns will be included in the mapping even if they are not in cols.
674
        The tables of the provided Col objects will be changed to into, so make
675
        copies of them if you want to keep the original tables.
676
    @param as_items Whether to return a list of dict items instead of a dict
677 2383 aaronmk
    @return dict(orig_col=new_col, ...)
678
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
679 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
680
        * All mappings use the into table so its name can easily be
681 2383 aaronmk
          changed for all columns at once
682
    '''
683 2415 aaronmk
    cols = lists.uniqify(cols)
684
685 2394 aaronmk
    items = []
686 2389 aaronmk
    for col in preserve:
687 2390 aaronmk
        orig_col = copy.copy(col)
688 2392 aaronmk
        col.table = into
689 2394 aaronmk
        items.append((orig_col, col))
690
    preserve = set(preserve)
691
    for col in cols:
692 2515 aaronmk
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
693 2394 aaronmk
694
    if not as_items: items = dict(items)
695
    return items
696 2383 aaronmk
697 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
698 2391 aaronmk
    '''For params, see mk_flatten_mapping()
699
    @return See return value of mk_flatten_mapping()
700
    '''
701 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
702
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
703 2391 aaronmk
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
704 2392 aaronmk
        into=into)
705 2394 aaronmk
    return dict(items)
706 2391 aaronmk
707 2414 aaronmk
##### Database structure queries
708
709 2426 aaronmk
def table_row_count(db, table, recover=None):
710
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
711 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
712 2426 aaronmk
713 2414 aaronmk
def table_cols(db, table, recover=None):
714
    return list(col_names(select(db, table, limit=0, order_by=None,
715 2443 aaronmk
        recover=recover, log_level=4)))
716 2414 aaronmk
717 2291 aaronmk
def pkey(db, table, recover=None):
718 832 aaronmk
    '''Assumed to be first column in table'''
719 2339 aaronmk
    return table_cols(db, table, recover)[0]
720 832 aaronmk
721 2559 aaronmk
not_null_col = 'not_null_col'
722 2340 aaronmk
723
def table_not_null_col(db, table, recover=None):
724
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
725
    if not_null_col in table_cols(db, table, recover): return not_null_col
726
    else: return pkey(db, table, recover)
727
728 853 aaronmk
def index_cols(db, table, index):
729
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
730
    automatically created. When you don't know whether something is a UNIQUE
731
    constraint or a UNIQUE index, use this function.'''
732 1909 aaronmk
    module = util.root_module(db.db)
733
    if module == 'psycopg2':
734
        return list(values(run_query(db, '''\
735 853 aaronmk
SELECT attname
736 866 aaronmk
FROM
737
(
738
        SELECT attnum, attname
739
        FROM pg_index
740
        JOIN pg_class index ON index.oid = indexrelid
741
        JOIN pg_class table_ ON table_.oid = indrelid
742
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
743
        WHERE
744
            table_.relname = %(table)s
745
            AND index.relname = %(index)s
746
    UNION
747
        SELECT attnum, attname
748
        FROM
749
        (
750
            SELECT
751
                indrelid
752
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
753
                    AS indkey
754
            FROM pg_index
755
            JOIN pg_class index ON index.oid = indexrelid
756
            JOIN pg_class table_ ON table_.oid = indrelid
757
            WHERE
758
                table_.relname = %(table)s
759
                AND index.relname = %(index)s
760
        ) s
761
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
762
) s
763 853 aaronmk
ORDER BY attnum
764
''',
765 2443 aaronmk
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
766 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
767
        ' database')
768 853 aaronmk
769 464 aaronmk
def constraint_cols(db, table, constraint):
770 1849 aaronmk
    module = util.root_module(db.db)
771 464 aaronmk
    if module == 'psycopg2':
772
        return list(values(run_query(db, '''\
773
SELECT attname
774
FROM pg_constraint
775
JOIN pg_class ON pg_class.oid = conrelid
776
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
777
WHERE
778
    relname = %(table)s
779
    AND conname = %(constraint)s
780
ORDER BY attnum
781
''',
782
            {'table': table, 'constraint': constraint})))
783
    else: raise NotImplementedError("Can't list constraint columns for "+module+
784
        ' database')
785
786 2096 aaronmk
row_num_col = '_row_num'
787
788 2538 aaronmk
def add_index(db, expr):
789
    '''Adds an index on a column or expression if it doesn't already exist.
790
    Currently, only function calls are supported as expressions.
791
    '''
792
    expr = copy.copy(expr) # don't modify input!
793
794
    # Extract col
795 2539 aaronmk
    if isinstance(expr, sql_gen.FunctionCall):
796
        col = expr.args[0]
797 2541 aaronmk
        expr = sql_gen.Expr(expr)
798 2538 aaronmk
    else: col = expr
799 2408 aaronmk
    assert sql_gen.is_table_col(col)
800
801 2538 aaronmk
    index = sql_gen.as_Table(str(expr))
802 2408 aaronmk
    table = col.table
803 2538 aaronmk
    col.table = None
804 2408 aaronmk
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
805 2538 aaronmk
        +' ('+expr.to_str(db)+')', recover=True, cacheable=True, log_level=3)
806 2408 aaronmk
    except DuplicateTableException: pass # index already existed
807
808 2594 aaronmk
def add_pkey(db, table, recover=None):
809 2406 aaronmk
    '''Makes the first column in a table the primary key.
810
    @pre The table must not already have a primary key.
811
    '''
812
    table = sql_gen.as_Table(table)
813
814 2590 aaronmk
    index = sql_gen.as_Table(sql_gen.add_suffix(table.name, '_pkey'))
815 2406 aaronmk
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
816
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
817 2443 aaronmk
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
818
        log_level=3)
819 2406 aaronmk
820 2086 aaronmk
def add_row_num(db, table):
821 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
822
    be the primary key.'''
823 2320 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
824 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
825 2443 aaronmk
        +' serial NOT NULL PRIMARY KEY', log_level=3)
826 2086 aaronmk
827 2548 aaronmk
def tables(db, schema_like='public', table_like='%'):
828 1849 aaronmk
    module = util.root_module(db.db)
829 2548 aaronmk
    params = {'schema_like': schema_like, 'table_like': table_like}
830 832 aaronmk
    if module == 'psycopg2':
831 1968 aaronmk
        return values(run_query(db, '''\
832
SELECT tablename
833
FROM pg_tables
834
WHERE
835 2548 aaronmk
    schemaname LIKE %(schema_like)s
836 1968 aaronmk
    AND tablename LIKE %(table_like)s
837
ORDER BY tablename
838
''',
839
            params, cacheable=True))
840
    elif module == 'MySQLdb':
841
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
842
            cacheable=True))
843 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
844 830 aaronmk
845 833 aaronmk
##### Database management
846
847 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
848
    '''For kw_args, see tables()'''
849
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
850 833 aaronmk
851 832 aaronmk
##### Heuristic queries
852
853 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
854 1554 aaronmk
    '''Recovers from errors.
855 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
856
    '''
857 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
858
859 471 aaronmk
    try:
860 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
861 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
862
            row_ct_ref[0] += cur.rowcount
863
        return value(cur)
864 471 aaronmk
    except DuplicateKeyException, e:
865 2104 aaronmk
        return value(select(db, table, [pkey_],
866 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
867 471 aaronmk
868 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
869 830 aaronmk
    '''Recovers from errors'''
870 2209 aaronmk
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
871 14 aaronmk
    except StopIteration:
872 40 aaronmk
        if not create: raise
873 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
874 2078 aaronmk
875 2593 aaronmk
def is_func_result(col):
876
    return col.table.name.find('(') >= 0 and col.name == 'result'
877
878 2592 aaronmk
def into_table_name(out_table, in_tables0, mapping, is_func):
879 2580 aaronmk
    str_ = str(out_table)
880
    if is_func:
881
        def col(out_col, in_col):
882
            # Add out_col
883
            out_col = sql_gen.to_name_only_col(out_col)
884
            str_ = ''
885
            if out_col.name != 'value': str_ += str(out_col)+'='
886
887
            # Add in_col
888 2592 aaronmk
            in_col = sql_gen.remove_col_rename(in_col)
889 2593 aaronmk
            if isinstance(in_col, sql_gen.Col):
890
                table = in_col.table
891
                if table == in_tables0:
892
                    in_col = sql_gen.to_name_only_col(in_col)
893
                elif is_func_result(in_col): in_col = table # omit col name
894 2580 aaronmk
            str_ += str(in_col)
895
896
            return str_
897
898
        str_ += '('+(', '.join((col(k, v) for k, v in mapping.iteritems())))+')'
899
    else: str_ += '_pkeys'
900
    return str_
901
902 2508 aaronmk
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
903 2552 aaronmk
    default=None, is_func=False):
904 2078 aaronmk
    '''Recovers from errors.
905
    Only works under PostgreSQL (uses INSERT RETURNING).
906 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
907
        tables to join with it using the main input table's pkey
908 2312 aaronmk
    @param mapping dict(out_table_col=in_table_col, ...)
909
        * out_table_col: sql_gen.Col|str
910 2323 aaronmk
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
911 2489 aaronmk
    @param into The table to contain the output and input pkeys.
912 2574 aaronmk
        Defaults to `out_table.name+'_pkeys'`.
913 2509 aaronmk
    @param default The *output* column to use as the pkey for missing rows.
914
        If this output column does not exist in the mapping, uses None.
915 2552 aaronmk
    @param is_func Whether out_table is the name of a SQL function, not a table
916 2312 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
917 2078 aaronmk
    '''
918 2329 aaronmk
    out_table = sql_gen.as_Table(out_table)
919 2565 aaronmk
    mapping = sql_gen.ColDict(mapping)
920 2312 aaronmk
921 2450 aaronmk
    def log_debug(msg): db.log_debug(msg, level=1.5)
922 2505 aaronmk
    def col_ustr(str_):
923 2567 aaronmk
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
924 2450 aaronmk
925 2486 aaronmk
    log_debug('********** New iteration **********')
926 2505 aaronmk
    log_debug('Inserting these input columns into '+strings.as_tt(
927
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
928 2463 aaronmk
929 2382 aaronmk
    # Create input joins from list of input tables
930
    in_tables_ = in_tables[:] # don't modify input!
931
    in_tables0 = in_tables_.pop(0) # first table is separate
932 2279 aaronmk
    in_pkey = pkey(db, in_tables0, recover=True)
933 2285 aaronmk
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
934 2460 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
935
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
936 2131 aaronmk
937 2592 aaronmk
    if into == None:
938
        into = into_table_name(out_table, in_tables0, mapping, is_func)
939
    into = sql_gen.as_Table(into)
940
941 2486 aaronmk
    log_debug('Joining together input tables into temp table')
942 2395 aaronmk
    # Place in new table for speed and so don't modify input if values edited
943 2584 aaronmk
    in_table = sql_gen.Table('in')
944 2395 aaronmk
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
945
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
946
        flatten_cols, preserve=[in_pkey_col], start=0))
947
    input_joins = [in_table]
948 2486 aaronmk
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
949 2395 aaronmk
950 2509 aaronmk
    # Resolve default value column
951
    try: default = mapping[default]
952
    except KeyError:
953
        if default != None:
954
            db.log_debug('Default value column '
955
                +strings.as_tt(strings.repr_no_u(default))
956 2511 aaronmk
                +' does not exist in mapping, falling back to None', level=2.1)
957 2509 aaronmk
            default = None
958
959 2279 aaronmk
    out_pkey = pkey(db, out_table, recover=True)
960 2285 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
961 2142 aaronmk
962 2387 aaronmk
    pkeys_names = [in_pkey, out_pkey]
963 2236 aaronmk
    pkeys_cols = [in_pkey_col, out_pkey_col]
964
965 2201 aaronmk
    pkeys_table_exists_ref = [False]
966 2420 aaronmk
    def insert_into_pkeys(joins, cols):
967
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
968 2201 aaronmk
        if pkeys_table_exists_ref[0]:
969 2489 aaronmk
            insert_select(db, into, pkeys_names, query, params)
970 2201 aaronmk
        else:
971 2489 aaronmk
            run_query_into(db, query, params, into=into)
972 2201 aaronmk
            pkeys_table_exists_ref[0] = True
973
974 2429 aaronmk
    limit_ref = [None]
975 2380 aaronmk
    conds = set()
976 2233 aaronmk
    distinct_on = []
977 2325 aaronmk
    def mk_main_select(joins, cols):
978 2429 aaronmk
        return mk_select(db, joins, cols, conds, distinct_on,
979
            limit=limit_ref[0], start=0)
980 2132 aaronmk
981 2519 aaronmk
    exc_strs = set()
982 2309 aaronmk
    def log_exc(e):
983 2519 aaronmk
        e_str = exc.str_(e, first_line_only=True)
984
        log_debug('Caught exception: '+e_str)
985
        assert e_str not in exc_strs # avoid infinite loops
986
        exc_strs.add(e_str)
987 2451 aaronmk
    def remove_all_rows():
988 2450 aaronmk
        log_debug('Returning NULL for all rows')
989 2429 aaronmk
        limit_ref[0] = 0 # just create an empty pkeys table
990 2409 aaronmk
    def ignore(in_col, value):
991 2545 aaronmk
        in_col_str = strings.as_tt(repr(in_col))
992 2544 aaronmk
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
993
            level=2.5)
994 2537 aaronmk
        add_index(db, in_col)
995 2545 aaronmk
        log_debug('Ignoring rows with '+in_col_str+' = '
996
            +strings.as_tt(repr(value)))
997 2403 aaronmk
    def remove_rows(in_col, value):
998 2409 aaronmk
        ignore(in_col, value)
999 2378 aaronmk
        cond = (in_col, sql_gen.CompareCond(value, '!='))
1000
        assert cond not in conds # avoid infinite loops
1001 2380 aaronmk
        conds.add(cond)
1002 2403 aaronmk
    def invalid2null(in_col, value):
1003 2409 aaronmk
        ignore(in_col, value)
1004 2403 aaronmk
        update(db, in_table, [(in_col, None)],
1005
            sql_gen.ColValueCond(in_col, value))
1006 2245 aaronmk
1007 2589 aaronmk
    def insert_pkeys_table(which):
1008
        return sql_gen.Table(sql_gen.add_suffix(in_table.name,
1009
            '_insert_'+which+'_pkeys'))
1010
    insert_out_pkeys = insert_pkeys_table('out')
1011
    insert_in_pkeys = insert_pkeys_table('in')
1012
1013 2206 aaronmk
    # Do inserts and selects
1014 2565 aaronmk
    join_cols = sql_gen.ColDict()
1015 2206 aaronmk
    while True:
1016 2521 aaronmk
        if limit_ref[0] == 0: # special case
1017
            log_debug('Creating an empty pkeys table')
1018
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
1019
                limit=limit_ref[0]), into=insert_out_pkeys)
1020
            break # don't do main case
1021
1022 2303 aaronmk
        has_joins = join_cols != {}
1023
1024 2305 aaronmk
        # Prepare to insert new rows
1025 2325 aaronmk
        insert_joins = input_joins[:] # don't modify original!
1026 2403 aaronmk
        insert_args = dict(recover=True, cacheable=False)
1027 2303 aaronmk
        if has_joins:
1028 2317 aaronmk
            distinct_on = [v.to_Col() for v in join_cols.values()]
1029 2325 aaronmk
            insert_joins.append(sql_gen.Join(out_table, join_cols,
1030
                sql_gen.filter_out))
1031
        else:
1032 2404 aaronmk
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
1033 2520 aaronmk
        main_select = mk_main_select(insert_joins, mapping.values())[0]
1034 2303 aaronmk
1035 2486 aaronmk
        log_debug('Trying to insert new rows')
1036 2206 aaronmk
        try:
1037 2518 aaronmk
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1038
                **insert_args)
1039 2357 aaronmk
            break # insert successful
1040 2206 aaronmk
        except DuplicateKeyException, e:
1041 2309 aaronmk
            log_exc(e)
1042
1043 2258 aaronmk
            old_join_cols = join_cols.copy()
1044 2565 aaronmk
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
1045 2486 aaronmk
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1046 2505 aaronmk
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1047 2258 aaronmk
            assert join_cols != old_join_cols # avoid infinite loops
1048 2230 aaronmk
        except NullValueException, e:
1049 2309 aaronmk
            log_exc(e)
1050
1051 2230 aaronmk
            out_col, = e.cols
1052
            try: in_col = mapping[out_col]
1053 2356 aaronmk
            except KeyError:
1054 2486 aaronmk
                log_debug('Missing mapping for NOT NULL column '+out_col)
1055 2451 aaronmk
                remove_all_rows()
1056 2403 aaronmk
            else: remove_rows(in_col, None)
1057 2542 aaronmk
        except FunctionValueException, e:
1058
            log_exc(e)
1059
1060
            func_name = e.name
1061
            value = e.value
1062
            for out_col, in_col in mapping.iteritems():
1063 2562 aaronmk
                invalid2null(sql_gen.unwrap_func_call(in_col, func_name), value)
1064 2525 aaronmk
        except MissingCastException, e:
1065
            log_exc(e)
1066
1067
            out_col = e.col
1068 2534 aaronmk
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
1069 2429 aaronmk
        except DatabaseErrors, e:
1070
            log_exc(e)
1071
1072 2531 aaronmk
            msg = 'No handler for exception: '+exc.str_(e)
1073 2451 aaronmk
            warnings.warn(DbWarning(msg))
1074
            log_debug(msg)
1075
            remove_all_rows()
1076 2358 aaronmk
        # after exception handled, rerun loop with additional constraints
1077 2132 aaronmk
1078 2357 aaronmk
    if row_ct_ref != None and cur.rowcount >= 0:
1079
        row_ct_ref[0] += cur.rowcount
1080
1081
    if has_joins:
1082
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1083 2486 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
1084 2420 aaronmk
        insert_into_pkeys(select_joins, pkeys_cols)
1085 2357 aaronmk
    else:
1086 2404 aaronmk
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1087 2357 aaronmk
1088 2486 aaronmk
        log_debug('Getting input table pkeys of inserted rows')
1089 2357 aaronmk
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1090 2404 aaronmk
            into=insert_in_pkeys)
1091
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1092 2357 aaronmk
1093 2428 aaronmk
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1094
            insert_in_pkeys)
1095
1096 2486 aaronmk
        log_debug('Combining output and input pkeys in inserted order')
1097 2404 aaronmk
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1098 2357 aaronmk
            {row_num_col: sql_gen.join_same_not_null})]
1099 2420 aaronmk
        insert_into_pkeys(pkey_joins, pkeys_names)
1100 2357 aaronmk
1101 2486 aaronmk
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1102 2594 aaronmk
    add_pkey(db, into)
1103 2407 aaronmk
1104 2508 aaronmk
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1105 2489 aaronmk
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1106 2357 aaronmk
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1107
        # must use join_same_not_null or query will take forever
1108 2420 aaronmk
    insert_into_pkeys(missing_rows_joins,
1109 2508 aaronmk
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1110 2357 aaronmk
1111 2489 aaronmk
    assert table_row_count(db, into) == table_row_count(db, in_table)
1112 2428 aaronmk
1113 2489 aaronmk
    return sql_gen.Col(out_pkey, into)
1114 2115 aaronmk
1115
##### Data cleanup
1116
1117 2290 aaronmk
def cleanup_table(db, table, cols):
1118 2115 aaronmk
    def esc_name_(name): return esc_name(db, name)
1119
1120 2290 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
1121 2115 aaronmk
    cols = map(esc_name_, cols)
1122
1123
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1124
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1125
            for col in cols))),
1126
        dict(null0='', null1=r'\N'))