Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the last query run on a cursor, for error messages.
    Checks the driver-specific attributes (psycopg2's `query`, MySQLdb's
    `_last_executed`); falls back to echoing the caller-supplied query/params.
    '''
    raw_query = None
    for attr in ('query', '_last_executed'):
        if hasattr(cur, attr):
            raw_query = getattr(cur, attr)
            break # first available attribute wins, even if its value is None
    
    if raw_query is not None: return raw_query
    return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
28
def _add_cursor_info(e, *args, **kw_args):
    '''Attaches the cursor's last query to an exception's message.
    For params, see get_cur_query().'''
    query_str = str(get_cur_query(*args, **kw_args))
    exc.add_msg(e, 'query: '+query_str)
31

    
32
class DbException(exc.ExceptionWithCause):
    '''Base class for all database-related exceptions in this module.
    If a cursor is provided, its last query is appended to the message.'''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur is not None: _add_cursor_info(self, cur)
36

    
37
class ExceptionWithName(DbException):
    '''A DbException associated with a named DB object (e.g. table, function)'''
    def __init__(self, name, cause=None):
        msg = 'for name: '+strings.as_tt(str(name))
        DbException.__init__(self, msg, cause)
        self.name = name
41

    
42
class ExceptionWithNameValue(DbException):
    '''A DbException associated with a (name, value) pair'''
    def __init__(self, name, value, cause=None):
        msg = ('for name: '+strings.as_tt(str(name))+'; value: '
            +strings.as_tt(repr(value)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.value = value
48

    
49
class ConstraintException(DbException):
    '''Raised when a DB constraint (unique, not-null, ...) is violated'''
    def __init__(self, name, cols, cause=None):
        msg = ('Violated '+strings.as_tt(name)+' constraint on columns: '
            +strings.as_tt(', '.join(cols)))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cols = cols
55

    
56
class MissingCastException(DbException):
    '''Raised when a value's type doesn't match the column's type
    (see the type-mismatch parsing in run_query())'''
    def __init__(self, type_, col, cause=None):
        msg = ('Missing cast to type '+strings.as_tt(type_)+' on column: '
            +strings.as_tt(col))
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col
62

    
63
# NOTE(review): not raised anywhere in this file; semantics inferred from the
# name (invalid DB identifier) — confirm against callers before relying on it
class NameException(DbException): pass
64

    
65
# Raised by run_query() when a query violates a unique constraint
class DuplicateKeyException(ConstraintException): pass
66

    
67
# Raised by run_query() when a query violates a NOT NULL constraint
class NullValueException(ConstraintException): pass
68

    
69
# Raised by run_query() when a value is invalid input to a SQL function
class FunctionValueException(ExceptionWithNameValue): pass
70

    
71
# Raised by run_query() when CREATE TABLE hits an existing relation
class DuplicateTableException(ExceptionWithName): pass
72

    
73
# Raised by run_query() when CREATE FUNCTION hits an existing function
class DuplicateFunctionException(ExceptionWithName): pass
74

    
75
# NOTE(review): not raised anywhere in this file; presumably signals a row with
# no usable values — confirm against callers
class EmptyRowException(DbException): pass
76

    
77
##### Warnings
78

    
79
# Warning category for suspicious queries (e.g. mk_select() with no WHERE/LIMIT)
class DbWarning(UserWarning): pass
80

    
81
##### Result retrieval
82

    
83
def col_names(cur):
    '''Generator over the column names of a cursor's result set
    (name is the first field of each DB-API description entry)'''
    return (desc[0] for desc in cur.description)
84

    
85
def rows(cur):
    '''Iterator over a cursor's rows, stopping when fetchone() returns None'''
    return iter(cur.fetchone, None) # two-arg iter(): call until sentinel
86

    
87
def consume_rows(cur):
    '''Exhausts the cursor's rows so the full result will be cached'''
    remaining = rows(cur)
    iters.consume_iter(remaining)
90

    
91
def next_row(cur):
    '''Returns the cursor's next row; raises StopIteration when exhausted'''
    return rows(cur).next()
92

    
93
def row(cur):
    '''Returns the first row, then consumes the rest so the result is cached'''
    first = next_row(cur)
    consume_rows(cur)
    return first
97

    
98
def next_value(cur):
    '''First column of the cursor's next row'''
    return next_row(cur)[0]
99

    
100
def value(cur):
    '''First column of the first row (consumes all rows for caching)'''
    return row(cur)[0]
101

    
102
def values(cur):
    '''Iterator over the first column of each remaining row'''
    return iters.func_iter(lambda: next_value(cur))
103

    
104
def value_or_none(cur):
    '''Like value(), but returns None instead of raising on an empty result'''
    try:
        return value(cur)
    except StopIteration:
        return None
107

    
108
##### Input validation
109

    
110
def esc_name_by_module(module, name):
    '''Escapes a DB identifier using the quote char of the given DB-API module.
    module None is treated as psycopg2 (double-quote).'''
    quote_chars = {'psycopg2': '"', None: '"', 'MySQLdb': '`'}
    try: quote = quote_chars[module]
    except KeyError:
        raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)
115

    
116
def esc_name_by_engine(engine, name, **kw_args):
    '''Escapes a DB identifier given an engine name from db_engines'''
    module = db_engines[engine][0]
    return esc_name_by_module(module, name, **kw_args)
118

    
119
def esc_name(db, name, **kw_args):
    '''Escapes a DB identifier for the given connection's driver'''
    module = util.root_module(db.db)
    return esc_name_by_module(module, name, **kw_args)
121

    
122
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified table name'''
    table_str = esc_name(db, table)
    if schema is None: return table_str
    return esc_name(db, schema)+'.'+table_str
127

    
128
##### Database connections
129

    
130
# Keys recognized in a db_config dict (see DbConn.__init__())
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# Maps engine name -> (DB-API module name, db_config key renames for connect())
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All known DB error types; _add_module() extends this as drivers are loaded
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set) # tuple form for use in except clauses
139

    
140
def _add_module(module):
    '''Registers a DB-API module's DatabaseError in DatabaseErrors'''
    global DatabaseErrors
    DatabaseErrors_set.add(module.DatabaseError)
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild the except-clause tuple
144

    
145
def db_config_str(db_config):
    '''One-line description of a db_config, for log/error messages'''
    return '%s database %s' % (db_config['engine'], db_config['database'])
147

    
148
def _query_lookup(query, params):
    '''Cache key for a (query, params) pair; params made hashable first'''
    return (query, dicts.make_hashable(params))
149

    
150
# Default no-op debug logger for DbConn; also used to detect "debugging off"
log_debug_none = lambda msg, level=2: None
151

    
152
class DbConn:
    '''A lazily-connecting DB connection with per-connection query caching.
    
    The underlying DB-API connection is not opened until the first access of
    self.db (see __getattr__()/_db()). Query results are cached in
    self.query_results when caching is enabled.
    '''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        '''
        @param db_config dict with keys from db_config_names
        @param serializable Whether to use SERIALIZABLE isolation (only applied
            when autocommit is off; see _db())
        @param autocommit Whether to commit after each query run outside a
            savepoint (see do_autocommit())
        @param caching Whether to cache query results (see run_query())
        @param log_debug Callback taking (msg, level); defaults to a no-op
        '''
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none # debugging iff a real callback
        
        self.__db = None # the lazily-created DB-API connection
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
    
    def __getattr__(self, name):
        # Only called for attributes not found normally; makes self.db lazy
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        # Pickling support: the callback and live connection can't be pickled
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self): return self.__db != None # whether a connection is open
    
    def _db(self):
        '''Returns the DB-API connection, opening and configuring it on first
        use (isolation level, search_path schemas).'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename config keys to the driver's connect() kwarg names
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas != None:
                # Prepend the requested schemas to the existing search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', \
%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''A cursor proxy that records rows/exceptions for the query cache'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None # cache key, set by execute()
            self.result = [] # fetched rows, or the raised exception
        
        def execute(self, query, params=None):
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner, query, params)
            except Exception, e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                # Snapshot just the attrs a cached cursor needs to replay
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''Replays a cached result (rows or exception) as a cursor'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Escapes a literal value for inlining into SQL'''
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def can_mogrify(self):
        # only psycopg2 cursors support mogrify()
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        '''Returns the query with params substituted, as the driver would run it'''
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def run_query(self, query, params=None, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''Runs a query, serving it from the cache when possible.
        (The log_ignore_excs param belongs to the module-level run_query()
        wrapper, which increases log_level by 2 for those exceptions.)
        @param cacheable Whether to cache (or reuse a cached) result
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        @return the cursor (live DbCursor, plain cursor, or CacheCursor)
        '''
        assert query != None
        
        if not self.caching: cacheable = False
        used_cache = False
        
        def log_msg(query):
            if used_cache: cache_status = 'cache hit'
            elif cacheable: cache_status = 'cache miss'
            else: cache_status = 'non-cacheable'
            return 'DB query: '+cache_status+':\n'+strings.as_code(query, 'SQL')
        
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Log query
            if self.debug and debug_msg_ref == None: # log before running
                self.log_debug(log_msg(query), log_level)
            
            # Run query
            cur.execute(query, params)
        finally:
            if self.debug and debug_msg_ref != None: # return after running
                debug_msg_ref[0] = log_msg(str(get_cur_query(cur, query,
                    params)))
        
        return cur
    
    def is_cached(self, query, params=None):
        # whether a result for this query+params is already cached
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a uniquely-named savepoint, rolling back to it on
        any exception and releasing it (plus autocommitting) on success.'''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: 
            try: return_val = func()
            finally:
                # restore depth before ROLLBACK/RELEASE so they run at depth 0
                # when appropriate (see do_autocommit())
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommiting')
            self.db.commit()
    
    def col_default(self, col):
        '''Returns the column's DEFAULT expression (from information_schema)
        as a sql_gen code object'''
        table = sql_gen.Table('columns', 'information_schema')
        
        conds = [('table_name', col.table.name), ('column_name', col.name)]
        schema = col.table.schema
        if schema != None: conds.append(('table_schema', schema))
        
        return sql_gen.as_Code(value(select(self, table, ['column_default'],
            conds, order_by='table_schema', limit=1, log_level=3)))
            # TODO: order_by search_path schema order
360

    
361
# Module-level alias: connect(db_config, ...) constructs a DbConn
connect = DbConn
362

    
363
##### Querying
364

    
365
def run_raw_query(db, *args, **kw_args):
    '''Thin wrapper around the connection's query runner.
    For params, see DbConn.run_query().'''
    run = db.run_query
    return run(*args, **kw_args)
368

    
369
def mogrify(db, query, params):
    '''Returns the query with params substituted, as the driver would run it.
    Only supported for psycopg2 connections.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't mogrify query for "+module+
            ' database')
    return db.db.cursor().mogrify(query, params)
374

    
375
##### Recoverable querying
376

    
377
def with_savepoint(db, func):
    '''Runs func() inside a DB savepoint; see DbConn.with_savepoint()'''
    return db.with_savepoint(func)
378

    
379
def run_query(db, query, params=None, recover=None, cacheable=False,
    log_level=2, log_ignore_excs=None, **kw_args):
    '''Runs a query, optionally inside a savepoint, translating driver errors
    into this module's typed exceptions (DuplicateKeyException, etc.).
    For other params, see run_raw_query().
    @param recover Whether to wrap the query in a savepoint so the transaction
        survives a failure (required for the exception translation, which may
        itself run queries such as index_cols())
    @param log_ignore_excs Exception types whose log_level is raised by 2
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    
    debug_msg_ref = None # usually, db.run_query() logs query before running it
    # But if filtering with log_ignore_excs, wait until after exception parsing
    if log_ignore_excs != () or not db.can_mogrify(): debug_msg_ref = [None] 
    
    try:
        try:
            def run(): return run_raw_query(db, query, params, cacheable,
                log_level, debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query, params):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            if not recover: raise # need savepoint to run index_cols()
            msg = exc.str_(e)
            
            # Parse the (PostgreSQL-style) error message into typed exceptions.
            # Each pattern extracts the names involved from the message text.
            match = re.search(r'duplicate key value violates unique constraint '
                r'"((_?[^\W_]+)_.+?)"', msg)
            if match:
                constraint, table = match.groups()
                # look up which columns the violated index covers
                try: cols = index_cols(db, table, constraint)
                except NotImplementedError: raise e
                else: raise DuplicateKeyException(constraint, cols, e)
            
            match = re.search(r'null value in column "(.+?)" violates not-null'
                r' constraint', msg)
            if match: raise NullValueException('NOT NULL', [match.group(1)], e)
            
            match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
                r'|date/time field value out of range): "(.+?)"\n'
                r'(?:(?s).*?)\bfunction "(.+?)"', msg)
            if match:
                value, name = match.groups()
                raise FunctionValueException(name, strings.to_unicode(value), e)
            
            match = re.search(r'column "(.+?)" is of type (.+?) but expression '
                r'is of type', msg)
            if match:
                col, type_ = match.groups()
                raise MissingCastException(type_, col, e)
            
            match = re.search(r'relation "(.+?)" already exists', msg)
            if match: raise DuplicateTableException(match.group(1), e)
            
            match = re.search(r'function "(.+?)" already exists', msg)
            if match: raise DuplicateFunctionException(match.group(1), e)
            
            raise # no specific exception raised
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        # emit the deferred log message (see debug_msg_ref above)
        if debug_msg_ref != None and debug_msg_ref[0] != None:
            db.log_debug(debug_msg_ref[0], log_level)
439

    
440
##### Basic queries
441

    
442
def next_version(name):
    '''Returns the name with its "#<N>" version suffix incremented.
    An unsuffixed name counts as version 0, so it becomes "<name>#1".'''
    base, version = name, 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match is not None:
        base = match.group(1)
        version = int(match.group(2))+1
    return sql_gen.add_suffix(base, '#'+str(version))
449

    
450
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table|None Table to CREATE ... AS the query into.
        NOTE: mutated in place — its schema may be cleared and its name bumped
        to the next free version (see next_version()).
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        assert isinstance(into, sql_gen.Table)
        
        kw_args['recover'] = True
        # a name collision is expected, so log it quietly (see run_query())
        kw_args.setdefault('log_ignore_excs', (DuplicateTableException,))
        
        temp = not db.autocommit # tables are permanent in autocommit mode
        # "temporary tables cannot specify a schema name", so remove schema
        if temp: into.schema = None
        
        while True:
            try:
                create_query = 'CREATE'
                if temp: create_query += ' TEMP'
                create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
                
                return run_query(db, create_query, params, *args, **kw_args)
                    # CREATE TABLE AS sets rowcount to # rows in query
            except DuplicateTableException, e:
                into.name = next_version(into.name)
                # try again with next version of name
476

    
477
order_by_pkey = object() # tells mk_select() to order by the pkey
478

    
479
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
480

    
481
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds a SELECT statement (without running it).
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns (default [] is never mutated, only compared)
    @param order_by order_by_pkey (default) orders by the first table's pkey
        unless DISTINCT ON is used; None omits the ORDER BY clause
    @param default_table Table assumed for columns given without a table
    @return tuple(query, params)
    '''
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        # DISTINCT ON would require its columns to lead the ORDER BY, so skip
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += '\n'
    if fields == None: query += '*'
    else: query += '\n, '.join(map(parse_col, fields))
    
    # Main table
    query += '\nFROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    # Track whether the query has any row-limiting clause at all
    missing = True
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
        missing = False
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, [])
559

    
560
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT. For params, see mk_select() and run_query().'''
    # Extract the run_query() options so mk_select() only sees its own params
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    select_query, select_params = mk_select(db, *args, **kw_args)
    return run_query(db, select_query, select_params, recover, cacheable,
        log_level=log_level)
568

    
569
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False):
    '''Builds an INSERT ... SELECT statement (without running it).
    @param cols Columns to insert into; [] or None means all-defaults insert
    @param select_query The SELECT (or VALUES) supplying the rows; None means
        DEFAULT VALUES
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        This wraps the INSERT in a SQL function and returns a SELECT from it
        (side effect: the function is created in the DB here).
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @return tuple(query, params)
    '''
    table = sql_gen.as_Table(table)
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None:
        cols = [sql_gen.to_name_only_col(v, table).to_str(db) for v in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    # Build query
    first_line = 'INSERT INTO '+table.to_str(db)
    query = first_line
    if cols != None: query += '\n('+', '.join(cols)+')'
    query += '\n'+select_query
    
    if returning != None:
        # RETURNING must use the bare column name, without a table qualifier
        returning_name = copy.copy(returning)
        returning_name.table = None
        returning_name = returning_name.to_str(db)
        query += '\nRETURNING '+returning_name
    
    if embeddable:
        assert returning != None
        
        # Create function
        function_name = sql_gen.clean_name(first_line)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        while True:
            try:
                func_schema = None
                # pg_temp makes the function session-local, like a temp table
                if not db.autocommit: func_schema = 'pg_temp'
                function = sql_gen.Table(function_name, func_schema).to_str(db)
                
                function_query = '''\
CREATE FUNCTION '''+function+'''()
RETURNS '''+return_type+'''
LANGUAGE sql
AS $$
'''+mogrify(db, query, params)+''';
$$;
'''
                run_query(db, function_query, recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateFunctionException,))
                break # this version was successful
            except DuplicateFunctionException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
            [returning_name]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return (query, params)
629

    
630
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT ... SELECT.
    For params, see mk_insert_select() and run_query_into().
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in (forces embeddable=True)
    '''
    into = kw_args.pop('into', None)
    if into is not None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    insert_query, insert_params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, insert_query, insert_params, into,
        recover=recover, cacheable=cacheable)
643

    
644
# Sentinel (compared with "is"):
default = object() # tells insert() to use the default value for a column
645

    
646
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row, given as a sequence of values or a dict of
    col -> value. For other params, see insert_select().'''
    cols = None
    if not lists.is_seq(row): # dict: split into parallel cols/values lists
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    placeholders = []
    query_params = []
    for item in row:
        item = sql_gen.remove_col_rename(sql_gen.as_Value(item)).value
        if item is default: placeholders.append('DEFAULT')
        else:
            placeholders.append('%s')
            query_params.append(item)
    
    # Build query (None = all-defaults insert)
    query = None
    if query_params != []: query = 'VALUES ('+(', '.join(placeholders))+')'
    
    return insert_select(db, table, cols, query, query_params, *args, **kw_args)
669

    
670
def mk_update(db, table, changes=None, cond=None):
    '''Builds an UPDATE statement (without running it).
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    assignments = []
    for col, new_value in changes:
        assignments.append(sql_gen.to_name_only_col(col, table).to_str(db)
            +' = '+sql_gen.as_Value(new_value).to_str(db))
    
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join(assignments)
    if cond is not None: query += '\nWHERE\n'+cond.to_str(db)
    
    return query
685

    
686
def update(db, *args, **kw_args):
    '''Builds and runs an UPDATE. For params, see mk_update() and run_query().'''
    recover = kw_args.pop('recover', None)
    query = mk_update(db, *args, **kw_args)
    return run_query(db, query, [], recover)
691

    
692
def last_insert_id(db):
    '''Returns the last sequence value assigned by an INSERT, or None if the
    driver offers no way to get it'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    if module == 'MySQLdb': return db.insert_id()
    return None
697

    
698
def truncate(db, table, schema='public'):
    '''Empties a table (and, via CASCADE, tables referencing it)'''
    qualified = qual_name(db, schema, table)
    return run_query(db, 'TRUNCATE '+qualified+' CASCADE')
700

    
701
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    mapping = []
    # Preserved cols keep their name; only their table is repointed at `into`
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        mapping.append((orig_col, col))
    preserved = set(preserve)
    # Everything else gets a fresh col named after its full string form
    mapping.extend((col, sql_gen.Col(str(col), into))
        for col in cols if col not in preserved)
    
    if as_items: return mapping
    return dict(mapping)
731

    
732
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Flattens joined tables into one table with distinct column names.
    For params, see mk_flatten_mapping().
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    select_cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    query, params = mk_select(db, joins, select_cols, limit=limit, start=start)
    run_query_into(db, query, params, into=into)
    return dict(items)
741

    
742
##### Database structure queries
743

    
744
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in the table'''
    query, params = mk_select(db, table, [sql_gen.row_count], order_by=None,
        start=0)
    return value(run_query(db, query, params, recover=recover, log_level=3))
747

    
748
def table_cols(db, table, recover=None):
    '''Returns the table's column names (via a LIMIT 0 query)'''
    cur = select(db, table, limit=0, order_by=None, recover=recover,
        log_level=4)
    return list(col_names(cur))
751

    
752
def pkey(db, table, recover=None):
    '''Returns the table's pkey, assumed to be its first column'''
    cols = table_cols(db, table, recover)
    return cols[0]
755

    
756
# Well-known column name checked by table_not_null_col()
not_null_col = 'not_null_col'
757

    
758
def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    cols = table_cols(db, table, recover)
    if not_null_col in cols: return not_null_col
    return pkey(db, table, recover)
762

    
763
def index_cols(db, table, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list index columns for "+module+
            ' database')
    # The UNION's first branch handles plain column indexes; the second parses
    # expression indexes, extracting the column numbers out of indexprs
    cur = run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
        {'table': table, 'index': index}, cacheable=True, log_level=4)
    return list(values(cur))
803

    
804
def constraint_cols(db, table, constraint):
    '''Returns the columns covered by the named constraint, in column order.'''
    module = util.root_module(db.db)
    if module != 'psycopg2':
        raise NotImplementedError("Can't list constraint columns for "+module+
            ' database')
    cur = run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
        {'table': table, 'constraint': constraint})
    return list(values(cur))
820

    
821
row_num_col = '_row_num'
822

    
823
def add_index(db, expr):
    '''Adds an index on a column or expression if it doesn't already exist.
    Currently, only function calls are supported as expressions.
    '''
    expr = copy.copy(expr) # don't modify input!
    
    # Extract the underlying column from the expression
    if not isinstance(expr, sql_gen.FunctionCall): col = expr
    else:
        col = expr.args[0]
        expr = sql_gen.Expr(expr)
    assert sql_gen.is_table_col(col)
    
    index = sql_gen.as_Table(str(expr))
    table = col.table
    col.table = None # strip the table qualification for the column list
    query = ('CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
        +' ('+expr.to_str(db)+')')
    try: run_query(db, query, recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
842

    
843
def add_pkey(db, table, recover=None):
    '''Makes the first column in a table the primary key.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    
    index = sql_gen.as_Table(sql_gen.add_suffix(table.name, '_pkey'))
    col = sql_gen.to_name_only_col(pkey(db, table, recover))
    # Retry until a free constraint name is found. Previously, a name
    # collision bumped index.name but then fell through without ever adding
    # the pkey; the loop makes the "try again" actually happen.
    while True:
        try:
            run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
                +index.to_str(db)+' PRIMARY KEY('+col.to_str(db)+')',
                recover=True, cacheable=True, log_level=3,
                log_ignore_excs=(DuplicateTableException,))
            break # pkey added successfully
        except DuplicateTableException:
            # try again with next version of name
            index.name = next_version(index.name)

    
860
def add_row_num(db, table):
    '''Adds a row number column to a table. Its name is in row_num_col. It will
    be the primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    query = ('ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY')
    run_query(db, query, log_level=3)
866

    
867
def tables(db, schema_like='public', table_like='%'):
    '''Iterates over the names of the tables matching the given LIKE patterns.
    @raise NotImplementedError on unsupported database modules
    '''
    params = {'schema_like': schema_like, 'table_like': table_like}
    module = util.root_module(db.db)
    # Pick the backend-specific listing query, then run it once
    if module == 'psycopg2':
        query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname LIKE %(schema_like)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
    elif module == 'MySQLdb':
        query = 'SHOW TABLES LIKE %(table_like)s'
    else: raise NotImplementedError("Can't list tables for "+module+' database')
    return values(run_query(db, query, params, cacheable=True))
884

    
885
##### Database management
886

    
887
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in the schema.
    For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args):
        truncate(db, table, schema)
890

    
891
##### Heuristic queries
892

    
893
def put(db, table, row, pkey_=None, row_ct_ref=None):
894
    '''Recovers from errors.
895
    Only works under PostgreSQL (uses INSERT RETURNING).
896
    '''
897
    row = sql_gen.ColDict(db, table, row)
898
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
899
    
900
    try:
901
        cur = insert(db, table, row, pkey_, recover=True)
902
        if row_ct_ref != None and cur.rowcount >= 0:
903
            row_ct_ref[0] += cur.rowcount
904
        return value(cur)
905
    except DuplicateKeyException, e:
906
        row = sql_gen.ColDict(db, table,
907
            util.dict_subset_right_join(row, e.cols))
908
        return value(select(db, table, [pkey_], row, recover=True))
909

    
910
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up a row's pkey, optionally inserting the row if missing.
    Recovers from errors'''
    try:
        return value(select(db, table, [pkey], row, limit=1, recover=True))
    except StopIteration: # no matching row
        if create: return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
916

    
917
def is_func_result(col):
    '''Whether col is the 'result' column of a function-call table.'''
    return col.name == 'result' and '(' in col.table.name
919

    
920
def into_table_name(out_table, in_tables0, mapping, is_func):
    '''Picks a default name for the table that will hold the output pkeys.'''
    def in_col_str(in_col):
        in_col = sql_gen.remove_col_rename(in_col)
        if isinstance(in_col, sql_gen.Col):
            table = in_col.table
            if table == in_tables0: in_col = sql_gen.to_name_only_col(in_col)
            elif is_func_result(in_col): in_col = table # omit col name
        return str(in_col)
    
    if not is_func: return str(out_table)+'_pkeys'
    
    # Function call: name looks like `func(args)`
    try: value_in_col = mapping['value']
    except KeyError:
        args_str = ', '.join((str(k)+'='+in_col_str(v)
            for k, v in mapping.iteritems()))
    else: args_str = in_col_str(value_in_col)
    return str(out_table)+'('+args_str+')'
943

    
944
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None, into=None,
    default=None, is_func=False):
    '''Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: str (*not* sql_gen.Col)
        * in_table_col: sql_gen.Col|literal-value
    @param into The table to contain the output and input pkeys.
        Defaults to `out_table.name+'_pkeys'`.
    @param default The *output* column to use as the pkey for missing rows.
        If this output column does not exist in the mapping, uses None.
    @param is_func Whether out_table is the name of a SQL function, not a table
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    
    def log_debug(msg): db.log_debug(msg, level=1.5)
    def col_ustr(str_):
        return strings.repr_no_u(sql_gen.remove_col_rename(str_))
    
    log_debug('********** New iteration **********')
    log_debug('Inserting these input columns into '+strings.as_tt(
        out_table.to_str(db))+':\n'+strings.as_table(mapping, ustr=col_ustr))
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v,
        {in_pkey: sql_gen.join_same_not_null}) for v in in_tables_]
    
    if into == None:
        into = into_table_name(out_table, in_tables0, mapping, is_func)
    into = sql_gen.as_Table(into)
    
    log_debug('Joining together input tables into temp table')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table('in')
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
        flatten_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    db.log_debug('Temp table: '+strings.as_tt(in_table.to_str(db)), level=2)
    
    mapping = sql_gen.ColDict(db, out_table, mapping)
        # after applying dicts.join() because that returns a plain dict
    
    # Resolve default value column
    try: default = mapping[default]
    except KeyError:
        if default != None:
            db.log_debug('Default value column '
                +strings.as_tt(strings.repr_no_u(default))
                +' does not exist in mapping, falling back to None', level=2.1)
            default = None
    
    out_pkey = pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    # The first call creates the into table; later calls append to it
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols):
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
        if pkeys_table_exists_ref[0]:
            insert_select(db, into, pkeys_names, query, params)
        else:
            run_query_into(db, query, params, into=into)
            pkeys_table_exists_ref[0] = True
    
    # Mutable state shared with the closures below; the retry loop adds
    # conditions/limits here as it recovers from exceptions
    limit_ref = [None] # limit 0 means "return NULL for all rows"
    conds = set()
    distinct_on = []
    def mk_main_select(joins, cols):
        return mk_select(db, joins, cols, conds, distinct_on,
            limit=limit_ref[0], start=0)
    
    exc_strs = set()
    def log_exc(e):
        e_str = exc.str_(e, first_line_only=True)
        log_debug('Caught exception: '+e_str)
        assert e_str not in exc_strs # avoid infinite loops
        exc_strs.add(e_str)
    def remove_all_rows():
        log_debug('Returning NULL for all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    def ignore(in_col, value):
        in_col_str = strings.as_tt(repr(in_col))
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering',
            level=2.5)
        add_index(db, in_col)
        log_debug('Ignoring rows with '+in_col_str+' = '
            +strings.as_tt(repr(value)))
    def remove_rows(in_col, value):
        ignore(in_col, value)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    def invalid2null(in_col, value):
        ignore(in_col, value)
        update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    def insert_pkeys_table(which):
        return sql_gen.Table(sql_gen.add_suffix(in_table.name,
            '_insert_'+which+'_pkeys'))
    insert_out_pkeys = insert_pkeys_table('out')
    insert_in_pkeys = insert_pkeys_table('in')
    
    # Do inserts and selects
    # Retry loop: each handled exception adds a constraint/workaround, then
    # the insert is attempted again until it succeeds
    join_cols = sql_gen.ColDict(db, out_table)
    while True:
        if limit_ref[0] == 0: # special case
            log_debug('Creating an empty pkeys table')
            cur = run_query_into(db, *mk_select(db, out_table, [out_pkey],
                limit=limit_ref[0]), into=insert_out_pkeys)
            break # don't do main case
        
        has_joins = join_cols != {}
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            # Filter out rows that already exist in out_table
            distinct_on = [v.to_Col() for v in join_cols.values()]
            insert_joins.append(sql_gen.Join(out_table, join_cols,
                sql_gen.filter_out))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        main_select = mk_main_select(insert_joins, mapping.values())[0]
        
        log_debug('Trying to insert new rows')
        try:
            cur = insert_select(db, out_table, mapping.keys(), main_select,
                **insert_args)
            break # insert successful
        except DuplicateKeyException, e:
            log_exc(e)
            
            # Compare on the violated unique constraint's columns next time
            old_join_cols = join_cols.copy()
            join_cols.update(util.dict_subset_right_join(mapping, e.cols))
            log_debug('Ignoring existing rows, comparing on these columns:\n'
                +strings.as_inline_table(join_cols, ustr=col_ustr))
            assert join_cols != old_join_cols # avoid infinite loops
        except NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                log_debug('Missing mapping for NOT NULL column '+out_col)
                remove_all_rows()
            else: remove_rows(in_col, None)
        except FunctionValueException, e:
            log_exc(e)
            
            # NULL out the input value that the function rejected
            func_name = e.name
            value = e.value
            for out_col, in_col in mapping.iteritems():
                invalid2null(sql_gen.unwrap_func_call(in_col, func_name), value)
        except MissingCastException, e:
            log_exc(e)
            
            # Wrap the input column in a cast to the required type
            out_col = e.col
            mapping[out_col] = sql_gen.wrap_in_func(e.type, mapping[out_col])
        except DatabaseErrors, e:
            log_exc(e)
            
            msg = 'No handler for exception: '+exc.str_(e)
            warnings.warn(DbWarning(msg))
            log_debug(msg)
            remove_all_rows()
        # after exception handled, rerun loop with additional constraints
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    if has_joins:
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        log_debug('Getting output table pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols)
    else:
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        log_debug('Getting input table pkeys of inserted rows')
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
            insert_in_pkeys)
        
        log_debug('Combining output and input pkeys in inserted order')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
    
    db.log_debug('Adding pkey on pkeys table to enable fast joins', level=2.5)
    add_pkey(db, into)
    
    log_debug('Setting pkeys of missing rows to '+strings.as_tt(repr(default)))
    missing_rows_joins = input_joins+[sql_gen.Join(into,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, default)])
    
    assert table_row_count(db, into) == table_row_count(db, in_table)
    
    return sql_gen.Col(out_pkey, into)
1158

    
1159
##### Data cleanup
1160

    
1161
def cleanup_table(db, table, cols):
1162
    def esc_name_(name): return esc_name(db, name)
1163
    
1164
    table = sql_gen.as_Table(table).to_str(db)
1165
    cols = map(esc_name_, cols)
1166
    
1167
    run_query(db, 'UPDATE '+table+' SET\n'+(',\n'.join(('\n'+col
1168
        +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)"
1169
            for col in cols))),
1170
        dict(null0='', null1=r'\N'))
(24-24/36)