Project

General

Profile

« Previous | Next » 

Revision 2704

sql.py: Added cast()

View differences:

lib/sql.py
752 752
        into=into)
753 753
    return dict(items)
754 754

  
755
def cast(db, type_, col, errors_table=None):
756
    '''Casts an (unrenamed) column or value.
757
    If a column, converts any errors to warnings.
758
    @param col sql_gen.Col|sql_gen.Literal
759
    @param errors_table None|sql_gen.Table|str
760
        If set and col is a column with srcs, saves any errors in this table,
761
        using column's srcs attr as the source columns.
762
    '''
763
    if isinstance(col, sql_gen.Literal): # literal value, so just cast
764
        return sql_gen.CustomCode(col.to_str(db)+'::'+type_)
765
    
766
    assert isinstance(col, sql_gen.Col)
767
    assert not isinstance(col, sql_gen.NamedCol)
768
    
769
    save_errors = errors_table != None and col.srcs != ()
770
    if save_errors:
771
        errors_table = sql_gen.as_Table(errors_table)
772
        srcs = map(sql_gen.to_name_only_col, col.srcs)
773
        function_name = str(sql_gen.FunctionCall(type_, srcs))
774
    else: function_name = type_
775
    function = sql_gen.TempFunction(function_name, db.autocommit)
776
    
777
    query = '''\
778
CREATE FUNCTION '''+function.to_str(db)+'''(value text)
779
RETURNS '''+type_+'''
780
LANGUAGE plpgsql
781
'''
782
    if not save_errors: query += 'IMMUTABLE\n'
783
    query += '''\
784
STRICT
785
AS $$
786
BEGIN
787
    /* The explicit cast to the return type is needed to make the cast happen
788
    inside the try block. (Implicit casts to the return type happen at the end
789
    of the function, outside any block.) */
790
    RETURN value::'''+type_+''';
791
EXCEPTION
792
    WHEN data_exception THEN
793
'''
794
    if save_errors:
795
        col_names = map(sql_gen.Literal, srcs)
796
        query += '''\
797
        -- Save error in errors table.
798
        BEGIN
799
            INSERT INTO '''+errors_table.to_str(db)+'''
800
            ("column", value, error)
801
            (VALUES '''+(', '.join(('('+c.to_str(db)+')' for c in col_names))
802
                )+''') AS c
803
            CROSS JOIN
804
            (VALUES (value, SQLERRM)) AS v
805
            ;
806
        EXCEPTION
807
            WHEN unique_violation THEN NULL; -- ignore duplicate key
808
        END;
809
        
810
'''
811
    query += '''\
812
        RAISE WARNING '%', SQLERRM;
813
        RETURN NULL;
814
END;
815
$$;
816
'''
817
    try:
818
        run_query(db, query, recover=True, cacheable=True,
819
            log_ignore_excs=(DuplicateFunctionException,))
820
    except DuplicateFunctionException: pass # function already existed
821
    
822
    return sql_gen.FunctionCall(function, col)
823

  
755 824
##### Database structure queries
756 825

  
757 826
def table_row_count(db, table, recover=None):

Also available in: Unified diff