diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 55227c3d7b229..851d72a654aff 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -608,6 +608,7 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { void ExternalTensorExpression::flatten(FlattenContext *ctx) { auto type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); + type = TypeFactory::get_instance().get_pointer_type((Type *)type); auto ptr = Stmt::make(arg_id, type, /*is_ptr=*/true, /*create_load=*/false); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f123c64c087c8..46341d4404172 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -502,7 +502,8 @@ class ExternalTensorExpression : public Expression { } void type_check(const CompileConfig *config) override { - ret_type = dt; + ret_type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, + needs_grad); ret_type.set_is_pointer(true); config_ = config; }