From fa208c631e88315c79b8d61f10db025e16561038 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 4 Feb 2024 14:31:21 +0530 Subject: [PATCH 1/5] Fixing Symbolic List assignment --- src/libasr/pass/replace_symbolic.cpp | 119 +++++++++++++++++++++++---- 1 file changed, 101 insertions(+), 18 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index cf3da1acf1..9bed5df1cb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -147,6 +147,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; call_args.reserve(al, 1); + return FunctionCall(loc, basic_new_heap_sym, {}, + ASRUtils::TYPE(ASR::make_CPtr_t(al, loc))); + } + ASR::stmt_t* basic_get_args(const Location& loc, ASR::expr_t *x, ASR::expr_t *y) { ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)); ASR::symbol_t* basic_get_args_sym = create_bindc_function(loc, @@ -323,8 +331,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); } @@ -357,13 +365,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::List) { ASR::List_t* list = ASR::down_cast(xx.m_type); if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ + std::string var_name = xx.m_name; + std::string placeholder = "_" + std::string(var_name); + ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type)); - xx.m_type = list_type; + + if(xx.m_intent == ASR::intentType::Local){ + ASR::symbol_t* sym2 = ASR::down_cast( + ASR::make_Variable_t(al, xx.base.base.loc, current_scope, + s2c(al, placeholder), nullptr, 0, + xx.m_intent, nullptr, + nullptr, xx.m_storage, + list_type, nullptr, xx.m_abi, + xx.m_access, xx.m_presence, + xx.m_value_attr)); + + current_scope->add_symbol(s2c(al, placeholder), sym2); + xx.m_type = list_type; + } } } } @@ -548,21 +572,80 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x.m_value); if (list_constant->m_type->type == ASR::ttypeType::List) { ASR::List_t* list = ASR::down_cast(list_constant->m_type); - if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ - Vec temp_list; - temp_list.reserve(al, list_constant->n_args + 1); - for (size_t i = 0; i < list_constant->n_args; ++i) { - ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]); - temp_list.push_back(al, value); + if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ + if(ASR::is_a(*x.m_target)) { + ASR::symbol_t *v = ASR::down_cast(x.m_target)->m_v; + if (ASR::is_a(*v)) { + ASRUtils::ASRBuilder b(al, x.base.base.loc); + ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); + ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, CPtr_type)); + ASR::Variable_t *list_variable = ASR::down_cast(v); + std::string list_name = list_variable->m_name; + std::string placeholder = "_" + std::string(list_name); + ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder); + ASR::expr_t* placeholder_target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, placeholder_sym)); + + Vec temp_list1, temp_list2; + temp_list1.reserve(al, list_constant->n_args + 1); + temp_list2.reserve(al, list_constant->n_args + 1); + + for (size_t i = 0; i < list_constant->n_args; ++i) { + ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]); + temp_list1.push_back(al, value); + } + + ASR::expr_t* temp_list_const1 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list1.p, + temp_list1.size(), list_type)); + ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, placeholder_target, temp_list_const1, nullptr)); + pass_result.push_back(al, stmt1); + + ASR::expr_t* temp_list_const2 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list2.p, + temp_list2.size(), list_type)); + ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const2, nullptr)); + pass_result.push_back(al, stmt2); + + std::string symbolic_list_index = current_scope->get_unique_name("symbolic_list_index"); + ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4)); + ASR::symbol_t* index_sym = ASR::down_cast( + ASR::make_Variable_t(al, x.base.base.loc, current_scope, s2c(al, symbolic_list_index), + nullptr, 0, ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default, + int32_type, nullptr, ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(symbolic_list_index, index_sym); + ASR::expr_t* index = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, index_sym)); + ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, index, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)), nullptr)); + pass_result.push_back(al, stmt3); + + std::string block_name = current_scope->get_unique_name("block"); + SymbolTable* block_symtab = al.make_new(current_scope); + char *tmp_var_name = s2c(al, "tmp"); + ASR::expr_t* tmp_var = b.Variable(block_symtab, tmp_var_name, CPtr_type, + ASR::intentType::Local, ASR::abiType::Source, false); + Vec block_body; block_body.reserve(al, 1); + ASR::stmt_t* block_stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, tmp_var, + basic_new_heap(x.base.base.loc), nullptr)); + block_body.push_back(al, block_stmt1); + ASR::stmt_t* block_stmt2 = ASRUtils::STMT(ASR::make_ListAppend_t(al, x.base.base.loc, x.m_target, tmp_var)); + block_body.push_back(al, block_stmt2); + block_body.push_back(al, basic_assign(x.base.base.loc, ASRUtils::EXPR(ASR::make_ListItem_t(al, + x.base.base.loc, x.m_target, index, CPtr_type, nullptr)), ASRUtils::EXPR(ASR::make_ListItem_t(al, + x.base.base.loc, placeholder_target, index, CPtr_type, nullptr)))); + ASR::symbol_t* block = ASR::down_cast(ASR::make_Block_t(al, x.base.base.loc, + block_symtab, s2c(al, block_name), block_body.p, block_body.n)); + current_scope->add_symbol(block_name, block); + ASR::stmt_t* block_call = ASRUtils::STMT(ASR::make_BlockCall_t( + al, x.base.base.loc, -1, block)); + std::vector do_loop_body; + do_loop_body.push_back(block_call); + ASR::stmt_t* stmt4 = b.DoLoop(index, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)), + ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, + ASRUtils::EXPR(ASR::make_ListLen_t(al, x.base.base.loc, placeholder_target, int32_type, nullptr)), ASR::binopType::Sub, + ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)), int32_type, nullptr)), + do_loop_body, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type))); + pass_result.push_back(al, stmt4); + } } - - ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); - ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type)); - ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p, - temp_list.size(), list_type)); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr)); - pass_result.push_back(al, stmt); } } } else if (ASR::is_a(*x.m_value)) { From 8b8109186504d667277f2b69978464041f7b442c Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 4 Feb 2024 15:53:42 +0530 Subject: [PATCH 2/5] Fixed failing test --- src/libasr/pass/replace_symbolic.cpp | 30 +++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 9bed5df1cb..71410bf856 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -388,25 +388,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitortype == ASR::ttypeType::List) { ASR::List_t* list = ASR::down_cast(xx.m_type); if (list->m_type->type == ASR::ttypeType::SymbolicExpression){ - std::string var_name = xx.m_name; - std::string placeholder = "_" + std::string(var_name); - ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc)); ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type)); - - if(xx.m_intent == ASR::intentType::Local){ - ASR::symbol_t* sym2 = ASR::down_cast( - ASR::make_Variable_t(al, xx.base.base.loc, current_scope, - s2c(al, placeholder), nullptr, 0, - xx.m_intent, nullptr, - nullptr, xx.m_storage, - list_type, nullptr, xx.m_abi, - xx.m_access, xx.m_presence, - xx.m_value_attr)); - - current_scope->add_symbol(s2c(al, placeholder), sym2); - xx.m_type = list_type; - } + xx.m_type = list_type; } } } @@ -583,7 +567,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(v); std::string list_name = list_variable->m_name; std::string placeholder = "_" + std::string(list_name); - ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder); + + ASR::symbol_t* placeholder_sym = ASR::down_cast( + ASR::make_Variable_t(al, list_variable->base.base.loc, current_scope, + s2c(al, placeholder), nullptr, 0, + list_variable->m_intent, nullptr, + nullptr, list_variable->m_storage, + list_type, nullptr, list_variable->m_abi, + list_variable->m_access, list_variable->m_presence, + list_variable->m_value_attr)); + + current_scope->add_symbol(s2c(al, placeholder), placeholder_sym); ASR::expr_t* placeholder_target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, placeholder_sym)); Vec temp_list1, temp_list2; From 4828e38a5105e9dcd98ab6663b22f4cd1f5b7133 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 4 Feb 2024 16:07:26 +0530 Subject: [PATCH 3/5] Maded suitable changes in symbolics_15.py --- integration_tests/symbolics_15.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/integration_tests/symbolics_15.py b/integration_tests/symbolics_15.py index 83a58b6b97..e326e891a5 100644 --- a/integration_tests/symbolics_15.py +++ b/integration_tests/symbolics_15.py @@ -33,13 +33,12 @@ def mmrv(r: Out[list[CPtr]]) -> None: basic_new_stack(x) basic_const_pi(x) - # l1: list[S] + # l1: list[S] = [x] + _l1: list[CPtr] = [x] l1: list[CPtr] = [] - # l1 = [x] i: i32 = 0 - Len: i32 = 1 - for i in range(Len): + for i in range(len(_l1)): tmp: CPtr = basic_new_heap() l1.append(tmp) basic_assign(l1[0], x) From 21ada16c2a653aa43d97eecf17c4deda5c0ec988 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 4 Feb 2024 16:32:42 +0530 Subject: [PATCH 4/5] Added tests --- integration_tests/CMakeLists.txt | 3 ++- integration_tests/symbolics_15.py | 4 ++-- integration_tests/symbolics_16.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 integration_tests/symbolics_16.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 54584dfc0f..12e8e30f21 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -718,9 +718,10 @@ RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_11 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_13 LABELS cpython_sym c_sym llvm_sym NOFAST) -RUN(NAME symbolics_14 LABELS cpython_sym llvm_sym NOFAST) +RUN(NAME symbolics_14 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME test_gruntz LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_15 LABELS c_sym llvm_sym NOFAST) +RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_15.py b/integration_tests/symbolics_15.py index e326e891a5..096f7c442d 100644 --- a/integration_tests/symbolics_15.py +++ b/integration_tests/symbolics_15.py @@ -56,8 +56,8 @@ def mmrv(r: Out[list[CPtr]]) -> None: def test_mrv(): # ans : list[S] # temp : list[S] - ans: list[CPtr] = [] - temp: list[CPtr] = [] + ans: list[CPtr] + temp: list[CPtr] # mmrv(ans) # temp = ans diff --git a/integration_tests/symbolics_16.py b/integration_tests/symbolics_16.py new file mode 100644 index 0000000000..abdaa2c92b --- /dev/null +++ b/integration_tests/symbolics_16.py @@ -0,0 +1,18 @@ +from lpython import S +from sympy import Symbol, pi, sin + +def mmrv() -> list[S]: + x: S = Symbol('x') + l1: list[S] = [pi, sin(x)] + return l1 + +def test_mrv1(): + ans: list[S] = mmrv() + element_1: S = ans[0] + element_2: S = ans[1] + assert element_1 == pi + assert element_2 == sin(Symbol('x')) + print(element_1, element_2) + + +test_mrv1() \ No newline at end of file From aef00c094da567ff3e6dcba48ddea44fc9c7b982 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 4 Feb 2024 16:46:54 +0530 Subject: [PATCH 5/5] Added comments --- src/libasr/pass/replace_symbolic.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 71410bf856..afa7082ff9 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -561,6 +561,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_target)) { ASR::symbol_t *v = ASR::down_cast(x.m_target)->m_v; if (ASR::is_a(*v)) { + // Step1: Add the placeholder for the list variable to the scope ASRUtils::ASRBuilder b(al, x.base.base.loc); ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)); ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, CPtr_type)); @@ -594,11 +595,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_unique_name("symbolic_list_index"); ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4)); ASR::symbol_t* index_sym = ASR::down_cast( @@ -611,6 +614,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_unique_name("block"); SymbolTable* block_symtab = al.make_new(current_scope); char *tmp_var_name = s2c(al, "tmp");