Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the SQL most recently executed on a cursor.
    Prefers the driver-reported query (psycopg2 exposes `query`, MySQLdb
    exposes `_last_executed`); falls back to the caller-supplied query
    template and params when the driver reports nothing.'''
    executed = None
    if hasattr(cur, 'query'): executed = cur.query
    elif hasattr(cur, '_last_executed'): executed = cur._last_executed
    
    if executed is not None: return executed
    return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
def _add_cursor_info(e, *args, **kw_args):
    '''Annotates exception e with the query that was running.
    For params, see get_cur_query().'''
    query = get_cur_query(*args, **kw_args)
    exc.add_msg(e, 'query: '+str(query))
31

    
32
class DbException(exc.ExceptionWithCause):
    '''Base class for database errors; when a cursor is given, the failing
    query is appended to the message.'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    '''A DbException tagged with the name of the offending object.'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name

class ExceptionWithNameValue(DbException):
    '''A DbException tagged with an offending name and its value.'''
    def __init__(self, name, value, cause=None):
        msg = ('for name: '+strings.as_tt(str(name))
            +'; value: '+strings.as_tt(repr(value)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.value = value

class ConstraintException(DbException):
    '''A DbException for a violated constraint, recording the constraint
    name and the columns involved.'''
    def __init__(self, name, cols, cause=None):
        msg = ('Violated '+strings.as_tt(name)
            +' constraint on columns: '+strings.as_tt(', '.join(cols)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols

class MissingCastException(DbException):
    '''A DbException for a value that needs an explicit cast to a type.'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)
            +' on column: '+strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
62

    
63
class NameException(DbException):
    '''Problem with a database object name.'''

class DuplicateKeyException(ConstraintException):
    '''Raised when a query violates a unique constraint
    (see the error parsing in run_query()).'''

class NullValueException(ConstraintException):
    '''Raised when a query violates a NOT NULL constraint.'''

class FunctionValueException(ExceptionWithNameValue):
    '''Raised when a function is given an invalid input value.'''

class DuplicateTableException(ExceptionWithName):
    '''Raised when creating a table whose name already exists.'''

class DuplicateFunctionException(ExceptionWithName):
    '''Raised when creating a function whose name already exists.'''

class EmptyRowException(DbException):
    '''Raised for a row with no usable values.'''
76

    
77
##### Warnings
78

    
79
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur):
    '''Returns a generator of the column names in a cursor's result set.'''
    return (col[0] for col in cur.description)

def rows(cur):
    '''Returns an iterator over the cursor's remaining rows, stopping when
    fetchone() returns None.'''
    return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur):
    '''Returns the cursor's next row. Raises StopIteration if exhausted.'''
    # next() builtin instead of the Py2-only .next() method: works on
    # Python 2.6+ and Python 3
    return next(rows(cur))

def row(cur):
    '''Returns the next row, then exhausts the cursor so the full result
    gets cached (see consume_rows()).'''
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur):
    '''Returns the first column of the cursor's next row.'''
    return next_row(cur)[0]

def value(cur):
    '''Returns the first column of row(cur).'''
    return row(cur)[0]

def values(cur):
    '''Returns an iterator over the first column of each remaining row.'''
    return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    '''Like value(), but returns None if the cursor has no rows.'''
    try: return value(cur)
    except StopIteration: return None
107

    
108
##### Input validation
109

    
110
def esc_name_by_module(module, name):
    '''Escapes a SQL identifier using the quoting style of the given DB-API
    module: backticks for MySQLdb, double quotes for psycopg2 (or None).'''
    if module == 'MySQLdb': quote = '`'
    elif module == 'psycopg2' or module == None: quote = '"'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a SQL identifier for the given engine name (see db_engines).'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)

def esc_name(db, name, **kw_args):
    '''Escapes a SQL identifier for the given connection's driver.'''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified table name.'''
    qualified = esc_name(db, table)
    if schema != None: qualified = esc_name(db, schema)+'.'+qualified
    return qualified
127

    
128
##### Database connections
129

    
130
# Keys that may appear in a db_config dict
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, renames applied to db_config keys
# before they are passed to module.connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All known database exception types; _add_module() extends the set (and
# rebuilds the tuple, which is what except clauses should use) as drivers load
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)
139

    
140
def _add_module(module):
    '''Registers a DB-API module's DatabaseError type in DatabaseErrors.'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set)
144

    
145
def db_config_str(db_config):
    '''Returns a human-readable description of a db_config.'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
147

    
148
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
149

    
150
log_debug_none = lambda msg, level=2: None
151

    
152
class DbConn:
    '''A lazily-connected database connection with optional query-result
    caching, savepoint management, and debug logging.'''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        # db_config: dict with keys from db_config_names
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        # debug is on iff a real logger (not the no-op default) was passed
        self.debug = log_debug != log_debug_none
        
        self.__db = None # underlying DB-API connection; created on first use
        self.query_results = {} # cache: _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db() # .db connects lazily
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self): return self.__db != None
    
    def _db(self):
        '''Returns the underlying DB-API connection, connecting on first
        call using db_config.'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # rename config keys to the names this driver's connect() expects
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas != None:
                # prepend the comma-separated schemas to search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', \
%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''A cursor proxy that records rows as they are fetched and stores
        the finished result (or the exception) in outer.query_results.'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query, params)
            except Exception, e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result (rows or exception) with the cursor
            interface run_query() expects.'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Escapes a literal value for inclusion in SQL text.'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def can_mogrify(self):
        # only psycopg2 cursors provide mogrify()
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def run_query(self, query, params=None, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param log_ignore_excs The log_level will be increased by 2 if the query
            throws one of these exceptions.
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        # NOTE(review): the docstring's log_ignore_excs is not a parameter
        # here; it is handled by the module-level run_query() wrapper
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query, params)
        finally:
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query,
                    params)))
        
        return cur
    
    def is_cached(self, query, params=None):
        # whether a result for this exact query+params is already cached
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a (nestable) savepoint; rolls back to the
        savepoint on any exception, otherwise releases it.'''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: 
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommiting')
            self.db.commit()
    
    def col_default(self, col):
        '''Returns the column's default value (from information_schema)
        as a sql_gen Code object.'''
        table = sql_gen.Table('columns', 'information_schema')
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        return sql_gen.as_Code(value(select(self, table, ['column_default'],
            conds, order_by='table_schema', limit=1, log_level=3)))
            # TODO: order_by search_path schema order
360

    
361
connect = DbConn
362

    
363
##### Querying
364

    
365
def run_raw_query(db, *args, **kw_args):
    '''Runs a query on db without any error parsing.
    For params, see DbConn.run_query().'''
    return db.run_query(*args, **kw_args)
368

    
369
def mogrify(db, query, params):
    '''Returns the query text with params substituted, as the driver would
    send it. Only supported for psycopg2.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
374

    
375
##### Recoverable querying
376

    
377
def with_savepoint(db, func): return db.with_savepoint(func)
378

    
379
def run_query(db, query, params=None, recover=None, cacheable=False,
    log_level=2, log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, and translates known
    PostgreSQL error messages into this module's exception types.
    For params, see run_raw_query()'''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    # single-element list so the inner run() can write the log message back
    debug_msg_ref = [None]
    try:
        try:
            def run(): return run_raw_query(db, query, params, cacheable,
                log_level, debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query, params):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            if not recover: raise # need savepoint to run index_cols()
            msg = exc.str_(e)
            
            # Each re.search below matches a PostgreSQL error message and
            # re-raises it as the corresponding typed exception
            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                constraint, table = match.groups()
                try: cols = index_cols(db, table, constraint)
                except NotImplementedError: raise e
                else: raise DuplicateKeyException(constraint, cols, e)
            
            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)
            
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.search(r'relation "(.+?)" already exists', msg)
            if match: raise DuplicateTableException(match.group(1), e)
            
            match = re.search(r'function "(.+?)" already exists', msg)
            if match: raise DuplicateFunctionException(match.group(1), e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2 # demote the log message for expected exceptions
        raise
    finally:
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)
435

    
436
##### Basic queries
437

    
438
def next_version(name):
    '''Returns name with its trailing '#<version>' suffix incremented.
    An unversioned name gets '#1' (the original name counts as version 0).'''
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name = match.group(1)
        version = int(match.group(2))+1
    return sql_gen.add_suffix(name, '#'+str(version))
445

    
446
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        assert isinstance(into, sql_gen.Table)
        
        kw_args['recover'] = True
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
        
        temp = not db.autocommit # tables are permanent in autocommit mode
        # "temporary tables cannot specify a schema name", so remove schema
        if temp: into.schema = None
        
        # retry with versioned names until one doesn't already exist
        while True:
            try:
                create_query = 'CREATE'
                if temp: create_query += ' TEMP'
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
                
                return run_query(db, create_query, params, *args, **kw_args)
                    # CREATE TABLE AS sets rowcount to # rows in query
            except DuplicateTableException, e:
                into.name = next_version(into.name)
                # try again with next version of name

    
473
order_by_pkey = object() # tells mk_select() to order by the pkey
474

    
475
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
476

    
477
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return tuple(query, params)
    '''
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        # DISTINCT ON suppresses the default pkey ordering
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += '\n'
    if fields == None: query += '*'
    else: query += '\n, '.join(map(parse_col, fields))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # missing tracks whether the query has any row-limiting clause at all
    missing = True
    if conds != []:
        # NOTE(review): whitespace is computed but never used — confirm
        if len(conds) == 1: whitespace = ' '
        else: whitespace = '\n'
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, [])
555

    
556
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT statement.
    For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, recover, cacheable,
        log_level=log_level)
564

    
565
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False):
    '''Builds an INSERT ... SELECT statement (optionally wrapped in a SQL
    function so it can be embedded as a nested SELECT).
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    '''
    table = sql_gen.as_Table(table)
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None:
        cols = [sql_gen.to_name_only_col(v, table).to_str(db) for v in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    # Build query
    first_line = 'INSERT INTO '+table.to_str(db)
    query = first_line
    if cols != None: query += '\n('+', '.join(cols)+')'
    query += '\n'+select_query
    
    if returning != None:
        # RETURNING clause uses the bare column name, without its table
        returning_name = copy.copy(returning)
        returning_name.table = None
        returning_name = returning_name.to_str(db)
        query += '\nRETURNING '+returning_name
    
    if embeddable:
        assert returning != None
        
        # Create function
        function_name = sql_gen.clean_name(first_line)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        # retry with versioned function names until creation succeeds
        while True:
            try:
                func_schema = None
                if not db.autocommit: func_schema = 'pg_temp'
                function = sql_gen.Table(function_name, func_schema).to_str(db)
                
                function_query = '''\
CREATE FUNCTION '''+function+'''()
RETURNS '''+return_type+'''
LANGUAGE sql
AS $$
'''+mogrify(db, query, params)+''';
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateFunctionException,))
                break # this version was successful
            except DuplicateFunctionException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
            [returning_name]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return (query, params)
625

    
626
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT ... SELECT statement.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into, recover=recover,
        cacheable=cacheable)
639

    
640
default = object() # tells insert() to use the default value for a column
641

    
642
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row, given as a sequence or a col-name dict.
    For params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        # dict input: keys()/values() pair up as long as row isn't modified
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = [] # one SQL placeholder (or DEFAULT) per column
    values = [] # params for the non-DEFAULT placeholders only
    for value in row:
        value = sql_gen.remove_col_rename(sql_gen.as_Value(value)).value
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    if values != []: query = 'VALUES ('+(', '.join(labels))+')'
    else: query = None # no values: mk_insert_select() uses DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
665

    
666
def mk_update(db, table, changes=None, cond=None):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    assignments = []
    for col, new_value in changes:
        assignments.append(sql_gen.to_name_only_col(col, table).to_str(db)
            +' = '+sql_gen.as_Value(new_value).to_str(db))
    
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join(assignments)
    if cond is not None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
681

    
682
def update(db, *args, **kw_args):
    '''Builds and runs an UPDATE statement.
    For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    query = mk_update(db, *args, **kw_args)
    return run_query(db, query, [], recover)
687

    
688
def last_insert_id(db):
    '''Returns the value generated by the most recent insert, or None when
    the driver doesn't support retrieving it.'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': return db.insert_id()
    return None
693

    
694
def truncate(db, table, schema='public'):
    '''Empties a table, cascading to dependent tables.'''
    query = 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE'
    return run_query(db, query)
696

    
697
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into # deliberate in-place mutation (see docstring)
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve: items.append((col, sql_gen.Col(str(col), into)))
    
    if not as_items: items = dict(items)
    return items
727

    
728
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Selects the given joins into a new table, flattening column names.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query, params = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, params, into=into)
    return dict(items)
737

    
738
##### Database structure queries
739

    
740
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the given table.'''
    query, params = mk_select(db, table, [sql_gen.row_count], order_by=None,
        start=0)
    return value(run_query(db, query, params, recover=recover, log_level=3))
743

    
744
def table_cols(db, table, recover=None):
    '''Returns the table's column names, in table order.'''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
747

    
748
def pkey(db, table, recover=None):
    '''Returns the primary key, assumed to be the first column in the table.'''
    return table_cols(db, table, recover)[0]
751

    
752
not_null_col = 'not_null_col'
753

    
754
def table_not_null_col(db, table, recover=None):
    '''Returns the name of a column expected to be non-null: the column named
    not_null_col if the table has one, otherwise the pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
758

    
759
def index_cols(db, table, index):
    '''Returns the column names covered by an index, in index order.
    
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    
    @param table str: unqualified table name (matched against pg_class.relname)
    @param index str: unqualified index name
    @raise NotImplementedError for non-psycopg2 (non-PostgreSQL) connections
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The UNION's first branch handles plain column indexes (attnum listed
        # directly in pg_index.indkey); the second branch handles expression
        # indexes, recovering the column numbers by regex-parsing the
        # :varattno nodes out of pg_index.indexprs.
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
799

    
800
def constraint_cols(db, table, constraint):
    '''Returns the column names covered by a table constraint.
    
    @param table str: unqualified table name (matched against pg_class.relname)
    @param constraint str: constraint name (matched against
        pg_constraint.conname)
    @raise NotImplementedError for non-psycopg2 (non-PostgreSQL) connections
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
816

    
817
# Name of the serial row-number column created by add_row_num()
row_num_col = '_row_num'
818

    
819
def add_index(db, expr):
    '''Adds an index on a column or expression if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    '''
    expr = copy.copy(expr) # don't modify input!
    
    # Determine the underlying table column the index is on
    if not isinstance(expr, sql_gen.FunctionCall): col = expr
    else:
        col = expr.args[0]
        expr = sql_gen.Expr(expr)
    assert sql_gen.is_table_col(col)
    
    index = sql_gen.as_Table(str(expr))
    table = col.table
    col.table = None # strip the table qualifier from the indexed expression
    query = ('CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)+' ('
        +expr.to_str(db)+')')
    try: run_query(db, query, recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
838

    
839
def add_pkey(db, table, recover=None):
840
    '''Makes the first column in a table the primary key.
841
    @pre The table must not already have a primary key.
842
    '''
843
    table = sql_gen.as_Table(table)
844
    
845
    index = sql_gen.as_Table(sql_gen.add_suffix(table.name, '_pkey'))
846
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
847
    try:
848
        run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
849
            +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')',
850
            recover=True, cacheable=True, log_level=3,
851
            log_ignore_excs=(DuplicateTableException,))
852
    except DuplicateTableException, e:
853
        index.name = next_version(index.name)
854
        # try again with next version of name
855

    
856
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    query = ('ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY')
    run_query(db, query, log_level=3)
862

    
863
def tables(db, schema_like='public', table_like='%'):
    '''Returns an iterator over the names of tables matching the given SQL
    LIKE patterns.
    @raise NotImplementedError for unsupported database drivers
    '''
    module = util.root_module(db.db)
    params = {'schema_like': schema_like, 'table_like': table_like}
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname LIKE %(schema_like)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
        return values(run_query(db, query, params, cacheable=True))
    if module == 'MySQLdb':
        # MySQL has no schema filter here; only the table pattern applies
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True))
    raise NotImplementedError("Can't list tables for "+module+' database')
880

    
881
##### Database management
882

    
883
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the given schema.
    For kw_args, see tables()'''
    for table_ in tables(db, schema, **kw_args):
        truncate(db, table_, schema)
886

    
887
##### Heuristic queries
888

    
889
def put(db, table, row, pkey_=None, row_ct_ref=None):
890
    '''Recovers from errors.
891
    Only works under PostgreSQL (uses INSERT RETURNING).
892
    '''
893
    row = sql_gen.ColDict(db, table, row)
894
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
895
    
896
    try:
897
        cur = insert(db, table, row, pkey_, recover=True)
898
        if row_ct_ref != None and cur.rowcount >= 0:
899
            row_ct_ref[0] += cur.rowcount
900
        return value(cur)
901
    except DuplicateKeyException, e:
902
        row = sql_gen.ColDict(db, table,
903
            util.dict_subset_right_join(row, e.cols))
904
        return value(select(db, table, [pkey_], row, recover=True))
905

    
906
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Fetches a matching row's pkey, optionally creating the row if missing.
    Recovers from errors'''
    try: return value(select(db, table, [pkey], row, limit=1, recover=True))
    except StopIteration: # no matching row
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
912

    
913
def is_func_result(col):
    '''Whether col is the 'result' column of a function-call table
    (a table whose name contains parentheses).'''
    return '(' in col.table.name and col.name == 'result'
915

    
916
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Derives the default name of the pkeys table produced by put_table().'''
    def fmt_in_col(in_col):
        # Render an input column compactly for embedding in the table name
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0: in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return str(in_col)
    
    name = str(out_table)
    if not is_func: return name+'_pkeys'
    
    # Function: render as out_table(value) or out_table(k=v, ...)
    try: value_in_col = mapping['value']
    except KeyError:
        args = ', '.join((str(k)+'='+fmt_in_col(v)
            for k, v in mapping.iteritems()))
    else: args = fmt_in_col(value_in_col)
    return name+'('+args+')'
939

    
940
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False):
    '''Bulk-inserts mapped input rows into out_table, recording each input
    row's output pkey. Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
        flatten_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column: replace the output column name with the
    # input column mapped to it, or fall back to None if unmapped
    try: default = mapping[default]
    except KeyError:
        if default != None:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    out_pkey = pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    # First call creates the `into` table; subsequent calls append to it
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols):
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
        if pkeys_table_exists_ref[0]:
            insert_select(db, into, pkeys_names, query, params)
        else:
            run_query_into(db, query, params, into=into)
            pkeys_table_exists_ref[0] = True
    
    # Mutable state shared with mk_main_select(); updated by the error
    # handlers below to exclude problem rows on the next attempt
    limit_ref = [None]
    conds = set()
    distinct_on = []
    def mk_main_select(joins, cols):
        return mk_select(db, joins, cols, conds, distinct_on,
            limit=limit_ref[0], start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    def remove_all_rows():
        log_debug('Returning NULL for all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    def ignore(in_col, value):
        in_col_str = strings.as_tt(repr(in_col))
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
            level=2.5)
        add_index(db, in_col)
        log_debug('Ignoring rows with '+in_col_str+' = '
            +strings.as_tt(repr(value)))
    def remove_rows(in_col, value):
        # Exclude rows with this value via a WHERE condition
        ignore(in_col, value)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    def invalid2null(in_col, value):
        # Replace this value with NULL in the temp input table
        ignore(in_col, value)
        update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.add_suffix(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects. Each handled exception tightens the constraints
    # (conds/join_cols/mapping) and the loop retries until an insert succeeds.
    join_cols = sql_gen.ColDict(db, out_table)
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
                limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            # Filter out rows that already exist in out_table
            distinct_on = [v.to_Col() for v in join_cols.values()]
            insert_joins.append(sql_gen.Join(out_table, join_cols,
                sql_gen.filter_out))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select(insert_joins, mapping.values())[0]
        
        log_debug('Trying to insert new rows')
        try:
            cur = insert_select(db, out_table, mapping.keys(), main_select,
                **insert_args)
            break # insert successful
        except DuplicateKeyException, e:
            log_exc(e)
            
            # Compare on the conflicting columns next time around
            old_join_cols = join_cols.copy()
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
        except NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None)
        except FunctionValueException, e:
            log_exc(e)
            
            # NULL out the value that the SQL function rejected
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                invalid2null(sql_gen.unwrap_func_call(in_col, func_name), value)
        except MissingCastException, e:
            log_exc(e)
            
            # Wrap the input column in a cast to the required type
            out_col = e.col
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
        except DatabaseErrors, e:
            log_exc(e)
            
            msg = 'No handler for exception: '+exc.str_(e)
            warnings.warn(DbWarning(msg))
            log_debug(msg)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    # NOTE(review): has_joins is only bound by the main case above; the
    # limit_ref[0] == 0 branch can only be taken after a prior main-case
    # iteration set it — confirm remove_all_rows() is never reachable first
    if has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols)
    else:
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
            insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert table_row_count(db, into) == table_row_count(db, in_table)
    
    return sql_gen.Col(out_pkey, into)
1154

    
1155
##### Data cleanup
1156

    
1157
def cleanup_table(db, table, cols):
    '''Trims whitespace in the given columns and NULLs out values that are
    empty or the literal backslash-N placeholder.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    esc_cols = [esc_name(db, col) for col in cols]
    
    # Build one SET clause per column
    assignments = []
    for col in esc_cols:
        assignments.append('\n'+col
            +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)")
    run_query(db, 'UPDATE '+table_str+' SET\n'+',\n'.join(assignments),
        dict(null0='', null1=r'\N'))
(24-24/36)