diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt
index 04be6abf67..d6ca666674 100644
--- a/integration_tests/CMakeLists.txt
+++ b/integration_tests/CMakeLists.txt
@@ -480,6 +480,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c)
 RUN(NAME expr_19 LABELS cpython llvm c)
 RUN(NAME expr_20 LABELS cpython llvm c)
 RUN(NAME expr_21 LABELS cpython llvm c)
+RUN(NAME expr_22 LABELS cpython llvm c)
 
 RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
 RUN(NAME expr_02u LABELS cpython llvm c NOFAST)
diff --git a/integration_tests/expr_22.py b/integration_tests/expr_22.py
new file mode 100644
index 0000000000..5b61bebc35
--- /dev/null
+++ b/integration_tests/expr_22.py
@@ -0,0 +1,10 @@
+from lpython import f64
+
+# Regression test for issue 1671: evaluates an `a + b*c`-shaped real expression.
+def test_fast_fma() -> f64:
+    a : f64 = 5.00
+    a = a + a * 10.00
+    assert abs(a - 55.00) < 1e-12
+    return a
+
+print(test_fast_fma())
diff --git a/src/libasr/pass/fma.cpp b/src/libasr/pass/fma.cpp
index ded6561ba5..ae1f49b8ec 100644
--- a/src/libasr/pass/fma.cpp
+++ b/src/libasr/pass/fma.cpp
@@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
         }
 
         fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
-            al, unit, pass_options, current_scope, x.base.base.loc,
-            [&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
+            al, unit, x.base.base.loc);
         from_fma = false;
     }
 
@@ -170,6 +169,9 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
                         const LCompilers::PassOptions& pass_options) {
     FMAVisitor v(al, unit, pass_options);
     v.visit_TranslationUnit(unit);
+    // NOTE(review): presumably refreshes dependencies for the newly generated FMA calls - confirm.
+    PassUtils::UpdateDependenciesVisitor u(al);
+    u.visit_TranslationUnit(unit);
 }
 
 
diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h
index 637ee58ed0..5fdc2a4d7a 100644
--- a/src/libasr/pass/intrinsic_function_registry.h
+++ b/src/libasr/pass/intrinsic_function_registry.h
@@ -41,6 +41,7 @@ enum class IntrinsicScalarFunctions : int64_t {
     Exp,
     Exp2,
     Expm1,
+    FMA,
     ListIndex,
     Partition,
     ListReverse,
@@ -93,6 +94,7 @@ inline std::string get_intrinsic_name(int x) {
         INTRINSIC_NAME_CASE(Exp)
         INTRINSIC_NAME_CASE(Exp2)
         INTRINSIC_NAME_CASE(Expm1)
+        INTRINSIC_NAME_CASE(FMA)
         INTRINSIC_NAME_CASE(ListIndex)
         INTRINSIC_NAME_CASE(Partition)
         INTRINSIC_NAME_CASE(ListReverse)
@@ -1281,6 +1283,97 @@ namespace Sign {
 
 } // namespace Sign
 
+namespace FMA {
+
+    // Multiply-add intrinsic: FMA(a, b, c) is defined throughout this namespace
+    // as `a + b*c` on real operands (constant folding and instantiation alike).
+
+    // ASR verification hook: requires exactly three arguments, all of real type.
+    static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
+        ASRUtils::require_impl(x.n_args == 3,
+            "ASR Verify: Call to FMA must have exactly 3 arguments",
+            x.base.base.loc, diagnostics);
+        ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
+        ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
+        ASR::ttype_t *type3 = ASRUtils::expr_type(x.m_args[2]);
+        ASRUtils::require_impl((is_real(*type1) && is_real(*type2) && is_real(*type3)),
+            "ASR Verify: Arguments to FMA must be of real type",
+            x.base.base.loc, diagnostics);
+    }
+
+    // Compile-time evaluation: reads the three operands as RealConstant values
+    // and returns a real constant of type `t1` holding a + b*c.
+    static ASR::expr_t *eval_FMA(Allocator &al, const Location &loc,
+            ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
+        double a = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
+        double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
+        double c = ASR::down_cast<ASR::RealConstant_t>(args[2])->m_r;
+        return make_ConstantWithType(make_RealConstant_t, a + b*c, t1, loc);
+    }
+
+    // Builds the IntrinsicScalarFunction node for FMA from `args`.
+    // The node's type is the type of the first argument; when every argument
+    // has a compile-time value, the folded value is attached to the node.
+    static inline ASR::asr_t* create_FMA(Allocator& al, const Location& loc,
+            Vec<ASR::expr_t*>& args,
+            const std::function<void (const std::string &, const Location &)> err) {
+        // NOTE(review): the checks below rely on err() not returning (presumably
+        // a throwing callback); otherwise args[0..2] would still be read.
+        if (args.size() != 3) {
+            err("Intrinsic FMA function accepts exactly 3 arguments", loc);
+        }
+        ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]);
+        ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]);
+        ASR::ttype_t *type3 = ASRUtils::expr_type(args[2]);
+        if (!ASRUtils::is_real(*type1) || !ASRUtils::is_real(*type2) || !ASRUtils::is_real(*type3)) {
+            err("Argument of the FMA function must be Real",
+                args[0]->base.loc);
+        }
+        ASR::expr_t *m_value = nullptr;
+        if (all_args_evaluated(args)) {
+            Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 3);
+            arg_values.push_back(al, expr_value(args[0]));
+            arg_values.push_back(al, expr_value(args[1]));
+            arg_values.push_back(al, expr_value(args[2]));
+            m_value = eval_FMA(al, loc, expr_type(args[0]), arg_values);
+        }
+        return ASR::make_IntrinsicScalarFunction_t(al, loc,
+            static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
+            args.p, args.n, 0, ASRUtils::expr_type(args[0]), m_value);
+    }
+
+    // Lowers an FMA node: returns the compile-time value if there is one,
+    // otherwise defines a helper function computing a + b*c, adds it to
+    // `scope`, and returns a call to it with `new_args`.
+    static inline ASR::expr_t* instantiate_FMA(Allocator &al, const Location &loc,
+            SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
+            Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
+            ASR::expr_t* compile_time_value) {
+        if (compile_time_value) {
+            return compile_time_value;
+        }
+        declare_basic_variables("_lcompilers_optimization_fma_" + type_to_str_python(arg_types[0]));
+        // Each formal parameter takes the type of its own actual argument.
+        fill_func_arg("a", arg_types[0]);
+        fill_func_arg("b", arg_types[1]);
+        fill_func_arg("c", arg_types[2]);
+        auto result = declare(fn_name, return_type, ReturnVar);
+        /*
+         * result = a + b*c
+        */
+
+        ASR::expr_t *op1 = b.ElementalMul(args[1], args[2], loc);
+        body.push_back(al, b.Assignment(result,
+        b.ElementalAdd(args[0], op1, loc)));
+
+        ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
+            body, result, Source, Implementation, nullptr);
+        scope->add_symbol(fn_name, f_sym);
+        return b.Call(f_sym, new_args, return_type, nullptr);
+    }
+
+} // namespace FMA
+
 #define create_exp_macro(X, stdeval) \
 namespace X { \
     static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
@@ -2314,6 +2407,8 @@ namespace IntrinsicScalarFunctionRegistry {
             {nullptr, &UnaryIntrinsicFunction::verify_args}},
         {static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
             {nullptr, &UnaryIntrinsicFunction::verify_args}},
+        {static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
+            {&FMA::instantiate_FMA, &FMA::verify_args}},
         {static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
             {&Abs::instantiate_Abs, &Abs::verify_args}},
         {static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
@@ -2400,6 +2495,8 @@ namespace IntrinsicScalarFunctionRegistry {
             "exp"},
         {static_cast<int64_t>(IntrinsicScalarFunctions::Exp2),
             "exp2"},
+        {static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
+            "fma"},
         {static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
             "expm1"},
         {static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
@@ -2474,6 +2571,7 @@ namespace IntrinsicScalarFunctionRegistry {
                 {"exp", {&Exp::create_Exp, &Exp::eval_Exp}},
                 {"exp2", {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
                 {"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
+                {"fma", {&FMA::create_FMA, &FMA::eval_FMA}},
                 {"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
                 {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
                 {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
diff --git a/src/libasr/pass/pass_utils.cpp b/src/libasr/pass/pass_utils.cpp
index 526746a540..75a1f5394c 100644
--- a/src/libasr/pass/pass_utils.cpp
+++ b/src/libasr/pass/pass_utils.cpp
@@ -666,11 +666,20 @@ namespace LCompilers {
         }
 
+        // Builds the expression standing for `arg0 + arg1*arg2`: looks up the
+        // instantiate function registered for the FMA intrinsic and invokes it
+        // with the global scope of `unit`, the three argument types and the
+        // three call arguments. The result type is the type of `arg0`.
         ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
-            Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
-            SymbolTable*& current_scope, Location& loc,
-            const std::function<void (const std::string &, const Location &)> err) {
-            ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
-                                al, unit, pass_options, current_scope, arg0->base.loc);
+            Allocator& al, ASR::TranslationUnit_t& unit, Location& loc) {
+            ASRUtils::impl_function instantiate_function =
+                ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
+                    static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
+            Vec<ASR::ttype_t*> arg_types;
+            ASR::ttype_t* type = ASRUtils::expr_type(arg0);
+            arg_types.reserve(al, 3);
+            arg_types.push_back(al, ASRUtils::expr_type(arg0));
+            arg_types.push_back(al, ASRUtils::expr_type(arg1));
+            arg_types.push_back(al, ASRUtils::expr_type(arg2));
             Vec<ASR::call_arg_t> args;
             args.reserve(al, 3);
             ASR::call_arg_t arg0_, arg1_, arg2_;
@@ -680,9 +689,9 @@ namespace LCompilers {
             args.push_back(al, arg1_);
             arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
             args.push_back(al, arg2_);
-            return ASRUtils::EXPR(
-                ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
-                    loc, v, args, current_scope, al, err));
+            return instantiate_function(al, loc,
+                unit.m_global_scope, arg_types, type, args, /*overload_id=*/0,
+                /*compile_time_value=*/nullptr);
         }
 
         ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,
diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h
index e0f0cf0083..c8bf786b99 100644
--- a/src/libasr/pass/pass_utils.h
+++ b/src/libasr/pass/pass_utils.h
@@ -90,9 +90,7 @@ namespace LCompilers {
                                     ASR::intentType var_intent=ASR::intentType::Local);
 
         ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
-            Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
-            SymbolTable*& current_scope,Location& loc,
-            const std::function<void (const std::string &, const Location &)> err);
+            Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
 
         ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
             Allocator& al, ASR::TranslationUnit_t& unit,