# Database access

import copy
import re
import time
import warnings

import exc
import dicts
import iters
import lists
import profiling
from Proxy import Proxy
import rand
import sql_gen
import strings
import util

##### Exceptions

def get_cur_query(cur, input_query=None):
    raw_query = None
    if hasattr(cur, 'query'): raw_query = cur.query
    elif hasattr(cur, '_last_executed'): raw_query = cur._last_executed
    
    if raw_query != None: return raw_query
    else: return '[input] '+strings.ustr(input_query)

def _add_cursor_info(e, *args, **kw_args):
    '''For params, see get_cur_query()'''
    exc.add_msg(e, 'query: '+strings.ustr(get_cur_query(*args, **kw_args)))

class DbException(exc.ExceptionWithCause):
    def __init__(self, msg, cause=None, cur=None):
        exc.ExceptionWithCause.__init__(self, msg, cause, cause_newline=True)
        if cur != None: _add_cursor_info(self, cur)

class ExceptionWithName(DbException):
    def __init__(self, name, cause=None):
        DbException.__init__(self, 'for name: '
            +strings.as_tt(strings.ustr(name)), cause)
        self.name = name

class ExceptionWithValue(DbException):
    def __init__(self, value, cause=None):
        DbException.__init__(self, 'for value: '
            +strings.as_tt(strings.urepr(value)), cause)
        self.value = value

class ExceptionWithNameType(DbException):
    def __init__(self, type_, name, cause=None):
        DbException.__init__(self, 'for type: '+strings.as_tt(strings.ustr(
            type_))+'; name: '+strings.as_tt(strings.ustr(name)), cause)
        self.type = type_
        self.name = name

class ConstraintException(DbException):
    def __init__(self, name, cond, cols, cause=None):
        msg = 'Violated '+strings.as_tt(name)+' constraint'
        if cond != None: msg += ' with condition '+strings.as_tt(cond)
        if cols != []: msg += ' on columns: '+strings.as_tt(', '.join(cols))
        DbException.__init__(self, msg, cause)
        self.name = name
        self.cond = cond
        self.cols = cols

class MissingCastException(DbException):
    def __init__(self, type_, col=None, cause=None):
        msg = 'Missing cast to type '+strings.as_tt(type_)
        if col != None: msg += ' on column: '+strings.as_tt(col)
        DbException.__init__(self, msg, cause)
        self.type = type_
        self.col = col

class EncodingException(ExceptionWithName): pass

class DuplicateKeyException(ConstraintException): pass

class NullValueException(ConstraintException): pass

class CheckException(ConstraintException): pass

class InvalidValueException(ExceptionWithValue): pass

class InvalidTypeException(ExceptionWithNameType): pass

class DuplicateException(ExceptionWithNameType): pass

class DoesNotExistException(ExceptionWithNameType): pass

class EmptyRowException(DbException): pass

##### Warnings

class DbWarning(UserWarning): pass

##### Result retrieval

def col_names(cur): return (col[0] for col in cur.description)

def rows(cur): return iter(lambda: cur.fetchone(), None)

def consume_rows(cur):
    '''Used to fetch all rows so result will be cached'''
    iters.consume_iter(rows(cur))

def next_row(cur): return rows(cur).next()

def row(cur):
    row_ = next_row(cur)
    consume_rows(cur)
    return row_

def next_value(cur): return next_row(cur)[0]

def value(cur): return row(cur)[0]

def values(cur): return iters.func_iter(lambda: next_value(cur))

def value_or_none(cur):
    try: return value(cur)
    except StopIteration: return None
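
# Usage sketch for the helpers above, assuming cur is a cursor returned by
# run_query() (defined below):
#     for row_ in rows(cur): ...   # iterate over all rows
#     v = value(cur)               # first column of the first row
#     v = value_or_none(cur)       # same, but None if the result set is empty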

##### Escaping

def esc_name_by_module(module, name):
    if module == 'psycopg2' or module == None: quote = '"'
    elif module == 'MySQLdb': quote = '`'
    else: raise NotImplementedError("Can't escape name for "+module+' database')
    return sql_gen.esc_name(name, quote)

def esc_name_by_engine(engine, name, **kw_args):
    return esc_name_by_module(db_engines[engine][0], name, **kw_args)

def esc_name(db, name, **kw_args):
    return esc_name_by_module(util.root_module(db.db), name, **kw_args)

def qual_name(db, schema, table):
    def esc_name_(name): return esc_name(db, name)
    table = esc_name_(table)
    if schema != None: return esc_name_(schema)+'.'+table
    else: return table
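
# e.g. qual_name(db, 'public', 'plots') produces something like
# '"public"."plots"' on PostgreSQL (exact quoting is up to sql_gen.esc_name())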

##### Database connections

db_config_names = ['engine', 'host', 'user', 'password', 'database', 'schemas']

db_engines = {
    'MySQL': ('MySQLdb', {'password': 'passwd', 'database': 'db'}),
    'PostgreSQL': ('psycopg2', {}),
}

DatabaseErrors_set = set([DbException])
DatabaseErrors = tuple(DatabaseErrors_set)

def _add_module(module):
    DatabaseErrors_set.add(module.DatabaseError)
    global DatabaseErrors
    DatabaseErrors = tuple(DatabaseErrors_set)

def db_config_str(db_config):
    return db_config['engine']+' database '+db_config['database']
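
# A db_config is a dict keyed by db_config_names, e.g. (hypothetical values):
#     db_config = dict(engine='PostgreSQL', host='localhost', user='bien',
#         password='secret', database='vegbien', schemas='public')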

log_debug_none = lambda msg, level=2: None

class DbConn:
    def __init__(self, db_config, autocommit=True, caching=True,
        log_debug=log_debug_none, debug_temp=False, src=None):
        '''
        @param debug_temp Whether temporary objects should instead be permanent.
            This assists in debugging the internal objects used by the program.
        @param src In autocommit mode, will be included in a comment in every
            query, to help identify the data source in pg_stat_activity.
        '''
        self.db_config = db_config
        self.autocommit = autocommit
        self.caching = caching
        self.log_debug = log_debug
        self.debug = log_debug != log_debug_none
        self.debug_temp = debug_temp
        self.src = src
        self.autoanalyze = False
        self.autoexplain = False
        self.profile_row_ct = None
        
        self._savepoint = 0
        self._reset()
    
    def __getattr__(self, name):
        if name == '__dict__': raise Exception('getting __dict__')
        if name == 'db': return self._db()
        else: raise AttributeError()
    
    def __getstate__(self):
        state = copy.copy(self.__dict__) # shallow copy
        state['log_debug'] = None # don't pickle the debug callback
        state['_DbConn__db'] = None # don't pickle the connection
        return state
    
    def clear_cache(self): self.query_results = {}
    
    def _reset(self):
        self.clear_cache()
        assert self._savepoint == 0
        self._notices_seen = set()
        self.__db = None
    
    def connected(self): return self.__db != None
    
    def close(self):
        if not self.connected(): return
        
        # Record that the automatic transaction is now closed
        self._savepoint -= 1
        
        self.db.close()
        self._reset()
    
    def reconnect(self):
        # Do not do this in test mode as it would roll back everything
        if self.autocommit: self.close()
        # Connection will be reopened automatically on first query
    
    def _db(self):
        if self.__db == None:
            # Process db_config
            db_config = self.db_config.copy() # don't modify input!
            schemas = db_config.pop('schemas', None)
            module_name, mappings = db_engines[db_config.pop('engine')]
            module = __import__(module_name)
            _add_module(module)
            for orig, new in mappings.iteritems():
                try: util.rename_key(db_config, orig, new)
                except KeyError: pass
            
            # Connect
            self.__db = module.connect(**db_config)
            
            # Record that a transaction is already open
            self._savepoint += 1
            
            # Configure connection
            if hasattr(self.db, 'set_isolation_level'):
                import psycopg2.extensions
                self.db.set_isolation_level(
                    psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)
            if schemas != None:
                search_path = [self.esc_name(s) for s in schemas.split(',')]
                run_query(self, 'SET search_path TO '+(','.join(search_path)),
                    log_level=3)
        
        return self.__db
    
    class DbCursor(Proxy):
        def __init__(self, outer):
            Proxy.__init__(self, outer.db.cursor())
            self.outer = outer
            self.query_results = outer.query_results
            self.query_lookup = None
            self.result = []
        
        def execute(self, query):
            self._is_insert = query.startswith('INSERT')
            self.query_lookup = query
            try:
                try: cur = self.inner.execute(query)
                finally: self.query = get_cur_query(self.inner, query)
            except Exception, e:
                self.result = e # cache the exception as the result
                self._cache_result()
                raise
            
            # Always cache certain queries
            query = sql_gen.lstrip(query)
            if query.startswith('CREATE') or query.startswith('ALTER'):
                # structural changes
                # Rest of query must be unique in the face of name collisions,
                # so don't cache ADD COLUMN unless it has distinguishing comment
                if query.find('ADD COLUMN') < 0 or query.endswith('*/'):
                    self._cache_result()
            elif self.rowcount == 0 and query.startswith('SELECT'): # empty
                consume_rows(self) # fetch all rows so result will be cached
            
            return cur
        
        def fetchone(self):
            row = self.inner.fetchone()
            if row != None: self.result.append(row)
            # otherwise, fetched all rows
            else: self._cache_result()
            return row
        
        def _cache_result(self):
            # For inserts that return a result set, don't cache result set since
            # inserts are not idempotent. Other non-SELECT queries don't have
            # their result set read, so only exceptions will be cached (an
            # invalid query will always be invalid).
            if self.query_results != None and (not self._is_insert
                or isinstance(self.result, Exception)):
                
                assert self.query_lookup != None
                self.query_results[self.query_lookup] = self.CacheCursor(
                    util.dict_subset(dicts.AttrsDictView(self),
                    ['query', 'result', 'rowcount', 'description']))
        
        class CacheCursor:
            def __init__(self, cached_result): self.__dict__ = cached_result
            
            def execute(self, *args, **kw_args):
                if isinstance(self.result, Exception): raise self.result
                # otherwise, result is a rows list
                self.iter = iter(self.result)
            
            def fetchone(self):
                try: return self.iter.next()
                except StopIteration: return None
    
    def esc_value(self, value):
        try: str_ = self.mogrify('%s', [value])
        except NotImplementedError, e:
            module = util.root_module(self.db)
            if module == 'MySQLdb':
                import _mysql
                str_ = _mysql.escape_string(value)
            else: raise e
        return strings.to_unicode(str_)
    
    def esc_name(self, name): return esc_name(self, name) # calls global func
    
    def std_code(self, str_):
        '''Standardizes SQL code.
        * Ensures that string literals are prefixed by `E`
        '''
        if str_.startswith("'"): str_ = 'E'+str_
        return str_
    
    def can_mogrify(self):
        module = util.root_module(self.db)
        return module == 'psycopg2'
    
    def mogrify(self, query, params=None):
        if self.can_mogrify(): return self.db.cursor().mogrify(query, params)
        else: raise NotImplementedError("Can't mogrify query")
    
    def set_encoding(self, encoding):
        encoding_str = sql_gen.Literal(encoding)
        run_query(self, 'SET NAMES '+encoding_str.to_str(self))
    
    def print_notices(self):
        if hasattr(self.db, 'notices'):
            for msg in self.db.notices:
                if msg not in self._notices_seen:
                    self._notices_seen.add(msg)
                    self.log_debug(msg, level=2)
    
    def run_query(self, query, cacheable=False, log_level=2,
        debug_msg_ref=None):
        '''
        @param debug_msg_ref If specified, the log message will be returned in
            this instead of being output. This allows you to filter log messages
            depending on the result of the query.
        '''
        assert query != None
        
        if self.autocommit and self.src != None:
            query = sql_gen.esc_comment(self.src)+'\t'+query
        
        if not self.caching: cacheable = False
        used_cache = False
        
        if self.debug:
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        try:
            # Get cursor
            if cacheable:
                try: cur = self.query_results[query]
                except KeyError: cur = self.DbCursor(self)
                else: used_cache = True
            else: cur = self.db.cursor()
            
            # Run query
            try: cur.execute(query)
            except Exception, e:
                _add_cursor_info(e, self, query)
                raise
            else: self.do_autocommit()
        finally:
            if self.debug:
                profiler.stop(self.profile_row_ct)
                
                ## Log or return query
                
                query = strings.ustr(get_cur_query(cur, query))
                # Put the src comment on a separate line in the log file
                query = query.replace('\t', '\n', 1)
                
                msg = 'DB query: '
                
                if used_cache: msg += 'cache hit'
                elif cacheable: msg += 'cache miss'
                else: msg += 'non-cacheable'
                
                msg += ':\n'+profiler.msg()+'\n'+strings.as_code(query, 'SQL')
                
                if debug_msg_ref != None: debug_msg_ref[0] = msg
                else: self.log_debug(msg, log_level)
                
                self.print_notices()
        
        return cur
    
    def is_cached(self, query): return query in self.query_results
    
    def with_autocommit(self, func):
        import psycopg2.extensions
        
        prev_isolation_level = self.db.isolation_level
        self.db.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try: return func()
        finally: self.db.set_isolation_level(prev_isolation_level)
    
    def with_savepoint(self, func):
        top = self._savepoint == 0
        savepoint = 'level_'+str(self._savepoint)
        
        if self.debug:
            self.log_debug('Begin transaction', level=4)
            profiler = profiling.ItersProfiler(start_now=True, iter_text='row')
        
        # Must happen before running queries so they don't get autocommitted
        self._savepoint += 1
        
        if top: query = 'START TRANSACTION ISOLATION LEVEL READ COMMITTED'
        else: query = 'SAVEPOINT '+savepoint
        self.run_query(query, log_level=4)
        try:
            return_val = func()
            if top: self.run_query('COMMIT', log_level=4)
            return return_val
        except:
            if top: query = 'ROLLBACK'
            else: query = 'ROLLBACK TO SAVEPOINT '+savepoint
            self.run_query(query, log_level=4)
            
            raise
        finally:
            # Always release savepoint, because after ROLLBACK TO SAVEPOINT,
            # "The savepoint remains valid and can be rolled back to again"
            # (http://www.postgresql.org/docs/8.3/static/sql-rollback-to.html).
            if not top:
                self.run_query('RELEASE SAVEPOINT '+savepoint, log_level=4)
            
            self._savepoint -= 1
            assert self._savepoint >= 0
            
            if self.debug:
                profiler.stop(self.profile_row_ct)
                self.log_debug('End transaction\n'+profiler.msg(), level=4)
            
            self.do_autocommit() # OK to do this after ROLLBACK TO SAVEPOINT
    
    def do_autocommit(self):
        '''Autocommits if outside savepoint'''
        assert self._savepoint >= 1
        if self.autocommit and self._savepoint == 1:
            self.log_debug('Autocommitting', level=4)
            self.db.commit()
    
    def col_info(self, col, cacheable=True):
        module = util.root_module(self.db)
        if module == 'psycopg2':
            qual_table = sql_gen.Literal(col.table.to_str(self))
            col_name_str = sql_gen.Literal(col.name)
            try:
                type_, is_array, default, nullable = row(run_query(self, '''\
SELECT
format_type(COALESCE(NULLIF(typelem, 0), pg_type.oid), -1) AS type
, typcategory = 'A' AS type_is_array
, pg_get_expr(pg_attrdef.adbin, attrelid, true) AS default
, NOT pg_attribute.attnotnull AS nullable
FROM pg_attribute
LEFT JOIN pg_type ON pg_type.oid = atttypid
LEFT JOIN pg_attrdef ON adrelid = attrelid AND adnum = attnum
WHERE
attrelid = '''+qual_table.to_str(self)+'''::regclass
AND attname = '''+col_name_str.to_str(self)+'''
'''
                    , recover=True, cacheable=cacheable, log_level=4))
            except (DoesNotExistException, StopIteration):
                raise sql_gen.NoUnderlyingTableException(col)
            if is_array: type_ = sql_gen.ArrayType(type_)
        else:
            table = sql_gen.Table('columns', 'information_schema')
            cols = [sql_gen.Col('data_type'), sql_gen.Col('udt_name'),
                'column_default', sql_gen.Cast('boolean',
                sql_gen.Col('is_nullable'))]
            
            conds = [('table_name', col.table.name),
                ('column_name', strings.ustr(col.name))]
            schema = col.table.schema
            if schema != None: conds.append(('table_schema', schema))
            
            cur = select(self, table, cols, conds, order_by='table_schema',
                limit=1, cacheable=cacheable, log_level=4)
            try: type_, extra_type, default, nullable = row(cur)
            except StopIteration: raise sql_gen.NoUnderlyingTableException(col)
            if type_ == 'USER-DEFINED': type_ = extra_type
            elif type_ == 'ARRAY':
                type_ = sql_gen.ArrayType(strings.remove_prefix('_', extra_type,
                    require=True))
        
        if default != None: default = sql_gen.as_Code(default, self)
        return sql_gen.TypedCol(col.name, type_, default, nullable)
    
    def TempFunction(self, name):
        if self.debug_temp: schema = None
        else: schema = 'pg_temp'
        return sql_gen.Function(name, schema)

connect = DbConn
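
# Usage sketch (hypothetical connection values):
#     db = connect(dict(engine='PostgreSQL', host='localhost', user='bien',
#         password='secret', database='vegbien'))
#     print value(run_query(db, 'SELECT 1'))
#     db.close()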

##### Recoverable querying

def parse_exception(db, e, recover=False):
    msg = strings.ustr(e.args[0])
    msg = re.sub(r'^(?:PL/Python: )?ValueError: ', r'', msg)
    
    match = re.match(r'^invalid byte sequence for encoding "(.+?)":', msg)
    if match:
        encoding, = match.groups()
        raise EncodingException(encoding, e)
    
    def make_DuplicateKeyException(constraint, e):
        cols = []
        cond = None
        if recover: # need auto-rollback to run index_cols()
            try:
                cols = index_cols(db, constraint)
                cond = index_cond(db, constraint)
            except NotImplementedError: pass
        return DuplicateKeyException(constraint, cond, cols, e)
    
    match = re.match(r'^duplicate key value violates unique constraint "(.+?)"',
        msg)
    if match:
        constraint, = match.groups()
        raise make_DuplicateKeyException(constraint, e)
    
    match = re.match(r'^could not create unique index "(.+?)"\n'
        r'DETAIL:  Key .+? is duplicated', msg)
    if match:
        constraint, = match.groups()
        raise DuplicateKeyException(constraint, None, [], e)
    
    match = re.match(r'^null value in column "(.+?)" violates not-null'
        r' constraint', msg)
    if match:
        col, = match.groups()
        raise NullValueException('NOT NULL', None, [col], e)
    
    match = re.match(r'^new row for relation "(.+?)" violates check '
        r'constraint "(.+?)"', msg)
    if match:
        table, constraint = match.groups()
        constraint = sql_gen.Col(constraint, table)
        cond = None
        if recover: # need auto-rollback to run constraint_cond()
            try: cond = constraint_cond(db, constraint)
            except NotImplementedError: pass
        raise CheckException(constraint.to_str(db), cond, [], e)
    
    match = re.match(r'^(?:invalid input (?:syntax|value)\b[^:]*'
        r'|.+? out of range)(?:: "(.+?)")?', msg)
    if match:
        value, = match.groups()
        value = util.do_ignore_none(strings.to_unicode, value)
        raise InvalidValueException(value, e)
    
    match = re.match(r'^column "(.+?)" is of type (.+?) but expression '
        r'is of type', msg)
    if match:
        col, type_ = match.groups()
        raise MissingCastException(type_, col, e)
    
    match = re.match(r'^could not determine polymorphic type because '
        r'input has type "unknown"', msg)
    if match: raise MissingCastException('text', None, e)
    
    match = re.match(r'^.+? types (.+?) and (.+?) cannot be matched', msg)
    if match:
        type0, type1 = match.groups()
        raise MissingCastException(type0, None, e)
    
    match = re.match(r'^.*?\brelation "(.+?)" is not a (table)', msg)
    if match:
        name, type_ = match.groups()
        raise InvalidTypeException(type_, name, e)
    
    typed_name_re = r'^(\S+) "?(.+?)"?(?: of relation ".+?")?'
    
    match = re.match(typed_name_re+r'.*? already exists', msg)
    if match:
        type_, name = match.groups()
        raise DuplicateException(type_, name, e)
    
    match = re.match(r'more than one (\S+) named ""(.+?)""', msg)
    if match:
        type_, name = match.groups()
        raise DuplicateException(type_, name, e)
    
    match = re.match(typed_name_re+r' does not exist', msg)
    if match:
        type_, name = match.groups()
        if type_ == 'function':
            match = re.match(r'^(.+?)\(.*\)$', name)
            if match: # includes params, so is call rather than cast to regproc
                function_name, = match.groups()
                func = sql_gen.Function(function_name)
                if function_exists(db, func) and msg.find('CAST') < 0:
                    # not found only because of a missing cast
                    type_ = function_param0_type(db, func)
                    col = None
                    if type_ == 'anyelement': type_ = 'text'
                    elif type_ == 'hstore': # cast just the value param
                        type_ = 'text'
                        col = 'value'
                    raise MissingCastException(type_, col, e)
        raise DoesNotExistException(type_, name, e)
    
    raise # no specific exception raised
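
# e.g. a driver error whose message begins 'duplicate key value violates
# unique constraint "plots_pkey"' is re-raised as
# DuplicateKeyException('plots_pkey', ...), so callers can catch it by type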

def with_savepoint(db, func): return db.with_savepoint(func)

def run_query(db, query, recover=None, cacheable=False, log_level=2,
    log_ignore_excs=None, **kw_args):
    '''For other params, see DbConn.run_query()
    @param log_ignore_excs The log_level will be increased by 2 if the query
        throws one of these exceptions.
    '''
    if recover == None: recover = False
    if log_ignore_excs == None: log_ignore_excs = ()
    log_ignore_excs = tuple(log_ignore_excs)
    debug_msg_ref = [None]
    
    query = with_explain_comment(db, query)
    
    try:
        try:
            def run(): return db.run_query(query, cacheable, log_level,
                debug_msg_ref, **kw_args)
            if recover and not db.is_cached(query):
                return with_savepoint(db, run)
            else: return run() # don't need savepoint if cached
        except Exception, e:
            # Give failed EXPLAIN approximately the log_level of its query
            if query.startswith('EXPLAIN'): log_level -= 1
            
            parse_exception(db, e, recover)
    except log_ignore_excs:
        log_level += 2
        raise
    finally:
        if debug_msg_ref[0] != None: db.log_debug(debug_msg_ref[0], log_level)

##### Basic queries

def is_explainable(query):
    # See <http://www.postgresql.org/docs/8.3/static/sql-explain.html#AEN57749>
    return re.match(r'^(?:SELECT|INSERT|UPDATE|DELETE|VALUES|EXECUTE|DECLARE)\b'
        , query)

def explain(db, query, **kw_args):
    '''
    For params, see run_query().
    '''
    kw_args.setdefault('log_level', 4)
    
    return strings.ustr(strings.join_lines(values(run_query(db,
        'EXPLAIN '+query, recover=True, cacheable=True, **kw_args))))
        # not a higher log_level because it's useful to see what query is being
        # run before it's executed, which EXPLAIN effectively provides

def has_comment(query): return query.endswith('*/')

def with_explain_comment(db, query, **kw_args):
    if db.autoexplain and not has_comment(query) and is_explainable(query):
        query += '\n'+sql_gen.esc_comment(' EXPLAIN:\n'
            +explain(db, query, **kw_args))
    return query

def next_version(name):
    version = 1 # first existing name was version 0
    match = re.match(r'^(.*)#(\d+)$', name)
    if match:
        name, version = match.groups()
        version = int(version)+1
    return sql_gen.concat(name, '#'+str(version))
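
# e.g. next_version('plots') -> 'plots#1'; next_version('plots#1') -> 'plots#2'
# (assuming sql_gen.concat() does not need to truncate the name)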

def lock_table(db, table, mode):
    table = sql_gen.as_Table(table)
    run_query(db, 'LOCK TABLE '+table.to_str(db)+' IN '+mode+' MODE')

def run_query_into(db, query, into=None, add_pkey_=False, **kw_args):
    '''Outputs a query to a temp table.
    For params, see run_query().
    '''
    if into == None: return run_query(db, query, **kw_args)
    
    assert isinstance(into, sql_gen.Table)
    
    into.is_temp = True
    # "temporary tables cannot specify a schema name", so remove schema
    into.schema = None
    
    kw_args['recover'] = True
    kw_args.setdefault('log_ignore_excs', (DuplicateException,))
    
    temp = not db.debug_temp # tables are permanent in debug_temp mode
    
    # Create table
    while True:
        create_query = 'CREATE'
        if temp: create_query += ' TEMP'
        create_query += ' TABLE '+into.to_str(db)+' AS\n'+query
        
        try:
            cur = run_query(db, create_query, **kw_args)
                # CREATE TABLE AS sets rowcount to # rows in query
            break
        except DuplicateException, e:
            into.name = next_version(into.name)
            # try again with next version of name
    
    if add_pkey_: add_pkey_or_index(db, into, warn=True)
    
    # According to the PostgreSQL doc, "The autovacuum daemon cannot access and
    # therefore cannot vacuum or analyze temporary tables. [...] if a temporary
    # table is going to be used in complex queries, it is wise to run ANALYZE on
    # the temporary table after it is populated."
    # (http://www.postgresql.org/docs/9.1/static/sql-createtable.html)
    # If into is not a temp table, ANALYZE is useful but not required.
    analyze(db, into)
    
    return cur

order_by_pkey = object() # tells mk_select() to order by the pkey

distinct_on_all = object() # tells mk_select() to SELECT DISTINCT ON all columns

def has_subset_func(db, table):
    return sql_gen.is_underlying_table(table) and function_exists(db, table)

def mk_select(db, tables=None, fields=None, conds=None, distinct_on=[],
    limit=None, start=None, order_by=order_by_pkey, default_table=None,
    explain=True):
    '''
    @param tables The single table to select from, or a list of tables to join
        together, with tables after the first being sql_gen.Join objects
    @param fields Use None to select all fields in the table
    @param conds WHERE conditions: [(compare_left_side, compare_right_side),...]
        * container can be any iterable type
        * compare_left_side: sql_gen.Code|str (for col name)
        * compare_right_side: sql_gen.ValueCond|literal value
    @param distinct_on The columns to SELECT DISTINCT ON, or distinct_on_all to
        use all columns
    @return query
    '''
    # Parse tables param
    tables = lists.mk_seq(tables)
    tables = list(tables) # don't modify input! (list() copies input)
    table0 = sql_gen.as_Table(tables.pop(0)) # first table is separate
    
    # Parse other params
    if conds == None: conds = []
    elif dicts.is_dict(conds): conds = conds.items()
    conds = list(conds) # don't modify input! (list() copies input)
    assert limit == None or isinstance(limit, (int, long))
    assert start == None or isinstance(start, (int, long))
    if limit == 0: order_by = None
    if order_by in (None, order_by_pkey) and has_subset_func(db, table0):
        # can use subset function for fast querying at large OFFSET values
        run_query(db, 'SET LOCAL enable_sort TO off')
        table0 = sql_gen.FunctionCall(table0, limit_=limit, offset_=start)
        if limit != 0: limit = None # done by function
        start = None # done by function
        order_by = None # done by function
    elif order_by is order_by_pkey:
        if lists.is_seq(distinct_on) and distinct_on: order_by = distinct_on[0]
        elif table0 != None: order_by = table_order_by(db, table0, recover=True)
        else: order_by = None
    
    query = 'SELECT'
    
    def parse_col(col): return sql_gen.as_Col(col, default_table).to_str(db)
    
    # DISTINCT ON columns
    if distinct_on != []:
        query += '\nDISTINCT'
        if distinct_on is not distinct_on_all:
            query += ' ON ('+(', '.join(map(parse_col, distinct_on)))+')'
    
    # Columns
    if query.find('\n') >= 0: whitespace = '\n'
    else: whitespace = ' '
    if fields == None: query += whitespace+'*'
    else:
        assert fields != []
        if len(fields) > 1: whitespace = '\n'
        query += whitespace+('\n, '.join(map(parse_col, fields)))
    
    # Main table
    if query.find('\n') >= 0 or len(tables) > 0: whitespace = '\n'
    else: whitespace = ' '
    if table0 != None: query += whitespace+'FROM '+table0.to_str(db)
    
    # Add joins
    left_table = table0
    for join_ in tables:
        table = join_.table
        
        # Parse special values
        if join_.type_ is sql_gen.filter_out: # filter no match
            conds.append((sql_gen.Col(table_not_null_col(db, table), table),
                sql_gen.CompareCond(None, '~=')))
        
        query += '\n'+join_.to_str(db, left_table)
        
        left_table = table
    
    if conds != []:
        query += '\n'+sql_gen.combine_conds([sql_gen.ColValueCond(l, r)
            .to_str(db) for l, r in conds], 'WHERE')
    if order_by != None:
        query += '\nORDER BY '+sql_gen.as_Col(order_by).to_str(db)
    if limit != None: query += '\nLIMIT '+str(limit)
    if start != None:
        if start != 0: query += '\nOFFSET '+str(start)
    
    if explain: query = with_explain_comment(db, query)
    
    return query
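
# Usage sketch (assuming a plots table with a plot_id column):
#     query = mk_select(db, 'plots', ['plot_id'], [('plot_id', 1)], limit=10)
#     cur = run_query(db, query)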

def select(db, *args, **kw_args):
    '''For params, see mk_select() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    return run_query(db, mk_select(db, *args, **kw_args), recover, cacheable,
        log_level=log_level)

def mk_insert_select(db, table, cols=None, select_query=None, returning=None,
    embeddable=False, ignore=False, src=None):
    '''
    @param returning str|None An inserted column (such as pkey) to return
    @param embeddable Whether the query should be embeddable as a nested SELECT.
        Warning: If you set this and cacheable=True when the query is run, the
        query will be fully cached, not just if it raises an exception.
    @param ignore Whether to ignore duplicate keys.
    @param src Will be included in the name of any created function, to help
        identify the data source in pg_stat_activity.
    '''
    table = sql_gen.remove_table_rename(sql_gen.as_Table(table))
    if cols == []: cols = None # no cols (all defaults) = unknown col names
    if cols != None: cols = [sql_gen.to_name_only_col(c, table) for c in cols]
    if select_query == None: select_query = 'DEFAULT VALUES'
    if returning != None: returning = sql_gen.as_Col(returning, table)
    
    first_line = 'INSERT INTO '+table.to_str(db)
    
    def mk_insert(select_query):
        query = first_line
        if cols != None:
            query += '\n('+(', '.join((c.to_str(db) for c in cols)))+')'
        query += '\n'+select_query
        
        if returning != None:
            returning_name_col = sql_gen.to_name_only_col(returning)
            query += '\nRETURNING '+returning_name_col.to_str(db)
        
        return query
    
    return_type = sql_gen.CustomCode('unknown')
    if returning != None: return_type = sql_gen.ColType(returning)
    
    if ignore:
        # Always return something to set the correct rowcount
        if returning == None: returning = sql_gen.NamedCol('NULL', None)
        
        embeddable = True # must use function
        
        if cols == None: row = [sql_gen.Col(sql_gen.all_cols, 'row')]
        else: row = [sql_gen.Col(c.name, 'row') for c in cols]
        
        query = sql_gen.RowExcIgnore(sql_gen.RowType(table), select_query,
            sql_gen.ReturnQuery(mk_insert(sql_gen.Values(row).to_str(db))),
            cols)
    else: query = mk_insert(select_query)
    
    if embeddable:
        # Create function
        function_name = sql_gen.clean_name(first_line)
        if src != None: function_name = src+': '+function_name
        while True:
            try:
                func = db.TempFunction(function_name)
                def_ = sql_gen.FunctionDef(func, sql_gen.SetOf(return_type),
                    query)
                
                run_query(db, def_.to_str(db), recover=True, cacheable=True,
                    log_ignore_excs=(DuplicateException,))
                break # this version was successful
            except DuplicateException, e:
                function_name = next_version(function_name)
                # try again with next version of name
        
        # Return query that uses function
        cols = None
        if returning != None: cols = [returning]
        func_table = sql_gen.NamedTable('f', sql_gen.FunctionCall(func), cols)
            # AS clause requires function alias
        return mk_select(db, func_table, order_by=None)
    
    return query
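
# Usage sketch (assuming src and dest tables that share a col column):
#     query = mk_insert_select(db, 'dest', ['col'],
#         mk_select(db, 'src', ['col'], order_by=None))
#     run_query(db, query)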

def insert_select(db, table, *args, **kw_args):
    '''For params, see mk_insert_select() and run_query_into()
    @param into sql_gen.Table with suggested name of temp table to put RETURNING
        values in
    '''
    returning = kw_args.get('returning', None)
    ignore = kw_args.get('ignore', False)
    
    into = kw_args.pop('into', None)
    if into != None: kw_args['embeddable'] = True
    recover = kw_args.pop('recover', None)
    if ignore: recover = True
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    rowcount_only = ignore and returning == None # keep NULL rows on server
    if rowcount_only: into = sql_gen.Table('rowcount')
    
    cur = run_query_into(db, mk_insert_select(db, table, *args, **kw_args),
        into, recover=recover, cacheable=cacheable, log_level=log_level)
    if rowcount_only: empty_temp(db, into)
    autoanalyze(db, table)
    return cur

default = sql_gen.default # tells insert() to use the default value for a column

def insert(db, table, row, *args, **kw_args):
    '''For params, see insert_select()'''
    ignore = kw_args.pop('ignore', False)
    if ignore: kw_args.setdefault('recover', True)
    
    if lists.is_seq(row): cols = None
    else:
        cols = row.keys()
        row = row.values()
    row = list(row) # ensure that "== []" works
    
    if row == []: query = None
    else: query = sql_gen.Values(row).to_str(db)
    
    try: return insert_select(db, table, cols, query, *args, **kw_args)
    except (DuplicateKeyException, NullValueException):
        if not ignore: raise
        return None
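
# Usage sketch (assuming a plots table with a plot_id column):
#     insert(db, 'plots', dict(plot_id=1), ignore=True)
#         # returns None instead of raising if plot_id 1 already exists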

def mk_update(db, table, changes=None, cond=None, in_place=False,
    cacheable_=True):
    '''
    @param changes [(col, new_value),...]
        * container can be any iterable type
        * col: sql_gen.Code|str (for col name)
        * new_value: sql_gen.Code|literal value
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @param in_place If set, locks the table and updates rows in place.
        This avoids creating dead rows in PostgreSQL.
        * cond must be None
    @param cacheable_ Whether column structure information used to generate the
        query can be cached
    @return str query
    '''
    table = sql_gen.as_Table(table)
    changes = [(sql_gen.to_name_only_col(c, table), sql_gen.as_Value(v))
        for c, v in changes]
    
    if in_place:
        assert cond == None
        
        def col_type(col):
            return sql_gen.canon_type(db.col_info(
                sql_gen.with_default_table(col, table), cacheable_).type)
        changes = [(c, v, col_type(c)) for c, v in changes]
        query = 'ALTER TABLE '+table.to_str(db)+'\n'
        query += ',\n'.join(('ALTER COLUMN '+c.to_str(db)+' TYPE '+t+'\nUSING '
            +v.to_str(db) for c, v, t in changes))
    else:
        query = 'UPDATE '+table.to_str(db)+'\nSET\n'
        query += ',\n'.join((c.to_str(db)+' = '+v.to_str(db)
            for c, v in changes))
        if cond != None: query += '\nWHERE\n'+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query
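
# Usage sketch (assuming a plots table with an is_active column; per the
# @param cond note above, the condition is a sql_gen.*Cond object):
#     query = mk_update(db, 'plots', [('is_active', True)],
#         sql_gen.ColValueCond('plot_id', 1))
#     run_query(db, query)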

def update(db, table, *args, **kw_args):
    '''For params, see mk_update() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', False)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_update(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

def mk_delete(db, table, cond=None):
    '''
    @param cond sql_gen.Code WHERE condition. e.g. use sql_gen.*Cond objects.
    @return str query
    '''
    query = 'DELETE FROM '+table.to_str(db)
    if cond != None: query += '\nWHERE '+cond.to_str(db)
    
    query = with_explain_comment(db, query)
    
    return query

def delete(db, table, *args, **kw_args):
    '''For params, see mk_delete() and run_query()'''
    recover = kw_args.pop('recover', None)
    cacheable = kw_args.pop('cacheable', True)
    log_level = kw_args.pop('log_level', 2)
    
    cur = run_query(db, mk_delete(db, table, *args, **kw_args), recover,
        cacheable, log_level=log_level)
    autoanalyze(db, table)
    return cur

def last_insert_id(db):
    module = util.root_module(db.db)
    if module == 'psycopg2': return value(run_query(db, 'SELECT lastval()'))
    elif module == 'MySQLdb': return db.db.insert_id()
    else: return None

def define_func(db, def_):
    func = def_.function
    while True:
        try:
            run_query(db, def_.to_str(db), recover=True, cacheable=True,
                log_ignore_excs=(DuplicateException,))
            break # successful
        except DuplicateException:
            func.name = next_version(func.name)
            # try again with next version of name

def mk_flatten_mapping(db, into, cols, preserve=[], as_items=False):
    '''Creates a mapping from original column names (which may have collisions)
    to names that will be distinct among the columns' tables.
    This is meant to be used for several tables that are being joined together.
    @param cols The columns to combine. Duplicates will be removed.
    @param into The table for the new columns.
    @param preserve [sql_gen.Col...] Columns not to rename. Note that these
        columns will be included in the mapping even if they are not in cols.
        The tables of the provided Col objects will be changed to into, so make
        copies of them if you want to keep the original tables.
    @param as_items Whether to return a list of dict items instead of a dict
    @return dict(orig_col=new_col, ...)
        * orig_col: sql_gen.Col(orig_col_name, orig_table)
        * new_col: sql_gen.Col(orig_col_name, into)
        * All mappings use the into table so its name can easily be
          changed for all columns at once
    '''
    cols = lists.uniqify(cols)
    
    items = []
    for col in preserve:
        orig_col = copy.copy(col)
        col.table = into
        items.append((orig_col, col))
    preserve = set(preserve)
    for col in cols:
        if col not in preserve:
            items.append((col, sql_gen.Col(strings.ustr(col), into, col.srcs)))
    
    if not as_items: items = dict(items)
    return items
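
# Usage sketch (hypothetical joined tables t0, t1 that both have an id column):
#     cols = [sql_gen.Col('id', 't0'), sql_gen.Col('id', 't1')]
#     mapping = mk_flatten_mapping(db, sql_gen.Table('flat'), cols)
#         # both id columns map to distinct names on the flat table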

def flatten(db, into, joins, cols, limit=None, start=None, **kw_args):
    '''For params, see mk_flatten_mapping()
    @return See return value of mk_flatten_mapping()
    '''
    items = mk_flatten_mapping(db, into, cols, as_items=True, **kw_args)
    cols = [sql_gen.NamedCol(new.name, old) for old, new in items]
    run_query_into(db, mk_select(db, joins, cols, limit=limit, start=start),
        into=into, add_pkey_=True)
        # don't cache because the temp table will usually be truncated after use
    return dict(items)

##### Database structure introspection

#### Tables

def tables(db, schema_like='public', table_like='%', exact=False,
    cacheable=True):
    if exact: compare = '='
    else: compare = 'LIKE'
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        conds = [('schemaname', sql_gen.CompareCond(schema_like, compare)),
            ('tablename', sql_gen.CompareCond(table_like, compare))]
        return values(select(db, 'pg_tables', ['tablename'], conds,
            order_by='tablename', cacheable=cacheable, log_level=4))
    elif module == 'MySQLdb':
        return values(run_query(db, 'SHOW TABLES LIKE '+db.esc_value(table_like)
            , cacheable=cacheable, log_level=4))
    else: raise NotImplementedError("Can't list tables for "+module+' database')

def table_exists(db, table, cacheable=True):
    table = sql_gen.as_Table(table)
    return list(tables(db, table.schema, table.name, True, cacheable)) != []

def table_row_count(db, table, recover=None):
    return value(run_query(db, mk_select(db, table, [sql_gen.row_count],
        order_by=None), recover=recover, log_level=3))

def table_col_names(db, table, recover=None):
    return list(col_names(select(db, table, limit=0, recover=recover,
        log_level=4)))

def table_cols(db, table, *args, **kw_args):
    return [sql_gen.as_Col(strings.ustr(c), table)
        for c in table_col_names(db, table, *args, **kw_args)]

def table_pkey_index(db, table, recover=None):
    table_str = sql_gen.Literal(table.to_str(db))
    try:
        return sql_gen.Table(value(run_query(db, '''\
SELECT relname
FROM pg_index
JOIN pg_class index ON index.oid = indexrelid
WHERE
indrelid = '''+table_str.to_str(db)+'''::regclass
AND indisprimary
'''
            , recover, cacheable=True, log_level=4)), table.schema)
    except StopIteration: raise DoesNotExistException('primary key', '')

def table_pkey_col(db, table, recover=None):
    table = sql_gen.as_Table(table)
    
    module = util.root_module(db.db)
    if module == 'psycopg2':
        return sql_gen.Col(index_cols(db, table_pkey_index(db, table,
            recover))[0], table)
    else:
        join_cols = ['table_schema', 'table_name', 'constraint_schema',
            'constraint_name']
        tables = [sql_gen.Table('key_column_usage', 'information_schema'),
            sql_gen.Join(
                sql_gen.Table('table_constraints', 'information_schema'),
                dict(((c, sql_gen.join_same_not_null) for c in join_cols)))]
        cols = [sql_gen.Col('column_name')]
        
        conds = [('constraint_type', 'PRIMARY KEY'), ('table_name', table.name)]
        schema = table.schema
        if schema != None: conds.append(('table_schema', schema))
        order_by = 'position_in_unique_constraint'
        
        try: return sql_gen.Col(value(select(db, tables, cols, conds,
            order_by=order_by, limit=1, log_level=4)), table)
        except StopIteration: raise DoesNotExistException('primary key', '')

def table_has_pkey(db, table, recover=None):
    try: table_pkey_col(db, table, recover)
    except DoesNotExistException: return False
    else: return True

def pkey_name(db, table, recover=None):
    '''If no pkey, returns the first column in the table.'''
    return pkey_col(db, table, recover).name

def pkey_col(db, table, recover=None):
    '''If no pkey, returns the first column in the table.'''
    try: return table_pkey_col(db, table, recover)
    except DoesNotExistException: return table_cols(db, table, recover)[0]

not_null_col = 'not_null_col'

def table_not_null_col(db, table, recover=None):
    '''Returns not_null_col if the table has a column by that name; otherwise
    falls back to the pkey name.'''
    if not_null_col in table_col_names(db, table, recover): return not_null_col
    else: return pkey_name(db, table, recover)

def constraint_cond(db, constraint):
    module = util.root_module(db.db)
    if module == 'psycopg2':
        table_str = sql_gen.Literal(constraint.table.to_str(db))
        name_str = sql_gen.Literal(constraint.name)
        return value(run_query(db, '''\
SELECT consrc
FROM pg_constraint
WHERE
conrelid = '''+table_str.to_str(db)+'''::regclass
AND conname = '''+name_str.to_str(db)+'''
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError("Can't get constraint condition for "
        +module+' database')

def index_exprs(db, index):
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        return list(values(run_query(db, '''\
SELECT pg_get_indexdef(indexrelid, generate_series(1, indnatts), true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4)))
    else: raise NotImplementedError()

def index_cols(db, index):
    '''Can also use this for UNIQUE constraints, because a UNIQUE index is
    automatically created. When you don't know whether something is a UNIQUE
    constraint or a UNIQUE index, use this function.'''
    return map(sql_gen.parse_expr_col, index_exprs(db, index))

def index_cond(db, index):
    index = sql_gen.as_Table(index)
    module = util.root_module(db.db)
    if module == 'psycopg2':
        qual_index = sql_gen.Literal(index.to_str(db))
        return value(run_query(db, '''\
SELECT pg_get_expr(indpred, indrelid, true)
FROM pg_index
WHERE indexrelid = '''+qual_index.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4))
    else: raise NotImplementedError()

def index_order_by(db, index):
    return sql_gen.CustomCode(', '.join(index_exprs(db, index)))

def table_cluster_on(db, table, recover=None):
    '''
    @return The table's cluster index, or its pkey if none is set
    '''
    table_str = sql_gen.Literal(table.to_str(db))
    try:
        return sql_gen.Table(value(run_query(db, '''\
SELECT relname
FROM pg_index
JOIN pg_class index ON index.oid = indexrelid
WHERE
indrelid = '''+table_str.to_str(db)+'''::regclass
AND indisclustered
'''
            , recover, cacheable=True, log_level=4)), table.schema)
    except StopIteration: return table_pkey_index(db, table, recover)

def table_order_by(db, table, recover=None):
    '''
    @return None if table is a view, because a view has no cluster index or
        pkey for table_cluster_on() to find
    '''
    if table.order_by == None:
        try: table.order_by = index_order_by(db, table_cluster_on(db, table,
            recover))
        except DoesNotExistException: pass
    return table.order_by

#### Views

def view_exists(db, view):
    view_str = sql_gen.Literal(view.to_str(db))
    try:
        return value(run_query(db, '''\
SELECT relkind = 'v'
FROM pg_class
WHERE oid = '''+view_str.to_str(db)+'''::regclass
'''
            , cacheable=True, log_level=4))
    except DoesNotExistException: return False

#### Functions

def function_exists(db, function):
    qual_function = sql_gen.Literal(function.to_str(db))
    try:
        select(db, fields=[sql_gen.Cast('regproc', qual_function)],
            recover=True, cacheable=True, log_level=4)
    except DoesNotExistException: return False
    except DuplicateException: return True # overloaded function
    else: return True

def function_param0_type(db, function):
    qual_function = sql_gen.Literal(function.to_str(db))
    return value(run_query(db, '''\
SELECT proargtypes[0]::regtype
FROM pg_proc
WHERE oid = '''+qual_function.to_str(db)+'''::regproc
'''
        , cacheable=True, log_level=4))

##### Structural changes

#### Columns

def add_col(db, table, col, comment=None, if_not_exists=False, **kw_args):
    '''
    @param col TypedCol Name may be versioned, so be sure to propagate any
        renaming back to any source column for the TypedCol.
    @param comment None|str SQL comment used to distinguish columns of the same
        name from each other when they contain different data, to allow the
        ADD COLUMN query to be cached. If not set, query will not be cached.
    '''
    assert isinstance(col, sql_gen.TypedCol)
    
    while True:
        str_ = 'ALTER TABLE '+table.to_str(db)+' ADD COLUMN '+col.to_str(db)
        if comment != None: str_ += ' '+sql_gen.esc_comment(comment)
        
        try:
            run_query(db, str_, recover=True, cacheable=True, **kw_args)
            break
        except DuplicateException:
            if if_not_exists: raise
            col.name = next_version(col.name)
            # try again with next version of name

def add_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' SET NOT NULL', cacheable=True, log_level=3)

def drop_not_null(db, col):
    table = col.table
    col = sql_gen.to_name_only_col(col)
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ALTER COLUMN '
        +col.to_str(db)+' DROP NOT NULL', cacheable=True, log_level=3)

row_num_col = '_row_num'

row_num_col_def = sql_gen.TypedCol('', 'serial', nullable=False,
    constraints='PRIMARY KEY')

def add_row_num(db, table, name=row_num_col):
    '''Adds a row number column to a table. Its definition is in
    row_num_col_def. It will be the primary key.'''
    col_def = copy.copy(row_num_col_def)
    col_def.name = name
    add_col(db, table, col_def, comment='', if_not_exists=True, log_level=3)

#### Indexes

def add_pkey(db, table, cols=None, recover=None):
    '''Adds a primary key.
    @param cols [sql_gen.Col,...] The columns in the primary key.
        Defaults to the first column in the table.
    @pre The table must not already have a primary key.
    '''
    table = sql_gen.as_Table(table)
    if cols == None: cols = [pkey_name(db, table, recover)]
    col_strs = [sql_gen.to_name_only_col(v).to_str(db) for v in cols]
    
    run_query(db, 'ALTER TABLE '+table.to_str(db)+' ADD PRIMARY KEY ('
        +(', '.join(col_strs))+')', recover=True, cacheable=True, log_level=3,
        log_ignore_excs=(DuplicateException,))

def add_index(db, exprs, table=None, unique=False, ensure_not_null_=True):
    '''Adds an index on column(s) or expression(s) if it doesn't already exist.
    Currently, only function calls and literal values are supported expressions.
    @param ensure_not_null_ If set, translates NULL values to sentinel values.
        This allows indexes to be used for comparisons where NULLs are equal.
    '''
    exprs = lists.mk_seq(exprs)
    
    # Parse exprs
    old_exprs = exprs[:]
    exprs = []
    cols = []
    for i, expr in enumerate(old_exprs):
        expr = sql_gen.as_Col(expr, table)
        
        # Handle nullable columns
        if ensure_not_null_:
            try: expr = sql_gen.ensure_not_null(db, expr)
            except KeyError: pass # unknown type, so just create plain index
        
        # Extract col
        expr = copy.deepcopy(expr) # don't modify input!
        col = expr
        if isinstance(expr, sql_gen.FunctionCall): col = expr.args[0]
        expr = sql_gen.cast_literal(expr)
        if not isinstance(expr, (sql_gen.Expr, sql_gen.Col)):
            expr = sql_gen.Expr(expr)
        
        # Extract table
        if table == None:
            assert sql_gen.is_table_col(col)
            table = col.table
        
        if isinstance(col, sql_gen.Col): col.table = None
        
        exprs.append(expr)
        cols.append(col)
    
    table = sql_gen.as_Table(table)
    
    # Add index
    str_ = 'CREATE'
    if unique: str_ += ' UNIQUE'
    str_ += ' INDEX ON '+table.to_str(db)+' ('+(
        ', '.join((v.to_str(db) for v in exprs)))+')'
    run_query(db, str_, recover=True, cacheable=True, log_level=3)
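
# Usage sketch (assuming a plots table with a plot_name column):
#     add_index(db, 'plot_name', 'plots')
#         # repeated identical calls in a session hit the query cache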

def add_pkey_index(db, table): add_index(db, pkey_col(db, table), table)

def add_pkey_or_index(db, table, cols=None, recover=None, warn=False):
    try: add_pkey(db, table, cols, recover)
    except DuplicateKeyException, e:
        if warn: warnings.warn(UserWarning(exc.str_(e)))
        add_pkey_index(db, table)

already_indexed = object() # tells add_indexes() the pkey has already been added

def add_indexes(db, table, has_pkey=True):
    '''Adds an index on all columns in a table.
    @param has_pkey bool|already_indexed Whether a pkey instead of a regular
        index should be added on the first column.
        * If already_indexed, the pkey is assumed to have already been added
    '''
    cols = table_col_names(db, table)
    if has_pkey:
        if has_pkey is not already_indexed: add_pkey(db, table)
        cols = cols[1:]
    for col in cols: add_index(db, col, table)

#### Tables

### Maintenance

def analyze(db, table):
    table = sql_gen.as_Table(table)
    run_query(db, 'ANALYZE '+table.to_str(db), log_level=3)

def autoanalyze(db, table):
    if db.autoanalyze: analyze(db, table)

def vacuum(db, table):
    table = sql_gen.as_Table(table)
    db.with_autocommit(lambda: run_query(db, 'VACUUM ANALYZE '+table.to_str(db),
        log_level=3))

### Lifecycle

def drop(db, type_, name):
    name = sql_gen.as_Name(name)
    run_query(db, 'DROP '+type_+' IF EXISTS '+name.to_str(db)+' CASCADE')

def drop_table(db, table): drop(db, 'TABLE', table)

def create_table(db, table, cols=[], has_pkey=True, col_indexes=True,
    like=None):
    '''Creates a table.
    @param cols [sql_gen.TypedCol,...] The column names and types
    @param has_pkey If set, the first column becomes the primary key.
    @param col_indexes bool|[ref]
        * If True, indexes will be added on all non-pkey columns.
        * If a list reference, [0] will be set to a function to do this.
          This can be used to delay index creation until the table is populated.
    '''
    table = sql_gen.as_Table(table)
    
    if like != None:
        cols = [sql_gen.CustomCode('LIKE '+like.to_str(db)+' INCLUDING ALL')
            ]+cols
        table.order_by = like.order_by
    if has_pkey:
        cols[0] = pkey = copy.copy(cols[0]) # don't modify input!
        pkey.constraints = 'PRIMARY KEY'
    
    temp = table.is_temp and not db.debug_temp
        # temp tables permanent in debug_temp mode
    
    # Create table
    def create():
        str_ = 'CREATE'
        if temp: str_ += ' TEMP'
        str_ += ' TABLE '+table.to_str(db)+' (\n'
        str_ += '\n, '.join(c.to_str(db) for c in cols)
        str_ += '\n);'
        
        opts = dict(recover=True, cacheable=True, log_level=2,
            log_ignore_excs=(DuplicateException,))
        try: run_query(db, str_, **opts)
        except InvalidTypeException: # like is not a table (e.g. a view), so
            # LIKE won't work; copy its structure via a zero-row SELECT instead
            run_query_into(db, mk_select(db, like, limit=0), into=table, **opts)
    if table.is_temp:
        while True:
            try:
                create()
                break
            except DuplicateException:
                table.name = next_version(table.name)
                # try again with next version of name
    else: create()
    
    # Add indexes
    if has_pkey: has_pkey = already_indexed
    def add_indexes_(): add_indexes(db, table, has_pkey)
    if isinstance(col_indexes, list): col_indexes[0] = add_indexes_ # defer
    elif col_indexes: add_indexes_() # add now
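
# Usage sketch (hypothetical table and column types):
#     create_table(db, 'plots', [sql_gen.TypedCol('plot_id', 'serial'),
#         sql_gen.TypedCol('plot_name', 'text')])
#         # plot_id becomes the pkey because has_pkey defaults to True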

def copy_table_struct(db, src, dest):
    '''Creates a structure-only copy of a table. (Does not copy data.)'''
    create_table(db, dest, has_pkey=False, col_indexes=False, like=src)

def copy_table(db, src, dest):
    '''Creates a copy of a table, including data'''
    copy_table_struct(db, src, dest)
    insert_select(db, dest, None, mk_select(db, src))

### Data

def truncate(db, table, schema='public', **kw_args):
    '''For params, see run_query()'''
    table = sql_gen.as_Table(table, schema)
    return run_query(db, 'TRUNCATE '+table.to_str(db)+' CASCADE', **kw_args)

def empty_temp(db, tables):
    tables = lists.mk_seq(tables)
    for table in tables: truncate(db, table, log_level=3)

def empty_db(db, schema='public', **kw_args):
    '''For kw_args, see tables()'''
    for table in tables(db, schema, **kw_args): truncate(db, table, schema)

def distinct_table(db, table, distinct_on, joins=None):
    '''Creates a copy of a temp table which is distinct on the given columns.
    Adds an index on table's distinct_on columns, to facilitate merge joins.
    @param distinct_on If empty, creates a table with one row. This is useful if
        your distinct_on columns are all literal values.
    @param joins The joins to use when creating the new table
    @return The new table.
    '''
    if joins == None: joins = [table]
    
    new_table = sql_gen.suffixed_table(table, '_distinct')
    distinct_on = filter(sql_gen.is_table_col, distinct_on)
    
    copy_table_struct(db, table, new_table)
    
    limit = None
    if distinct_on == []: limit = 1 # one sample row
    else: add_index(db, distinct_on, table) # for join optimization
    
    insert_select(db, new_table, None, mk_select(db, joins,
        [sql_gen.Col(sql_gen.all_cols, table)], distinct_on=distinct_on,
        order_by=None, limit=limit))
    analyze(db, new_table)
    
    return new_table
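
# Usage sketch, assuming t is a temp table (e.g. from run_query_into()) with a
# column c:
#     t_distinct = distinct_table(db, t, [sql_gen.Col('c', t)])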
