Revision 2704
Added by Aaron Marcuse-Kubitza over 12 years ago
lib/sql.py | ||
---|---|---|
752 | 752 |
into=into) |
753 | 753 |
return dict(items) |
754 | 754 |
|
755 |
def cast(db, type_, col, errors_table=None): |
|
756 |
'''Casts an (unrenamed) column or value. |
|
757 |
If a column, converts any errors to warnings. |
|
758 |
@param col sql_gen.Col|sql_gen.Literal |
|
759 |
@param errors_table None|sql_gen.Table|str |
|
760 |
If set and col is a column with srcs, saves any errors in this table, |
|
761 |
using column's srcs attr as the source columns. |
|
762 |
''' |
|
763 |
if isinstance(col, sql_gen.Literal): # literal value, so just cast |
|
764 |
return sql_gen.CustomCode(col.to_str(db)+'::'+type_) |
|
765 |
|
|
766 |
assert isinstance(col, sql_gen.Col) |
|
767 |
assert not isinstance(col, sql_gen.NamedCol) |
|
768 |
|
|
769 |
save_errors = errors_table != None and col.srcs != () |
|
770 |
if save_errors: |
|
771 |
errors_table = sql_gen.as_Table(errors_table) |
|
772 |
srcs = map(sql_gen.to_name_only_col, col.srcs) |
|
773 |
function_name = str(sql_gen.FunctionCall(type_, srcs)) |
|
774 |
else: function_name = type_ |
|
775 |
function = sql_gen.TempFunction(function_name, db.autocommit) |
|
776 |
|
|
777 |
query = '''\ |
|
778 |
CREATE FUNCTION '''+function.to_str(db)+'''(value text) |
|
779 |
RETURNS '''+type_+''' |
|
780 |
LANGUAGE plpgsql |
|
781 |
''' |
|
782 |
if not save_errors: query += 'IMMUTABLE\n' |
|
783 |
query += '''\ |
|
784 |
STRICT |
|
785 |
AS $$ |
|
786 |
BEGIN |
|
787 |
/* The explicit cast to the return type is needed to make the cast happen |
|
788 |
inside the try block. (Implicit casts to the return type happen at the end |
|
789 |
of the function, outside any block.) */ |
|
790 |
RETURN value::'''+type_+'''; |
|
791 |
EXCEPTION |
|
792 |
WHEN data_exception THEN |
|
793 |
''' |
|
794 |
if save_errors: |
|
795 |
col_names = map(sql_gen.Literal, srcs) |
|
796 |
query += '''\ |
|
797 |
-- Save error in errors table. |
|
798 |
BEGIN |
|
799 |
INSERT INTO '''+errors_table.to_str(db)+''' |
|
800 |
("column", value, error) |
|
801 |
(VALUES '''+(', '.join(('('+c.to_str(db)+')' for c in col_names)) |
|
802 |
)+''') AS c |
|
803 |
CROSS JOIN |
|
804 |
(VALUES (value, SQLERRM)) AS v |
|
805 |
; |
|
806 |
EXCEPTION |
|
807 |
WHEN unique_violation THEN NULL; -- ignore duplicate key |
|
808 |
END; |
|
809 |
|
|
810 |
''' |
|
811 |
query += '''\ |
|
812 |
RAISE WARNING '%', SQLERRM; |
|
813 |
RETURN NULL; |
|
814 |
END; |
|
815 |
$$; |
|
816 |
''' |
|
817 |
try: |
|
818 |
run_query(db, query, recover=True, cacheable=True, |
|
819 |
log_ignore_excs=(DuplicateFunctionException,)) |
|
820 |
except DuplicateFunctionException: pass # function already existed |
|
821 |
|
|
822 |
return sql_gen.FunctionCall(function, col) |
|
823 |
|
|
755 | 824 |
##### Database structure queries |
756 | 825 |
|
757 | 826 |
def table_row_count(db, table, recover=None): |
Also available in: Unified diff
sql.py: Added cast()