Skip to content

Commit

Permalink
[ir] Let the type of ExternalTensorExpression be an ndarray struct
Browse files Browse the repository at this point in the history
ghstack-source-id: 429c5e9f00736ace86cdd11d65c861da2119550f
Pull Request resolved: #8191
  • Loading branch information
lin-hitonami committed Jun 15, 2023
1 parent eec09b3 commit 1f4da93
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
void ExternalTensorExpression::flatten(FlattenContext *ctx) {
auto type =
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);
type = TypeFactory::get_instance().get_pointer_type((Type *)type);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*create_load=*/false);
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ class ExternalTensorExpression : public Expression {
}

void type_check(const CompileConfig *config) override {
ret_type = dt;
ret_type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim,
needs_grad);
ret_type.set_is_pointer(true);
config_ = config;
}
Expand Down

0 comments on commit 1f4da93

Please sign in to comment.