diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e80ed5fb1f8f..d530ef587782 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -156,12 +156,13 @@ class NDArray : public ObjectRef { TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope = NullOpt); /*! - * \brief Create a NDArray backed by an external DLTensor. + * \brief Create a NDArray backed by an external DLTensor without memory copying. * + * If DLTensor is not contiguous or has bad aligned data, It fails. * This allows us to create a NDArray using the memory * allocated by an external source. Responsibility for memory * retaining lies with the external source. - * \param dl_tensor The DLTensor to copy from. + * \param dl_tensor The DLTensor for NDArray base. * \return The created NDArray view. */ TVM_DLL static NDArray FromExternalDLTensor(const DLTensor& dl_tensor); @@ -172,7 +173,7 @@ class NDArray : public ObjectRef { * \param dev device location of the created NDArray. * \return The created NDArray view. */ - TVM_DLL static NDArray NewFromDLTensor(DLTensor* dl_tensor, Device dev); + TVM_DLL static NDArray NewFromDLTensor(DLTensor* dl_tensor, const Device& dev); /*! * \brief Create a NDArray backed by a dlpack tensor. * @@ -196,9 +197,23 @@ class NDArray : public ObjectRef { TVM_DLL ShapeTuple Shape() const; TVM_DLL runtime::DataType DataType() const; + /*! + * \brief Check conditions for construction NDArray over DLTensor without copying. + * There are three conditions to check: + * 1. Destination device is the same as DLTensor device + * 2. Destination device id is the same as DLTensor device id + * 3. Memory in DLTensor is aligned as expected for NDArray + * \param tensor the DLTensor. + * \param dev destination device. + * \return true if all conditions are satisfied. + */ + TVM_DLL static bool AbilityOfZeroCopyForDLTensor(DLTensor* tensor, const Device& dev); // internal namespace struct Internal; + private: + TVM_DLL static bool IsAligned(const DLTensor& tensor); + protected: friend class TVMPODValue_; friend class TVMRetValue; @@ -345,7 +360,7 @@ inline size_t GetDataSize(const DLTensor& arr) { * \param arr The input DLTensor. * \return The check result. */ -inline bool IsContiguous(const DLTensor& arr) { +static inline bool IsContiguous(const DLTensor& arr) { if (arr.strides == nullptr) return true; int64_t expected_stride = 1; for (int32_t i = arr.ndim; i != 0; --i) { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index f44dc86f902a..0b4a9dfdd9e9 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -206,6 +206,9 @@ NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional< } NDArray NDArray::FromExternalDLTensor(const DLTensor& dl_tensor) { + ICHECK(::tvm::runtime::IsContiguous(dl_tensor)) + << "External DLTensor is not contiguous. It does not support for now"; + ICHECK(IsAligned(dl_tensor)) << "Data in DLTensor is not aligned as required by NDArray"; NDArray::Container* data = new NDArray::Container(); data->SetDeleter(Internal::SelfDeleter); @@ -219,7 +222,9 @@ NDArray NDArray::FromExternalDLTensor(const DLTensor& dl_tensor) { return NDArray(GetObjectPtr(data)); } -NDArray NDArray::NewFromDLTensor(DLTensor* tensor, Device dev) { +NDArray NDArray::NewFromDLTensor(DLTensor* tensor, const Device& dev) { + ICHECK(::tvm::runtime::IsContiguous(*tensor)) + << "DLTensor is not contiguous. It does not support for now"; std::vector shape; for (int64_t i = 0; i < tensor->ndim; i++) { shape.push_back(tensor->shape[i]); @@ -276,10 +281,24 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str } ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; } + runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); } +bool NDArray::AbilityOfZeroCopyForDLTensor(DLTensor* tensor, const Device& dev) { + bool device_check = (dev.device_type == tensor->device.device_type); + bool device_id_check = (dev.device_id == tensor->device.device_id); + bool alignment_check = IsAligned(*tensor); + return device_check && device_id_check && alignment_check; +} + +bool NDArray::IsAligned(const DLTensor& tensor) { + return (reinterpret_cast(static_cast(tensor.data) + tensor.byte_offset) % + tvm::runtime::kAllocAlignment == + 0); +} + TVM_REGISTER_OBJECT_TYPE(NDArray::Container); } // namespace runtime diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 41b9395237ee..8d03dbf210c3 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -308,13 +308,10 @@ void VirtualMachine::CreateInputsOrCheckSize(const std::string& func_name, size_ void VirtualMachine::SetInputTensorWithIndex(std::vector& tensors, const TVMArgValue& inp_tensor, int index, Device dev) { if (inp_tensor.type_code() == kTVMDLTensorHandle) { - // Automatically convert input DLTensors to NDArray - DLTensor* tensor = inp_tensor; - if (dev.device_type == tensor->device.device_type && - dev.device_id == tensor->device.device_id) { - tensors[index] = NDArray::FromExternalDLTensor(*tensor); + if (NDArray::AbilityOfZeroCopyForDLTensor(inp_tensor, dev)) { + tensors[index] = NDArray::FromExternalDLTensor(*inp_tensor); } else { - tensors[index] = NDArray::NewFromDLTensor(tensor, dev); + tensors[index] = NDArray::NewFromDLTensor(inp_tensor, dev); } } else { tensors[index] = CopyTo(inp_tensor, dev);