Skip to content

Commit

Permalink
[ir] Let the type of ExternalTensorExpression be an ndarray struct
Browse files Browse the repository at this point in the history
ghstack-source-id: 8b5cea534c215bcacff943ff89294700ffa63ce3
Pull Request resolved: taichi-dev#8191
  • Loading branch information
lin-hitonami authored and PGZXB committed Jul 13, 2023
1 parent 6cf2d46 commit 871a493
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
void ExternalTensorExpression::flatten(FlattenContext *ctx) {
auto type =
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);
type = TypeFactory::get_instance().get_pointer_type((Type *)type);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*create_load=*/false);
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ class ExternalTensorExpression : public Expression {
}

void type_check(const CompileConfig *config) override {
ret_type = dt;
ret_type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim,
needs_grad);
ret_type.set_is_pointer(true);
config_ = config;
}
Expand Down

0 comments on commit 871a493

Please sign in to comment.