Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
21
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24
    
25
    if raw_query != None: return raw_query
26
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
28
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31

    
32
class DbException(exc.ExceptionWithCause):
33
    def __init__(self, msg, cause=None, cur=None):
34
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35
        if cur != None: _add_cursor_info(self, cur)
36

    
37
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40
        self.name = name
41

    
42
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46
        self.name = name
47
        self.value = value
48

    
49
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53
        self.name = name
54
        self.cols = cols
55

    
56
class MissingCastException(DbException):
57
    def __init__(self, type_, col, cause=None):
58
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
59
            +' on column: '+strings.as_tt(col), cause)
60
        self.type = type_
61
        self.col = col
62

    
63
class NameException(DbException): pass
64

    
65
class DuplicateKeyException(ConstraintException): pass
66

    
67
class NullValueException(ConstraintException): pass
68

    
69
class FunctionValueException(ExceptionWithNameValue): pass
70

    
71
class DuplicateTableException(ExceptionWithName): pass
72

    
73
class DuplicateFunctionException(ExceptionWithName): pass
74

    
75
class EmptyRowException(DbException): pass
76

    
77
##### Warnings
78

    
79
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur): return (col[0] for col in cur.description)
84

    
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86

    
87
def consume_rows(cur):
88
    '''Used to fetch all rows so result will be cached'''
89
    iters.consume_iter(rows(cur))
90

    
91
def next_row(cur): return rows(cur).next()
92

    
93
def row(cur):
94
    row_ = next_row(cur)
95
    consume_rows(cur)
96
    return row_
97

    
98
def next_value(cur): return next_row(cur)[0]
99

    
100
def value(cur): return row(cur)[0]
101

    
102
def values(cur): return iters.func_iter(lambda: next_value(cur))
103

    
104
def value_or_none(cur):
105
    try: return value(cur)
106
    except StopIteration: return None
107

    
108
##### Input validation
109

    
110
def check_name(name):
111
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
112
        +'" may contain only alphanumeric characters and _')
113

    
114
def esc_name_by_module(module, name, ignore_case=False):
115
    if module == 'psycopg2' or module == None:
116
        if ignore_case:
117
            # Don't enclose in quotes because this disables case-insensitivity
118
            check_name(name)
119
            return name
120
        else: quote = '"'
121
    elif module == 'MySQLdb': quote = '`'
122
    else: raise NotImplementedError("Can't escape name for "+module+' database')
123
    return sql_gen.esc_name(name, quote)
124

    
125
def esc_name_by_engine(engine, name, **kw_args):
126
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
127

    
128
def esc_name(db, name, **kw_args):
129
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
130

    
131
def qual_name(db, schema, table):
132
    def esc_name_(name): return esc_name(db, name)
133
    table = esc_name_(table)
134
    if schema != None: return esc_name_(schema)+'.'+table
135
    else: return table
136

    
137
##### Database connections
138

    
139
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
140

    
141
db_engines = {
142
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
143
    'PostgreSQL': ('psycopg2', {}),
144
}
145

    
146
DatabaseErrors_set = set([DbException])
147
DatabaseErrors = tuple(DatabaseErrors_set)
148

    
149
def _add_module(module):
150
    DatabaseErrors_set.add(module.DatabaseError)
151
    global DatabaseErrors
152
    DatabaseErrors = tuple(DatabaseErrors_set)
153

    
154
def db_config_str(db_config):
155
    return db_config['engine']+' database '+db_config['database']
156

    
157
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
158

    
159
log_debug_none = lambda msg, level=2: None
160

    
161
class DbConn:
162
    def __init__(self, db_config, serializable=True, autocommit=False,
163
        caching=True, log_debug=log_debug_none):
164
        self.db_config = db_config
165
        self.serializable = serializable
166
        self.autocommit = autocommit
167
        self.caching = caching
168
        self.log_debug = log_debug
169
        self.debug = log_debug != log_debug_none
170
        
171
        self.__db = None
172
        self.query_results = {}
173
        self._savepoint = 0
174
    
175
    def __getattr__(self, name):
176
        if name == '__dict__': raise Exception('getting __dict__')
177
        if name == 'db': return self._db()
178
        else: raise AttributeError()
179
    
180
    def __getstate__(self):
181
        state = copy.copy(self.__dict__) # shallow copy
182
        state['log_debug'] = None # don't pickle the debug callback
183
        state['_DbConn__db'] = None # don't pickle the connection
184
        return state
185
    
186
    def connected(self): return self.__db != None
187
    
188
    def _db(self):
189
        if self.__db == None:
190
            # Process db_config
191
            db_config = self.db_config.copy() # don't modify input!
192
            schemas = db_config.pop('schemas', None)
193
            module_name, mappings = db_engines[db_config.pop('engine')]
194
            module = __import__(module_name)
195
            _add_module(module)
196
            for orig, new in mappings.iteritems():
197
                try: util.rename_key(db_config, orig, new)
198
                except KeyError: pass
199
            
200
            # Connect
201
            self.__db = module.connect(**db_config)
202
            
203
            # Configure connection
204
            if self.serializable and not self.autocommit: run_raw_query(self,
205
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
206
            if schemas != None:
207
                schemas_ = ''.join((esc_name(self, s)+', '
208
                    for s in schemas.split(',')))
209
                run_raw_query(self, "SELECT set_config('search_path', \
210
%s || current_setting('search_path'), false)", [schemas_])
211
        
212
        return self.__db
213
    
214
    class DbCursor(Proxy):
215
        def __init__(self, outer):
216
            Proxy.__init__(self, outer.db.cursor())
217
            self.outer = outer
218
            self.query_results = outer.query_results
219
            self.query_lookup = None
220
            self.result = []
221
        
222
        def execute(self, query, params=None):
223
            self._is_insert = query.upper().find('INSERT') >= 0
224
            self.query_lookup = _query_lookup(query, params)
225
            try:
226
                try:
227
                    return_value = self.inner.execute(query, params)
228
                    self.outer.do_autocommit()
229
                finally: self.query = get_cur_query(self.inner)
230
            except Exception, e:
231
                _add_cursor_info(e, self, query, params)
232
                self.result = e # cache the exception as the result
233
                self._cache_result()
234
                raise
235
            # Fetch all rows so result will be cached
236
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
237
            return return_value
238
        
239
        def fetchone(self):
240
            row = self.inner.fetchone()
241
            if row != None: self.result.append(row)
242
            # otherwise, fetched all rows
243
            else: self._cache_result()
244
            return row
245
        
246
        def _cache_result(self):
247
            # For inserts, only cache exceptions since inserts are not
248
            # idempotent, but an invalid insert will always be invalid
249
            if self.query_results != None and (not self._is_insert
250
                or isinstance(self.result, Exception)):
251
                
252
                assert self.query_lookup != None
253
                self.query_results[self.query_lookup] = self.CacheCursor(
254
                    util.dict_subset(dicts.AttrsDictView(self),
255
                    ['query', 'result', 'rowcount', 'description']))
256
        
257
        class CacheCursor:
258
            def __init__(self, cached_result): self.__dict__ = cached_result
259
            
260
            def execute(self, *args, **kw_args):
261
                if isinstance(self.result, Exception): raise self.result
262
                # otherwise, result is a rows list
263
                self.iter = iter(self.result)
264
            
265
            def fetchone(self):
266
                try: return self.iter.next()
267
                except StopIteration: return None
268
    
269
    def esc_value(self, value):
270
        module = util.root_module(self.db)
271
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
272
        elif module == 'MySQLdb':
273
            import _mysql
274
            str_ = _mysql.escape_string(value)
275
        else: raise NotImplementedError("Can't escape value for "+module
276
            +' database')
277
        return strings.to_unicode(str_)
278
    
279
    def esc_name(self, name): return esc_name(self, name) # calls global func
280
    
281
    def run_query(self, query, params=None, cacheable=False, log_level=2,
282
        debug_msg_ref=None):
283
        '''
284
        @param log_ignore_excs The log_level will be increased by 2 if the query
285
            throws one of these exceptions.
286
        '''
287
        assert query != None
288
        
289
        if not self.caching: cacheable = False
290
        used_cache = False
291
        try:
292
            # Get cursor
293
            if cacheable:
294
                query_lookup = _query_lookup(query, params)
295
                try:
296
                    cur = self.query_results[query_lookup]
297
                    used_cache = True
298
                except KeyError: cur = self.DbCursor(self)
299
            else: cur = self.db.cursor()
300
            
301
            # Run query
302
            cur.execute(query, params)
303
        finally:
304
            if self.debug and debug_msg_ref != None:# only compute msg if needed
305
                if used_cache: cache_status = 'cache hit'
306
                elif cacheable: cache_status = 'cache miss'
307
                else: cache_status = 'non-cacheable'
308
                query_code = strings.as_code(str(get_cur_query(cur, query,
309
                    params)), 'SQL')
310
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
311
        
312
        return cur
313
    
314
    def is_cached(self, query, params=None):
315
        return _query_lookup(query, params) in self.query_results
316
    
317
    def with_savepoint(self, func):
318
        savepoint = 'level_'+str(self._savepoint)
319
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
320
        self._savepoint += 1
321
        try: 
322
            try: return_val = func()
323
            finally:
324
                self._savepoint -= 1
325
                assert self._savepoint >= 0
326
        except:
327
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
328
            raise
329
        else:
330
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
331
            self.do_autocommit()
332
            return return_val
333
    
334
    def do_autocommit(self):
335
        '''Autocommits if outside savepoint'''
336
        assert self._savepoint >= 0
337
        if self.autocommit and self._savepoint == 0:
338
            self.log_debug('Autocommiting')
339
            self.db.commit()
340

    
341
connect = DbConn
342

    
343
##### Querying
344

    
345
def run_raw_query(db, *args, **kw_args):
346
    '''For params, see DbConn.run_query()'''
347
    return db.run_query(*args, **kw_args)
348

    
349
def mogrify(db, query, params):
350
    module = util.root_module(db.db)
351
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
352
    else: raise NotImplementedError("Can't mogrify query for "+module+
353
        ' database')
354

    
355
##### Recoverable querying
356

    
357
def with_savepoint(db, func): return db.with_savepoint(func)
358

    
359
def run_query(db, query, params=None, recover=None, cacheable=False,
360
    log_level=2, log_ignore_excs=None, **kw_args):
361
    '''For params, see run_raw_query()'''
362
    if recover == None: recover = False
363
    if log_ignore_excs == None: log_ignore_excs = ()
364
    log_ignore_excs = tuple(log_ignore_excs)
365
    
366
    debug_msg_ref = [None]
367
    try:
368
        try:
369
            def run(): return run_raw_query(db, query, params, cacheable,
370
                log_level, debug_msg_ref, **kw_args)
371
            if recover and not db.is_cached(query, params):
372
                return with_savepoint(db, run)
373
            else: return run() # don't need savepoint if cached
374
        except Exception, e:
375
            if not recover: raise # need savepoint to run index_cols()
376
            msg = exc.str_(e)
377
            
378
            match = re.search(r'duplicate key value violates unique constraint '
379
                r'"((_?[^\W_]+)_.+?)"', msg)
380
            if match:
381
                constraint, table = match.groups()
382
                try: cols = index_cols(db, table, constraint)
383
                except NotImplementedError: raise e
384
                else: raise DuplicateKeyException(constraint, cols, e)
385
            
386
            match = re.search(r'null value in column "(.+?)" violates not-null'
387
                r' constraint', msg)
388
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
389
            
390
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
391
                r'|date/time field value out of range): "(.+?)"\n'
392
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
393
            if match:
394
                value, name = match.groups()
395
                raise FunctionValueException(name, strings.to_unicode(value), e)
396
            
397
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
398
                r'is of type', msg)
399
            if match:
400
                col, type_ = match.groups()
401
                raise MissingCastException(type_, col, e)
402
            
403
            match = re.search(r'relation "(.+?)" already exists', msg)
404
            if match: raise DuplicateTableException(match.group(1), e)
405
            
406
            match = re.search(r'function "(.+?)" already exists', msg)
407
            if match: raise DuplicateFunctionException(match.group(1), e)
408
            
409
            raise # no specific exception raised
410
    except log_ignore_excs:
411
        log_level += 2
412
        raise
413
    finally:
414
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
415

    
416
##### Basic queries
417

    
418
def next_version(name):
419
    '''Prepends the version # so it won't be removed if the name is truncated'''
420
    version = 1 # first existing name was version 0
421
    match = re.match(r'^#(\d+)-(.*)$', name)
422
    if match:
423
        version = int(match.group(1))+1
424
        name = match.group(2)
425
    return '#'+str(version)+'-'+name
426

    
427
def run_query_into(db, query, params, into=None, *args, **kw_args):
428
    '''Outputs a query to a temp table.
429
    For params, see run_query().
430
    '''
431
    if into == None: return run_query(db, query, params, *args, **kw_args)
432
    else: # place rows in temp table
433
        assert isinstance(into, sql_gen.Table)
434
        
435
        kw_args['recover'] = True
436
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
437
        
438
        temp = not db.autocommit # tables are permanent in autocommit mode
439
        # "temporary tables cannot specify a schema name", so remove schema
440
        if temp: into.schema = None
441
        
442
        while True:
443
            try:
444
                create_query = 'CREATE'
445
                if temp: create_query += ' TEMP'
446
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
447
                
448
                return run_query(db, create_query, params, *args, **kw_args)
449
                    # CREATE TABLE AS sets rowcount to # rows in query
450
            except DuplicateTableException, e:
451
                into.name = next_version(into.name)
452
                # try again with next version of name
453

    
454
order_by_pkey = object() # tells mk_select() to order by the pkey
455

    
456
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
457

    
458
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
459
    start=None, order_by=order_by_pkey, default_table=None):
460
    '''
461
    @param tables The single table to select from, or a list of tables to join
462
        together, with tables after the first being sql_gen.Join objects
463
    @param fields Use None to select all fields in the table
464
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
465
        * container can be any iterable type
466
        * compare_left_side: sql_gen.Code|str (for col name)
467
        * compare_right_side: sql_gen.ValueCond|literal value
468
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
469
        use all columns
470
    @return tuple(query, params)
471
    '''
472
    # Parse tables param
473
    if not lists.is_seq(tables): tables = [tables]
474
    tables = list(tables) # don't modify input! (list() copies input)
475
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
476
    
477
    # Parse other params
478
    if conds == None: conds = []
479
    elif isinstance(conds, dict): conds = conds.items()
480
    conds = list(conds) # don't modify input! (list() copies input)
481
    assert limit == None or type(limit) == int
482
    assert start == None or type(start) == int
483
    if order_by is order_by_pkey:
484
        if distinct_on != []: order_by = None
485
        else: order_by = pkey(db, table0, recover=True)
486
    
487
    query = 'SELECT'
488
    
489
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
490
    
491
    # DISTINCT ON columns
492
    if distinct_on != []:
493
        query += '\nDISTINCT'
494
        if distinct_on is not distinct_on_all:
495
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
496
    
497
    # Columns
498
    query += '\n'
499
    if fields == None: query += '*'
500
    else: query += '\n, '.join(map(parse_col, fields))
501
    
502
    # Main table
503
    query += '\nFROM '+table0.to_str(db)
504
    
505
    # Add joins
506
    left_table = table0
507
    for join_ in tables:
508
        table = join_.table
509
        
510
        # Parse special values
511
        if join_.type_ is sql_gen.filter_out: # filter no match
512
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
513
                None))
514
        
515
        query += '\n'+join_.to_str(db, left_table)
516
        
517
        left_table = table
518
    
519
    missing = True
520
    if conds != []:
521
        query += '\nWHERE\n'+('\nAND\n'.join(('('+sql_gen.ColValueCond(l, r)
522
            .to_str(db)+')' for l, r in conds)))
523
        missing = False
524
    if order_by != None:
525
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
526
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
527
    if start != None:
528
        if start != 0: query += '\nOFFSET '+str(start)
529
        missing = False
530
    if missing: warnings.warn(DbWarning(
531
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
532
    
533
    return (query, [])
534

    
535
def select(db, *args, **kw_args):
536
    '''For params, see mk_select() and run_query()'''
537
    recover = kw_args.pop('recover', None)
538
    cacheable = kw_args.pop('cacheable', True)
539
    log_level = kw_args.pop('log_level', 2)
540
    
541
    query, params = mk_select(db, *args, **kw_args)
542
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
543

    
544
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
545
    returning=None, embeddable=False):
546
    '''
547
    @param returning str|None An inserted column (such as pkey) to return
548
    @param embeddable Whether the query should be embeddable as a nested SELECT.
549
        Warning: If you set this and cacheable=True when the query is run, the
550
        query will be fully cached, not just if it raises an exception.
551
    '''
552
    table = sql_gen.as_Table(table)
553
    if cols == []: cols = None # no cols (all defaults) = unknown col names
554
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
555
    if select_query == None: select_query = 'DEFAULT VALUES'
556
    if returning != None: returning = sql_gen.as_Col(returning, table)
557
    
558
    # Build query
559
    first_line = 'INSERT INTO '+table.to_str(db)
560
    query = first_line
561
    if cols != None: query += '\n('+', '.join(cols)+')'
562
    query += '\n'+select_query
563
    
564
    if returning != None:
565
        returning_name = copy.copy(returning)
566
        returning_name.table = None
567
        returning_name = returning_name.to_str(db)
568
        query += '\nRETURNING '+returning_name
569
    
570
    if embeddable:
571
        assert returning != None
572
        
573
        # Create function
574
        function_name = sql_gen.clean_name(first_line)
575
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
576
        while True:
577
            try:
578
                func_schema = None
579
                if not db.autocommit: func_schema = 'pg_temp'
580
                function = sql_gen.Table(function_name, func_schema).to_str(db)
581
                
582
                function_query = '''\
583
CREATE FUNCTION '''+function+'''()
584
RETURNS '''+return_type+'''
585
LANGUAGE sql
586
AS $$
587
'''+mogrify(db, query, params)+''';
588
$$;
589
'''
590
                run_query(db, function_query, recover=True, cacheable=True,
591
                    log_ignore_excs=(DuplicateFunctionException,))
592
                break # this version was successful
593
            except DuplicateFunctionException, e:
594
                function_name = next_version(function_name)
595
                # try again with next version of name
596
        
597
        # Return query that uses function
598
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
599
            [returning_name]) # AS clause requires function alias
600
        return mk_select(db, func_table, start=0, order_by=None)
601
    
602
    return (query, params)
603

    
604
def insert_select(db, *args, **kw_args):
605
    '''For params, see mk_insert_select() and run_query_into()
606
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
607
        values in
608
    '''
609
    into = kw_args.pop('into', None)
610
    if into != None: kw_args['embeddable'] = True
611
    recover = kw_args.pop('recover', None)
612
    cacheable = kw_args.pop('cacheable', True)
613
    
614
    query, params = mk_insert_select(db, *args, **kw_args)
615
    return run_query_into(db, query, params, into, recover=recover,
616
        cacheable=cacheable)
617

    
618
default = object() # tells insert() to use the default value for a column
619

    
620
def insert(db, table, row, *args, **kw_args):
621
    '''For params, see insert_select()'''
622
    if lists.is_seq(row): cols = None
623
    else:
624
        cols = row.keys()
625
        row = row.values()
626
    row = list(row) # ensure that "!= []" works
627
    
628
    # Check for special values
629
    labels = []
630
    values = []
631
    for value in row:
632
        if value is default: labels.append('DEFAULT')
633
        else:
634
            labels.append('%s')
635
            values.append(value)
636
    
637
    # Build query
638
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
639
    else: query = None
640
    
641
    return insert_select(db, table, cols, query, values, *args, **kw_args)
642

    
643
def mk_update(db, table, changes=None, cond=None):
644
    '''
645
    @param changes [(col, new_value),...]
646
        * container can be any iterable type
647
        * col: sql_gen.Code|str (for col name)
648
        * new_value: sql_gen.Code|literal value
649
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
650
    @return str query
651
    '''
652
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
653
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
654
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
655
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
656
    
657
    return query
658

    
659
def update(db, *args, **kw_args):
660
    '''For params, see mk_update() and run_query()'''
661
    recover = kw_args.pop('recover', None)
662
    
663
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
664

    
665
def last_insert_id(db):
666
    module = util.root_module(db.db)
667
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
668
    elif module == 'MySQLdb': return db.insert_id()
669
    else: return None
670

    
671
def truncate(db, table, schema='public'):
672
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
673

    
674
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
675
    '''Creates a mapping from original column names (which may have collisions)
676
    to names that will be distinct among the columns' tables.
677
    This is meant to be used for several tables that are being joined together.
678
    @param cols The columns to combine. Duplicates will be removed.
679
    @param into The table for the new columns.
680
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
681
        columns will be included in the mapping even if they are not in cols.
682
        The tables of the provided Col objects will be changed to into, so make
683
        copies of them if you want to keep the original tables.
684
    @param as_items Whether to return a list of dict items instead of a dict
685
    @return dict(orig_col=new_col, ...)
686
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
687
        * new_col: sql_gen.Col(orig_col_name, into)
688
        * All mappings use the into table so its name can easily be
689
          changed for all columns at once
690
    '''
691
    cols = lists.uniqify(cols)
692
    
693
    items = []
694
    for col in preserve:
695
        orig_col = copy.copy(col)
696
        col.table = into
697
        items.append((orig_col, col))
698
    preserve = set(preserve)
699
    for col in cols:
700
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
701
    
702
    if not as_items: items = dict(items)
703
    return items
704

    
705
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
706
    '''For params, see mk_flatten_mapping()
707
    @return See return value of mk_flatten_mapping()
708
    '''
709
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
710
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
711
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
712
        into=into)
713
    return dict(items)
714

    
715
##### Database structure queries
716

    
717
def table_row_count(db, table, recover=None):
718
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
719
        order_by=None, start=0), recover=recover, log_level=3))
720

    
721
def table_cols(db, table, recover=None):
722
    return list(col_names(select(db, table, limit=0, order_by=None,
723
        recover=recover, log_level=4)))
724

    
725
def pkey(db, table, recover=None):
726
    '''Assumed to be first column in table'''
727
    return table_cols(db, table, recover)[0]
728

    
729
not_null_col = 'not_null'
730

    
731
def table_not_null_col(db, table, recover=None):
732
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
733
    if not_null_col in table_cols(db, table, recover): return not_null_col
734
    else: return pkey(db, table, recover)
735

    
736
def index_cols(db, table, index):
737
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
738
    automatically created. When you don't know whether something is a UNIQUE
739
    constraint or a UNIQUE index, use this function.'''
740
    module = util.root_module(db.db)
741
    if module == 'psycopg2':
742
        return list(values(run_query(db, '''\
743
SELECT attname
744
FROM
745
(
746
        SELECT attnum, attname
747
        FROM pg_index
748
        JOIN pg_class index ON index.oid = indexrelid
749
        JOIN pg_class table_ ON table_.oid = indrelid
750
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
751
        WHERE
752
            table_.relname = %(table)s
753
            AND index.relname = %(index)s
754
    UNION
755
        SELECT attnum, attname
756
        FROM
757
        (
758
            SELECT
759
                indrelid
760
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
761
                    AS indkey
762
            FROM pg_index
763
            JOIN pg_class index ON index.oid = indexrelid
764
            JOIN pg_class table_ ON table_.oid = indrelid
765
            WHERE
766
                table_.relname = %(table)s
767
                AND index.relname = %(index)s
768
        ) s
769
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
770
) s
771
ORDER BY attnum
772
''',
773
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
774
    else: raise NotImplementedError("Can't list index columns for "+module+
775
        ' database')
776

    
777
def constraint_cols(db, table, constraint):
778
    module = util.root_module(db.db)
779
    if module == 'psycopg2':
780
        return list(values(run_query(db, '''\
781
SELECT attname
782
FROM pg_constraint
783
JOIN pg_class ON pg_class.oid = conrelid
784
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
785
WHERE
786
    relname = %(table)s
787
    AND conname = %(constraint)s
788
ORDER BY attnum
789
''',
790
            {'table': table, 'constraint': constraint})))
791
    else: raise NotImplementedError("Can't list constraint columns for "+module+
792
        ' database')
793

    
794
row_num_col = '_row_num'
795

    
796
def add_index(db, expr):
797
    '''Adds an index on a column or expression if it doesn't already exist.
798
    Currently, only function calls are supported as expressions.
799
    '''
800
    expr = copy.copy(expr) # don't modify input!
801
    
802
    # Extract col
803
    if isinstance(expr, sql_gen.FunctionCall):
804
        col = expr.args[0]
805
        expr = sql_gen.Expr(expr)
806
    else: col = expr
807
    assert sql_gen.is_table_col(col)
808
    
809
    index = sql_gen.as_Table(str(expr))
810
    table = col.table
811
    col.table = None
812
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
813
        +' ('+expr.to_str(db)+')', recover=True, cacheable=True, log_level=3)
814
    except DuplicateTableException: pass # index already existed
815

    
816
def index_pkey(db, table, recover=None):
817
    '''Makes the first column in a table the primary key.
818
    @pre The table must not already have a primary key.
819
    '''
820
    table = sql_gen.as_Table(table)
821
    
822
    index = sql_gen.as_Table(table.name+'_pkey')
823
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
824
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
825
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
826
        log_level=3)
827

    
828
def add_row_num(db, table):
829
    '''Adds a row number column to a table. Its name is in row_num_col. It will
830
    be the primary key.'''
831
    table = sql_gen.as_Table(table).to_str(db)
832
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
833
        +' serial NOT NULL PRIMARY KEY', log_level=3)
834

    
835
def tables(db, schema_like='public', table_like='%'):
836
    module = util.root_module(db.db)
837
    params = {'schema_like': schema_like, 'table_like': table_like}
838
    if module == 'psycopg2':
839
        return values(run_query(db, '''\
840
SELECT tablename
841
FROM pg_tables
842
WHERE
843
    schemaname LIKE %(schema_like)s
844
    AND tablename LIKE %(table_like)s
845
ORDER BY tablename
846
''',
847
            params, cacheable=True))
848
    elif module == 'MySQLdb':
849
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
850
            cacheable=True))
851
    else: raise NotImplementedError("Can't list tables for "+module+' database')
852

    
853
##### Database management
854

    
855
def empty_db(db, schema='public', **kw_args):
856
    '''For kw_args, see tables()'''
857
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
858

    
859
##### Heuristic queries
860

    
861
def put(db, table, row, pkey_=None, row_ct_ref=None):
862
    '''Recovers from errors.
863
    Only works under PostgreSQL (uses INSERT RETURNING).
864
    '''
865
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
866
    
867
    try:
868
        cur = insert(db, table, row, pkey_, recover=True)
869
        if row_ct_ref != None and cur.rowcount >= 0:
870
            row_ct_ref[0] += cur.rowcount
871
        return value(cur)
872
    except DuplicateKeyException, e:
873
        return value(select(db, table, [pkey_],
874
            util.dict_subset_right_join(row, e.cols), recover=True))
875

    
876
def get(db, table, row, pkey, row_ct_ref=None, create=False):
877
    '''Recovers from errors'''
878
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
879
    except StopIteration:
880
        if not create: raise
881
        return put(db, table, row, pkey, row_ct_ref) # insert new row
882

    
883
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
884
    default=None):
885
    '''Recovers from errors.
886
    Only works under PostgreSQL (uses INSERT RETURNING).
887
    @param in_tables The main input table to select from, followed by a list of
888
        tables to join with it using the main input table's pkey
889
    @param mapping dict(out_table_col=in_table_col, ...)
890
        * out_table_col: sql_gen.Col|str
891
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
892
    @param into The table to contain the output and input pkeys.
893
        Defaults to `out_table.name+'-pkeys'`.
894
    @param default The *output* column to use as the pkey for missing rows.
895
        If this output column does not exist in the mapping, uses None.
896
    @return sql_gen.Col Where the output pkeys are made available
897
    '''
898
    out_table = sql_gen.as_Table(out_table)
899
    for in_table_col in mapping.itervalues():
900
        assert isinstance(in_table_col, sql_gen.Col)
901
    if into == None: into = out_table.name+'-pkeys'
902
    into = sql_gen.as_Table(into)
903
    
904
    def log_debug(msg): db.log_debug(msg, level=1.5)
905
    def col_ustr(str_):
906
        return strings.repr_no_u(sql_gen.remove_col_rename(
907
            sql_gen.as_Col(str_)))
908
    
909
    log_debug('********** New iteration **********')
910
    log_debug('Inserting these input columns into '+strings.as_tt(
911
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
912
    
913
    # Create input joins from list of input tables
914
    in_tables_ = in_tables[:] # don't modify input!
915
    in_tables0 = in_tables_.pop(0) # first table is separate
916
    in_pkey = pkey(db, in_tables0, recover=True)
917
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
918
    input_joins = [in_tables0]+[sql_gen.Join(v,
919
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
920
    
921
    log_debug('Joining together input tables into temp table')
922
    # Place in new table for speed and so don't modify input if values edited
923
    in_table = sql_gen.Table(into.name.replace('-pkeys', '')+'-input')
924
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
925
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
926
        flatten_cols, preserve=[in_pkey_col], start=0))
927
    input_joins = [in_table]
928
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
929
    
930
    # Resolve default value column
931
    try: default = mapping[default]
932
    except KeyError:
933
        if default != None:
934
            db.log_debug('Default value column '
935
                +strings.as_tt(strings.repr_no_u(default))
936
                +' does not exist in mapping, falling back to None', level=2.1)
937
            default = None
938
    
939
    out_pkey = pkey(db, out_table, recover=True)
940
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
941
    
942
    pkeys_names = [in_pkey, out_pkey]
943
    pkeys_cols = [in_pkey_col, out_pkey_col]
944
    
945
    pkeys_table_exists_ref = [False]
946
    def insert_into_pkeys(joins, cols):
947
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
948
        if pkeys_table_exists_ref[0]:
949
            insert_select(db, into, pkeys_names, query, params)
950
        else:
951
            run_query_into(db, query, params, into=into)
952
            pkeys_table_exists_ref[0] = True
953
    
954
    limit_ref = [None]
955
    conds = set()
956
    distinct_on = []
957
    def mk_main_select(joins, cols):
958
        return mk_select(db, joins, cols, conds, distinct_on,
959
            limit=limit_ref[0], start=0)
960
    
961
    exc_strs = set()
962
    def log_exc(e):
963
        e_str = exc.str_(e, first_line_only=True)
964
        log_debug('Caught exception: '+e_str)
965
        assert e_str not in exc_strs # avoid infinite loops
966
        exc_strs.add(e_str)
967
    def remove_all_rows():
968
        log_debug('Returning NULL for all rows')
969
        limit_ref[0] = 0 # just create an empty pkeys table
970
    def ignore(in_col, value):
971
        in_col_str = strings.as_tt(repr(in_col))
972
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
973
            level=2.5)
974
        add_index(db, in_col)
975
        log_debug('Ignoring rows with '+in_col_str+' = '
976
            +strings.as_tt(repr(value)))
977
    def remove_rows(in_col, value):
978
        ignore(in_col, value)
979
        cond = (in_col, sql_gen.CompareCond(value, '!='))
980
        assert cond not in conds # avoid infinite loops
981
        conds.add(cond)
982
    def invalid2null(in_col, value):
983
        ignore(in_col, value)
984
        update(db, in_table, [(in_col, None)],
985
            sql_gen.ColValueCond(in_col, value))
986
    
987
    # Do inserts and selects
988
    join_cols = {}
989
    insert_out_pkeys = sql_gen.Table(into.name+'-insert_out_pkeys')
990
    insert_in_pkeys = sql_gen.Table(into.name+'-insert_in_pkeys')
991
    while True:
992
        if limit_ref[0] == 0: # special case
993
            log_debug('Creating an empty pkeys table')
994
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
995
                limit=limit_ref[0]), into=insert_out_pkeys)
996
            break # don't do main case
997
        
998
        has_joins = join_cols != {}
999
        
1000
        # Prepare to insert new rows
1001
        insert_joins = input_joins[:] # don't modify original!
1002
        insert_args = dict(recover=True, cacheable=False)
1003
        if has_joins:
1004
            distinct_on = [v.to_Col() for v in join_cols.values()]
1005
            insert_joins.append(sql_gen.Join(out_table, join_cols,
1006
                sql_gen.filter_out))
1007
        else:
1008
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
1009
        main_select = mk_main_select(insert_joins, mapping.values())[0]
1010
        
1011
        log_debug('Trying to insert new rows')
1012
        try:
1013
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1014
                **insert_args)
1015
            break # insert successful
1016
        except DuplicateKeyException, e:
1017
            log_exc(e)
1018
            
1019
            old_join_cols = join_cols.copy()
1020
            join_cols.update(util.dict_subset(mapping, e.cols))
1021
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1022
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1023
            assert join_cols != old_join_cols # avoid infinite loops
1024
        except NullValueException, e:
1025
            log_exc(e)
1026
            
1027
            out_col, = e.cols
1028
            try: in_col = mapping[out_col]
1029
            except KeyError:
1030
                log_debug('Missing mapping for NOT NULL column '+out_col)
1031
                remove_all_rows()
1032
            else: remove_rows(in_col, None)
1033
        except FunctionValueException, e:
1034
            log_exc(e)
1035
            
1036
            func_name = e.name
1037
            value = e.value
1038
            for out_col, in_col in mapping.iteritems():
1039
                in_col = sql_gen.remove_col_rename(in_col)
1040
                if (isinstance(in_col, sql_gen.FunctionCall)
1041
                    and in_col.function.name == func_name):
1042
                    invalid2null(in_col.args[0], value)
1043
        except MissingCastException, e:
1044
            log_exc(e)
1045
            
1046
            out_col = e.col
1047
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
1048
        except DatabaseErrors, e:
1049
            log_exc(e)
1050
            
1051
            msg = 'No handler for exception: '+exc.str_(e)
1052
            warnings.warn(DbWarning(msg))
1053
            log_debug(msg)
1054
            remove_all_rows()
1055
        # after exception handled, rerun loop with additional constraints
1056
    
1057
    if row_ct_ref != None and cur.rowcount >= 0:
1058
        row_ct_ref[0] += cur.rowcount
1059
    
1060
    if has_joins:
1061
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1062
        log_debug('Getting output table pkeys of existing/inserted rows')
1063
        insert_into_pkeys(select_joins, pkeys_cols)
1064
    else:
1065
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1066
        
1067
        log_debug('Getting input table pkeys of inserted rows')
1068
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1069
            into=insert_in_pkeys)
1070
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1071
        
1072
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1073
            insert_in_pkeys)
1074
        
1075
        log_debug('Combining output and input pkeys in inserted order')
1076
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1077
            {row_num_col: sql_gen.join_same_not_null})]
1078
        insert_into_pkeys(pkey_joins, pkeys_names)
1079
    
1080
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1081
    index_pkey(db, into)
1082
    
1083
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1084
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1085
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1086
        # must use join_same_not_null or query will take forever
1087
    insert_into_pkeys(missing_rows_joins,
1088
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1089
    
1090
    assert table_row_count(db, into) == table_row_count(db, in_table)
1091
    
1092
    return sql_gen.Col(out_pkey, into)
1093

    
1094
##### Data cleanup
1095

    
1096
def cleanup_table(db, table, cols):
1097
    def esc_name_(name): return esc_name(db, name)
1098
    
1099
    table = sql_gen.as_Table(table).to_str(db)
1100
    cols = map(esc_name_, cols)
1101
    
1102
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1103
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1104
            for col in cols))),
1105
        dict(null0='', null1=r'\N'))
(24-24/36)