From 995931639cc8dc07e4fcbb8908c30f7371cc5e91 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 8 Oct 2023 13:35:58 +0530 Subject: [PATCH 1/3] Added support for comparing symbolic expressions --- src/libasr/pass/replace_symbolic.cpp | 143 +++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 044d944ecd..59bdf45f81 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -626,6 +626,96 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_eq"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + + ASR::symbol_t* declare_basic_neq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_neq"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -772,6 +862,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_value); + if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { + ASR::symbol_t* sym = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + } else { + sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + } + ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); + ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = value1; + call_args.push_back(al, call_arg1); + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg2); + + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); + pass_result.push_back(al, stmt); + } } } @@ -905,6 +1022,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(val); + if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { + ASR::symbol_t* sym = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + } else { + sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + } + ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); + ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = value1; + call_args.push_back(al, call_arg1); + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg2); + + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } } else { print_tmp.push_back(x.m_values[i]); } From b50013baaeb880a08cd5d45034c7379b8aa746e3 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 8 Oct 2023 13:59:14 +0530 Subject: [PATCH 2/3] Added tests --- integration_tests/symbolics_02.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index f22a432606..e7f562f664 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -1,9 +1,11 @@ -from sympy import Symbol +from sympy import Symbol, pi from lpython import S def test_symbolic_operations(): x: S = Symbol('x') y: S = Symbol('y') + p1: S = pi + p2: S = pi # Addition z: S = x + y @@ -37,4 +39,19 @@ def test_symbolic_operations(): assert(c == S(0)) print(c) + # Comparison + b1: bool = p1 == p2 + print(b1) + assert(b1 == True) + b2: bool = p1 == pi + print(b2) + assert(b2 == True) + b3: bool = p1 != x + print(b3) + assert(b3 == True) + b4: bool = pi != Symbol("x") + print(b4) + assert(b4 == True) + + test_symbolic_operations() From 4ec4a76fe3164fbe04c94244de98f9e9b177ba93 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 8 Oct 2023 14:05:34 +0530 Subject: [PATCH 3/3] Improved test cases --- integration_tests/symbolics_02.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index e7f562f664..74f4a4af35 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -43,15 +43,15 @@ def test_symbolic_operations(): b1: bool = p1 == p2 print(b1) assert(b1 == True) - b2: bool = p1 == pi + b2: bool = p1 != pi print(b2) - assert(b2 == True) + assert(b2 == False) b3: bool = p1 != x print(b3) assert(b3 == True) - b4: bool = pi != Symbol("x") + b4: bool = pi == Symbol("x") print(b4) - assert(b4 == True) + assert(b4 == False) test_symbolic_operations()