Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
21
    raw_query = None
22
    if hasattr(cur, 'query'): raw_query = cur.query
23
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
24
    
25
    if raw_query != None: return raw_query
26
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
28
def _add_cursor_info(e, *args, **kw_args):
29
    '''For params, see get_cur_query()'''
30
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31

    
32
class DbException(exc.ExceptionWithCause):
33
    def __init__(self, msg, cause=None, cur=None):
34
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
35
        if cur != None: _add_cursor_info(self, cur)
36

    
37
class ExceptionWithName(DbException):
38
    def __init__(self, name, cause=None):
39
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name)), cause)
40
        self.name = name
41

    
42
class ExceptionWithNameValue(DbException):
43
    def __init__(self, name, value, cause=None):
44
        DbException.__init__(self, 'for name: '+strings.as_tt(str(name))
45
            +'; value: '+strings.as_tt(repr(value)), cause)
46
        self.name = name
47
        self.value = value
48

    
49
class ConstraintException(DbException):
50
    def __init__(self, name, cols, cause=None):
51
        DbException.__init__(self, 'Violated '+strings.as_tt(name)
52
            +' constraint on columns: '+strings.as_tt(', '.join(cols)), cause)
53
        self.name = name
54
        self.cols = cols
55

    
56
class MissingCastException(DbException):
57
    def __init__(self, type_, col, cause=None):
58
        DbException.__init__(self, 'Missing cast to type '+strings.as_tt(type_)
59
            +' on column: '+strings.as_tt(col), cause)
60
        self.type = type_
61
        self.col = col
62

    
63
class NameException(DbException): pass
64

    
65
class DuplicateKeyException(ConstraintException): pass
66

    
67
class NullValueException(ConstraintException): pass
68

    
69
class FunctionValueException(ExceptionWithNameValue): pass
70

    
71
class DuplicateTableException(ExceptionWithName): pass
72

    
73
class DuplicateFunctionException(ExceptionWithName): pass
74

    
75
class EmptyRowException(DbException): pass
76

    
77
##### Warnings
78

    
79
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur): return (col[0] for col in cur.description)
84

    
85
def rows(cur): return iter(lambda: cur.fetchone(), None)
86

    
87
def consume_rows(cur):
88
    '''Used to fetch all rows so result will be cached'''
89
    iters.consume_iter(rows(cur))
90

    
91
def next_row(cur): return rows(cur).next()
92

    
93
def row(cur):
94
    row_ = next_row(cur)
95
    consume_rows(cur)
96
    return row_
97

    
98
def next_value(cur): return next_row(cur)[0]
99

    
100
def value(cur): return row(cur)[0]
101

    
102
def values(cur): return iters.func_iter(lambda: next_value(cur))
103

    
104
def value_or_none(cur):
105
    try: return value(cur)
106
    except StopIteration: return None
107

    
108
##### Input validation
109

    
110
def check_name(name):
111
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
112
        +'" may contain only alphanumeric characters and _')
113

    
114
def esc_name_by_module(module, name, ignore_case=False):
115
    if module == 'psycopg2' or module == None:
116
        if ignore_case:
117
            # Don't enclose in quotes because this disables case-insensitivity
118
            check_name(name)
119
            return name
120
        else: quote = '"'
121
    elif module == 'MySQLdb': quote = '`'
122
    else: raise NotImplementedError("Can't escape name for "+module+' database')
123
    return sql_gen.esc_name(name, quote)
124

    
125
def esc_name_by_engine(engine, name, **kw_args):
126
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)
127

    
128
def esc_name(db, name, **kw_args):
129
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
130

    
131
def qual_name(db, schema, table):
132
    def esc_name_(name): return esc_name(db, name)
133
    table = esc_name_(table)
134
    if schema != None: return esc_name_(schema)+'.'+table
135
    else: return table
136

    
137
##### Database connections
138

    
139
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']
140

    
141
db_engines = {
142
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
143
    'PostgreSQL': ('psycopg2', {}),
144
}
145

    
146
DatabaseErrors_set = set([DbException])
147
DatabaseErrors = tuple(DatabaseErrors_set)
148

    
149
def _add_module(module):
150
    DatabaseErrors_set.add(module.DatabaseError)
151
    global DatabaseErrors
152
    DatabaseErrors = tuple(DatabaseErrors_set)
153

    
154
def db_config_str(db_config):
155
    return db_config['engine']+' database '+db_config['database']
156

    
157
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
158

    
159
log_debug_none = lambda msg, level=2: None
160

    
161
class DbConn:
162
    def __init__(self, db_config, serializable=True, autocommit=False,
163
        caching=True, log_debug=log_debug_none):
164
        self.db_config = db_config
165
        self.serializable = serializable
166
        self.autocommit = autocommit
167
        self.caching = caching
168
        self.log_debug = log_debug
169
        self.debug = log_debug != log_debug_none
170
        
171
        self.__db = None
172
        self.query_results = {}
173
        self._savepoint = 0
174
    
175
    def __getattr__(self, name):
176
        if name == '__dict__': raise Exception('getting __dict__')
177
        if name == 'db': return self._db()
178
        else: raise AttributeError()
179
    
180
    def __getstate__(self):
181
        state = copy.copy(self.__dict__) # shallow copy
182
        state['log_debug'] = None # don't pickle the debug callback
183
        state['_DbConn__db'] = None # don't pickle the connection
184
        return state
185
    
186
    def connected(self): return self.__db != None
187
    
188
    def _db(self):
189
        if self.__db == None:
190
            # Process db_config
191
            db_config = self.db_config.copy() # don't modify input!
192
            schemas = db_config.pop('schemas', None)
193
            module_name, mappings = db_engines[db_config.pop('engine')]
194
            module = __import__(module_name)
195
            _add_module(module)
196
            for orig, new in mappings.iteritems():
197
                try: util.rename_key(db_config, orig, new)
198
                except KeyError: pass
199
            
200
            # Connect
201
            self.__db = module.connect(**db_config)
202
            
203
            # Configure connection
204
            if self.serializable and not self.autocommit: run_raw_query(self,
205
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
206
            if schemas != None:
207
                schemas_ = ''.join((esc_name(self, s)+', '
208
                    for s in schemas.split(',')))
209
                run_raw_query(self, "SELECT set_config('search_path', \
210
%s || current_setting('search_path'), false)", [schemas_])
211
        
212
        return self.__db
213
    
214
    class DbCursor(Proxy):
215
        def __init__(self, outer):
216
            Proxy.__init__(self, outer.db.cursor())
217
            self.outer = outer
218
            self.query_results = outer.query_results
219
            self.query_lookup = None
220
            self.result = []
221
        
222
        def execute(self, query, params=None):
223
            self._is_insert = query.upper().find('INSERT') >= 0
224
            self.query_lookup = _query_lookup(query, params)
225
            try:
226
                try:
227
                    return_value = self.inner.execute(query, params)
228
                    self.outer.do_autocommit()
229
                finally: self.query = get_cur_query(self.inner)
230
            except Exception, e:
231
                _add_cursor_info(e, self, query, params)
232
                self.result = e # cache the exception as the result
233
                self._cache_result()
234
                raise
235
            # Fetch all rows so result will be cached
236
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
237
            return return_value
238
        
239
        def fetchone(self):
240
            row = self.inner.fetchone()
241
            if row != None: self.result.append(row)
242
            # otherwise, fetched all rows
243
            else: self._cache_result()
244
            return row
245
        
246
        def _cache_result(self):
247
            # For inserts, only cache exceptions since inserts are not
248
            # idempotent, but an invalid insert will always be invalid
249
            if self.query_results != None and (not self._is_insert
250
                or isinstance(self.result, Exception)):
251
                
252
                assert self.query_lookup != None
253
                self.query_results[self.query_lookup] = self.CacheCursor(
254
                    util.dict_subset(dicts.AttrsDictView(self),
255
                    ['query', 'result', 'rowcount', 'description']))
256
        
257
        class CacheCursor:
258
            def __init__(self, cached_result): self.__dict__ = cached_result
259
            
260
            def execute(self, *args, **kw_args):
261
                if isinstance(self.result, Exception): raise self.result
262
                # otherwise, result is a rows list
263
                self.iter = iter(self.result)
264
            
265
            def fetchone(self):
266
                try: return self.iter.next()
267
                except StopIteration: return None
268
    
269
    def esc_value(self, value):
270
        module = util.root_module(self.db)
271
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
272
        elif module == 'MySQLdb':
273
            import _mysql
274
            str_ = _mysql.escape_string(value)
275
        else: raise NotImplementedError("Can't escape value for "+module
276
            +' database')
277
        return strings.to_unicode(str_)
278
    
279
    def esc_name(self, name): return esc_name(self, name) # calls global func
280
    
281
    def run_query(self, query, params=None, cacheable=False, log_level=2,
282
        debug_msg_ref=None):
283
        '''
284
        @param log_ignore_excs The log_level will be increased by 2 if the query
285
            throws one of these exceptions.
286
        '''
287
        assert query != None
288
        
289
        if not self.caching: cacheable = False
290
        used_cache = False
291
        try:
292
            # Get cursor
293
            if cacheable:
294
                query_lookup = _query_lookup(query, params)
295
                try:
296
                    cur = self.query_results[query_lookup]
297
                    used_cache = True
298
                except KeyError: cur = self.DbCursor(self)
299
            else: cur = self.db.cursor()
300
            
301
            # Run query
302
            cur.execute(query, params)
303
        finally:
304
            if self.debug and debug_msg_ref != None:# only compute msg if needed
305
                if used_cache: cache_status = 'cache hit'
306
                elif cacheable: cache_status = 'cache miss'
307
                else: cache_status = 'non-cacheable'
308
                query_code = strings.as_code(str(get_cur_query(cur, query,
309
                    params)), 'SQL')
310
                debug_msg_ref[0] = 'DB query: '+cache_status+':\n'+query_code
311
        
312
        return cur
313
    
314
    def is_cached(self, query, params=None):
315
        return _query_lookup(query, params) in self.query_results
316
    
317
    def with_savepoint(self, func):
318
        savepoint = 'level_'+str(self._savepoint)
319
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
320
        self._savepoint += 1
321
        try: 
322
            try: return_val = func()
323
            finally:
324
                self._savepoint -= 1
325
                assert self._savepoint >= 0
326
        except:
327
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
328
            raise
329
        else:
330
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
331
            self.do_autocommit()
332
            return return_val
333
    
334
    def do_autocommit(self):
335
        '''Autocommits if outside savepoint'''
336
        assert self._savepoint >= 0
337
        if self.autocommit and self._savepoint == 0:
338
            self.log_debug('Autocommiting')
339
            self.db.commit()
340

    
341
connect = DbConn
342

    
343
##### Querying
344

    
345
def run_raw_query(db, *args, **kw_args):
346
    '''For params, see DbConn.run_query()'''
347
    return db.run_query(*args, **kw_args)
348

    
349
def mogrify(db, query, params):
350
    module = util.root_module(db.db)
351
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
352
    else: raise NotImplementedError("Can't mogrify query for "+module+
353
        ' database')
354

    
355
##### Recoverable querying
356

    
357
def with_savepoint(db, func): return db.with_savepoint(func)
358

    
359
def run_query(db, query, params=None, recover=None, cacheable=False,
360
    log_level=2, log_ignore_excs=None, **kw_args):
361
    '''For params, see run_raw_query()'''
362
    if recover == None: recover = False
363
    if log_ignore_excs == None: log_ignore_excs = ()
364
    log_ignore_excs = tuple(log_ignore_excs)
365
    
366
    debug_msg_ref = [None]
367
    try:
368
        try:
369
            def run(): return run_raw_query(db, query, params, cacheable,
370
                log_level, debug_msg_ref, **kw_args)
371
            if recover and not db.is_cached(query, params):
372
                return with_savepoint(db, run)
373
            else: return run() # don't need savepoint if cached
374
        except Exception, e:
375
            if not recover: raise # need savepoint to run index_cols()
376
            msg = exc.str_(e)
377
            
378
            match = re.search(r'duplicate key value violates unique constraint '
379
                r'"((_?[^\W_]+)_.+?)"', msg)
380
            if match:
381
                constraint, table = match.groups()
382
                try: cols = index_cols(db, table, constraint)
383
                except NotImplementedError: raise e
384
                else: raise DuplicateKeyException(constraint, cols, e)
385
            
386
            match = re.search(r'null value in column "(.+?)" violates not-null'
387
                r' constraint', msg)
388
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
389
            
390
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
391
                r'|date/time field value out of range): "(.+?)"\n'
392
                r'(?:(?s).*?)\bfunction "(.+?)".*?\bat assignment', msg)
393
            if match:
394
                value, name = match.groups()
395
                raise FunctionValueException(name, strings.to_unicode(value), e)
396
            
397
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
398
                r'is of type', msg)
399
            if match:
400
                col, type_ = match.groups()
401
                raise MissingCastException(type_, col, e)
402
            
403
            match = re.search(r'relation "(.+?)" already exists', msg)
404
            if match: raise DuplicateTableException(match.group(1), e)
405
            
406
            match = re.search(r'function "(.+?)" already exists', msg)
407
            if match: raise DuplicateFunctionException(match.group(1), e)
408
            
409
            raise # no specific exception raised
410
    except log_ignore_excs:
411
        log_level += 2
412
        raise
413
    finally:
414
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
415

    
416
##### Basic queries
417

    
418
def next_version(name):
419
    '''Prepends the version # so it won't be removed if the name is truncated'''
420
    version = 1 # first existing name was version 0
421
    match = re.match(r'^#(\d+)-(.*)$', name)
422
    if match:
423
        version = int(match.group(1))+1
424
        name = match.group(2)
425
    return '#'+str(version)+'-'+name
426

    
427
def run_query_into(db, query, params, into=None, *args, **kw_args):
428
    '''Outputs a query to a temp table.
429
    For params, see run_query().
430
    '''
431
    if into == None: return run_query(db, query, params, *args, **kw_args)
432
    else: # place rows in temp table
433
        assert isinstance(into, sql_gen.Table)
434
        
435
        kw_args['recover'] = True
436
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
437
        
438
        temp = not db.autocommit # tables are permanent in autocommit mode
439
        # "temporary tables cannot specify a schema name", so remove schema
440
        if temp: into.schema = None
441
        
442
        while True:
443
            try:
444
                create_query = 'CREATE'
445
                if temp: create_query += ' TEMP'
446
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
447
                
448
                return run_query(db, create_query, params, *args, **kw_args)
449
                    # CREATE TABLE AS sets rowcount to # rows in query
450
            except DuplicateTableException, e:
451
                into.name = next_version(into.name)
452
                # try again with next version of name
453

    
454
order_by_pkey = object() # tells mk_select() to order by the pkey
455

    
456
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
457

    
458
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
459
    start=None, order_by=order_by_pkey, default_table=None):
460
    '''
461
    @param tables The single table to select from, or a list of tables to join
462
        together, with tables after the first being sql_gen.Join objects
463
    @param fields Use None to select all fields in the table
464
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
465
        * container can be any iterable type
466
        * compare_left_side: sql_gen.Code|str (for col name)
467
        * compare_right_side: sql_gen.ValueCond|literal value
468
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
469
        use all columns
470
    @return tuple(query, params)
471
    '''
472
    # Parse tables param
473
    if not lists.is_seq(tables): tables = [tables]
474
    tables = list(tables) # don't modify input! (list() copies input)
475
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
476
    
477
    # Parse other params
478
    if conds == None: conds = []
479
    elif isinstance(conds, dict): conds = conds.items()
480
    conds = list(conds) # don't modify input! (list() copies input)
481
    assert limit == None or type(limit) == int
482
    assert start == None or type(start) == int
483
    if order_by is order_by_pkey:
484
        if distinct_on != []: order_by = None
485
        else: order_by = pkey(db, table0, recover=True)
486
    
487
    query = 'SELECT'
488
    
489
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
490
    
491
    # DISTINCT ON columns
492
    if distinct_on != []:
493
        query += '\nDISTINCT'
494
        if distinct_on is not distinct_on_all:
495
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
496
    
497
    # Columns
498
    query += '\n'
499
    if fields == None: query += '*'
500
    else: query += '\n, '.join(map(parse_col, fields))
501
    
502
    # Main table
503
    query += '\nFROM '+table0.to_str(db)
504
    
505
    # Add joins
506
    left_table = table0
507
    for join_ in tables:
508
        table = join_.table
509
        
510
        # Parse special values
511
        if join_.type_ is sql_gen.filter_out: # filter no match
512
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
513
                None))
514
        
515
        query += '\n'+join_.to_str(db, left_table)
516
        
517
        left_table = table
518
    
519
    missing = True
520
    if conds != []:
521
        query += '\nWHERE\n'+('\nAND\n'.join(('('+sql_gen.ColValueCond(l, r)
522
            .to_str(db)+')' for l, r in conds)))
523
        missing = False
524
    if order_by != None:
525
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
526
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
527
    if start != None:
528
        if start != 0: query += '\nOFFSET '+str(start)
529
        missing = False
530
    if missing: warnings.warn(DbWarning(
531
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
532
    
533
    return (query, [])
534

    
535
def select(db, *args, **kw_args):
536
    '''For params, see mk_select() and run_query()'''
537
    recover = kw_args.pop('recover', None)
538
    cacheable = kw_args.pop('cacheable', True)
539
    log_level = kw_args.pop('log_level', 2)
540
    
541
    query, params = mk_select(db, *args, **kw_args)
542
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
543

    
544
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
545
    returning=None, embeddable=False):
546
    '''
547
    @param returning str|None An inserted column (such as pkey) to return
548
    @param embeddable Whether the query should be embeddable as a nested SELECT.
549
        Warning: If you set this and cacheable=True when the query is run, the
550
        query will be fully cached, not just if it raises an exception.
551
    '''
552
    table = sql_gen.as_Table(table)
553
    if cols == []: cols = None # no cols (all defaults) = unknown col names
554
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
555
    if select_query == None: select_query = 'DEFAULT VALUES'
556
    if returning != None: returning = sql_gen.as_Col(returning, table)
557
    
558
    # Build query
559
    first_line = 'INSERT INTO '+table.to_str(db)
560
    query = first_line
561
    if cols != None: query += '\n('+', '.join(cols)+')'
562
    query += '\n'+select_query
563
    
564
    if returning != None:
565
        returning_name = copy.copy(returning)
566
        returning_name.table = None
567
        returning_name = returning_name.to_str(db)
568
        query += '\nRETURNING '+returning_name
569
    
570
    if embeddable:
571
        assert returning != None
572
        
573
        # Create function
574
        function_name = sql_gen.clean_name(first_line)
575
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
576
        while True:
577
            try:
578
                func_schema = None
579
                if not db.autocommit: func_schema = 'pg_temp'
580
                function = sql_gen.Table(function_name, func_schema).to_str(db)
581
                
582
                function_query = '''\
583
CREATE FUNCTION '''+function+'''()
584
RETURNS '''+return_type+'''
585
LANGUAGE sql
586
AS $$
587
'''+mogrify(db, query, params)+''';
588
$$;
589
'''
590
                run_query(db, function_query, recover=True, cacheable=True,
591
                    log_ignore_excs=(DuplicateFunctionException,))
592
                break # this version was successful
593
            except DuplicateFunctionException, e:
594
                function_name = next_version(function_name)
595
                # try again with next version of name
596
        
597
        # Return query that uses function
598
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
599
            [returning_name]) # AS clause requires function alias
600
        return mk_select(db, func_table, start=0, order_by=None)
601
    
602
    return (query, params)
603

    
604
def insert_select(db, *args, **kw_args):
605
    '''For params, see mk_insert_select() and run_query_into()
606
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
607
        values in
608
    '''
609
    into = kw_args.pop('into', None)
610
    if into != None: kw_args['embeddable'] = True
611
    recover = kw_args.pop('recover', None)
612
    cacheable = kw_args.pop('cacheable', True)
613
    
614
    query, params = mk_insert_select(db, *args, **kw_args)
615
    return run_query_into(db, query, params, into, recover=recover,
616
        cacheable=cacheable)
617

    
618
default = object() # tells insert() to use the default value for a column
619

    
620
def insert(db, table, row, *args, **kw_args):
621
    '''For params, see insert_select()'''
622
    if lists.is_seq(row): cols = None
623
    else:
624
        cols = row.keys()
625
        row = row.values()
626
    row = list(row) # ensure that "!= []" works
627
    
628
    # Check for special values
629
    labels = []
630
    values = []
631
    for value in row:
632
        if value is default: labels.append('DEFAULT')
633
        else:
634
            labels.append('%s')
635
            values.append(value)
636
    
637
    # Build query
638
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
639
    else: query = None
640
    
641
    return insert_select(db, table, cols, query, values, *args, **kw_args)
642

    
643
def mk_update(db, table, changes=None, cond=None):
644
    '''
645
    @param changes [(col, new_value),...]
646
        * container can be any iterable type
647
        * col: sql_gen.Code|str (for col name)
648
        * new_value: sql_gen.Code|literal value
649
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
650
    @return str query
651
    '''
652
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
653
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
654
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
655
    if cond != None: query += '\nWHERE\n'+cond.to_str(db)
656
    
657
    return query
658

    
659
def update(db, *args, **kw_args):
660
    '''For params, see mk_update() and run_query()'''
661
    recover = kw_args.pop('recover', None)
662
    
663
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
664

    
665
def last_insert_id(db):
666
    module = util.root_module(db.db)
667
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
668
    elif module == 'MySQLdb': return db.insert_id()
669
    else: return None
670

    
671
def truncate(db, table, schema='public'):
672
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
673

    
674
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
675
    '''Creates a mapping from original column names (which may have collisions)
676
    to names that will be distinct among the columns' tables.
677
    This is meant to be used for several tables that are being joined together.
678
    @param cols The columns to combine. Duplicates will be removed.
679
    @param into The table for the new columns.
680
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
681
        columns will be included in the mapping even if they are not in cols.
682
        The tables of the provided Col objects will be changed to into, so make
683
        copies of them if you want to keep the original tables.
684
    @param as_items Whether to return a list of dict items instead of a dict
685
    @return dict(orig_col=new_col, ...)
686
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
687
        * new_col: sql_gen.Col(orig_col_name, into)
688
        * All mappings use the into table so its name can easily be
689
          changed for all columns at once
690
    '''
691
    cols = lists.uniqify(cols)
692
    
693
    items = []
694
    for col in preserve:
695
        orig_col = copy.copy(col)
696
        col.table = into
697
        items.append((orig_col, col))
698
    preserve = set(preserve)
699
    for col in cols:
700
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
701
    
702
    if not as_items: items = dict(items)
703
    return items
704

    
705
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
706
    '''For params, see mk_flatten_mapping()
707
    @return See return value of mk_flatten_mapping()
708
    '''
709
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
710
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
711
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
712
        into=into)
713
    return dict(items)
714

    
715
##### Database structure queries
716

    
717
def table_row_count(db, table, recover=None):
718
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
719
        order_by=None, start=0), recover=recover, log_level=3))
720

    
721
def table_cols(db, table, recover=None):
722
    return list(col_names(select(db, table, limit=0, order_by=None,
723
        recover=recover, log_level=4)))
724

    
725
def pkey(db, table, recover=None):
726
    '''Assumed to be first column in table'''
727
    return table_cols(db, table, recover)[0]
728

    
729
not_null_col = 'not_null'
730

    
731
def table_not_null_col(db, table, recover=None):
732
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
733
    if not_null_col in table_cols(db, table, recover): return not_null_col
734
    else: return pkey(db, table, recover)
735

    
736
def index_cols(db, table, index):
737
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
738
    automatically created. When you don't know whether something is a UNIQUE
739
    constraint or a UNIQUE index, use this function.'''
740
    module = util.root_module(db.db)
741
    if module == 'psycopg2':
742
        return list(values(run_query(db, '''\
743
SELECT attname
744
FROM
745
(
746
        SELECT attnum, attname
747
        FROM pg_index
748
        JOIN pg_class index ON index.oid = indexrelid
749
        JOIN pg_class table_ ON table_.oid = indrelid
750
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
751
        WHERE
752
            table_.relname = %(table)s
753
            AND index.relname = %(index)s
754
    UNION
755
        SELECT attnum, attname
756
        FROM
757
        (
758
            SELECT
759
                indrelid
760
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
761
                    AS indkey
762
            FROM pg_index
763
            JOIN pg_class index ON index.oid = indexrelid
764
            JOIN pg_class table_ ON table_.oid = indrelid
765
            WHERE
766
                table_.relname = %(table)s
767
                AND index.relname = %(index)s
768
        ) s
769
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
770
) s
771
ORDER BY attnum
772
''',
773
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
774
    else: raise NotImplementedError("Can't list index columns for "+module+
775
        ' database')
776

    
777
def constraint_cols(db, table, constraint):
778
    module = util.root_module(db.db)
779
    if module == 'psycopg2':
780
        return list(values(run_query(db, '''\
781
SELECT attname
782
FROM pg_constraint
783
JOIN pg_class ON pg_class.oid = conrelid
784
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
785
WHERE
786
    relname = %(table)s
787
    AND conname = %(constraint)s
788
ORDER BY attnum
789
''',
790
            {'table': table, 'constraint': constraint})))
791
    else: raise NotImplementedError("Can't list constraint columns for "+module+
792
        ' database')
793

    
794
row_num_col = '_row_num'
795

    
796
def index_col(db, col):
797
    '''Adds an index on a column if it doesn't already exist.'''
798
    assert sql_gen.is_table_col(col)
799
    
800
    table = col.table
801
    index = sql_gen.as_Table(str(col))
802
    col = sql_gen.to_name_only_col(col)
803
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
804
        +' ('+col.to_str(db)+')', recover=True, cacheable=True, log_level=3)
805
    except DuplicateTableException: pass # index already existed
806

    
807
def index_pkey(db, table, recover=None):
808
    '''Makes the first column in a table the primary key.
809
    @pre The table must not already have a primary key.
810
    '''
811
    table = sql_gen.as_Table(table)
812
    
813
    index = sql_gen.as_Table(table.name+'_pkey')
814
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
815
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
816
        +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')', recover=recover,
817
        log_level=3)
818

    
819
def add_row_num(db, table):
820
    '''Adds a row number column to a table. Its name is in row_num_col. It will
821
    be the primary key.'''
822
    table = sql_gen.as_Table(table).to_str(db)
823
    run_query(db, 'ALTER TABLE '+table+' ADD COLUMN '+row_num_col
824
        +' serial NOT NULL PRIMARY KEY', log_level=3)
825

    
826
def tables(db, schema='public', table_like='%'):
827
    module = util.root_module(db.db)
828
    params = {'schema': schema, 'table_like': table_like}
829
    if module == 'psycopg2':
830
        return values(run_query(db, '''\
831
SELECT tablename
832
FROM pg_tables
833
WHERE
834
    schemaname = %(schema)s
835
    AND tablename LIKE %(table_like)s
836
ORDER BY tablename
837
''',
838
            params, cacheable=True))
839
    elif module == 'MySQLdb':
840
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
841
            cacheable=True))
842
    else: raise NotImplementedError("Can't list tables for "+module+' database')
843

    
844
##### Database management
845

    
846
def empty_db(db, schema='public', **kw_args):
847
    '''For kw_args, see tables()'''
848
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)
849

    
850
##### Heuristic queries
851

    
852
def put(db, table, row, pkey_=None, row_ct_ref=None):
853
    '''Recovers from errors.
854
    Only works under PostgreSQL (uses INSERT RETURNING).
855
    '''
856
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
857
    
858
    try:
859
        cur = insert(db, table, row, pkey_, recover=True)
860
        if row_ct_ref != None and cur.rowcount >= 0:
861
            row_ct_ref[0] += cur.rowcount
862
        return value(cur)
863
    except DuplicateKeyException, e:
864
        return value(select(db, table, [pkey_],
865
            util.dict_subset_right_join(row, e.cols), recover=True))
866

    
867
def get(db, table, row, pkey, row_ct_ref=None, create=False):
868
    '''Recovers from errors'''
869
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
870
    except StopIteration:
871
        if not create: raise
872
        return put(db, table, row, pkey, row_ct_ref) # insert new row
873

    
874
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
875
    default=None):
876
    '''Recovers from errors.
877
    Only works under PostgreSQL (uses INSERT RETURNING).
878
    @param in_tables The main input table to select from, followed by a list of
879
        tables to join with it using the main input table's pkey
880
    @param mapping dict(out_table_col=in_table_col, ...)
881
        * out_table_col: sql_gen.Col|str
882
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
883
    @param into The table to contain the output and input pkeys.
884
        Defaults to `out_table.name+'-pkeys'`.
885
    @param default The *output* column to use as the pkey for missing rows.
886
        If this output column does not exist in the mapping, uses None.
887
    @return sql_gen.Col Where the output pkeys are made available
888
    '''
889
    out_table = sql_gen.as_Table(out_table)
890
    for in_table_col in mapping.itervalues():
891
        assert isinstance(in_table_col, sql_gen.Col)
892
    if into == None: into = out_table.name+'-pkeys'
893
    into = sql_gen.as_Table(into)
894
    
895
    def log_debug(msg): db.log_debug(msg, level=1.5)
896
    def col_ustr(str_):
897
        return strings.repr_no_u(sql_gen.remove_col_rename(
898
            sql_gen.as_Col(str_)))
899
    
900
    log_debug('********** New iteration **********')
901
    log_debug('Inserting these input columns into '+strings.as_tt(
902
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
903
    
904
    # Create input joins from list of input tables
905
    in_tables_ = in_tables[:] # don't modify input!
906
    in_tables0 = in_tables_.pop(0) # first table is separate
907
    in_pkey = pkey(db, in_tables0, recover=True)
908
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
909
    input_joins = [in_tables0]+[sql_gen.Join(v,
910
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
911
    
912
    log_debug('Joining together input tables into temp table')
913
    # Place in new table for speed and so don't modify input if values edited
914
    in_table = sql_gen.Table(into.name+'-input')
915
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
916
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
917
        flatten_cols, preserve=[in_pkey_col], start=0))
918
    input_joins = [in_table]
919
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
920
    
921
    # Resolve default value column
922
    try: default = mapping[default]
923
    except KeyError:
924
        if default != None:
925
            db.log_debug('Default value column '
926
                +strings.as_tt(strings.repr_no_u(default))
927
                +' does not exist in mapping, falling back to None', level=2.1)
928
            default = None
929
    
930
    out_pkey = pkey(db, out_table, recover=True)
931
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
932
    
933
    pkeys_names = [in_pkey, out_pkey]
934
    pkeys_cols = [in_pkey_col, out_pkey_col]
935
    
936
    pkeys_table_exists_ref = [False]
937
    def insert_into_pkeys(joins, cols):
938
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
939
        if pkeys_table_exists_ref[0]:
940
            insert_select(db, into, pkeys_names, query, params)
941
        else:
942
            run_query_into(db, query, params, into=into)
943
            pkeys_table_exists_ref[0] = True
944
    
945
    limit_ref = [None]
946
    conds = set()
947
    distinct_on = []
948
    def mk_main_select(joins, cols):
949
        return mk_select(db, joins, cols, conds, distinct_on,
950
            limit=limit_ref[0], start=0)
951
    
952
    exc_strs = set()
953
    def log_exc(e):
954
        e_str = exc.str_(e, first_line_only=True)
955
        log_debug('Caught exception: '+e_str)
956
        assert e_str not in exc_strs # avoid infinite loops
957
        exc_strs.add(e_str)
958
    def remove_all_rows():
959
        log_debug('Returning NULL for all rows')
960
        limit_ref[0] = 0 # just create an empty pkeys table
961
    def ignore(in_col, value):
962
        in_col_str = repr(in_col)
963
        log_debug('Adding index on '+in_col_str+' to enable fast filtering')
964
        index_col(db, in_col)
965
        log_debug('Ignoring rows with '+in_col_str+' = '+repr(value))
966
    def remove_rows(in_col, value):
967
        ignore(in_col, value)
968
        cond = (in_col, sql_gen.CompareCond(value, '!='))
969
        assert cond not in conds # avoid infinite loops
970
        conds.add(cond)
971
    def invalid2null(in_col, value):
972
        ignore(in_col, value)
973
        update(db, in_table, [(in_col, None)],
974
            sql_gen.ColValueCond(in_col, value))
975
    
976
    # Do inserts and selects
977
    join_cols = {}
978
    insert_out_pkeys = sql_gen.Table(into.name+'-insert_out_pkeys')
979
    insert_in_pkeys = sql_gen.Table(into.name+'-insert_in_pkeys')
980
    while True:
981
        if limit_ref[0] == 0: # special case
982
            log_debug('Creating an empty pkeys table')
983
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
984
                limit=limit_ref[0]), into=insert_out_pkeys)
985
            break # don't do main case
986
        
987
        has_joins = join_cols != {}
988
        
989
        # Prepare to insert new rows
990
        insert_joins = input_joins[:] # don't modify original!
991
        insert_args = dict(recover=True, cacheable=False)
992
        if has_joins:
993
            distinct_on = [v.to_Col() for v in join_cols.values()]
994
            insert_joins.append(sql_gen.Join(out_table, join_cols,
995
                sql_gen.filter_out))
996
        else:
997
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
998
        main_select = mk_main_select(insert_joins, mapping.values())[0]
999
        
1000
        log_debug('Trying to insert new rows')
1001
        try:
1002
            cur = insert_select(db, out_table, mapping.keys(), main_select,
1003
                **insert_args)
1004
            break # insert successful
1005
        except DuplicateKeyException, e:
1006
            log_exc(e)
1007
            
1008
            old_join_cols = join_cols.copy()
1009
            join_cols.update(util.dict_subset(mapping, e.cols))
1010
            log_debug('Ignoring existing rows, comparing on these columns:\n'
1011
                +strings.as_inline_table(join_cols, ustr=col_ustr))
1012
            assert join_cols != old_join_cols # avoid infinite loops
1013
        except NullValueException, e:
1014
            log_exc(e)
1015
            
1016
            out_col, = e.cols
1017
            try: in_col = mapping[out_col]
1018
            except KeyError:
1019
                log_debug('Missing mapping for NOT NULL column '+out_col)
1020
                remove_all_rows()
1021
            else: remove_rows(in_col, None)
1022
        except FunctionValueException, e:
1023
            log_exc(e)
1024
            
1025
            assert e.name == out_table.name
1026
            out_col = 'value' # assume function param was named "value"
1027
            invalid2null(mapping[out_col], e.value)
1028
        except MissingCastException, e:
1029
            log_exc(e)
1030
            
1031
            out_col = e.col
1032
            mapping[out_col] = sql_gen.FunctionCall(e.type, mapping[out_col])
1033
        except DatabaseErrors, e:
1034
            log_exc(e)
1035
            
1036
            msg = 'No handler for exception: '+exc.str_(e, first_line_only=True)
1037
            warnings.warn(DbWarning(msg))
1038
            log_debug(msg)
1039
            remove_all_rows()
1040
        # after exception handled, rerun loop with additional constraints
1041
    
1042
    if row_ct_ref != None and cur.rowcount >= 0:
1043
        row_ct_ref[0] += cur.rowcount
1044
    
1045
    if has_joins:
1046
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
1047
        log_debug('Getting output table pkeys of existing/inserted rows')
1048
        insert_into_pkeys(select_joins, pkeys_cols)
1049
    else:
1050
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
1051
        
1052
        log_debug('Getting input table pkeys of inserted rows')
1053
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
1054
            into=insert_in_pkeys)
1055
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
1056
        
1057
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
1058
            insert_in_pkeys)
1059
        
1060
        log_debug('Combining output and input pkeys in inserted order')
1061
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
1062
            {row_num_col: sql_gen.join_same_not_null})]
1063
        insert_into_pkeys(pkey_joins, pkeys_names)
1064
    
1065
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
1066
    index_pkey(db, into)
1067
    
1068
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
1069
    missing_rows_joins = input_joins+[sql_gen.Join(into,
1070
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
1071
        # must use join_same_not_null or query will take forever
1072
    insert_into_pkeys(missing_rows_joins,
1073
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
1074
    
1075
    assert table_row_count(db, into) == table_row_count(db, in_table)
1076
    
1077
    return sql_gen.Col(out_pkey, into)
1078

    
1079
##### Data cleanup
1080

    
1081
def cleanup_table(db, table, cols):
1082
    def esc_name_(name): return esc_name(db, name)
1083
    
1084
    table = sql_gen.as_Table(table).to_str(db)
1085
    cols = map(esc_name_, cols)
1086
    
1087
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1088
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1089
            for col in cols))),
1090
        dict(null0='', null1=r'\N'))
(24-24/36)