Skip to content

Commit

Permalink
Merge pull request #2335 from anutosh491/Implement_visit_SubroutineCall
Browse files Browse the repository at this point in the history
Implemented `visit_SubroutineCall` for the symbolic pass
  • Loading branch information
certik authored Sep 24, 2023
2 parents f9b09dd + 37cc6eb commit f76dff7
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 2 deletions.
8 changes: 6 additions & 2 deletions integration_tests/symbolics_09.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sympy import Symbol, pi, S
from sympy import Symbol, pi, sin, cos
from lpython import S, i32

def addInteger(x: S, y: S, z: S, i: i32):
Expand All @@ -9,7 +9,11 @@ def call_addInteger():
a: S = Symbol("x")
b: S = Symbol("y")
c: S = pi
addInteger(a, b, c, 2)
d: S = sin(a)
e: S = cos(b)
addInteger(c, d, e, 2)
addInteger(c, sin(a), cos(b), 2)
addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2)

def main0():
call_addInteger()
Expand Down
69 changes: 69 additions & 0 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,56 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
SymbolTable* module_scope = current_scope->parent;
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);

for (size_t i=0; i<x.n_args; i++) {
ASR::expr_t* val = x.m_args[i].m_value;
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val) && ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
std::string symengine_var = symengine_stack.push();
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local,
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
current_scope->add_symbol(s2c(al, symengine_var), arg);
for (auto &item : current_scope->get_scope()) {
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
this->visit_Variable(*s);
}
}

ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);

ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
this->visit_Cast(*cast_t);
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
} else {
call_args.push_back(al, x.m_args[i]);
}
}
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, x.m_name,
x.m_name, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
}

void visit_Print(const ASR::Print_t &x) {
std::vector<ASR::expr_t*> print_tmp;
SymbolTable* module_scope = current_scope->parent;
Expand Down Expand Up @@ -739,6 +789,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);

// Now create the FunctionCall node for basic_str
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
this->visit_Cast(*cast_t);
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

// Now create the FunctionCall node for basic_str
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
Vec<ASR::call_arg_t> call_args;
Expand Down

0 comments on commit f76dff7

Please sign in to comment.