From c044da29ccaba80a99588c786c00331e762de4db Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Thu, 15 Jun 2023 16:40:07 +0800 Subject: [PATCH] [ir] Let the type of ExternalTensorExpression be an ndarray struct [ghstack-poisoned] --- taichi/ir/frontend_ir.cpp | 1 + taichi/ir/frontend_ir.h | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 028f9655bb084..3b6a659fdc63c 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -608,6 +608,7 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { void ExternalTensorExpression::flatten(FlattenContext *ctx) { auto type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); + type = TypeFactory::get_instance().get_pointer_type((Type *)type); auto ptr = Stmt::make(arg_id, type, /*is_ptr=*/true, /*create_load=*/false); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f123c64c087c8..46341d4404172 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -502,7 +502,8 @@ class ExternalTensorExpression : public Expression { } void type_check(const CompileConfig *config) override { - ret_type = dt; + ret_type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, + needs_grad); ret_type.set_is_pointer(true); config_ = config; }