Skip to content

Commit

Permalink
Merge pull request #2477 from anutosh491/fix_symbolic_list
Browse files Browse the repository at this point in the history
Fixing Symbolic List assignment
  • Loading branch information
certik authored Feb 7, 2024
2 parents 467081e + aef00c0 commit a6b9256
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 24 deletions.
3 changes: 2 additions & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -718,9 +718,10 @@ RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_11 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_13 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_14 LABELS cpython_sym llvm_sym NOFAST)
RUN(NAME symbolics_14 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME test_gruntz LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_15 LABELS c_sym llvm_sym NOFAST)
RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym NOFAST)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
11 changes: 5 additions & 6 deletions integration_tests/symbolics_15.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ def mmrv(r: Out[list[CPtr]]) -> None:
basic_new_stack(x)
basic_const_pi(x)

# l1: list[S]
# l1: list[S] = [x]
_l1: list[CPtr] = [x]
l1: list[CPtr] = []

# l1 = [x]
i: i32 = 0
Len: i32 = 1
for i in range(Len):
for i in range(len(_l1)):
tmp: CPtr = basic_new_heap()
l1.append(tmp)
basic_assign(l1[0], x)
Expand All @@ -57,8 +56,8 @@ def mmrv(r: Out[list[CPtr]]) -> None:
def test_mrv():
# ans : list[S]
# temp : list[S]
ans: list[CPtr] = []
temp: list[CPtr] = []
ans: list[CPtr]
temp: list[CPtr]

# mmrv(ans)
# temp = ans
Expand Down
18 changes: 18 additions & 0 deletions integration_tests/symbolics_16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from lpython import S
from sympy import Symbol, pi, sin

def mmrv() -> list[S]:
x: S = Symbol('x')
l1: list[S] = [pi, sin(x)]
return l1

def test_mrv1():
ans: list[S] = mmrv()
element_1: S = ans[0]
element_2: S = ans[1]
assert element_1 == pi
assert element_2 == sin(Symbol('x'))
print(element_1, element_2)


test_mrv1()
115 changes: 98 additions & 17 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return SubroutineCall(loc, basic_free_stack_sym, {x});
}

ASR::expr_t *basic_new_heap(const Location& loc) {
ASR::symbol_t* basic_new_heap_sym = create_bindc_function(loc,
"basic_new_heap", {}, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))));
Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
return FunctionCall(loc, basic_new_heap_sym, {},
ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)));
}

ASR::stmt_t* basic_get_args(const Location& loc, ASR::expr_t *x, ASR::expr_t *y) {
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
ASR::symbol_t* basic_get_args_sym = create_bindc_function(loc,
Expand Down Expand Up @@ -323,8 +331,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
std::string var_name = xx.m_name;
std::string placeholder = "_" + std::string(var_name);

ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
xx.m_type = type1;
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
xx.m_type = CPtr_type;
if (var_name != "_lpython_return_variable" && xx.m_intent != ASR::intentType::Out) {
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
}
Expand Down Expand Up @@ -357,13 +365,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));

// statement 2
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, CPtr_type));

// statement 3
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
type1, nullptr));
CPtr_type, nullptr));

// defining the assignment statement
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
Expand Down Expand Up @@ -548,21 +556,94 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::ListConstant_t* list_constant = ASR::down_cast<ASR::ListConstant_t>(x.m_value);
if (list_constant->m_type->type == ASR::ttypeType::List) {
ASR::List_t* list = ASR::down_cast<ASR::List_t>(list_constant->m_type);
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
Vec<ASR::expr_t*> temp_list;
temp_list.reserve(al, list_constant->n_args + 1);

for (size_t i = 0; i < list_constant->n_args; ++i) {
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
temp_list.push_back(al, value);
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
if(ASR::is_a<ASR::Var_t>(*x.m_target)) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_target)->m_v;
if (ASR::is_a<ASR::Variable_t>(*v)) {
// Step1: Add the placeholder for the list variable to the scope
ASRUtils::ASRBuilder b(al, x.base.base.loc);
ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, CPtr_type));
ASR::Variable_t *list_variable = ASR::down_cast<ASR::Variable_t>(v);
std::string list_name = list_variable->m_name;
std::string placeholder = "_" + std::string(list_name);

ASR::symbol_t* placeholder_sym = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, list_variable->base.base.loc, current_scope,
s2c(al, placeholder), nullptr, 0,
list_variable->m_intent, nullptr,
nullptr, list_variable->m_storage,
list_type, nullptr, list_variable->m_abi,
list_variable->m_access, list_variable->m_presence,
list_variable->m_value_attr));

current_scope->add_symbol(s2c(al, placeholder), placeholder_sym);
ASR::expr_t* placeholder_target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, placeholder_sym));

Vec<ASR::expr_t*> temp_list1, temp_list2;
temp_list1.reserve(al, list_constant->n_args + 1);
temp_list2.reserve(al, list_constant->n_args + 1);

for (size_t i = 0; i < list_constant->n_args; ++i) {
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
temp_list1.push_back(al, value);
}

ASR::expr_t* temp_list_const1 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list1.p,
temp_list1.size(), list_type));
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, placeholder_target, temp_list_const1, nullptr));
pass_result.push_back(al, stmt1);

// Step2: Add the empty list variable
ASR::expr_t* temp_list_const2 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list2.p,
temp_list2.size(), list_type));
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const2, nullptr));
pass_result.push_back(al, stmt2);

// Step3: Add the list index to the function scope
std::string symbolic_list_index = current_scope->get_unique_name("symbolic_list_index");
ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4));
ASR::symbol_t* index_sym = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, x.base.base.loc, current_scope, s2c(al, symbolic_list_index),
nullptr, 0, ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default,
int32_type, nullptr, ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false));
current_scope->add_symbol(symbolic_list_index, index_sym);
ASR::expr_t* index = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, index_sym));
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, index,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)), nullptr));
pass_result.push_back(al, stmt3);

// Step4: Add the DoLoop for appending elements into the list
std::string block_name = current_scope->get_unique_name("block");
SymbolTable* block_symtab = al.make_new<SymbolTable>(current_scope);
char *tmp_var_name = s2c(al, "tmp");
ASR::expr_t* tmp_var = b.Variable(block_symtab, tmp_var_name, CPtr_type,
ASR::intentType::Local, ASR::abiType::Source, false);
Vec<ASR::stmt_t*> block_body; block_body.reserve(al, 1);
ASR::stmt_t* block_stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, tmp_var,
basic_new_heap(x.base.base.loc), nullptr));
block_body.push_back(al, block_stmt1);
ASR::stmt_t* block_stmt2 = ASRUtils::STMT(ASR::make_ListAppend_t(al, x.base.base.loc, x.m_target, tmp_var));
block_body.push_back(al, block_stmt2);
block_body.push_back(al, basic_assign(x.base.base.loc, ASRUtils::EXPR(ASR::make_ListItem_t(al,
x.base.base.loc, x.m_target, index, CPtr_type, nullptr)), ASRUtils::EXPR(ASR::make_ListItem_t(al,
x.base.base.loc, placeholder_target, index, CPtr_type, nullptr))));
ASR::symbol_t* block = ASR::down_cast<ASR::symbol_t>(ASR::make_Block_t(al, x.base.base.loc,
block_symtab, s2c(al, block_name), block_body.p, block_body.n));
current_scope->add_symbol(block_name, block);
ASR::stmt_t* block_call = ASRUtils::STMT(ASR::make_BlockCall_t(
al, x.base.base.loc, -1, block));
std::vector<ASR::stmt_t*> do_loop_body;
do_loop_body.push_back(block_call);
ASR::stmt_t* stmt4 = b.DoLoop(index, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)),
ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc,
ASRUtils::EXPR(ASR::make_ListLen_t(al, x.base.base.loc, placeholder_target, int32_type, nullptr)), ASR::binopType::Sub,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)), int32_type, nullptr)),
do_loop_body, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)));
pass_result.push_back(al, stmt4);
}
}

ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type));
ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p,
temp_list.size(), list_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr));
pass_result.push_back(al, stmt);
}
}
} else if (ASR::is_a<ASR::ListItem_t>(*x.m_value)) {
Expand Down

0 comments on commit a6b9256

Please sign in to comment.