Project

General

Profile

1
# Database access
2

    
3
import copy
4
import operator
5
import re
6
import warnings
7

    
8
import exc
9
import dicts
10
import iters
11
import lists
12
from Proxy import Proxy
13
import rand
14
import sql_gen
15
import strings
16
import util
17

    
18
##### Exceptions
19

    
20
def get_cur_query(cur, input_query=None, input_params=None):
    '''Returns the last query executed on a cursor, for error messages/logging.
    Tries driver-specific attributes first: psycopg2 exposes cur.query,
    MySQLdb exposes cur._last_executed. Falls back to reconstructing the
    query text from the caller-supplied input query and params.
    @param cur a DB-API cursor (or cache wrapper with a query attribute)
    @param input_query the query string that was passed to execute(), if known
    @param input_params the params that were passed to execute(), if known
    @return the raw query string, or a '[input] ...' description of it
    '''
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)+' % '+repr(input_params)
27

    
28
def _add_cursor_info(e, *args, **kw_args):
    '''Appends the cursor's last query to an exception's message.
    For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+str(get_cur_query(*args, **kw_args)))
31

    
32
class DbException(exc.ExceptionWithCause):
    '''Base class for all database errors raised by this module.
    Optionally attaches the offending query (from the cursor) to the message.
    '''
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        # If a cursor was provided, include its last query in the message
        if cur != None: _add_cursor_info(self, cur)
36

    
37
class ExceptionWithName(DbException):
    '''A database error associated with a named object (table, function, ...).
    The name is kept in self.name for programmatic recovery by callers.
    '''
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '+str(name), cause)
        self.name = name
41

    
42
class ExceptionWithNameValue(DbException):
    '''A database error associated with a named object and an offending value.
    Both are kept as attributes (self.name, self.value) for recovery code.
    '''
    def __init__(self, name, value, cause=None):
        DbException.__init__(self,
            'for name: '+str(name)+'; value: '+repr(value), cause)
        self.name = name
        self.value = value
48

    
49
class ConstraintException(DbException):
    '''A violation of a named constraint on a set of columns.
    self.name is the constraint name; self.cols the affected column names.
    '''
    def __init__(self, name, cols, cause=None):
        DbException.__init__(self, 'Violated '+name+ ' constraint on columns: '
            +(', '.join(cols)), cause)
        self.name = name
        self.cols = cols
55

    
56
class NameException(DbException): pass
57

    
58
class DuplicateKeyException(ConstraintException): pass
59

    
60
class NullValueException(ConstraintException): pass
61

    
62
class FunctionValueException(ExceptionWithNameValue): pass
63

    
64
class DuplicateTableException(ExceptionWithName): pass
65

    
66
class DuplicateFunctionException(ExceptionWithName): pass
67

    
68
class EmptyRowException(DbException): pass
69

    
70
##### Warnings
71

    
72
class DbWarning(UserWarning): pass
73

    
74
##### Result retrieval
75

    
76
def col_names(cur):
    '''Generator of the column names in a cursor's result description.'''
    return (col[0] for col in cur.description)

def rows(cur):
    '''Iterator over a cursor's remaining rows (stops when fetchone()
    returns None, the DB-API end-of-results sentinel).'''
    return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur):
    '''Returns the next row; raises StopIteration if there are none left.'''
    return rows(cur).next()

def row(cur):
    '''Returns the first remaining row, then drains the cursor so the
    result gets cached (see DbConn.DbCursor).'''
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur):
    '''First column of the next row.'''
    return next_row(cur)[0]

def value(cur):
    '''First column of the first row (drains the cursor).'''
    return row(cur)[0]

def values(cur):
    '''Iterator over the first column of each remaining row.'''
    return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    '''Like value(), but returns None instead of raising on an empty result.'''
    try: return value(cur)
    except StopIteration: return None
100

    
101
##### Input validation
102

    
103
def clean_name(name):
    '''Turns an arbitrary string into a safe SQL identifier fragment:
    dots become underscores, and every other non-word character is dropped.'''
    dotted_to_underscore = name.replace('.', '_')
    return re.sub(r'\W', r'', dotted_to_underscore)
104

    
105
def check_name(name):
    '''Validates that name is a safe SQL identifier (word characters only).
    @raise NameException if name contains any non-alphanumeric/_ character
    '''
    if re.search(r'\W', name) != None: raise NameException('Name "'+name
        +'" may contain only alphanumeric characters and _')
108

    
109
def esc_name_by_module(module, name, ignore_case=False):
    '''Quotes an SQL identifier using the given driver module's quote style.
    @param module the DB-API module name ('psycopg2', 'MySQLdb'), or None
        for the default (PostgreSQL-style) quoting
    @param ignore_case if True (PostgreSQL only), the name is validated and
        returned unquoted, since quoting makes identifiers case-sensitive
    @raise NotImplementedError for an unsupported driver module
    '''
    if module == 'psycopg2' or module == None:
        if ignore_case:
            # Quoting would disable case-insensitive matching, so just
            # validate the bare name instead
            check_name(name)
            return name
        quote = '"'
    elif module == 'MySQLdb':
        quote = '`'
    else:
        raise NotImplementedError("Can't escape name for "+module+' database')
    # Strip embedded quote chars so the result can't break out of the quoting
    return quote + name.replace(quote, '') + quote
119

    
120
def esc_name_by_engine(engine, name, **kw_args):
    '''Quotes an identifier given an engine name ('MySQL'/'PostgreSQL').
    For kw_args, see esc_name_by_module().'''
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    '''Quotes an identifier for an open DbConn, using its driver module.
    For kw_args, see esc_name_by_module().'''
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)
125

    
126
def qual_name(db, schema, table):
    '''Returns the escaped, optionally schema-qualified name of a table.
    @param schema the schema name, or None for an unqualified table name
    '''
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
131

    
132
##### Database connections
133

    
134
# Keys recognized in a db_config dict (see DbConn)
db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

# engine name -> (DB-API module name, db_config key renames for that module)
db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

# All exception types that should be treated as database errors.
# Driver-specific DatabaseError classes are added lazily by _add_module().
DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set) # tuple form, usable in except clauses
143

    
144
def _add_module(module):
    '''Registers a DB-API module's DatabaseError in the DatabaseErrors tuple,
    so callers can catch driver errors generically. Called on first connect.'''
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set) # rebuild the except-able tuple
148

    
149
def db_config_str(db_config):
    '''Short human-readable description of a db_config dict,
    e.g. "PostgreSQL database mydb".'''
    engine = db_config['engine']
    database = db_config['database']
    return engine + ' database ' + database
151

    
152
def _query_lookup(query, params): return (query, dicts.make_hashable(params))
153

    
154
log_debug_none = lambda msg: None
155

    
156
class DbConn:
    '''A lazily-connecting database connection with optional query-result
    caching, nested-savepoint support, and debug logging.
    
    @param db_config dict with keys from db_config_names; 'engine' selects
        the driver via db_engines, 'schemas' (optional, comma-separated) is
        prepended to search_path on connect
    @param serializable whether to use SERIALIZABLE isolation (ignored when
        autocommit is on)
    @param autocommit whether to commit after each statement run outside a
        savepoint (see do_autocommit())
    @param caching whether run_query(cacheable=True) may reuse cached results
    @param log_debug callback(msg, ...) for debug messages; debug features
        are enabled only when this is not the default no-op logger
    '''
    def __init__(self, db_config, serializable=True, autocommit=False,
        caching=True, log_debug=log_debug_none):
        self.db_config = db_config
        self.serializable = serializable
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        
        self.__db = None # the underlying DB-API connection, opened lazily
        self.query_results = {} # _query_lookup() key -> CacheCursor
        self._savepoint = 0 # current savepoint nesting depth
    
    def __getattr__(self, name):
        # Only called for attributes not found normally; exposes .db as a
        # lazily-opened connection
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        '''Makes the object picklable by dropping unpicklable members.'''
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def connected(self): return self.__db != None
    
    def _db(self):
        '''Returns the underlying connection, opening and configuring it on
        first use.'''
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            # Rename config keys to the names this driver expects
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Configure connection
            if self.serializable and not self.autocommit: run_raw_query(self,
                'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
            if schemas != None:
                # Prepend the requested schemas to search_path
                schemas_ = ''.join((esc_name(self, s)+', '
                    for s in schemas.split(',')))
                run_raw_query(self, "SELECT set_config('search_path', \
%s || current_setting('search_path'), false)", [schemas_])
        
        return self.__db
    
    class DbCursor(Proxy):
        '''A cursor wrapper that records each query's rows (or exception)
        in the connection's query_results cache.'''
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query, params=None):
            self._is_insert = query.upper().find('INSERT') >= 0
            self.query_lookup = _query_lookup(query, params)
            try:
                try:
                    return_value = self.inner.execute(query, params)
                    self.outer.do_autocommit()
                finally: self.query = get_cur_query(self.inner)
            except Exception, e:
                _add_cursor_info(e, self, query, params)
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            # Fetch all rows so result will be cached
            if self.rowcount == 0 and not self._is_insert: consume_rows(self)
            return return_value
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts, only cache exceptions since inserts are not
            # idempotent, but an invalid insert will always be invalid
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            '''A replayable snapshot of a finished cursor: re-raises a cached
            exception or re-iterates the cached rows list.'''
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        '''Escapes a literal value for direct inclusion in SQL, using the
        driver's own escaping. @raise NotImplementedError for other drivers.'''
        module = util.root_module(self.db)
        if module == 'psycopg2': str_ = self.db.cursor().mogrify('%s', [value])
        elif module == 'MySQLdb':
            import _mysql
            str_ = _mysql.escape_string(value)
        else: raise NotImplementedError("Can't escape value for "+module
            +' database')
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def run_query(self, query, params=None, cacheable=False, log_level=2,
        exc_log_level=None):
        '''Runs a query, optionally reusing/storing a cached result.
        @param cacheable whether the result may be served from/saved to the
            cache (ignored when self.caching is off)
        @param log_level debug-log verbosity level for this query
        @param exc_log_level The log_level if the query throws an exception.
            Defaults to the value of log_level.
        @return the cursor (a DbCursor, CacheCursor, or raw driver cursor)
        '''
        assert query != None
        if exc_log_level == None: exc_log_level = log_level
        
        if not self.caching: cacheable = False
        used_cache = False
        success = False
        try:
            # Get cursor
            if cacheable:
                query_lookup = _query_lookup(query, params)
                try:
                    cur = self.query_results[query_lookup]
                    used_cache = True
                except KeyError: cur = self.DbCursor(self)
            else: cur = self.db.cursor()
            
            # Run query
            cur.execute(query, params)
            
            success = True
        finally:
            if self.debug: # only compute msg if needed
                if not success: log_level = exc_log_level
                if used_cache: cache_status = 'Cache hit'
                elif cacheable: cache_status = 'Cache miss'
                else: cache_status = 'Non-cacheable'
                self.log_debug(cache_status+': '+strings.one_line(
                    str(get_cur_query(cur, query, params))), log_level)
        
        return cur
    
    def is_cached(self, query, params=None):
        return _query_lookup(query, params) in self.query_results
    
    def with_savepoint(self, func):
        '''Runs func() inside a savepoint: rolls back to it on any exception
        (re-raising), releases it and autocommits on success.
        Savepoints nest; names encode the nesting depth.'''
        savepoint = 'level_'+str(self._savepoint)
        self.run_query('SAVEPOINT '+savepoint, log_level=4)
        self._savepoint += 1
        try: 
            try: return_val = func()
            finally:
                self._savepoint -= 1
                assert self._savepoint >= 0
        except:
            self.run_query('ROLLBACK TO SAVEPOINT '+savepoint, log_level=4)
            raise
        else:
            self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            self.do_autocommit()
            return return_val
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 0
        if self.autocommit and self._savepoint == 0:
            self.log_debug('Autocommiting')
            self.db.commit()
339

    
340
connect = DbConn # alias: connect(db_config, ...) opens a DbConn
341

    
342
##### Querying
343

    
344
def run_raw_query(db, *args, **kw_args):
    '''Runs a query with no error recovery. For params, see DbConn.run_query()'''
    return db.run_query(*args, **kw_args)
347

    
348
def mogrify(db, query, params):
    '''Substitutes params into a query string using the driver's escaping.
    @raise NotImplementedError for drivers other than psycopg2
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2': return db.db.cursor().mogrify(query, params)
    else: raise NotImplementedError("Can't mogrify query for "+module+
        ' database')
353

    
354
##### Recoverable querying
355

    
356
def with_savepoint(db, func):
    '''Module-level convenience wrapper for DbConn.with_savepoint().'''
    return db.with_savepoint(func)
357

    
358
def run_query(db, query, params=None, recover=None, cacheable=False, **kw_args):
    '''Runs a query, optionally inside a savepoint, translating driver error
    messages (parsed by regex, PostgreSQL wording) into this module's
    specific exception types so callers can recover programmatically.
    @param recover if True, run inside a savepoint (unless cached) and raise
        DuplicateKeyException/NullValueException/FunctionValueException/
        DuplicateTableException/DuplicateFunctionException where recognized
    For other params, see run_raw_query()
    '''
    if recover == None: recover = False
    
    try:
        def run(): return run_raw_query(db, query, params, cacheable, **kw_args)
        if recover and not db.is_cached(query, params):
            return with_savepoint(db, run)
        else: return run() # don't need savepoint if cached
    except Exception, e:
        if not recover: raise # need savepoint to run index_cols()
        msg = exc.str_(e)
        
        match = re.search(r'duplicate key value violates unique constraint '
            r'"((_?[^\W_]+)_[^"]+?)"', msg)
        if match:
            # group 2 recovers the table name from the constraint's prefix
            constraint, table = match.groups()
            try: cols = index_cols(db, table, constraint)
            except NotImplementedError: raise e
            else: raise DuplicateKeyException(constraint, cols, e)
        
        match = re.search(r'null value in column "(\w+?)" violates not-null '
            r'constraint', msg)
        if match: raise NullValueException('NOT NULL', [match.group(1)], e)
        
        match = re.search(r'\b(?:invalid input (?:syntax|value)\b.*?'
            r'|date/time field value out of range): "(.+?)"\n'
            r'(?:(?s).*?)\bfunction "(\w+?)".*?\bat assignment', msg)
        if match:
            value, name = match.groups()
            raise FunctionValueException(name, strings.to_unicode(value), e)
        
        match = re.search(r'relation "(\w+?)" already exists', msg)
        if match: raise DuplicateTableException(match.group(1), e)
        
        match = re.search(r'function "(\w+?)" already exists', msg)
        if match: raise DuplicateFunctionException(match.group(1), e)
        
        raise # no specific exception raised
397

    
398
##### Basic queries
399

    
400
def next_version(name):
    '''Returns the next versioned form of a name, e.g. 'x' -> 'v1_x',
    'v1_x' -> 'v2_x'. The version # is prepended so it won't be removed if
    the name is truncated.'''
    parsed = re.match(r'^v(\d+)_(.*)$', name)
    if parsed is None:
        version = 1 # the first, unprefixed name counts as version 0
        base = name
    else:
        version = 1 + int(parsed.group(1))
        base = parsed.group(2)
    return 'v'+str(version)+'_'+base
408

    
409
def run_query_into(db, query, params, into=None, *args, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    @param into sql_gen.Table to create from the query's rows, or None to
        just run the query. On a name collision, the table name is bumped to
        the next version (see next_version()) and the CREATE is retried.
    '''
    if into == None: return run_query(db, query, params, *args, **kw_args)
    else: # place rows in temp table
        assert isinstance(into, sql_gen.Table)
        
        kw_args['recover'] = True # needed to catch DuplicateTableException
        
        temp = not db.debug # tables are created as permanent in debug mode
        # "temporary tables cannot specify a schema name", so remove schema
        if temp: into.schema = None
        
        while True:
            try:
                create_query = 'CREATE'
                if temp: create_query += ' TEMP'
                create_query += ' TABLE '+into.to_str(db)+' AS '+query
                
                return run_query(db, create_query, params, *args, **kw_args)
                    # CREATE TABLE AS sets rowcount to # rows in query
            except DuplicateTableException, e:
                into.name = next_version(into.name)
                # try again with next version of name

    
435
order_by_pkey = object() # tells mk_select() to order by the pkey
436

    
437
distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns
438

    
439
def mk_select(db, tables, fields=None, conds=None, distinct_on=[], limit=None,
    start=None, order_by=order_by_pkey, default_table=None):
    '''Builds a SELECT statement (without running it).
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @param limit LIMIT row count, or None
    @param start OFFSET row count, or None
    @param order_by column to ORDER BY, order_by_pkey for the table's pkey,
        or None for no ordering
    @param default_table table to qualify unqualified columns with
    @return tuple(query, params)
    '''
    # Parse tables param
    if not lists.is_seq(tables): tables = [tables]
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif isinstance(conds, dict): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or type(limit) == int
    assert start == None or type(start) == int
    if order_by is order_by_pkey:
        # DISTINCT ON and ORDER BY pkey would conflict, so skip the ordering
        if distinct_on != []: order_by = None
        else: order_by = pkey(db, table0, recover=True)
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += ' DISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    query += ' '
    if fields == None: query += '*'
    else: query += ', '.join(map(parse_col, fields))
    
    # Main table
    query += ' FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                None))
        
        query += ' '+join_.to_str(db, left_table)
        
        left_table = table
    
    # missing tracks whether the query has any row-limiting clause at all
    missing = True
    if conds != []:
        query += ' WHERE '+(' AND '.join(('('+sql_gen.ColValueCond(l, r)
            .to_str(db)+')' for l, r in conds)))
        missing = False
    if order_by != None:
        query += ' ORDER BY '+sql_gen.as_Col(order_by, table0).to_str(db)
    if limit != None: query += ' LIMIT '+str(limit); missing = False
    if start != None:
        if start != 0: query += ' OFFSET '+str(start)
        missing = False
    if missing: warnings.warn(DbWarning(
        'SELECT statement missing a WHERE, LIMIT, or OFFSET clause: '+query))
    
    return (query, [])
515

    
516
def select(db, *args, **kw_args):
    '''Builds and runs a SELECT. Cacheable by default.
    For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    query, params = mk_select(db, *args, **kw_args)
    return run_query(db, query, params, recover, cacheable, log_level=log_level)
524

    
525
def mk_insert_select(db, table, cols=None, select_query=None, params=None,
    returning=None, embeddable=False):
    '''Builds an INSERT ... SELECT statement (without running it).
    @param cols columns to insert into, or None for all-default columns
    @param select_query the SELECT (or VALUES) supplying rows; defaults to
        'DEFAULT VALUES'
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        The INSERT is wrapped in a generated SQL function so it can appear in
        a FROM clause. Warning: If you set this and cacheable=True when the
        query is run, the query will be fully cached, not just if it raises
        an exception.
    @return tuple(query, params)
    '''
    table = sql_gen.as_Table(table)
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.as_Col(v).to_str(db) for v in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    # Build query
    query = 'INSERT INTO '+table.to_str(db)
    if cols != None: query += ' ('+', '.join(cols)+')'
    query += ' '+select_query
    
    if returning != None:
        # Strip the table qualifier for the RETURNING clause
        returning_name = copy.copy(returning)
        returning_name.table = None
        returning_name = returning_name.to_str(db)
        query += ' RETURNING '+returning_name
    
    if embeddable:
        assert returning != None
        
        # Create function
        function_name = '_'.join(['insert', table.name] + cols)
        return_type = 'SETOF '+returning.to_str(db)+'%TYPE'
        while True:
            try:
                # In debug mode the function is created permanently (no
                # pg_temp schema) so it can be inspected afterward
                func_schema = None
                if not db.debug: func_schema = 'pg_temp'
                function = sql_gen.Table(function_name, func_schema).to_str(db)
                
                function_query = '''\
CREATE FUNCTION '''+function+'''() RETURNS '''+return_type+'''
    LANGUAGE sql
    AS $$'''+mogrify(db, query, params)+''';$$;
'''
                run_query(db, function_query, recover=True, cacheable=True)
                break # this version was successful
            except DuplicateFunctionException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        func_table = sql_gen.NamedTable('f', sql_gen.CustomCode(function+'()'),
            [returning_name]) # AS clause requires function alias
        return mk_select(db, func_table, start=0, order_by=None)
    
    return (query, params)
579

    
580
def insert_select(db, *args, **kw_args):
    '''Builds and runs an INSERT ... SELECT. Cacheable by default.
    For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True # needed to nest the INSERT
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    
    query, params = mk_insert_select(db, *args, **kw_args)
    return run_query_into(db, query, params, into, recover=recover,
        cacheable=cacheable)

default = object() # tells insert() to use the default value for a column
595

    
596
def insert(db, table, row, *args, **kw_args):
    '''Inserts a single row.
    @param row a sequence of values (in table-column order), or a dict of
        col name -> value; values may be the module-level `default` sentinel
    For other params, see insert_select()'''
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "!= []" works
    
    # Check for special values
    labels = [] # the per-column placeholders for the VALUES clause
    values = [] # the params for the non-DEFAULT placeholders
    for value in row:
        if value is default: labels.append('DEFAULT')
        else:
            labels.append('%s')
            values.append(value)
    
    # Build query
    if values != []: query = ' VALUES ('+(', '.join(labels))+')'
    else: query = None # all defaults -> mk_insert_select() uses DEFAULT VALUES
    
    return insert_select(db, table, cols, query, values, *args, **kw_args)
618

    
619
def mk_update(db, table, changes=None, cond=None):
    '''Builds an UPDATE statement (without running it).
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    query = 'UPDATE '+sql_gen.as_Table(table).to_str(db)+'\nSET\n'
    query += ',\n'.join((sql_gen.to_name_only_col(col, table).to_str(db)+' = '
        +sql_gen.as_Value(new_value).to_str(db) for col, new_value in changes))
    if cond != None: query += ' WHERE '+cond.to_str(db)
    
    return query
634

    
635
def update(db, *args, **kw_args):
    '''Builds and runs an UPDATE. For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    
    return run_query(db, mk_update(db, *args, **kw_args), [], recover)
640

    
641
def last_insert_id(db):
    '''Returns the id generated by the last insert on this connection,
    or None for drivers where this isn't supported.'''
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.insert_id()
    else: return None

def truncate(db, table, schema='public'):
    '''Empties a table (and, via CASCADE, dependent tables).'''
    return run_query(db, 'TRUNCATE '+qual_name(db, schema, table)+' CASCADE')
649

    
650
def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col) # keep the pre-mutation Col as the key
        col.table = into # note: mutates the caller's Col (documented above)
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            # Renamed to the flattened (table-qualified) name to avoid
            # collisions between joined tables
            items.append((col, sql_gen.Col(clean_name(str(col)), into)))
    
    if not as_items: items = dict(items)
    return items
681

    
682
def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''Joins tables and materializes selected columns (renamed to be
    collision-free) into a new table.
    For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    # SELECT old AS new for each mapping entry
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, *mk_select(db, joins, cols, limit=limit, start=start),
        into=into)
    return dict(items)
691

    
692
##### Database structure queries
693

    
694
def table_row_count(db, table, recover=None):
    '''Returns the number of rows in a table.'''
    return value(run_query(db, *mk_select(db, table, [sql_gen.row_count],
        order_by=None, start=0), recover=recover, log_level=3))

def table_cols(db, table, recover=None):
    '''Returns the table's column names, via a zero-row SELECT.'''
    return list(col_names(select(db, table, limit=0, order_by=None,
        recover=recover, log_level=4)))

def pkey(db, table, recover=None):
    '''Assumed to be first column in table'''
    return table_cols(db, table, recover)[0]

# Conventional name of a column guaranteed non-null (used by filter-out joins)
not_null_col = 'not_null'

def table_not_null_col(db, table, recover=None):
    '''Name assumed to be the value of not_null_col. If not found, uses pkey.'''
    if not_null_col in table_cols(db, table, recover): return not_null_col
    else: return pkey(db, table, recover)
712

    
713
def index_cols(db, table, index):
    '''Returns the column names covered by an index, in index-column order.
    Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.
    @raise NotImplementedError for drivers other than psycopg2
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        # The UNION's second branch handles expression indexes, whose column
        # numbers must be parsed out of pg_index.indexprs
        return list(values(run_query(db, '''\
SELECT attname
FROM
(
        SELECT attnum, attname
        FROM pg_index
        JOIN pg_class index ON index.oid = indexrelid
        JOIN pg_class table_ ON table_.oid = indrelid
        JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY (indkey)
        WHERE
            table_.relname = %(table)s
            AND index.relname = %(index)s
    UNION
        SELECT attnum, attname
        FROM
        (
            SELECT
                indrelid
                , (regexp_matches(indexprs, E':varattno (\\\\d+)', 'g'))[1]::int
                    AS indkey
            FROM pg_index
            JOIN pg_class index ON index.oid = indexrelid
            JOIN pg_class table_ ON table_.oid = indrelid
            WHERE
                table_.relname = %(table)s
                AND index.relname = %(index)s
        ) s
        JOIN pg_attribute ON attrelid = indrelid AND attnum = indkey
) s
ORDER BY attnum
''',
            {'table': table, 'index': index}, cacheable=True, log_level=4)))
    else: raise NotImplementedError("Can't list index columns for "+module+
        ' database')
753

    
754
def constraint_cols(db, table, constraint):
    '''Returns the column names covered by a (non-unique) table constraint.
    For UNIQUE constraints, prefer index_cols().
    @raise NotImplementedError for drivers other than psycopg2
    '''
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return list(values(run_query(db, '''\
SELECT attname
FROM pg_constraint
JOIN pg_class ON pg_class.oid = conrelid
JOIN pg_attribute ON attrelid = conrelid AND attnum = ANY (conkey)
WHERE
    relname = %(table)s
    AND conname = %(constraint)s
ORDER BY attnum
''',
            {'table': table, 'constraint': constraint})))
    else: raise NotImplementedError("Can't list constraint columns for "+module+
        ' database')
770

    
771
# Conventional name of the serial row-number column added by add_row_num()
row_num_col = '_row_num'

def index_col(db, col):
    '''Adds an index on a column if it doesn't already exist.'''
    assert sql_gen.is_table_col(col)
    
    table = col.table
    # The index is named after the (cleaned) qualified column name
    index = sql_gen.as_Table(clean_name(str(col)))
    col = sql_gen.to_name_only_col(col)
    try: run_query(db, 'CREATE INDEX '+index.to_str(db)+' ON '+table.to_str(db)
        +' ('+col.to_str(db)+')', recover=True, cacheable=True, log_level=3)
    except DuplicateTableException: pass # index already existed
783

    
784
def index_pkey(db, table, recover=None):
    '''Promotes a table's first column to be its primary key.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    
    constraint = sql_gen.as_Table(table.name+'_pkey')
    pkey_col = sql_gen.to_name_only_col(pkey(db, table, recover))
    query = ('ALTER TABLE '+table.to_str(db)+' ADD CONSTRAINT '
        +constraint.to_str(db)+' PRIMARY KEY('+pkey_col.to_str(db)+')')
    run_query(db, query, recover=recover, log_level=3)
795

    
796
def add_row_num(db, table):
    '''Appends a serial row-number column (named by row_num_col) to a table
    and makes it the primary key.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    query = ('ALTER TABLE '+table_str+' ADD COLUMN '+row_num_col
        +' serial NOT NULL PRIMARY KEY')
    run_query(db, query, log_level=3)
802

    
803
def tables(db, schema='public', table_like='%'):
    '''Iterates over the names of the tables in a schema that match a SQL LIKE
    pattern.'''
    driver = util.root_module(db.db)
    params = {'schema': schema, 'table_like': table_like}
    if driver == 'psycopg2':
        pg_query = '''\
SELECT tablename
FROM pg_tables
WHERE
    schemaname = %(schema)s
    AND tablename LIKE %(table_like)s
ORDER BY tablename
'''
        return values(run_query(db, pg_query, params, cacheable=True))
    if driver == 'MySQLdb':
        # NOTE(review): schema appears unused here; SHOW TABLES presumably
        # lists the current database — confirm against callers
        return values(run_query(db, 'SHOW TABLES LIKE %(table_like)s', params,
            cacheable=True))
    raise NotImplementedError("Can't list tables for "+driver+' database')
820

    
821
##### Database management
822

    
823
def empty_db(db, schema='public', **kw_args):
    '''Truncates every table in a schema. For kw_args, see tables().'''
    matching = tables(db, schema, **kw_args)
    for table in matching:
        truncate(db, table, schema)
826

    
827
##### Heuristic queries
828

    
829
def put(db, table, row, pkey_=None, row_ct_ref=None):
830
    '''Recovers from errors.
831
    Only works under PostgreSQL (uses INSERT RETURNING).
832
    '''
833
    if pkey_ == None: pkey_ = pkey(db, table, recover=True)
834
    
835
    try:
836
        cur = insert(db, table, row, pkey_, recover=True)
837
        if row_ct_ref != None and cur.rowcount >= 0:
838
            row_ct_ref[0] += cur.rowcount
839
        return value(cur)
840
    except DuplicateKeyException, e:
841
        return value(select(db, table, [pkey_],
842
            util.dict_subset_right_join(row, e.cols), recover=True))
843

    
844
def get(db, table, row, pkey, row_ct_ref=None, create=False):
    '''Looks up a row's pkey, optionally inserting the row if it's missing.
    Recovers from errors.'''
    try:
        return value(select(db, table, [pkey], row, limit=1, recover=True))
    except StopIteration: # no matching row
        if create:
            return put(db, table, row, pkey, row_ct_ref) # insert new row
        raise
850

    
851
def put_table(db, out_table, in_tables, mapping, row_ct_ref=None):
    '''Bulk-inserts mapped input rows into an output table, returning where
    each input row's output pkey can be found. Retries on constraint errors,
    progressively excluding the offending rows, until the insert succeeds.
    Recovers from errors.
    Only works under PostgreSQL (uses INSERT RETURNING).
    @param in_tables The main input table to select from, followed by a list of
        tables to join with it using the main input table's pkey
    @param mapping dict(out_table_col=in_table_col, ...)
        * out_table_col: sql_gen.Col|str
        * in_table_col: sql_gen.Col Wrap literal values in a sql_gen.NamedCol
    @param row_ct_ref [int] single-element list; incremented by the number of
        rows inserted, if provided
    @return sql_gen.Col Where the output pkeys are made available
    '''
    out_table = sql_gen.as_Table(out_table)
    for in_table_col in mapping.itervalues():
        assert isinstance(in_table_col, sql_gen.Col)
    
    # Temp tables created below are namespaced by the output table's name
    temp_prefix = out_table.name
    pkeys = sql_gen.Table(temp_prefix+'_pkeys')
    
    # Create input joins from list of input tables
    in_tables_ = in_tables[:] # don't modify input!
    in_tables0 = in_tables_.pop(0) # first table is separate
    in_pkey = pkey(db, in_tables0, recover=True)
    in_pkey_col = sql_gen.as_Col(in_pkey, in_tables0)
    input_joins = [in_tables0]+[sql_gen.Join(v, {in_pkey: sql_gen.join_same})
        for v in in_tables_]
    
    db.log_debug('Joining together input tables')
    # Place in new table for speed and so don't modify input if values edited
    in_table = sql_gen.Table(temp_prefix+'_in')
    flatten_cols = filter(sql_gen.is_table_col, mapping.values())
    mapping = dicts.join(mapping, flatten(db, in_table, input_joins,
        flatten_cols, preserve=[in_pkey_col], start=0))
    input_joins = [in_table]
    
    out_pkey = pkey(db, out_table, recover=True)
    out_pkey_col = sql_gen.as_Col(out_pkey, out_table)
    
    pkeys_names = [in_pkey, out_pkey]
    pkeys_cols = [in_pkey_col, out_pkey_col]
    
    # Single-element list so the closure below can mutate the flag
    pkeys_table_exists_ref = [False]
    def insert_into_pkeys(joins, cols):
        # Appends (in_pkey, out_pkey) pairs to the pkeys temp table, creating
        # it on first use
        query, params = mk_select(db, joins, cols, order_by=None, start=0)
        if pkeys_table_exists_ref[0]:
            insert_select(db, pkeys, pkeys_names, query, params)
        else:
            run_query_into(db, query, params, into=pkeys)
            pkeys_table_exists_ref[0] = True
    
    # Mutable state shared with the closures below; the exception handlers in
    # the retry loop tighten these to exclude problem rows
    limit_ref = [None]
    conds = set()
    distinct_on = []
    def mk_main_select(joins, cols):
        # Builds the input select with the current exclusion filters applied
        return mk_select(db, joins, cols, conds, distinct_on,
            limit=limit_ref[0], start=0)
    
    def log_exc(e):
        db.log_debug('Caught exception: '+exc.str_(e, first_line_only=True))
    def remove_all_rows(msg):
        # Give up on inserting anything; every input row gets a NULL out pkey
        warnings.warn(DbWarning(msg))
        db.log_debug(msg.partition('\n')[0])
        db.log_debug('Returning NULL for all rows')
        limit_ref[0] = 0 # just create an empty pkeys table
    def ignore(in_col, value):
        in_col_str = str(in_col)
        db.log_debug('Adding index on '+in_col_str+' to enable fast filtering')
        index_col(db, in_col)
        db.log_debug('Ignoring rows with '+in_col_str+' = '+repr(value))
    def remove_rows(in_col, value):
        # Excludes rows with the given value from all future selects
        ignore(in_col, value)
        cond = (in_col, sql_gen.CompareCond(value, '!='))
        assert cond not in conds # avoid infinite loops
        conds.add(cond)
    def invalid2null(in_col, value):
        # Rewrites the offending value to NULL in the input copy
        ignore(in_col, value)
        update(db, in_table, [(in_col, None)],
            sql_gen.ColValueCond(in_col, value))
    
    # Do inserts and selects
    join_cols = {}
    insert_out_pkeys = sql_gen.Table(temp_prefix+'_insert_out_pkeys')
    insert_in_pkeys = sql_gen.Table(temp_prefix+'_insert_in_pkeys')
    # Retry loop: each handled exception adds a constraint/filter, then the
    # insert is attempted again
    while True:
        has_joins = join_cols != {}
        
        # Prepare to insert new rows
        insert_joins = input_joins[:] # don't modify original!
        insert_args = dict(recover=True, cacheable=False)
        if has_joins:
            # Anti-join against out_table to skip rows that already exist
            distinct_on = [v.to_Col() for v in join_cols.values()]
            insert_joins.append(sql_gen.Join(out_table, join_cols,
                sql_gen.filter_out))
        else:
            insert_args.update(dict(returning=out_pkey, into=insert_out_pkeys))
        
        db.log_debug('Inserting new rows')
        try:
            cur = insert_select(db, out_table, mapping.keys(),
                *mk_main_select(insert_joins, mapping.values()), **insert_args)
            break # insert successful
        except DuplicateKeyException, e:
            log_exc(e)
            
            # Start (or widen) deduplication on the violated key's columns
            old_join_cols = join_cols.copy()
            join_cols.update(util.dict_subset(mapping, e.cols))
            db.log_debug('Ignoring existing rows, comparing on '+str(join_cols))
            assert join_cols != old_join_cols # avoid infinite loops
        except NullValueException, e:
            log_exc(e)
            
            out_col, = e.cols
            try: in_col = mapping[out_col]
            except KeyError:
                remove_all_rows('Missing mapping for NOT NULL '+out_col)
            else: remove_rows(in_col, None)
        except FunctionValueException, e:
            log_exc(e)
            
            assert e.name == out_table.name
            out_col = 'value' # assume function param was named "value"
            invalid2null(mapping[out_col], e.value)
        except DatabaseErrors, e:
            log_exc(e)
            
            remove_all_rows('No handler for exception: '+exc.str_(e))
        # after exception handled, rerun loop with additional constraints
    
    if row_ct_ref != None and cur.rowcount >= 0:
        row_ct_ref[0] += cur.rowcount
    
    # has_joins/cur survive from the loop's final (successful) iteration
    if has_joins:
        # Existing and inserted rows alike can be joined back to out_table
        select_joins = input_joins+[sql_gen.Join(out_table, join_cols)]
        db.log_debug('Getting output pkeys of existing/inserted rows')
        insert_into_pkeys(select_joins, pkeys_cols)
    else:
        # No dedup columns: pair input and output pkeys up by row position
        add_row_num(db, insert_out_pkeys) # for joining with input pkeys
        
        db.log_debug('Getting input pkeys for rows in insert')
        run_query_into(db, *mk_main_select(input_joins, [in_pkey]),
            into=insert_in_pkeys)
        add_row_num(db, insert_in_pkeys) # for joining with output pkeys
        
        assert table_row_count(db, insert_out_pkeys) == table_row_count(db,
            insert_in_pkeys)
        
        db.log_debug('Joining together output and input pkeys')
        pkey_joins = [insert_in_pkeys, sql_gen.Join(insert_out_pkeys,
            {row_num_col: sql_gen.join_same_not_null})]
        insert_into_pkeys(pkey_joins, pkeys_names)
    
    db.log_debug('Adding pkey on returned pkeys table to enable fast joins')
    index_pkey(db, pkeys)
    
    db.log_debug("Setting missing rows' pkeys to NULL")
    missing_rows_joins = input_joins+[sql_gen.Join(pkeys,
        {in_pkey: sql_gen.join_same_not_null}, sql_gen.filter_out)]
        # must use join_same_not_null or query will take forever
    insert_into_pkeys(missing_rows_joins,
        [in_pkey_col, sql_gen.NamedCol(out_pkey, None)])
    
    # Every input row must now have exactly one entry in pkeys
    assert table_row_count(db, pkeys) == table_row_count(db, in_table)
    
    return sql_gen.Col(out_pkey, pkeys)
1013

    
1014
##### Data cleanup
1015

    
1016
def cleanup_table(db, table, cols):
    '''Trims surrounding whitespace in the given columns and converts empty
    strings and the \\N placeholder to NULL.'''
    table_str = sql_gen.as_Table(table).to_str(db)
    esc_cols = [esc_name(db, name) for name in cols]
    
    assignments = []
    for col in esc_cols:
        assignments.append('\n'+col
            +' = nullif(nullif(trim(both from '+col+"), %(null0)s), %(null1)s)")
    query = 'UPDATE '+table_str+' SET\n'+(',\n'.join(assignments))
    run_query(db, query, dict(null0='', null1=r'\N'))
(23-23/35)