Skip to content

Commit

Permalink
Merge pull request #2378 from anutosh491/basic_assign
Browse files Browse the repository at this point in the history
Supporting assignment through `basic_assign`
  • Loading branch information
certik authored Oct 9, 2023
2 parents 87e7caa + 1830fc9 commit 013df0c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
3 changes: 3 additions & 0 deletions integration_tests/symbolics_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ def main0():
y: S = Symbol('y')
x = pi
z: S = x + y
x = z
print(x)
print(z)
assert(x == z)
assert(z == pi + y)
assert(z != S(2)*pi + y)

Expand Down
60 changes: 59 additions & 1 deletion src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::symbol_t* declare_basic_assign_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_assign";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_str_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_str";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -794,7 +833,26 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

void visit_Assignment(const ASR::Assignment_t &x) {
SymbolTable* module_scope = current_scope->parent;
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_value)) {
if (ASR::is_a<ASR::Var_t>(*x.m_value) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(x.m_value))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope);
ASR::symbol_t* var_sym = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = target;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_value)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_value);
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target);
Expand Down

0 comments on commit 013df0c

Please sign in to comment.