# SQL code generation

import copy
import itertools
import operator
from ordereddict import OrderedDict
import re
import UserDict
import warnings

import dicts
import exc
import iters
import lists
import objects
import regexp
import strings
import util

##### Names

# Length limit that concat() passes to strings.concat() for identifiers
identifier_max_len = 63 # works for both PostgreSQL and MySQL

def concat(str_, suffix):
    '''Appends a suffix to a string using strings.concat(), passing
    identifier_max_len as the limit.
    Preserves the version (the trailing part of str_ matched by the 2nd regexp
    group) by moving it into the suffix, so that it won't be truncated off the
    string, leading to collisions.
    '''
    version_re = r'^(.*?)((?:(?:#\d+)?\)?)*(?:\.\w+)+(?:::[\w ]+)*)$'
    parsed = re.match(version_re, str_)
    if parsed:
        # Preserve version
        str_ = parsed.group(1)
        suffix = parsed.group(2)+suffix
    
    return strings.concat(str_, suffix, identifier_max_len)

def truncate(str_):
    '''Runs a string through concat() with an empty suffix.'''
    return concat(str_, '')

def is_safe_name(name):
    r'''A name is safe *and unambiguous* if it:
    * contains only *lowercase* word (\w) characters
    * doesn't start with a digit
    * contains "_", so that it's not a keyword
    
    @return A match object (truthy) if the name is safe, otherwise None
    '''
    # \Z rather than $, because $ also matches just before a trailing "\n",
    # which would let a name ending in a newline count as safe
    return re.match(r'^(?=.*_)(?!\d)[^\WA-Z]+\Z', name)

def esc_name(name, quote='"'):
    '''Wraps a name in the quote char, escaping any embedded quote chars.
    Doubling an embedded quote escapes it in both PostgreSQL and MySQL.
    '''
    escaped = name.replace(quote, quote*2)
    return quote+escaped+quote

def unesc_name(name, quote='"'):
    '''Removes the quoting added by esc_name().
    If strings.remove_prefix() reports that no leading quote was removed, the
    name is returned without further changes.
    '''
    found_ref = [False]
    name = strings.remove_prefix(quote, name, found_ref)
    if not found_ref[0]: return name # no opening quote
    
    name = strings.remove_suffix(quote, name, found_ref)
    assert found_ref[0] # an opening quote requires a closing quote
    return name.replace(quote+quote, quote)

def clean_name(name):
    '''Removes all double quotes and backticks from a name.'''
    for quote in ('"', '`'): name = name.replace(quote, '')
    return name

def esc_comment(comment):
    '''Wraps text in /* */, splitting any embedded "*/" into "* /" so that it
    can't end the comment early.'''
    body = comment.replace('*/', '* /')
    return '/*'+body+'*/'

def lstrip(str_):
    '''Removes leading whitespace. Also removes comments.
    Handles any number of leading /* */ comments, with or without whitespace
    before and between them. An unterminated comment consumes the rest of the
    string.
    '''
    while True:
        str_ = str_.lstrip()
        if not str_.startswith('/*'): return str_
        str_ = str_.partition('*/')[2] # drop the comment; keep what follows

##### General SQL code objects

class MockDb:
    '''Stand-in db object that provides the methods the code objects call,
    without needing a connection. Used (as mockDb) to render reprs.'''
    
    def esc_name(self, name):
        '''Escapes a name using the module-level esc_name() defaults.'''
        return esc_name(name)
    
    def esc_value(self, value):
        '''Escapes a value using strings.repr_no_u().'''
        return strings.repr_no_u(value)
    
    def col_info(self, col):
        '''Returns a placeholder, nullable TypedCol with the column's name.'''
        placeholder_default = CustomCode('<default>')
        return TypedCol(col.name, '<type>', placeholder_default, True)

# Shared MockDb instance, passed to to_str() by the __repr__()/__str__() methods
mockDb = MockDb()

class BasicObject(objects.BasicObject):
    '''objects.BasicObject whose str() is the strings.repr_no_u() form with
    the quote chars removed by clean_name().'''
    
    def __str__(self):
        repr_ = strings.repr_no_u(self)
        return clean_name(repr_)

##### Unparameterized code objects

class Code(BasicObject):
    '''Base class for objects that are rendered to code by to_str().'''
    
    def __init__(self, lang='sql'):
        '''
        @param lang str The name of the language the code is in
        '''
        self.lang = lang
    
    def __repr__(self):
        # Render using the mock db, so that no connection is needed
        return self.to_str(mockDb)
    
    def to_str(self, db):
        '''Renders this object to a string. Must be overridden.
        @param db The db object to escape names/values with
        '''
        raise NotImplementedError()

class CustomCode(Code):
    '''Code given as a literal string, which is rendered unchanged.'''
    
    def __init__(self, str_):
        '''
        @param str_ The code string that to_str() returns as-is
        '''
        Code.__init__(self)
        self.str_ = str_
    
    def to_str(self, db):
        return self.str_

def as_Code(value, db=None):
    '''Converts a value to a Code object:
    * Code objects are returned unchanged
    * strings become CustomCode
    * anything else becomes a Literal
    @param db If set, runs db.std_code() on the value.
    '''
    if isinstance(value, Code): return value
    if not util.is_str(value): return Literal(value)
    
    if db != None: value = db.std_code(value)
    return CustomCode(value)

class Expr(Code):
    '''Wraps another code object's output in parentheses.'''
    
    def __init__(self, expr):
        '''
        @param expr The code object (anything with to_str(db)) to parenthesize
        '''
        Code.__init__(self)
        self.expr = expr
    
    def to_str(self, db):
        inner = self.expr.to_str(db)
        return '('+inner+')'

##### Names

class Name(Code):
    '''An identifier, which is escaped by the db when rendered.'''
    
    def __init__(self, name):
        '''
        @param name str The identifier. It is passed through truncate().
        '''
        Code.__init__(self)
        self.name = truncate(name)
    
    def to_str(self, db):
        return db.esc_name(self.name)

def as_Name(value):
    '''Wraps any non-Code value in a Name. Code objects are returned as-is.'''
    if not isinstance(value, Code): value = Name(value)
    return value

##### Literal values

#### Primitives

class Literal(Code):
    '''A literal value, which is escaped by the db when rendered.'''
    
    def __init__(self, value):
        '''
        @param value The raw (non-Code) value
        '''
        Code.__init__(self)
        self.value = value
    
    def to_str(self, db):
        return db.esc_value(self.value)

def as_Value(value):
    '''Wraps any non-Code value in a Literal. Code objects are returned as-is.
    '''
    if not isinstance(value, Code): value = Literal(value)
    return value

def get_value(value):
    '''Unwraps a Literal's value.
    Any column renaming is removed first. A value that is not a Literal must
    not be any other kind of Code object, and is returned as-is.
    '''
    unwrapped = remove_col_rename(value)
    if isinstance(unwrapped, Literal): return unwrapped.value
    
    assert not isinstance(unwrapped, Code)
    return unwrapped

def is_literal(value):
    '''Whether a value is a Literal object.'''
    return isinstance(value, Literal)

def is_null(value):
    '''Whether a value is a Literal that wraps None.'''
    if not is_literal(value): return False
    return value.value == None

#### Composites

class List(Code):
    '''A comma-separated list of code objects.'''
    
    def __init__(self, values):
        '''
        @param values Iterable of objects with a to_str(db) method
        '''
        Code.__init__(self)
        self.values = values
    
    def to_str(self, db):
        rendered = [value.to_str(db) for value in self.values]
        return ', '.join(rendered)

class Tuple(List):
    '''A List that takes its values as varargs and renders in parentheses.'''
    
    def __init__(self, *values):
        List.__init__(self, values)
    
    def to_str(self, db):
        inner = List.to_str(self, db)
        return '('+inner+')'

class Row(Tuple):
    '''A Tuple rendered with a ROW prefix.'''
    
    def to_str(self, db):
        tuple_str = Tuple.to_str(self, db)
        return 'ROW'+tuple_str

### Arrays

class Array(List):
    '''A List rendered as ARRAY[...], with column renamings removed from the
    elements.'''
    
    def __init__(self, values):
        elems = [remove_col_rename(value) for value in values]
        List.__init__(self, elems)
    
    def to_str(self, db):
        elems_str = List.to_str(self, db)
        return 'ARRAY['+elems_str+']'

def to_Array(value):
    '''Converts a value to an Array, using lists.mk_seq() to turn it into a
    sequence first. Array objects are returned as-is.'''
    if not isinstance(value, Array): value = Array(lists.mk_seq(value))
    return value

##### Derived elements

# Sentinel for the srcs param of Derived.set_srcs(), which replaces it with
# (self,)
src_self = object() # tells Col that it is its own source column

class Derived(Code):
    '''An element which was derived from some other element(s).'''
    
    def __init__(self, srcs):
        '''
        @param srcs See self.set_srcs()
        '''
        Code.__init__(self)
        
        self.set_srcs(srcs)
    
    def set_srcs(self, srcs, overwrite=True):
        '''
        @param srcs (self_type...)|src_self The element(s) this is derived from
        @param overwrite If False, srcs that are already set (non-empty) are
            left unchanged
        '''
        if not overwrite and self.srcs != (): return # already set
        
        # src_self is a sentinel, so it's compared by identity like the other
        # sentinels in this module
        if srcs is src_self: srcs = (self,)
        srcs = tuple(srcs) # make Col hashable
        self.srcs = srcs
    
    def _compare_on(self):
        '''Returns a copy of the instance attrs without `srcs` (presumably the
        attrs that objects.BasicObject compares on).'''
        compare_on = self.__dict__.copy()
        del compare_on['srcs'] # ignore
        return compare_on

def cols_srcs(cols):
    '''Collects the srcs of all the given cols, passed through
    iters.flatten() and lists.uniqify().'''
    all_srcs = iters.flatten((col.srcs for col in cols))
    return lists.uniqify(all_srcs)

##### Tables

class Table(Derived):
    '''A reference to a table, optionally qualified by a schema.'''
    
    def __init__(self, name, schema=None, srcs=(), is_temp=False):
        '''
        @param name str|Code String names are passed through truncate()
        @param schema str|None (for no schema)
        @param srcs (Table...)|src_self See Derived.set_srcs()
        @param is_temp Stored as-is in self.is_temp
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        
        self.name = name
        self.schema = schema
        self.is_temp = is_temp
        # Not set by the constructor; both are ignored by _compare_on()
        self.order_by = None
        self.index_cols = {}
    
    def to_str(self, db):
        '''Renders as `schema.name`, or just `name` if there is no schema.'''
        parts = []
        if self.schema != None: parts.append(as_Name(self.schema).to_str(db))
        parts.append(as_Name(self.name).to_str(db))
        return '.'.join(parts)
    
    def to_Table(self):
        '''Returns the plain Table for this object, which is itself.'''
        return self
    
    def _compare_on(self):
        compare_on = Derived._compare_on(self)
        for ignored in ('order_by', 'index_cols'): del compare_on[ignored]
        return compare_on

def is_underlying_table(table):
    '''Whether a value is a Table whose to_Table() returns the object itself.
    '''
    if not isinstance(table, Table): return False
    return table.to_Table() is table

class NoUnderlyingTableException(Exception):
    '''Raised when an object can't be resolved to an underlying table.'''
    
    def __init__(self, ref):
        '''
        @param ref The object that has no underlying table. Stored in self.ref.
        '''
        msg = 'for: '+strings.as_tt(strings.urepr(ref))
        Exception.__init__(self, msg)
        self.ref = ref

def underlying_table(table):
    '''Resolves a table reference to its underlying table.
    Removes any table renaming and, if the result has srcs, follows its single
    source table.
    @return Table
    @throws NoUnderlyingTableException if the result is not an underlying table
    '''
    table = remove_table_rename(table)
    # The code of a NamedTable may be an Expr or FunctionCall, which have no
    # srcs attr. Treat that as "no srcs" so that the documented exception is
    # raised below instead of an AttributeError.
    if table != None and getattr(table, 'srcs', ()):
        table, = table.srcs # for derived tables or row vars
    if not is_underlying_table(table): raise NoUnderlyingTableException(table)
    return table

def as_Table(table, schema=None):
    '''Converts a value to a Table. None and Code objects are returned as-is.
    @param schema Only used when a new Table is created
    '''
    if table == None: return table
    if isinstance(table, Code): return table
    return Table(table, schema)

def suffixed_table(table, suffix):
    '''Returns a shallow copy of a table with the suffix added to its name
    using concat().'''
    renamed = copy.copy(table) # don't modify input!
    renamed.name = concat(renamed.name, suffix)
    return renamed

class NamedTable(Table):
    '''A table renaming: renders as `<code> AS <name> (<cols>)`.'''
    
    def __init__(self, name, code, cols=None):
        '''
        @param code The code being named. Anything that isn't a Table,
            FunctionCall, or Expr after as_Table() is wrapped in an Expr.
        @param cols None|[...] Converted to name-only Cols
        '''
        Table.__init__(self, name)
        
        code = as_Table(code)
        if not isinstance(code, (Table, FunctionCall, Expr)): code = Expr(code)
        if cols != None:
            cols = [to_name_only_col(col).to_Col() for col in cols]
        
        self.code = code
        self.cols = cols
    
    def to_str(self, db):
        code_str = self.code.to_str(db)
        # Multiline code gets the AS clause on its own line
        if '\n' in code_str: sep = '\n'
        else: sep = ' '
        str_ = code_str+sep+'AS '+Table.to_str(self, db)
        if self.cols != None:
            cols_str = ', '.join([col.to_str(db) for col in self.cols])
            str_ += ' ('+cols_str+')'
        return str_
    
    def to_Table(self):
        '''Returns a plain Table with just this table's name.'''
        return Table(self.name)

def remove_table_rename(table):
    '''Returns a NamedTable's code. Anything else is returned as-is.'''
    if isinstance(table, NamedTable): return table.code
    return table

##### Columns

class Col(Derived):
    '''A reference to a column, optionally qualified by a table.'''
    
    def __init__(self, name, table=None, srcs=()):
        '''
        @param name str|Code String names are passed through truncate()
        @param table Table|str|None (for no table)
        @param srcs (Col...)|src_self See Derived.set_srcs()
        '''
        Derived.__init__(self, srcs)
        
        if util.is_str(name): name = truncate(name)
        if util.is_str(table): table = Table(table)
        assert table == None or isinstance(table, Table)
        
        self.name = name
        self.table = table
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str Whether to render for str(): quote chars are removed
            from the name, and the table is prepended using concat()
        '''
        name_str = as_Name(self.name).to_str(db)
        if for_str: name_str = clean_name(name_str)
        
        if self.table != None:
            table = self.table.to_Table()
            if for_str: return concat(strings.ustr(table), '.'+name_str)
            return table.to_str(db)+'.'+name_str
        return name_str
    
    def __str__(self):
        return self.to_str(mockDb, for_str=True)
    
    def to_Col(self):
        '''Returns the plain Col for this object, which is itself.'''
        return self

def is_col(col):
    '''Whether a value is a Col object.'''
    return isinstance(col, Col)

def is_table_col(col):
    '''Whether a value is a Col that has a table.'''
    if not is_col(col): return False
    return col.table != None

def index_col(col):
    '''Looks up the col registered for a column in its table's index_cols.
    @return Col|None None if it's not a table column or has no entry
    '''
    if not is_table_col(col): return None
    
    owner = col.table
    try: index_name = owner.index_cols[col.name]
    except KeyError: return None # no entry for this column
    return Col(index_name, owner, col.srcs)

def is_temp_col(col):
    '''Whether a value is a table column whose table has is_temp set.'''
    in_table = is_table_col(col)
    if not in_table: return in_table
    return col.table.is_temp

def as_Col(col, table=None, name=None):
    '''Converts a value to a code object:
    * Code objects are returned as-is
    * strings become a Col in `table`
    * anything else becomes a Literal
    @param name If not None, any non-Col input will be renamed using NamedCol.
    '''
    if name != None:
        col = as_Value(col)
        if not isinstance(col, Col): col = NamedCol(name, col)
    
    if isinstance(col, Code): return col
    if util.is_str(col): return Col(col, table)
    return Literal(col)

def with_table(col, table):
    '''Returns a copy of a column with its table set.
    * NamedCol: returned unchanged
    * FunctionCall: deep-copied, and the table is set on its first arg
    * Col: shallow-copied, and the table is set on it
    * anything else: returned unchanged
    '''
    if isinstance(col, NamedCol): return col # doesn't take a table
    
    if isinstance(col, FunctionCall):
        call = copy.deepcopy(col) # don't modify input!
        call.args[0].table = table
        return call
    
    if isinstance(col, Col):
        new_col = copy.copy(col) # don't modify input!
        new_col.table = table
        return new_col
    
    return col

def with_default_table(col, table):
    '''Converts a value with as_Col() and, if it has no table, applies
    with_table() with the given table.'''
    col = as_Col(col)
    if col.table == None: return with_table(col, table)
    return col

def set_cols_table(table, cols):
    '''Converts each entry of `cols` with as_Col() and sets its table.
    @param cols Mutable sequence. Modified in place, as are its entries.
    '''
    table = as_Table(table)
    
    for idx, entry in enumerate(cols):
        converted = as_Col(entry)
        cols[idx] = converted
        converted.table = table

def to_name_only_col(col, check_table=None):
    '''Returns a table column as a new Col with just its name. Anything that
    isn't a table column is only converted with as_Col().
    @param check_table If set, asserts that the column's table equals it
    '''
    col = as_Col(col)
    if is_table_col(col):
        if check_table != None:
            table = col.table
            assert table == None or table == check_table
        col = Col(col.name)
    return col

def suffixed_col(col, suffix):
    '''Returns a new Col with the same table and srcs, and the suffix added to
    the name using concat().'''
    new_name = concat(col.name, suffix)
    return Col(new_name, col.table, col.srcs)

def has_srcs(col):
    '''Whether a value is a Col with srcs.
    @return False for non-Cols, otherwise the col's srcs (falsy if empty)
    '''
    if not is_col(col): return False
    return col.srcs

def cross_join_srcs(cols):
    '''Returns a Col for each combination of one src per col, named with the
    srcs' names joined by ",".
    Cols without srcs are left out, because empty srcs would mess up the cross
    join.
    '''
    src_names = [[src.name for src in col.srcs]
        for col in cols if has_srcs(col)]
    # itertools.product() returns [()] for empty input
    if not src_names: return []
    return [Col(','.join(combo)) for combo in itertools.product(*src_names)]

class NamedCol(Col):
    '''A column renaming: renders as `<code> AS <name>`.'''
    
    def __init__(self, name, code):
        '''
        @param name The new column name
        @param code The value being named. Converted with as_Value().
        '''
        Col.__init__(self, name)
        
        code = as_Value(code)
        
        self.code = code
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str Whether to render for str(). This must be accepted
            because the inherited Col.__str__() calls
            self.to_str(mockDb, for_str=True). As in BasicObject.__str__(), the
            quote chars are removed from the rendered string.
        '''
        str_ = self.code.to_str(db)+' AS '+Col.to_str(self, db)
        if for_str: str_ = clean_name(str_)
        return str_
    
    def to_Col(self):
        '''Returns a plain Col with just this column's name.'''
        return Col(self.name)

def remove_col_rename(col):
    '''Returns a NamedCol's code. Anything else is returned as-is.'''
    if isinstance(col, NamedCol): return col.code
    return col

def underlying_col(col):
    '''Returns a new Col with the same name and srcs, in the column's
    underlying table. Any column renaming is removed first.
    @throws NoUnderlyingTableException if the value isn't a Col, or
        underlying_table() raises it for the column's table
    '''
    col = remove_col_rename(col)
    if isinstance(col, Col):
        return Col(col.name, underlying_table(col.table), col.srcs)
    raise NoUnderlyingTableException(col)

def wrap(wrap_func, value):
    '''Wraps a value, propagating any column renaming to the returned value.
    @param wrap_func Called with the value (for a NamedCol, with its code)
    '''
    if not isinstance(value, NamedCol): return wrap_func(value)
    return NamedCol(value.name, wrap_func(value.code))

class ColDict(dicts.DictProxy):
    '''A dict that automatically makes inserted entries Col objects.
    Anything that isn't a column is wrapped in a NamedCol with the key's column
    name by `as_Col(value, name=key.name)`.
    '''
    
    def __init__(self, db, keys_table, dict_={}):
        '''
        @param db Used to look up column defaults for None values
        @param keys_table The table that string keys are columns of
        @param dict_ Initial entries, which are added using update()
        '''
        dicts.DictProxy.__init__(self, OrderedDict())
        
        self.db = db
        self.table = as_Table(keys_table)
        self.update(dict_) # after setting vars because __setitem__() needs them
    
    def copy(self):
        return ColDict(self.db, self.table, self.inner.copy())
    
    def __getitem__(self, key):
        col_key = self._key(key)
        return dicts.DictProxy.__getitem__(self, col_key)
    
    def __setitem__(self, key, value):
        col_key = self._key(key)
        if value == None:
            # Use the column's default instead, if it can be looked up
            try: value = self.db.col_info(col_key).default
            except NoUnderlyingTableException: pass # not a table column
        entry = as_Col(value, name=col_key.name)
        dicts.DictProxy.__setitem__(self, col_key, entry)
    
    def _key(self, key):
        '''Converts a key to a column of self.table using as_Col().'''
        return as_Col(key, self.table)

##### Definitions

class TypedCol(Col):
    '''A column definition: name, type, and optional NOT NULL, DEFAULT, and
    constraints clauses.'''
    
    def __init__(self, name, type_, default=None, nullable=True,
        constraints=None):
        '''
        @param type_ str|Code The column type. Converted with as_Code() when
            rendered.
        @param default Code|None (for no DEFAULT clause)
        @param nullable If false, NOT NULL is added
        @param constraints str|None Appended after the other clauses
        '''
        assert default == None or isinstance(default, Code)
        
        Col.__init__(self, name)
        
        self.type = type_
        self.default = default
        self.nullable = nullable
        self.constraints = constraints
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str Whether to render for str(). This must be accepted
            because the inherited Col.__str__() calls
            self.to_str(mockDb, for_str=True). As in BasicObject.__str__(), the
            quote chars are removed from the rendered string.
        '''
        str_ = Col.to_str(self, db)+' '+as_Code(self.type).to_str(db)
        if not self.nullable: str_ += ' NOT NULL'
        if self.default != None: str_ += ' DEFAULT '+self.default.to_str(db)
        if self.constraints != None: str_ += ' '+self.constraints
        if for_str: str_ = clean_name(str_)
        return str_
    
    def to_Col(self):
        '''Returns a plain Col with just this column's name.'''
        return Col(self.name)

class SetOf(Code):
    '''A type rendered with a SETOF prefix.'''
    
    def __init__(self, type_):
        '''
        @param type_ Code The element type (must have a to_str(db) method)
        '''
        Code.__init__(self)
        self.type = type_
    
    def to_str(self, db):
        type_str = self.type.to_str(db)
        return 'SETOF '+type_str

class RowType(Code):
    '''A table's row type, rendered as `<table>%ROWTYPE`.'''
    
    def __init__(self, table):
        '''
        @param table Code The table (must have a to_str(db) method)
        '''
        Code.__init__(self)
        self.table = table
    
    def to_str(self, db):
        table_str = self.table.to_str(db)
        return table_str+'%ROWTYPE'

class ColType(Code):
    '''A column's type, rendered as `<col>%TYPE`.'''
    
    def __init__(self, col):
        '''
        @param col Code The column (must have a to_str(db) method)
        '''
        Code.__init__(self)
        self.col = col
    
    def to_str(self, db):
        col_str = self.col.to_str(db)
        return col_str+'%TYPE'

class ArrayType(Code):
    '''An array type, rendered as `<elem_type>[]`.'''
    
    def __init__(self, elem_type):
        '''
        @param elem_type str|Code The element type. Converted with as_Code().
        '''
        Code.__init__(self)
        self.elem_type = as_Code(elem_type)
    
    def to_str(self, db):
        elem_str = self.elem_type.to_str(db)
        return elem_str+'[]'

##### Functions

# Functions are referenced with the same class and converter as tables
# (a name with an optional schema)
Function = Table
as_Function = as_Table

class InternalFunction(CustomCode):
    '''A function name that is rendered as-is (like any CustomCode).'''

#### Calls

class NamedArg(NamedCol):
    '''A named function argument: renders as `<name> := <value>`.'''
    
    def __init__(self, name, value):
        NamedCol.__init__(self, name, value)
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str Whether to render for str(). This must be accepted
            because the inherited Col.__str__() calls
            self.to_str(mockDb, for_str=True). As in BasicObject.__str__(), the
            quote chars are removed from the rendered string.
        '''
        str_ = Col.to_str(self, db)+' := '+self.code.to_str(db)
        if for_str: str_ = clean_name(str_)
        return str_

class FunctionCall(Code):
    '''A call to a function: renders as `<function>(<args>)`.'''
    
    def __init__(self, function, *args, **kw_args):
        '''
        @param function Converted with as_Function()
        @param args [Code|literal-value...] The function's arguments
        @param kw_args Named arguments, which are appended to the positional
            ones as NamedArg objects
        '''
        Code.__init__(self)
        
        def to_arg(arg):
            '''Converts an arg to a value without any column renaming'''
            return remove_col_rename(as_Value(arg))
        
        positional = [to_arg(arg) for arg in args]
        named = [NamedArg(name, to_arg(value))
            for name, value in kw_args.iteritems()]
        
        self.function = as_Function(function)
        self.args = positional+named
    
    def to_str(self, db):
        args_str = ', '.join([arg.to_str(db) for arg in self.args])
        return self.function.to_str(db)+'('+args_str+')'

def wrap_in_func(function, value):
    '''Wraps a value inside a function call.
    Propagates any column renaming to the returned value.
    '''
    def call(arg): return FunctionCall(function, arg)
    return wrap(call, value)

def unwrap_func_call(func_call, check_name=None):
    '''Unwraps any function call to its first argument.
    Also removes any column renaming.
    @param check_name If set, asserts that the function's name equals it
    '''
    func_call = remove_col_rename(func_call)
    if isinstance(func_call, FunctionCall):
        if check_name != None:
            name = func_call.function.name
            assert name == None or name == check_name
        return func_call.args[0]
    return func_call # not a function call

#### Definitions

class FunctionDef(Code):
    '''A CREATE FUNCTION statement.'''
    
    def __init__(self, function, return_type, body, params=None,
        modifiers=None):
        '''
        @param function The function being defined (must have to_str(db))
        @param return_type str|Code Converted with as_Code()
        @param body str|Code Converted with as_Code(). Its `lang` attr is used
            for the LANGUAGE clause.
        @param params [...]|None The parameters (objects with to_str(db)).
            None means no parameters. A new list is created for each instance,
            so that instances never share a default list.
        @param modifiers str|None Added on its own line before the body
        '''
        Code.__init__(self)
        
        return_type = as_Code(return_type)
        body = as_Code(body)
        if params is None: params = []
        
        self.function = function
        self.return_type = return_type
        self.body = body
        self.params = params
        self.modifiers = modifiers
    
    def to_str(self, db):
        params_str = (', '.join((p.to_str(db) for p in self.params)))
        str_ = '''\
CREATE FUNCTION '''+self.function.to_str(db)+'''('''+params_str+''')
RETURNS '''+self.return_type.to_str(db)+'''
LANGUAGE '''+self.body.lang+'''
'''
        if self.modifiers != None: str_ += self.modifiers+'\n'
        str_ += '''\
AS $$
'''+self.body.to_str(db)+'''
$$;
'''
        return str_

class FunctionParam(TypedCol):
    '''A function parameter definition, optionally marked as OUT.'''
    
    def __init__(self, name, type_, default=None, out=False):
        '''
        @param out Whether to add the OUT prefix
        '''
        TypedCol.__init__(self, name, type_, default)
        
        self.out = out
    
    def to_str(self, db, for_str=False):
        '''
        @param for_str Whether to render for str(). This must be accepted
            because the inherited Col.__str__() calls
            self.to_str(mockDb, for_str=True). As in BasicObject.__str__(), the
            quote chars are removed from the rendered string.
        '''
        str_ = TypedCol.to_str(self, db)
        if self.out: str_ = 'OUT '+str_
        if for_str: str_ = clean_name(str_)
        return str_
    
    def to_Col(self):
        '''Returns a plain Col with just this parameter's name.'''
        return Col(self.name)

### PL/pgSQL

class ReturnQuery(Code):
    '''A RETURN QUERY statement, with the query on the following line(s)
    passed through strings.indent().'''
    
    def __init__(self, query):
        '''
        @param query str|Code Converted with as_Code()
        '''
        Code.__init__(self)
        self.query = as_Code(query)
    
    def to_str(self, db):
        query_str = strings.indent(self.query.to_str(db))
        return 'RETURN QUERY\n'+query_str+';\n'

## Exceptions

class BaseExcHandler(BasicObject):
    '''Base class for exception handlers, which render code around a body.'''
    
    def __repr__(self):
        # Render using the mock db and a placeholder body
        return self.to_str(mockDb, '<body>')
    
    def to_str(self, db, body):
        '''Renders the handler around the body. Must be overridden.
        @param body The code that the handler applies to
        '''
        raise NotImplementedError()

# Handler code that does nothing (used by ExcHandler when it has no handler)
suppress_exc = 'NULL;\n'

# Handler code that raises the caught exception again
reraise_exc = 'RAISE USING ERRCODE = SQLSTATE, MESSAGE = SQLERRM;\n'

class ExcHandler(BaseExcHandler):
    '''Renders a BEGIN ... EXCEPTION WHEN <exc> THEN <handler> END; block
    around a body.'''
    
    def __init__(self, exc, handler=None):
        '''
        @param exc str The exception name placed after WHEN
        @param handler str|Code|None The code to run for the exception.
            Converted with as_Code(). If None, suppress_exc is used.
        '''
        if handler != None: handler = as_Code(handler)
        
        self.exc = exc
        self.handler = handler
    
    def to_str(self, db, body):
        '''
        @param body str|Code The code the handler applies to. Converted with
            as_Code().
        '''
        body = as_Code(body)
        
        # A custom handler goes on its own lines, passed through
        # strings.indent(..., 2); the default goes on the same line as THEN
        if self.handler != None:
            handler_str = '\n'+strings.indent(self.handler.to_str(db), 2)
        else: handler_str = ' '+suppress_exc
        
        # The "\"s at the ends of lines keep those newlines out of the string
        str_ = '''\
BEGIN
'''+strings.indent(body.to_str(db))+'''\
EXCEPTION
    WHEN '''+self.exc+''' THEN'''+handler_str+'''\
END;\
'''
        return str_

class NestedExcHandler(BaseExcHandler):
    '''Applies several exception handlers to the same body.'''
    
    def __init__(self, *handlers):
        '''
        @param handlers Sorted from outermost to innermost
        '''
        self.handlers = handlers
    
    def to_str(self, db, body):
        # Apply the innermost handler first, so that it ends up innermost
        wrapped = body
        for handler in self.handlers[::-1]:
            wrapped = handler.to_str(db, wrapped)
        return wrapped

class ExcToWarning(Code):
    '''Handler code that raises the error message as a warning and then runs a
    return statement.'''
    
    def __init__(self, return_):
        '''
        @param return_ Statement to return a default value in case of error
        '''
        Code.__init__(self)
        self.return_ = as_Code(return_)
    
    def to_str(self, db):
        return_str = self.return_.to_str(db)
        return "RAISE WARNING '%', SQLERRM;\n"+return_str

# Handles unique_violation with the default handler code (suppress_exc)
unique_violation_handler = ExcHandler('unique_violation')

# Handles internal_error. If SQLERRM has the form "[PL/Python: ]<name>: <msg>",
# raises data_exception with just <msg>; otherwise runs reraise_exc.
# Note doubled "\"s because inside Python string
plpythonu_error_handler = ExcHandler('internal_error', '''\
-- Handle PL/Python exceptions
DECLARE
    matches text[] := regexp_matches(SQLERRM,
        E'^(?:PL/Python: )?(\\\\w+): (.*)$'); -- .* also matches \\n
    exc_name text := matches[1];
    msg text := matches[2];
BEGIN
    /* Re-raise PL/Python exceptions with the PL/Python prefix removed.
    This allows the exception to be parsed like a native exception.
    Always raise as data_exception so it goes in the errors table. */
    IF exc_name IS NOT NULL THEN
        RAISE data_exception USING MESSAGE = msg;
    -- Re-raise non-PL/Python exceptions
    ELSE
        '''+reraise_exc+'''\
    END IF;
END;
''')

def data_exception_handler(handler):
    '''Creates an ExcHandler that runs the given handler code for
    data_exception.'''
    return ExcHandler('data_exception', handler=handler)

# Default row variable of RowExcIgnore, which declares it itself when it's used
row_var = Table('row')

class RowExcIgnore(Code):
    '''PL/pgSQL code that loops over a query's rows and runs code for each row
    inside its own exception handler.'''
    
    def __init__(self, row_type, select_query, with_row, cols=None,
        exc_handler=unique_violation_handler, row_var=row_var):
        '''
        @param row_type Ignored if a custom row_var is used.
        @param select_query str|Code The query to loop over
        @param with_row str|Code The code to run for each row
        @param cols None|[...] If set, the loop assigns to one row_var column
            per entry (using the entry's name) instead of to row_var itself
        @param exc_handler The handler whose to_str(db, body) wraps with_row
        @param row_var The variable that holds each row
        @pre If a custom row_var is used, it must already be defined.
        '''
        Code.__init__(self, lang='plpgsql')
        
        row_type = as_Code(row_type)
        select_query = as_Code(select_query)
        with_row = as_Code(with_row)
        row_var = as_Table(row_var)
        
        self.row_type = row_type
        self.select_query = select_query
        self.with_row = with_row
        self.cols = cols
        self.exc_handler = exc_handler
        self.row_var = row_var
    
    def to_str(self, db):
        # The loop's target(s): the whole row var, or one var per col
        if self.cols == None: row_vars = [self.row_var]
        else: row_vars = [Col(c.name, self.row_var) for c in self.cols]
        
        # Need an EXCEPTION block for each individual row because "When an error
        # is caught by an EXCEPTION clause, [...] all changes to persistent
        # database state within the block are rolled back."
        # This is unfortunate because "A block containing an EXCEPTION clause is
        # significantly more expensive to enter and exit than a block without
        # one."
        # (http://www.postgresql.org/docs/8.3/static/\
        # plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING)
        str_ = '''\
FOR '''+(', '.join((v.to_str(db) for v in row_vars)))+''' IN
'''+strings.indent(self.select_query.to_str(db))+'''\
LOOP
'''+strings.indent(self.exc_handler.to_str(db, self.with_row))+'''\
END LOOP;
'''
        # Here, `row_var` is the module-level default (this method has no local
        # of that name). Only a row var equal to the default is declared; a
        # custom one must already be defined.
        if self.row_var == row_var:
            str_ = '''\
DECLARE
    '''+self.row_var.to_str(db)+''' '''+self.row_type.to_str(db)+''';
BEGIN
'''+strings.indent(str_)+'''\
END;
'''
        return str_

##### Casts

class Cast(FunctionCall):
    '''A type cast: renders as `CAST(<value> AS <type>)`.'''
    
    def __init__(self, type_, value):
        '''
        @param type_ str|Code The type to cast to. Converted with as_Code().
        @param value The value to cast. Converted with as_Value().
        '''
        # Set the attrs that every Code object has (lang).
        # FunctionCall.__init__() is not called (to_str() renders its own
        # syntax), so `function` and `args` are not set on a Cast.
        Code.__init__(self)
        
        type_ = as_Code(type_)
        value = as_Value(value)
        
        # Most types cannot be cast directly to unknown
        if type_.to_str(mockDb) == 'unknown': value = Cast('text', value)
        
        self.type_ = type_
        self.value = value
    
    def to_str(self, db):
        return 'CAST('+self.value.to_str(db)+' AS '+self.type_.to_str(db)+')'

def cast_literal(value):
    '''Casts a string Literal to text. Anything else is returned as-is.'''
    if is_literal(value) and util.is_str(value.value):
        return Cast('text', value)
    return value

##### Conditions

class NotCond(Code):
    '''A negated condition: renders as `NOT <cond>`.'''
    
    def __init__(self, cond):
        '''
        @param cond The condition to negate. Unless it's already a Coalesce,
            it is wrapped in `Coalesce(cond, False)`.
        '''
        Code.__init__(self)
        
        if isinstance(cond, Coalesce): self.cond = cond
        else: self.cond = Coalesce(cond, False)
    
    def to_str(self, db):
        cond_str = self.cond.to_str(db)
        return 'NOT '+cond_str

class ColValueCond(Code):
    '''A condition that applies a ValueCond to a column.'''
    
    def __init__(self, col, value):
        '''
        @param col The column that the condition is applied on
        @param value Converted with as_ValueCond()
        '''
        Code.__init__(self)
        
        self.col = col
        self.value = as_ValueCond(value)
    
    def to_str(self, db):
        cond = self.value
        return cond.to_str(db, self.col)

def combine_conds(conds, keyword=None):
    '''Joins condition strings with "\\nAND ".
    @param conds [str...] The rendered conditions
    @param keyword The keyword to add before the conditions, if any
    '''
    prefix = ''
    if keyword != None:
        # A single condition goes on the same line as the keyword
        if conds == []: sep = ''
        elif len(conds) == 1: sep = ' '
        else: sep = '\n'
        prefix = keyword+sep
    
    return prefix+'\nAND '.join(conds)

##### Condition column comparisons

class ValueCond(BasicObject):
    '''Base class for conditions that compare some left value to a stored
    value.'''
    
    def __init__(self, value):
        '''
        @param value The value to compare to. Converted with as_Value(), and
            any column renaming is removed.
        '''
        value = remove_col_rename(as_Value(value))
        
        self.value = value
    
    def to_str(self, db, left_value):
        '''Renders the condition. Must be overridden.
        @param left_value The Code object that the condition is being applied on
        '''
        # NotImplementedError is the exception class; NotImplemented is a
        # non-callable constant, so calling it would raise a TypeError instead
        raise NotImplementedError()
    
    def __repr__(self): return self.to_str(mockDb, '<left_value>')

class CompareCond(ValueCond):
    '''Compares the left value to the stored value using an operator.'''
    
    def __init__(self, value, operator='='):
        '''
        @param operator By default, compares NULL values literally. Use '~=' or
            '~!=' to pass NULLs through.
        '''
        ValueCond.__init__(self, value)
        self.operator = operator
    
    def to_str(self, db, left_value):
        '''
        @param left_value str|Code The left side of the comparison. Converted
            with as_Col(), and any column renaming is removed.
        '''
        left_value = remove_col_rename(as_Col(left_value))
        
        right_value = self.value
        
        # Parse operator
        # (the local `operator` shadows the operator module in this method)
        operator = self.operator
        passthru_null_ref = [False]
        operator = strings.remove_prefix('~', operator, passthru_null_ref)
        neg_ref = [False]
        operator = strings.remove_prefix('!', operator, neg_ref)
        equals = operator.endswith('=') # also includes <=, >=
        
        # Handle nullable columns
        check_null = False
        if not passthru_null_ref[0]: # NULLs compare equal
            try: left_value = ensure_not_null(db, left_value)
            except ensure_not_null_excs: # fall back to alternate method
                check_null = equals and isinstance(right_value, Col)
            else:
                if isinstance(left_value, EnsureNotNull):
                    right_value = ensure_not_null(db, right_value,
                        left_value.type) # apply same function to both sides
        
        # `= NULL` is rendered as `IS NULL`
        if equals and is_null(right_value): operator = 'IS'
        
        left = left_value.to_str(db)
        right = right_value.to_str(db)
        
        # Create str
        str_ = left+' '+operator+' '+right
        if check_null:
            str_ = '('+str_+' OR ('+left+' IS NULL AND '+right+' IS NULL))'
        if neg_ref[0]: str_ = 'NOT '+str_
        return str_

# Tells as_ValueCond() to assume a non-ValueCond is a literal value
# (i.e. to not apply with_default_table() to it)
assume_literal = object()

def as_ValueCond(value, default_table=assume_literal):
    '''Converts a value to a ValueCond, by wrapping anything that isn't one in
    a CompareCond.
    @param default_table Unless this is assume_literal, with_default_table() is
        applied to the value before it's wrapped
    '''
    if isinstance(value, ValueCond): return value
    
    if default_table is not assume_literal:
        value = with_default_table(value, default_table)
    return CompareCond(value)

##### Joins

# Sentinels for Join: the first two are used as values of its mapping param,
# and filter_out is used as its type_ param

join_same = object() # tells Join the left and right columns have the same name

# Tells Join the left and right columns have the same name and are never NULL
join_same_not_null = object()

filter_out = object() # tells Join to filter out rows that match the join

class Join(BasicObject):
    '''A JOIN clause that joins a table to some left table.'''
    
    def __init__(self, table, mapping={}, type_=None):
        '''
        @param table str|Table The table to join with
        @param mapping dict(right_table_col=left_table_col, ...)
            or [using_col...]
            * if left_table_col is join_same: left_table_col = right_table_col
              * Note that right_table_col must be a string
            * if left_table_col is join_same_not_null:
              left_table_col = right_table_col and both have NOT NULL constraint
              * Note that right_table_col must be a string
        @param type_ None (for plain join)|str (e.g. 'LEFT')|filter_out
            * filter_out: equivalent to 'LEFT' with the query filtered by
              `table_pkey IS NULL` (indicating no match)
        '''
        if util.is_str(table): table = Table(table)
        if lists.is_seq(mapping):
            # A list of USING cols: each maps to join_same_not_null
            mapping = dict([(using_col, join_same_not_null)
                for using_col in mapping])
        assert type_ == None or util.is_str(type_) or type_ is filter_out
        
        self.table = table
        self.mapping = mapping
        self.type_ = type_
    
    def to_str(self, db, left_table_):
        '''
        @param left_table_ The table that self.table is being joined to
        '''
        def cond_str(right_table_col, left_table_col):
            '''Parses non-USING joins'''
            # The order is switched: right_table_col is on the left in the
            # comparison
            cmp_left = with_default_table(right_table_col, self.table)
            cmp_right = left_table_col
            
            # Parse special values
            same_col = Col(cmp_left.name, left_table_)
            if cmp_right is join_same: cmp_right = same_col
            elif cmp_right is join_same_not_null:
                cmp_right = CompareCond(same_col, '~=')
            
            return as_ValueCond(cmp_right, left_table_).to_str(db, cmp_left)
        
        # Create join condition
        type_ = self.type_
        mapping = self.mapping
        if mapping == {}: join_cond = None
        elif type_ is not filter_out and all((v is join_same_not_null
            for v in mapping.itervalues())):
            # all cols w/ USING, so can use simpler USING syntax
            using_cols = [to_name_only_col(c) for c in mapping.iterkeys()]
            using_str = ', '.join([c.to_str(db) for c in using_cols])
            join_cond = 'USING ('+using_str+')'
        else:
            conds = [cond_str(right_col, left_col)
                for right_col, left_col in mapping.iteritems()]
            join_cond = combine_conds(conds, 'ON')
        
        if isinstance(self.table, NamedTable): whitespace = '\n'
        else: whitespace = ' '
        
        # Create join
        if type_ is filter_out: type_ = 'LEFT'
        str_ = 'JOIN'+whitespace+self.table.to_str(db)
        if type_ != None: str_ = type_+' '+str_
        if join_cond != None: str_ += whitespace+join_cond
        return str_
    
    def __repr__(self):
        # Render using the mock db and a placeholder left table
        return self.to_str(mockDb, '<left_table>')

##### Value exprs

# The "*" that stands for all columns
all_cols = CustomCode('*')

# The DEFAULT keyword, as a value
default = CustomCode('DEFAULT')

# COUNT(*)
row_count = FunctionCall(InternalFunction('COUNT'), all_cols)

class Coalesce(FunctionCall):
    '''A call to the COALESCE function.'''
    
    def __init__(self, *args):
        function = InternalFunction('COALESCE')
        FunctionCall.__init__(self, function, *args)

class Nullif(FunctionCall):
    '''A call to the NULLIF function.'''
    
    def __init__(self, *args):
        function = InternalFunction('NULLIF')
        FunctionCall.__init__(self, function, *args)

# The NULL literal, and the same literal cast to text
null = Literal(None)
null_as_str = Cast('text', null)

def to_text(value):
    '''Casts a value to text, coalesced with null_as_str.'''
    as_text = Cast('text', value)
    return Coalesce(as_text, null_as_str)

# Maps a type name to the value that EnsureNotNull substitutes for NULL in
# columns of that type. Types that aren't listed cause a KeyError there.
# See <http://www.postgresql.org/docs/8.3/static/datatype-numeric.html>
null_sentinels = {
    'character varying': r'\N',
    'double precision': 'NaN',
    'integer': 2147483647,
    'text': r'\N',
    'date': 'infinity',
    'timestamp with time zone': 'infinity',
    'taxonrank': 'unknown',
}

class EnsureNotNull(Coalesce):
    '''Coalesces a value with the null sentinel of its type.'''
    
    def __init__(self, value, type_):
        '''
        @param value Converted with as_Col()
        @param type_ str|ArrayType The value's type. ArrayTypes use [] as the
            sentinel; other types are looked up in null_sentinels.
        @throws KeyError if the type has no entry in null_sentinels
        '''
        if isinstance(type_, ArrayType): sentinel = []
        else: sentinel = null_sentinels[type_]
        Coalesce.__init__(self, as_Col(value), Cast(type_, sentinel))
        
        self.type = type_
    
    def to_str(self, db):
        # Render just the col's index col instead, if it has one
        index_col_ = index_col(self.args[0])
        if index_col_ != None: return index_col_.to_str(db)
        return Coalesce.to_str(self, db)

#### Arrays

class ArrayMerge(FunctionCall):
    '''A call to array_to_string(<array>, <sep>).'''
    
    def __init__(self, sep, array):
        '''
        @param sep The separator, which is passed as the 2nd argument
        @param array Converted with to_Array()
        '''
        function = InternalFunction('array_to_string')
        FunctionCall.__init__(self, function, to_Array(array), sep)

def merge_not_null(db, sep, values):
    '''Creates an ArrayMerge of the values, each converted with to_text().
    @param db Not used
    '''
    texts = [to_text(value) for value in values]
    return ArrayMerge(sep, texts)

##### Table exprs

class Values(Code):
    '''A VALUES expression with one or more rows.'''
    
    def __init__(self, values):
        '''
        @param values [...]|[[...], ...] Can be one or multiple rows.
            A list of multiple rows is modified in place: its rows are replaced
            by the converted rows.
        '''
        Code.__init__(self)
        
        is_single_row = len(values) >= 1 and not lists.is_seq(values[0])
        if is_single_row: rows = [values]
        else: rows = values
        for idx, row in enumerate(rows):
            rows[idx] = [remove_col_rename(as_Value(value)) for value in row]
        
        self.rows = rows
    
    def to_str(self, db):
        rows_str = ', '.join([Tuple(*row).to_str(db) for row in self.rows])
        return 'VALUES '+rows_str

def NamedValues(name, cols, values):
    '''Creates a NamedTable for a Values expression.
    @param cols None|[...]
    @post `cols` will be changed to Col objects with the table set to `name`.
    '''
    named_table = NamedTable(name, Values(values), cols)
    if cols != None: set_cols_table(named_table, cols)
    return named_table

##### Database structure

def is_nullable(db, value):
    '''Whether `value` can be NULL.
    For a table column, this is the `nullable` flag of db.col_info(); a column
    without an underlying table is assumed to be nullable. Anything else is
    nullable only if is_null() says it is NULL.
    '''
    if is_table_col(value):
        try: return db.col_info(value).nullable
        except NoUnderlyingTableException: return True # not a table column
    return is_null(value)

# Type names treated as textual; canon_type() collapses all of these to 'text'
text_types = set(['character varying', 'text'])

def is_text_type(type_):
    '''Whether type_ is one of the text_types.'''
    return type_ in text_types

def is_text_col(db, col):
    '''Whether the type that db.col_info() reports for col is a text type.'''
    col_type = db.col_info(col).type
    return is_text_type(col_type)

def canon_type(type_):
    '''Maps every text type to 'text'; other types are returned unchanged.'''
    if type_ in text_types: return 'text'
    return type_

# Exceptions that ensure_not_null() is documented to throw; try_mk_not_null()
# catches exactly these
ensure_not_null_excs = (NoUnderlyingTableException, KeyError)

def ensure_not_null(db, col, type_=None):
    '''
    @param col If type_ is not set, must have an underlying column.
    @param type_ If set, overrides the underlying column's type and casts the
        column to it if needed.
    @return EnsureNotNull|Col
    @throws ensure_not_null_excs
    '''
    col = remove_col_rename(col)
    
    try: col_type = db.col_info(underlying_col(col)).type
    except NoUnderlyingTableException:
        if type_ == None and is_null(col): raise # NULL has no type
    else:
        if type_ == None: type_ = col_type
        elif type_ != col_type: col = Cast(type_, col)
    
    if is_nullable(db, col):
        try: col = EnsureNotNull(col, type_)
        except KeyError, e:
            # Warn of no null sentinel for type, even if caller catches error
            warnings.warn(UserWarning(exc.str_(e)))
            raise
    
    return col

def try_mk_not_null(db, value):
    '''Best-effort ensure_not_null(): returns `value` unchanged if
    ensure_not_null() throws one of ensure_not_null_excs.
    Warning: This function does not guarantee that its result is NOT NULL.
    '''
    try:
        result = ensure_not_null(db, value)
    except ensure_not_null_excs:
        result = value
    return result

##### Expression transforming

# Text of the boolean constants as they appear in expression strings
true_expr = 'true'
false_expr = 'false'

# Regexp fragments from which the simplification patterns are built
true_re = true_expr
false_re = false_expr
bool_re = r'(?:'+true_re+r'|'+false_re+r')' # either boolean constant
atom_re = r'(?:'+bool_re+r'|\([^()]*\)'+r')'
    # a boolean constant, or a parenthesized group without nested parens

def logic_op_re(op, value_re, expr_re=''):
    '''Builds a regexp matching `value_re` on either side of a logic operator.
    @param op The operator's text; matched with a space on each side
    @param value_re Regexp for the operand that must be on one side
    @param expr_re Regexp for the operand on the other side (default: nothing)
    @return A non-capturing group: expr op value | value op expr
    '''
    op_re = ' '+op+' '
    expr_first = ''.join([expr_re, op_re, value_re])
    value_first = ''.join([value_re, op_re, expr_re])
    return ''.join(['(?:', expr_first, '|', value_first, ')'])

# Patterns applied by simplify_expr()
not_re = r'\bNOT '
not_false_re = not_re+false_re+r'\b' # "NOT false"; replaced with true
not_true_re = not_re+true_re+r'\b' # "NOT true"
and_false_re = logic_op_re('AND', false_re, atom_re)
    # "<atom> AND false" or "false AND <atom>"
and_false_not_true_re = '(?:'+not_true_re+'|'+and_false_re+')'
    # everything that simplify_expr() replaces with false
and_true_re = logic_op_re('AND', true_re) # " AND true" or "true AND "
or_re = logic_op_re('OR', bool_re) # " OR <bool>" or "<bool> OR "
or_and_true_re = '(?:'+and_true_re+'|'+or_re+')'
    # everything that simplify_expr() deletes outright

def simplify_parens(expr):
    '''Removes parentheses that directly enclose a single atom (see atom_re),
    using regexp.sub_nested().'''
    wrapped_atom_re = r'\(('+atom_re+')\)'
    return regexp.sub_nested(wrapped_atom_re, r'\1', expr)

def simplify_recursive(sub_func, expr):
    '''Runs sub_func through regexp.sub_recursive(), calling simplify_parens()
    on the string before each call to sub_func.
    @param sub_func See regexp.sub_recursive() sub_func param
    '''
    def sub_after_parens(str_):
        return sub_func(simplify_parens(str_))
    
    return regexp.sub_recursive(sub_after_parens, expr)
        # simplify_parens() is also done at end in final iteration

def simplify_expr(expr):
    '''Simplifies constant NULL tests and boolean logic in an expression str.
    @return str The simplified expression
    '''
    # (pattern, replacement) pairs, applied in this order on every pass
    logic_subs = (
        (not_false_re, true_re),
        (and_false_not_true_re, false_expr),
        (or_and_true_re, r''),
    )
    
    def simplify_logic_ops(str_):
        '''@return tuple (new str, total # of substitutions made)'''
        count = 0
        for pattern, repl in logic_subs:
            str_, n = re.subn(pattern, repl, str_)
            count += n
        return str_, count
    
    # Constant NULL tests are plain text, so they only need replacing once
    for null_test, result in (
        ('NULL IS NULL', true_expr),
        ('NULL IS NOT NULL', false_expr),
    ):
        expr = expr.replace(null_test, result)
    
    return simplify_recursive(simplify_logic_ops, expr)

# Matches a name: either bare word characters or a double-quoted name (the
# quoted group may repeat, which also covers a doubled "" inside the quotes)
name_re = r'(?:\w+|(?:"[^"]*")+)'

def parse_expr_col(str_):
    '''Gets the unescaped column name from a string.
    If the string has the form `(func(col ...))`, the name is the first
    argument of the function call; otherwise it is the whole string.
    '''
    wrapped_col_re = r'^\('+name_re+r'\(('+name_re+r').*\)\)$'
    match = re.match(wrapped_col_re, str_)
    if not match: return unesc_name(str_)
    return unesc_name(match.group(1))

def map_expr(db, expr, mapping, in_cols_found=None):
    '''Replaces output columns with input columns in an expression.
    @param expr str The expression's text
    @param mapping dict-like {output col: input col}
    @param in_cols_found If set, will be filled in with the expr's (input) cols
    @return str The mapped expression, after simplify_expr()
    '''
    for out, in_ in mapping.iteritems():
        orig_expr = expr
        out = to_name_only_col(out)
        in_str = to_name_only_col(remove_col_rename(in_)).to_str(db)
        
        # Replace out both with and without quotes
        expr = expr.replace(out.to_str(db), in_str)
        # The name must be matched literally, so it is escaped before being
        # put in the pattern. The replacement is given as a function so that
        # any backslashes in in_str are inserted as-is instead of being
        # interpreted as escapes or group references.
        unquoted_re = (r'(?<!["\'\.=\[])\b'+re.escape(out.name)
            +r'\b(?!["\',\.=\]])')
        expr = re.sub(unquoted_re, lambda match: in_str, expr)
        
        if in_cols_found is not None and expr != orig_expr: # replaced something
            in_cols_found.append(in_)
    
    return simplify_expr(expr)
