Project

General

Profile

1 11 aaronmk
# Database access
2
3 1869 aaronmk
import copy
4 2127 aaronmk
import operator
5 11 aaronmk
import re
6 865 aaronmk
import warnings
7 11 aaronmk
8 300 aaronmk
import exc
9 1909 aaronmk
import dicts
10 1893 aaronmk
import iters
11 1960 aaronmk
import lists
12 1889 aaronmk
from Proxy import Proxy
13 1872 aaronmk
import rand
14 2217 aaronmk
import sql_gen
15 862 aaronmk
import strings
16 131 aaronmk
import util
17 11 aaronmk
18 832 aaronmk
##### Exceptions
19
20 2170 aaronmk
def get_cur_query(cur, input_query=None, input_params=None):
21 2168 aaronmk
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24 2170 aaronmk
25
    if raw_query != None: return raw_query
26 2371 aaronmk
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27 14 aaronmk
28 2170 aaronmk
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31 135 aaronmk
32 300 aaronmk
class DbException(exc.ExceptionWithCause):
33 14 aaronmk
    def __init__(self, msg, cause=None, cur=None):
34 2145 aaronmk
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35 14 aaronmk
        if cur != None: _add_cursor_info(self, cur)
36
37 2143 aaronmk
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40 2143 aaronmk
        self.name = name
41 360 aaronmk
42 2240 aaronmk
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44 2484 aaronmk
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46 2240 aaronmk
        self.name = name
47
        self.value = value
48
49 2306 aaronmk
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51 2484 aaronmk
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53 2306 aaronmk
        self.name = name
54 468 aaronmk
        self.cols = cols
55 11 aaronmk
56 2523 aaronmk
class MissingCastException(DbException):
57
    def __init__(self, type_, col, cause=None):
58
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
59
            +' on column: '+strings.as_tt(col), cause)
60
        self.type = type_
61
        self.col = col
62
63 2143 aaronmk
class NameException(DbException): pass
64
65 2306 aaronmk
class DuplicateKeyException(ConstraintException): pass
66 13 aaronmk
67 2306 aaronmk
class NullValueException(ConstraintException): pass
68 13 aaronmk
69 2240 aaronmk
class FunctionValueException(ExceptionWithNameValue): pass
70 2239 aaronmk
71 2143 aaronmk
class DuplicateTableException(ExceptionWithName): pass
72
73 2188 aaronmk
class DuplicateFunctionException(ExceptionWithName): pass
74
75 89 aaronmk
class EmptyRowException(DbException): pass
76
77 865 aaronmk
##### Warnings
78
79
class DbWarning(UserWarning): pass
80
81 1930 aaronmk
##### Result retrieval
82
83
def col_names(cur): return (col[0] for col in cur.description)
84
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86
87
def consume_rows(cur):
88
    '''Used to fetch all rows so result will be cached'''
89
    iters.consume_iter(rows(cur))
90
91
def next_row(cur): return rows(cur).next()
92
93
def row(cur):
94
    row_ = next_row(cur)
95
    consume_rows(cur)
96
    return row_
97
98
def next_value(cur): return next_row(cur)[0]
99
100
def value(cur): return row(cur)[0]
101
102
def values(cur): return iters.func_iter(lambda: next_value(cur))
103
104
def value_or_none(cur):
105
    try: return value(cur)
106
    except StopIteration: return None
107
108 2101 aaronmk
##### Input validation
109
110
def check_name(name):
111
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
112
        +'" may contain only alphanumeric characters and _')
113
114
def esc_name_by_module(module, name, ignore_case=False):
115 2388 aaronmk
    if module == 'psycopg2' or module == None:
116 2101 aaronmk
        if ignore_case:
117
            # Don't enclose in quotes because this disables case-insensitivity
118
            check_name(name)
119
            return name
120
        else: quote = '"'
121
    elif module == 'MySQLdb': quote = '`'
122
    else: raise NotImplementedError("Can't escape name for "+module+' database')
123 2500 aaronmk
    return sql_gen.esc_name(name, quote)
124 2101 aaronmk
125
def esc_name_by_engine(engine, name, **kw_args):
126
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
127
128
def esc_name(db, name, **kw_args):
129
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
130
131
def qual_name(db, schema, table):
132
    def esc_name_(name): return esc_name(db, name)
133
    table = esc_name_(table)
134
    if schema != None: return esc_name_(schema)+'.'+table
135
    else: return table
136
137 1869 aaronmk
##### Database connections
138 1849 aaronmk
139 2097 aaronmk
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
140 1926 aaronmk
141 1869 aaronmk
db_engines = {
142
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
143
    'PostgreSQL': ('psycopg2', {}),
144
}
145
146
DatabaseErrors_set = set([DbException])
147
DatabaseErrors = tuple(DatabaseErrors_set)
148
149
def _add_module(module):
150
    DatabaseErrors_set.add(module.DatabaseError)
151
    global DatabaseErrors
152
    DatabaseErrors = tuple(DatabaseErrors_set)
153
154
def db_config_str(db_config):
155
    return db_config['engine']+' database '+db_config['database']
156
157 1909 aaronmk
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
158 1894 aaronmk
159 2448 aaronmk
log_debug_none = lambda msg, level=2: None
160 1901 aaronmk
161 1849 aaronmk
class DbConn:
162 2190 aaronmk
    def __init__(self, db_config, serializable=True, autocommit=False,
163
        caching=True, log_debug=log_debug_none):
164 1869 aaronmk
        self.db_config = db_config
165
        self.serializable = serializable
166 2190 aaronmk
        self.autocommit = autocommit
167
        self.caching = caching
168 1901 aaronmk
        self.log_debug = log_debug
169 2193 aaronmk
        self.debug = log_debug != log_debug_none
170 1869 aaronmk
171
        self.__db = None
172 1889 aaronmk
        self.query_results = {}
173 2139 aaronmk
        self._savepoint = 0
174 1869 aaronmk
175
    def __getattr__(self, name):
176
        if name == '__dict__': raise Exception('getting __dict__')
177
        if name == 'db': return self._db()
178
        else: raise AttributeError()
179
180
    def __getstate__(self):
181
        state = copy.copy(self.__dict__) # shallow copy
182 1915 aaronmk
        state['log_debug'] = None # don't pickle the debug callback
183 1869 aaronmk
        state['_DbConn__db'] = None # don't pickle the connection
184
        return state
185
186 2165 aaronmk
    def connected(self): return self.__db != None
187
188 1869 aaronmk
    def _db(self):
189
        if self.__db == None:
190
            # Process db_config
191
            db_config = self.db_config.copy() # don't modify input!
192 2097 aaronmk
            schemas = db_config.pop('schemas', None)
193 1869 aaronmk
            module_name, mappings = db_engines[db_config.pop('engine')]
194
            module = __import__(module_name)
195
            _add_module(module)
196
            for orig, new in mappings.iteritems():
197
                try: util.rename_key(db_config, orig, new)
198
                except KeyError: pass
199
200
            # Connect
201
            self.__db = module.connect(**db_config)
202
203
            # Configure connection
204 2234 aaronmk
            if self.serializable and not self.autocommit: run_raw_query(self,
205
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
206 2101 aaronmk
            if schemas != None:
207
                schemas_ = ''.join((esc_name(self, s)+', '
208
                    for s in schemas.split(',')))
209
                run_raw_query(self, "SELECT set_config('search_path', \
210
%s || current_setting('search_path'), false)", [schemas_])
211 1869 aaronmk
212
        return self.__db
213 1889 aaronmk
214 1891 aaronmk
    class DbCursor(Proxy):
215 1927 aaronmk
        def __init__(self, outer):
216 1891 aaronmk
            Proxy.__init__(self, outer.db.cursor())
217 2191 aaronmk
            self.outer = outer
218 1927 aaronmk
            self.query_results = outer.query_results
219 1894 aaronmk
            self.query_lookup = None
220 1891 aaronmk
            self.result = []
221 1889 aaronmk
222 1894 aaronmk
        def execute(self, query, params=None):
223 1930 aaronmk
            self._is_insert = query.upper().find('INSERT') >= 0
224 1894 aaronmk
            self.query_lookup = _query_lookup(query, params)
225 2148 aaronmk
            try:
226 2191 aaronmk
                try:
227
                    return_value = self.inner.execute(query, params)
228
                    self.outer.do_autocommit()
229 2148 aaronmk
                finally: self.query = get_cur_query(self.inner)
230 1904 aaronmk
            except Exception, e:
231 2170 aaronmk
                _add_cursor_info(e, self, query, params)
232 1904 aaronmk
                self.result = e # cache the exception as the result
233
                self._cache_result()
234
                raise
235 1930 aaronmk
            # Fetch all rows so result will be cached
236
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
237 1894 aaronmk
            return return_value
238
239 1891 aaronmk
        def fetchone(self):
240
            row = self.inner.fetchone()
241 1899 aaronmk
            if row != None: self.result.append(row)
242
            # otherwise, fetched all rows
243 1904 aaronmk
            else: self._cache_result()
244
            return row
245
246
        def _cache_result(self):
247 1906 aaronmk
            # For inserts, only cache exceptions since inserts are not
248
            # idempotent, but an invalid insert will always be invalid
249 1930 aaronmk
            if self.query_results != None and (not self._is_insert
250 1906 aaronmk
                or isinstance(self.result, Exception)):
251
252 1894 aaronmk
                assert self.query_lookup != None
253 1916 aaronmk
                self.query_results[self.query_lookup] = self.CacheCursor(
254
                    util.dict_subset(dicts.AttrsDictView(self),
255
                    ['query', 'result', 'rowcount', 'description']))
256 1906 aaronmk
257 1916 aaronmk
        class CacheCursor:
258
            def __init__(self, cached_result): self.__dict__ = cached_result
259
260 1927 aaronmk
            def execute(self, *args, **kw_args):
261 1916 aaronmk
                if isinstance(self.result, Exception): raise self.result
262
                # otherwise, result is a rows list
263
                self.iter = iter(self.result)
264
265
            def fetchone(self):
266
                try: return self.iter.next()
267
                except StopIteration: return None
268 1891 aaronmk
269 2212 aaronmk
    def esc_value(self, value):
270 2215 aaronmk
        module = util.root_module(self.db)
271 2374 aaronmk
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
272 2212 aaronmk
        elif module == 'MySQLdb':
273
            import _mysql
274 2374 aaronmk
            str_ = _mysql.escape_string(value)
275 2212 aaronmk
        else: raise NotImplementedError("Can't escape value for "+module
276
            +' database')
277 2374 aaronmk
        return strings.to_unicode(str_)
278 2212 aaronmk
279 2347 aaronmk
    def esc_name(self, name): return esc_name(self, name) # calls global func
280
281 2445 aaronmk
    def run_query(self, query, params=None, cacheable=False, log_level=2,
282 2464 aaronmk
        debug_msg_ref=None):
283 2445 aaronmk
        '''
284 2464 aaronmk
        @param log_ignore_excs The log_level will be increased by 2 if the query
285
            throws one of these exceptions.
286 2445 aaronmk
        '''
287 2167 aaronmk
        assert query != None
288
289 2047 aaronmk
        if not self.caching: cacheable = False
290 1903 aaronmk
        used_cache = False
291
        try:
292 1927 aaronmk
            # Get cursor
293
            if cacheable:
294
                query_lookup = _query_lookup(query, params)
295
                try:
296
                    cur = self.query_results[query_lookup]
297
                    used_cache = True
298
                except KeyError: cur = self.DbCursor(self)
299
            else: cur = self.db.cursor()
300
301
            # Run query
302 2148 aaronmk
            cur.execute(query, params)
303 1903 aaronmk
        finally:
304 2464 aaronmk
            if self.debug and debug_msg_ref != None:# only compute msg if needed
305 2470 aaronmk
                if used_cache: cache_status = 'cache hit'
306
                elif cacheable: cache_status = 'cache miss'
307
                else: cache_status = 'non-cacheable'
308 2472 aaronmk
                query_code = strings.as_code(str(get_cur_query(cur, query,
309
                    params)), 'SQL')
310
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
311 1903 aaronmk
312
        return cur
313 1914 aaronmk
314
    def is_cached(self, query, params=None):
315
        return _query_lookup(query, params) in self.query_results
316 2139 aaronmk
317
    def with_savepoint(self, func):
318 2171 aaronmk
        savepoint = 'level_'+str(self._savepoint)
319 2443 aaronmk
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
320 2139 aaronmk
        self._savepoint += 1
321
        try:
322
            try: return_val = func()
323
            finally:
324
                self._savepoint -= 1
325
                assert self._savepoint >= 0
326
        except:
327 2443 aaronmk
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
328 2139 aaronmk
            raise
329
        else:
330 2443 aaronmk
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
331 2191 aaronmk
            self.do_autocommit()
332 2139 aaronmk
            return return_val
333 2191 aaronmk
334
    def do_autocommit(self):
335
        '''Autocommits if outside savepoint'''
336
        assert self._savepoint >= 0
337
        if self.autocommit and self._savepoint == 0:
338
            self.log_debug('Autocommiting')
339
            self.db.commit()
340 1849 aaronmk
341 1869 aaronmk
connect = DbConn
342
343 832 aaronmk
##### Querying
344
345 1894 aaronmk
def run_raw_query(db, *args, **kw_args):
346 2085 aaronmk
    '''For params, see DbConn.run_query()'''
347 1894 aaronmk
    return db.run_query(*args, **kw_args)
348 11 aaronmk
349 2068 aaronmk
def mogrify(db, query, params):
350
    module = util.root_module(db.db)
351
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
352
    else: raise NotImplementedError("Can't mogrify query for "+module+
353
        ' database')
354
355 832 aaronmk
##### Recoverable querying
356 15 aaronmk
357 2139 aaronmk
def with_savepoint(db, func): return db.with_savepoint(func)
358 11 aaronmk
359 2464 aaronmk
def run_query(db, query, params=None, recover=None, cacheable=False,
360
    log_level=2, log_ignore_excs=None, **kw_args):
361 2441 aaronmk
    '''For params, see run_raw_query()'''
362 830 aaronmk
    if recover == None: recover = False
363 2464 aaronmk
    if log_ignore_excs == None: log_ignore_excs = ()
364
    log_ignore_excs = tuple(log_ignore_excs)
365 830 aaronmk
366 2464 aaronmk
    debug_msg_ref = [None]
367 2148 aaronmk
    try:
368 2464 aaronmk
        try:
369
            def run(): return run_raw_query(db, query, params, cacheable,
370
                log_level, debug_msg_ref, **kw_args)
371
            if recover and not db.is_cached(query, params):
372
                return with_savepoint(db, run)
373
            else: return run() # don't need savepoint if cached
374
        except Exception, e:
375
            if not recover: raise # need savepoint to run index_cols()
376
            msg = exc.str_(e)
377
378
            match = re.search(r'duplicate key value violates unique constraint '
379 2493 aaronmk
                r'"((_?[^\W_]+)_.+?)"', msg)
380 2464 aaronmk
            if match:
381
                constraint, table = match.groups()
382
                try: cols = index_cols(db, table, constraint)
383
                except NotImplementedError: raise e
384
                else: raise DuplicateKeyException(constraint, cols, e)
385
386 2493 aaronmk
            match = re.search(r'null value in column "(.+?)" violates not-null'
387 2464 aaronmk
                r' constraint', msg)
388
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
389
390
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
391
                r'|date/time field value out of range): "(.+?)"\n'
392 2535 aaronmk
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
393 2464 aaronmk
            if match:
394
                value, name = match.groups()
395
                raise FunctionValueException(name, strings.to_unicode(value), e)
396
397 2526 aaronmk
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
398 2523 aaronmk
                r'is of type', msg)
399
            if match:
400
                col, type_ = match.groups()
401
                raise MissingCastException(type_, col, e)
402
403 2493 aaronmk
            match = re.search(r'relation "(.+?)" already exists', msg)
404 2464 aaronmk
            if match: raise DuplicateTableException(match.group(1), e)
405
406 2493 aaronmk
            match = re.search(r'function "(.+?)" already exists', msg)
407 2464 aaronmk
            if match: raise DuplicateFunctionException(match.group(1), e)
408
409
            raise # no specific exception raised
410
    except log_ignore_excs:
411
        log_level += 2
412
        raise
413
    finally:
414
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
415 830 aaronmk
416 832 aaronmk
##### Basic queries
417
418 2153 aaronmk
def next_version(name):
419
    '''Prepends the version # so it won't be removed if the name is truncated'''
420 2163 aaronmk
    version = 1 # first existing name was version 0
421 2498 aaronmk
    match = re.match(r'^#(\d+)-(.*)$', name)
422 2153 aaronmk
    if match:
423
        version = int(match.group(1))+1
424
        name = match.group(2)
425 2498 aaronmk
    return '#'+str(version)+'-'+name
426 2153 aaronmk
427 2386 aaronmk
def run_query_into(db, query, params, into=None, *args, **kw_args):
428 2085 aaronmk
    '''Outputs a query to a temp table.
429
    For params, see run_query().
430
    '''
431 2386 aaronmk
    if into == None: return run_query(db, query, params, *args, **kw_args)
432 2085 aaronmk
    else: # place rows in temp table
433 2386 aaronmk
        assert isinstance(into, sql_gen.Table)
434 2385 aaronmk
435 2153 aaronmk
        kw_args['recover'] = True
436 2464 aaronmk
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
437 2440 aaronmk
438 2468 aaronmk
        temp = not db.autocommit # tables are permanent in autocommit mode
439 2440 aaronmk
        # "temporary tables cannot specify a schema name", so remove schema
440
        if temp: into.schema = None
441
442 2153 aaronmk
        while True:
443
            try:
444 2194 aaronmk
                create_query = 'CREATE'
445 2440 aaronmk
                if temp: create_query += ' TEMP'
446 2467 aaronmk
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
447 2194 aaronmk
448
                return run_query(db, create_query, params, *args, **kw_args)
449 2153 aaronmk
                    # CREATE TABLE AS sets rowcount to # rows in query
450
            except DuplicateTableException, e:
451 2386 aaronmk
                into.name = next_version(into.name)
452 2153 aaronmk
                # try again with next version of name
453 2085 aaronmk
454 2120 aaronmk
order_by_pkey = object() # tells mk_select() to order by the pkey
455
456 2199 aaronmk
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
457
458 2233 aaronmk
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
459 2293 aaronmk
    start=None, order_by=order_by_pkey, default_table=None):
460 1981 aaronmk
    '''
461 2121 aaronmk
    @param tables The single table to select from, or a list of tables to join
462 2280 aaronmk
        together, with tables after the first being sql_gen.Join objects
463 1981 aaronmk
    @param fields Use None to select all fields in the table
464 2377 aaronmk
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
465 2379 aaronmk
        * container can be any iterable type
466 2399 aaronmk
        * compare_left_side: sql_gen.Code|str (for col name)
467
        * compare_right_side: sql_gen.ValueCond|literal value
468 2199 aaronmk
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
469
        use all columns
470 2054 aaronmk
    @return tuple(query, params)
471 1981 aaronmk
    '''
472 2315 aaronmk
    # Parse tables param
473 2121 aaronmk
    if not lists.is_seq(tables): tables = [tables]
474 2141 aaronmk
    tables = list(tables) # don't modify input! (list() copies input)
475 2315 aaronmk
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
476 2121 aaronmk
477 2315 aaronmk
    # Parse other params
478 2376 aaronmk
    if conds == None: conds = []
479
    elif isinstance(conds, dict): conds = conds.items()
480 2379 aaronmk
    conds = list(conds) # don't modify input! (list() copies input)
481 135 aaronmk
    assert limit == None or type(limit) == int
482 865 aaronmk
    assert start == None or type(start) == int
483 2315 aaronmk
    if order_by is order_by_pkey:
484
        if distinct_on != []: order_by = None
485
        else: order_by = pkey(db, table0, recover=True)
486 865 aaronmk
487 2315 aaronmk
    query = 'SELECT'
488 2056 aaronmk
489 2315 aaronmk
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
490 2056 aaronmk
491 2200 aaronmk
    # DISTINCT ON columns
492 2233 aaronmk
    if distinct_on != []:
493 2467 aaronmk
        query += '\nDISTINCT'
494 2254 aaronmk
        if distinct_on is not distinct_on_all:
495 2200 aaronmk
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
496
497
    # Columns
498 2467 aaronmk
    query += '\n'
499 1135 aaronmk
    if fields == None: query += '*'
500 2479 aaronmk
    else: query += '\n, '.join(map(parse_col, fields))
501 2200 aaronmk
502
    # Main table
503 2467 aaronmk
    query += '\nFROM '+table0.to_str(db)
504 865 aaronmk
505 2122 aaronmk
    # Add joins
506 2271 aaronmk
    left_table = table0
507 2263 aaronmk
    for join_ in tables:
508
        table = join_.table
509 2238 aaronmk
510 2343 aaronmk
        # Parse special values
511
        if join_.type_ is sql_gen.filter_out: # filter no match
512 2376 aaronmk
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
513
                None))
514 2343 aaronmk
515 2467 aaronmk
        query += '\n'+join_.to_str(db, left_table)
516 2122 aaronmk
517
        left_table = table
518
519 865 aaronmk
    missing = True
520 2376 aaronmk
    if conds != []:
521 2467 aaronmk
        query += '\nWHERE\n'+('\nAND\n'.join(('('+sql_gen.ColValueCond(l, r)
522 2410 aaronmk
            .to_str(db)+')' for l, r in conds)))
523 865 aaronmk
        missing = False
524 2227 aaronmk
    if order_by != None:
525 2467 aaronmk
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
526
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
527 865 aaronmk
    if start != None:
528 2467 aaronmk
        if start != 0: query += '\nOFFSET '+str(start)
529 865 aaronmk
        missing = False
530
    if missing: warnings.warn(DbWarning(
531
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
532
533 2315 aaronmk
    return (query, [])
534 11 aaronmk
535 2054 aaronmk
def select(db, *args, **kw_args):
536
    '''For params, see mk_select() and run_query()'''
537
    recover = kw_args.pop('recover', None)
538
    cacheable = kw_args.pop('cacheable', True)
539 2442 aaronmk
    log_level = kw_args.pop('log_level', 2)
540 2054 aaronmk
541
    query, params = mk_select(db, *args, **kw_args)
542 2442 aaronmk
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
543 2054 aaronmk
544 2066 aaronmk
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
545 2292 aaronmk
    returning=None, embeddable=False):
546 1960 aaronmk
    '''
547
    @param returning str|None An inserted column (such as pkey) to return
548 2070 aaronmk
    @param embeddable Whether the query should be embeddable as a nested SELECT.
549 2073 aaronmk
        Warning: If you set this and cacheable=True when the query is run, the
550
        query will be fully cached, not just if it raises an exception.
551 1960 aaronmk
    '''
552 2328 aaronmk
    table = sql_gen.as_Table(table)
553 2318 aaronmk
    if cols == []: cols = None # no cols (all defaults) = unknown col names
554
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
555 2063 aaronmk
    if select_query == None: select_query = 'DEFAULT VALUES'
556 2327 aaronmk
    if returning != None: returning = sql_gen.as_Col(returning, table)
557 2063 aaronmk
558
    # Build query
559 2497 aaronmk
    first_line = 'INSERT INTO '+table.to_str(db)
560
    query = first_line
561 2467 aaronmk
    if cols != None: query += '\n('+', '.join(cols)+')'
562
    query += '\n'+select_query
563 2063 aaronmk
564
    if returning != None:
565 2327 aaronmk
        returning_name = copy.copy(returning)
566
        returning_name.table = None
567
        returning_name = returning_name.to_str(db)
568 2467 aaronmk
        query += '\nRETURNING '+returning_name
569 2063 aaronmk
570 2070 aaronmk
    if embeddable:
571 2327 aaronmk
        assert returning != None
572
573 2070 aaronmk
        # Create function
574 2513 aaronmk
        function_name = sql_gen.clean_name(first_line)
575 2327 aaronmk
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
576 2189 aaronmk
        while True:
577
            try:
578 2327 aaronmk
                func_schema = None
579 2468 aaronmk
                if not db.autocommit: func_schema = 'pg_temp'
580 2327 aaronmk
                function = sql_gen.Table(function_name, func_schema).to_str(db)
581 2194 aaronmk
582 2189 aaronmk
                function_query = '''\
583 2467 aaronmk
CREATE FUNCTION '''+function+'''()
584
RETURNS '''+return_type+'''
585
LANGUAGE sql
586
AS $$
587
'''+mogrify(db, query, params)+''';
588
$$;
589 2070 aaronmk
'''
590 2446 aaronmk
                run_query(db, function_query, recover=True, cacheable=True,
591 2464 aaronmk
                    log_ignore_excs=(DuplicateFunctionException,))
592 2189 aaronmk
                break # this version was successful
593
            except DuplicateFunctionException, e:
594
                function_name = next_version(function_name)
595
                # try again with next version of name
596 2070 aaronmk
597 2337 aaronmk
        # Return query that uses function
598
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
599
            [returning_name]) # AS clause requires function alias
600
        return mk_select(db, func_table, start=0, order_by=None)
601 2070 aaronmk
602 2066 aaronmk
    return (query, params)
603
604
def insert_select(db, *args, **kw_args):
605 2085 aaronmk
    '''For params, see mk_insert_select() and run_query_into()
606 2386 aaronmk
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
607
        values in
608 2072 aaronmk
    '''
609 2386 aaronmk
    into = kw_args.pop('into', None)
610
    if into != None: kw_args['embeddable'] = True
611 2066 aaronmk
    recover = kw_args.pop('recover', None)
612
    cacheable = kw_args.pop('cacheable', True)
613
614
    query, params = mk_insert_select(db, *args, **kw_args)
615 2386 aaronmk
    return run_query_into(db, query, params, into, recover=recover,
616 2153 aaronmk
        cacheable=cacheable)
617 2063 aaronmk
618 2066 aaronmk
default = object() # tells insert() to use the default value for a column
619
620 2063 aaronmk
def insert(db, table, row, *args, **kw_args):
621 2085 aaronmk
    '''For params, see insert_select()'''
622 1960 aaronmk
    if lists.is_seq(row): cols = None
623
    else:
624
        cols = row.keys()
625
        row = row.values()
626
    row = list(row) # ensure that "!= []" works
627
628 1961 aaronmk
    # Check for special values
629
    labels = []
630
    values = []
631
    for value in row:
632 2254 aaronmk
        if value is default: labels.append('DEFAULT')
633 1961 aaronmk
        else:
634
            labels.append('%s')
635
            values.append(value)
636
637
    # Build query
638 2467 aaronmk
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
639 2063 aaronmk
    else: query = None
640 1554 aaronmk
641 2064 aaronmk
    return insert_select(db, table, cols, query, values, *args, **kw_args)
642 11 aaronmk
643 2402 aaronmk
def mk_update(db, table, changes=None, cond=None):
644
    '''
645
    @param changes [(col, new_value),...]
646
        * container can be any iterable type
647
        * col: sql_gen.Code|str (for col name)
648
        * new_value: sql_gen.Code|literal value
649
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
650
    @return str query
651
    '''
652
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
653 2405 aaronmk
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
654 2402 aaronmk
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
655 2467 aaronmk
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
656 2402 aaronmk
657
    return query
658
659
def update(db, *args, **kw_args):
660
    '''For params, see mk_update() and run_query()'''
661
    recover = kw_args.pop('recover', None)
662
663
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
664
665 135 aaronmk
def last_insert_id(db):
666 1849 aaronmk
    module = util.root_module(db.db)
667 135 aaronmk
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
668
    elif module == 'MySQLdb': return db.insert_id()
669
    else: return None
670 13 aaronmk
671 1968 aaronmk
def truncate(db, table, schema='public'):
672
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
673 832 aaronmk
674 2394 aaronmk
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
675 2383 aaronmk
    '''Creates a mapping from original column names (which may have collisions)
676 2415 aaronmk
    to names that will be distinct among the columns' tables.
677 2383 aaronmk
    This is meant to be used for several tables that are being joined together.
678 2415 aaronmk
    @param cols The columns to combine. Duplicates will be removed.
679
    @param into The table for the new columns.
680 2394 aaronmk
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
681
        columns will be included in the mapping even if they are not in cols.
682
        The tables of the provided Col objects will be changed to into, so make
683
        copies of them if you want to keep the original tables.
684
    @param as_items Whether to return a list of dict items instead of a dict
685 2383 aaronmk
    @return dict(orig_col=new_col, ...)
686
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
687 2392 aaronmk
        * new_col: sql_gen.Col(orig_col_name, into)
688
        * All mappings use the into table so its name can easily be
689 2383 aaronmk
          changed for all columns at once
690
    '''
691 2415 aaronmk
    cols = lists.uniqify(cols)
692
693 2394 aaronmk
    items = []
694 2389 aaronmk
    for col in preserve:
695 2390 aaronmk
        orig_col = copy.copy(col)
696 2392 aaronmk
        col.table = into
697 2394 aaronmk
        items.append((orig_col, col))
698
    preserve = set(preserve)
699
    for col in cols:
700 2515 aaronmk
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
701 2394 aaronmk
702
    if not as_items: items = dict(items)
703
    return items
704 2383 aaronmk
705 2393 aaronmk
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
706 2391 aaronmk
    '''For params, see mk_flatten_mapping()
707
    @return See return value of mk_flatten_mapping()
708
    '''
709 2394 aaronmk
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
710
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
711 2391 aaronmk
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
712 2392 aaronmk
        into=into)
713 2394 aaronmk
    return dict(items)
714 2391 aaronmk
715 2414 aaronmk
##### Database structure queries
716
717 2426 aaronmk
def table_row_count(db, table, recover=None):
718
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
719 2443 aaronmk
        order_by=None, start=0), recover=recover, log_level=3))
720 2426 aaronmk
721 2414 aaronmk
def table_cols(db, table, recover=None):
722
    return list(col_names(select(db, table, limit=0, order_by=None,
723 2443 aaronmk
        recover=recover, log_level=4)))
724 2414 aaronmk
725 2291 aaronmk
def pkey(db, table, recover=None):
726 832 aaronmk
    '''Assumed to be first column in table'''
727 2339 aaronmk
    return table_cols(db, table, recover)[0]
728 832 aaronmk
729 2340 aaronmk
not_null_col = 'not_null'
730
731
def table_not_null_col(db, table, recover=None):
732
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
733
    if not_null_col in table_cols(db, table, recover): return not_null_col
734
    else: return pkey(db, table, recover)
735
736 853 aaronmk
def index_cols(db, table, index):
737
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
738
    automatically created. When you don't know whether something is a UNIQUE
739
    constraint or a UNIQUE index, use this function.'''
740 1909 aaronmk
    module = util.root_module(db.db)
741
    if module == 'psycopg2':
742
        return list(values(run_query(db, '''\
743 853 aaronmk
SELECT attname
744 866 aaronmk
FROM
745
(
746
        SELECT attnum, attname
747
        FROM pg_index
748
        JOIN pg_class index ON index.oid = indexrelid
749
        JOIN pg_class table_ ON table_.oid = indrelid
750
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
751
        WHERE
752
            table_.relname = %(table)s
753
            AND index.relname = %(index)s
754
    UNION
755
        SELECT attnum, attname
756
        FROM
757
        (
758
            SELECT
759
                indrelid
760
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
761
                    AS indkey
762
            FROM pg_index
763
            JOIN pg_class index ON index.oid = indexrelid
764
            JOIN pg_class table_ ON table_.oid = indrelid
765
            WHERE
766
                table_.relname = %(table)s
767
                AND index.relname = %(index)s
768
        ) s
769
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
770
) s
771 853 aaronmk
ORDER BY attnum
772
''',
773 2443 aaronmk
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
774 1909 aaronmk
    else: raise NotImplementedError("Can't list index columns for "+module+
775
        ' database')
776 853 aaronmk
777 464 aaronmk
def constraint_cols(db, table, constraint):
778 1849 aaronmk
    module = util.root_module(db.db)
779 464 aaronmk
    if module == 'psycopg2':
780
        return list(values(run_query(db, '''\
781
SELECT attname
782
FROM pg_constraint
783
JOIN pg_class ON pg_class.oid = conrelid
784
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
785
WHERE
786
    relname = %(table)s
787
    AND conname = %(constraint)s
788
ORDER BY attnum
789
''',
790
            {'table': table, 'constraint': constraint})))
791
    else: raise NotImplementedError("Can't list constraint columns for "+module+
792
        ' database')
793
794 2096 aaronmk
row_num_col = '_row_num'
795
796 2538 aaronmk
def add_index(db, expr):
797
    '''Adds an index on a column or expression if it doesn't already exist.
798
    Currently, only function calls are supported as expressions.
799
    '''
800
    expr = copy.copy(expr) # don't modify input!
801
802
    # Extract col
803 2539 aaronmk
    if isinstance(expr, sql_gen.FunctionCall):
804
        col = expr.args[0]
805 2541 aaronmk
        expr = sql_gen.Expr(expr)
806 2538 aaronmk
    else: col = expr
807 2408 aaronmk
    assert sql_gen.is_table_col(col)
808
809 2538 aaronmk
    index = sql_gen.as_Table(str(expr))
810 2408 aaronmk
    table = col.table
811 2538 aaronmk
    col.table = None
812 2408 aaronmk
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
813 2538 aaronmk
        +' ('+expr.to_str(db)+')', recover=True, cacheable=True, log_level=3)
814 2408 aaronmk
    except DuplicateTableException: pass # index already existed
815
816 2406 aaronmk
def index_pkey(db, table, recover=None):
817
    '''Makes the first column in a table the primary key.
818
    @pre The table must not already have a primary key.
819
    '''
820
    table = sql_gen.as_Table(table)
821
822
    index = sql_gen.as_Table(table.name+'_pkey')
823
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
824
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
825 2443 aaronmk
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
826
        log_level=3)
827 2406 aaronmk
828 2086 aaronmk
def add_row_num(db, table):
829 2117 aaronmk
    '''Adds a row number column to a table. Its name is in row_num_col. It will
830
    be the primary key.'''
831 2320 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
832 2096 aaronmk
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
833 2443 aaronmk
        +' serial NOT NULL PRIMARY KEY', log_level=3)
834 2086 aaronmk
835 2548 aaronmk
def tables(db, schema_like='public', table_like='%'):
836 1849 aaronmk
    module = util.root_module(db.db)
837 2548 aaronmk
    params = {'schema_like': schema_like, 'table_like': table_like}
838 832 aaronmk
    if module == 'psycopg2':
839 1968 aaronmk
        return values(run_query(db, '''\
840
SELECT tablename
841
FROM pg_tables
842
WHERE
843 2548 aaronmk
    schemaname LIKE %(schema_like)s
844 1968 aaronmk
    AND tablename LIKE %(table_like)s
845
ORDER BY tablename
846
''',
847
            params, cacheable=True))
848
    elif module == 'MySQLdb':
849
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
850
            cacheable=True))
851 832 aaronmk
    else: raise NotImplementedError("Can't list tables for "+module+' database')
852 830 aaronmk
853 833 aaronmk
##### Database management
854
855 1968 aaronmk
def empty_db(db, schema='public', **kw_args):
856
    '''For kw_args, see tables()'''
857
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
858 833 aaronmk
859 832 aaronmk
##### Heuristic queries
860
861 2104 aaronmk
def put(db, table, row, pkey_=None, row_ct_ref=None):
862 1554 aaronmk
    '''Recovers from errors.
863 2077 aaronmk
    Only works under PostgreSQL (uses INSERT RETURNING).
864
    '''
865 2104 aaronmk
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
866
867 471 aaronmk
    try:
868 2149 aaronmk
        cur = insert(db, table, row, pkey_, recover=True)
869 1554 aaronmk
        if row_ct_ref != None and cur.rowcount >= 0:
870
            row_ct_ref[0] += cur.rowcount
871
        return value(cur)
872 471 aaronmk
    except DuplicateKeyException, e:
873 2104 aaronmk
        return value(select(db, table, [pkey_],
874 1069 aaronmk
            util.dict_subset_right_join(row, e.cols), recover=True))
875 471 aaronmk
876 473 aaronmk
def get(db, table, row, pkey, row_ct_ref=None, create=False):
877 830 aaronmk
    '''Recovers from errors'''
878 2209 aaronmk
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
879 14 aaronmk
    except StopIteration:
880 40 aaronmk
        if not create: raise
881 471 aaronmk
        return put(db, table, row, pkey, row_ct_ref) # insert new row
882 2078 aaronmk
883 2508 aaronmk
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
884 2552 aaronmk
    default=None, is_func=False):
885 2078 aaronmk
    '''Recovers from errors.
886
    Only works under PostgreSQL (uses INSERT RETURNING).
887 2131 aaronmk
    @param in_tables The main input table to select from, followed by a list of
888
        tables to join with it using the main input table's pkey
889 2312 aaronmk
    @param mapping dict(out_table_col=in_table_col, ...)
890
        * out_table_col: sql_gen.Col|str
891 2323 aaronmk
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
892 2489 aaronmk
    @param into The table to contain the output and input pkeys.
893 2495 aaronmk
        Defaults to `out_table.name+'-pkeys'`.
894 2509 aaronmk
    @param default The *output* column to use as the pkey for missing rows.
895
        If this output column does not exist in the mapping, uses None.
896 2552 aaronmk
    @param is_func Whether out_table is the name of a SQL function, not a table
897 2312 aaronmk
    @return sql_gen.Col Where the output pkeys are made available
898 2078 aaronmk
    '''
899 2329 aaronmk
    out_table = sql_gen.as_Table(out_table)
900 2312 aaronmk
    for in_table_col in mapping.itervalues():
901
        assert isinstance(in_table_col, sql_gen.Col)
902 2495 aaronmk
    if into == None: into = out_table.name+'-pkeys'
903 2489 aaronmk
    into = sql_gen.as_Table(into)
904 2312 aaronmk
905 2450 aaronmk
    def log_debug(msg): db.log_debug(msg, level=1.5)
906 2505 aaronmk
    def col_ustr(str_):
907 2514 aaronmk
        return strings.repr_no_u(sql_gen.remove_col_rename(
908
            sql_gen.as_Col(str_)))
909 2450 aaronmk
910 2486 aaronmk
    log_debug('********** New iteration **********')
911 2505 aaronmk
    log_debug('Inserting these input columns into '+strings.as_tt(
912
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
913 2463 aaronmk
914 2382 aaronmk
    # Create input joins from list of input tables
915
    in_tables_ = in_tables[:] # don't modify input!
916
    in_tables0 = in_tables_.pop(0) # first table is separate
917 2279 aaronmk
    in_pkey = pkey(db, in_tables0, recover=True)
918 2285 aaronmk
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
919 2460 aaronmk
    input_joins = [in_tables0]+[sql_gen.Join(v,
920
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
921 2131 aaronmk
922 2486 aaronmk
    log_debug('Joining together input tables into temp table')
923 2395 aaronmk
    # Place in new table for speed and so don't modify input if values edited
924 2546 aaronmk
    in_table = sql_gen.Table(into.name.replace('-pkeys', '')+'-input')
925 2395 aaronmk
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
926
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
927
        flatten_cols, preserve=[in_pkey_col], start=0))
928
    input_joins = [in_table]
929 2486 aaronmk
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
930 2395 aaronmk
931 2509 aaronmk
    # Resolve default value column
932
    try: default = mapping[default]
933
    except KeyError:
934
        if default != None:
935
            db.log_debug('Default value column '
936
                +strings.as_tt(strings.repr_no_u(default))
937 2511 aaronmk
                +' does not exist in mapping, falling back to None', level=2.1)
938 2509 aaronmk
            default = None
939
940 2279 aaronmk
    out_pkey = pkey(db, out_table, recover=True)
941 2285 aaronmk
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
942 2142 aaronmk
943 2387 aaronmk
    pkeys_names = [in_pkey, out_pkey]
944 2236 aaronmk
    pkeys_cols = [in_pkey_col, out_pkey_col]
945
946 2201 aaronmk
    pkeys_table_exists_ref = [False]
947 2420 aaronmk
    def insert_into_pkeys(joins, cols):
948
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
949 2201 aaronmk
        if pkeys_table_exists_ref[0]:
950 2489 aaronmk
            insert_select(db, into, pkeys_names, query, params)
951 2201 aaronmk
        else:
952 2489 aaronmk
            run_query_into(db, query, params, into=into)
953 2201 aaronmk
            pkeys_table_exists_ref[0] = True
954
955 2429 aaronmk
    limit_ref = [None]
956 2380 aaronmk
    conds = set()
957 2233 aaronmk
    distinct_on = []
958 2325 aaronmk
    def mk_main_select(joins, cols):
959 2429 aaronmk
        return mk_select(db, joins, cols, conds, distinct_on,
960
            limit=limit_ref[0], start=0)
961 2132 aaronmk
962 2519 aaronmk
    exc_strs = set()
963 2309 aaronmk
    def log_exc(e):
964 2519 aaronmk
        e_str = exc.str_(e, first_line_only=True)
965
        log_debug('Caught exception: '+e_str)
966
        assert e_str not in exc_strs # avoid infinite loops
967
        exc_strs.add(e_str)
968 2451 aaronmk
    def remove_all_rows():
969 2450 aaronmk
        log_debug('Returning NULL for all rows')
970 2429 aaronmk
        limit_ref[0] = 0 # just create an empty pkeys table
971 2409 aaronmk
    def ignore(in_col, value):
972 2545 aaronmk
        in_col_str = strings.as_tt(repr(in_col))
973 2544 aaronmk
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
974
            level=2.5)
975 2537 aaronmk
        add_index(db, in_col)
976 2545 aaronmk
        log_debug('Ignoring rows with '+in_col_str+' = '
977
            +strings.as_tt(repr(value)))
978 2403 aaronmk
    def remove_rows(in_col, value):
979 2409 aaronmk
        ignore(in_col, value)
980 2378 aaronmk
        cond = (in_col, sql_gen.CompareCond(value, '!='))
981
        assert cond not in conds # avoid infinite loops
982 2380 aaronmk
        conds.add(cond)
983 2403 aaronmk
    def invalid2null(in_col, value):
984 2409 aaronmk
        ignore(in_col, value)
985 2403 aaronmk
        update(db, in_table, [(in_col, None)],
986
            sql_gen.ColValueCond(in_col, value))
987 2245 aaronmk
988 2206 aaronmk
    # Do inserts and selects
989 2257 aaronmk
    join_cols = {}
990 2495 aaronmk
    insert_out_pkeys = sql_gen.Table(into.name+'-insert_out_pkeys')
991
    insert_in_pkeys = sql_gen.Table(into.name+'-insert_in_pkeys')
992 2206 aaronmk
    while True:
993 2521 aaronmk
        if limit_ref[0] == 0: # special case
994
            log_debug('Creating an empty pkeys table')
995
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
996
                limit=limit_ref[0]), into=insert_out_pkeys)
997
            break # don't do main case
998
999 2303 aaronmk
        has_joins = join_cols != {}
1000
1001 2305 aaronmk
        # Prepare to insert new rows
1002 2325 aaronmk
        insert_joins = input_joins[:] # don't modify original!
1003 2403 aaronmk
        insert_args = dict(recover=True, cacheable=False)
1004 2303 aaronmk
        if has_joins:
1005 2317 aaronmk
            distinct_on = [v.to_Col() for v in join_cols.values()]
1006 2325 aaronmk
            insert_joins.append(sql_gen.Join(out_table, join_cols,
1007
                sql_gen.filter_out))
1008
        else:
1009 2404 aaronmk
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
1010 2520 aaronmk
        main_select = mk_main_select(insert_joins, mapping.values())[0]
1011 2303 aaronmk
1012 2486 aaronmk
        log_debug('Trying to insert new rows')
1013 2206 aaronmk
        try:
1014 2518 aaronmk
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1015
                **insert_args)
1016 2357 aaronmk
            break # insert successful
1017 2206 aaronmk
        except DuplicateKeyException, e:
1018 2309 aaronmk
            log_exc(e)
1019
1020 2258 aaronmk
            old_join_cols = join_cols.copy()
1021 2334 aaronmk
            join_cols.update(util.dict_subset(mapping, e.cols))
1022 2486 aaronmk
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1023 2505 aaronmk
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1024 2258 aaronmk
            assert join_cols != old_join_cols # avoid infinite loops
1025 2230 aaronmk
        except NullValueException, e:
1026 2309 aaronmk
            log_exc(e)
1027
1028 2230 aaronmk
            out_col, = e.cols
1029
            try: in_col = mapping[out_col]
1030 2356 aaronmk
            except KeyError:
1031 2486 aaronmk
                log_debug('Missing mapping for NOT NULL column '+out_col)
1032 2451 aaronmk
                remove_all_rows()
1033 2403 aaronmk
            else: remove_rows(in_col, None)
1034 2542 aaronmk
        except FunctionValueException, e:
1035
            log_exc(e)
1036
1037
            func_name = e.name
1038
            value = e.value
1039
            for out_col, in_col in mapping.iteritems():
1040
                in_col = sql_gen.remove_col_rename(in_col)
1041
                if (isinstance(in_col, sql_gen.FunctionCall)
1042
                    and in_col.function.name == func_name):
1043
                    invalid2null(in_col.args[0], value)
1044 2525 aaronmk
        except MissingCastException, e:
1045
            log_exc(e)
1046
1047
            out_col = e.col
1048 2534 aaronmk
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
1049 2429 aaronmk
        except DatabaseErrors, e:
1050
            log_exc(e)
1051
1052 2531 aaronmk
            msg = 'No handler for exception: '+exc.str_(e)
1053 2451 aaronmk
            warnings.warn(DbWarning(msg))
1054
            log_debug(msg)
1055
            remove_all_rows()
1056 2358 aaronmk
        # after exception handled, rerun loop with additional constraints
1057 2132 aaronmk
1058 2357 aaronmk
    if row_ct_ref != None and cur.rowcount >= 0:
1059
        row_ct_ref[0] += cur.rowcount
1060
1061
    if has_joins:
1062
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1063 2486 aaronmk
        log_debug('Getting output table pkeys of existing/inserted rows')
1064 2420 aaronmk
        insert_into_pkeys(select_joins, pkeys_cols)
1065 2357 aaronmk
    else:
1066 2404 aaronmk
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1067 2357 aaronmk
1068 2486 aaronmk
        log_debug('Getting input table pkeys of inserted rows')
1069 2357 aaronmk
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1070 2404 aaronmk
            into=insert_in_pkeys)
1071
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1072 2357 aaronmk
1073 2428 aaronmk
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1074
            insert_in_pkeys)
1075
1076 2486 aaronmk
        log_debug('Combining output and input pkeys in inserted order')
1077 2404 aaronmk
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1078 2357 aaronmk
            {row_num_col: sql_gen.join_same_not_null})]
1079 2420 aaronmk
        insert_into_pkeys(pkey_joins, pkeys_names)
1080 2357 aaronmk
1081 2486 aaronmk
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1082 2489 aaronmk
    index_pkey(db, into)
1083 2407 aaronmk
1084 2508 aaronmk
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1085 2489 aaronmk
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1086 2357 aaronmk
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1087
        # must use join_same_not_null or query will take forever
1088 2420 aaronmk
    insert_into_pkeys(missing_rows_joins,
1089 2508 aaronmk
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1090 2357 aaronmk
1091 2489 aaronmk
    assert table_row_count(db, into) == table_row_count(db, in_table)
1092 2428 aaronmk
1093 2489 aaronmk
    return sql_gen.Col(out_pkey, into)
1094 2115 aaronmk
1095
##### Data cleanup
1096
1097 2290 aaronmk
def cleanup_table(db, table, cols):
1098 2115 aaronmk
    def esc_name_(name): return esc_name(db, name)
1099
1100 2290 aaronmk
    table = sql_gen.as_Table(table).to_str(db)
1101 2115 aaronmk
    cols = map(esc_name_, cols)
1102
1103
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1104
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1105
            for col in cols))),
1106
        dict(null0='', null1=r'\N'))