From 3b866d1b3e4c83f3e2891a169286abef3dd455ad Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 28 Nov 2023 09:37:49 -0600 Subject: [PATCH] [Container] Support non-nullable types in Array::Map Prior to this commit, the `Array::Map` member function could only be applied to nullable object types. This was due to the internal use of `U()` as the default value for initializing the output `ArrayNode`, where `U` is the return type of the mapping function. This default constructor is only available for nullable types, and would result in a compile-time failure for non-nullable types. This commit replaces `U()` with `ObjectRef()` in `Array::Map`, removing this limitation. Since all items in the output array are overwritten before returning to the calling scope, initializing the output array with `ObjectRef()` does not violate type safety. --- include/tvm/runtime/container/array.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index ff0bd03ab9cb..ba8fdfac5565 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -827,8 +827,13 @@ class Array : public ObjectRef { // consisting of any previous elements that had mapped to // themselves (if any), and the element that didn't map to // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. Since the default `ObjectRef()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. all_identical = false; - output = ArrayNode::CreateRepeated(arr->size(), U()); + output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); output->InitRange(0, arr->begin(), it); output->SetItem(it - arr->begin(), std::move(mapped)); it++; @@ -843,7 +848,12 @@ class Array : public ObjectRef { // compatible types isn't strictly necessary, as the first // mapped.same_as(*it) would return false, but we might as well // avoid it altogether. - output = ArrayNode::CreateRepeated(arr->size(), U()); + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `ObjectRef()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. + output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); } // Normal path for incompatible types, or post-copy path for