From a19434e40d26ece2f5d298b5274d1aab266d7eb3 Mon Sep 17 00:00:00 2001 From: co63oc Date: Fri, 8 Dec 2023 21:48:31 +0800 Subject: [PATCH] Fix Tensor.to --- python/paddle/base/dygraph/tensor_patch_methods.py | 1 + test/legacy_test/test_Tensor_to.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index 674d9c740f158..1b779ed87cb5d 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -455,6 +455,7 @@ def _to(self, device=None, dtype=None, blocking=None): elif isinstance( device, ( + core.Place, core.CPUPlace, core.CUDAPlace, core.CUDAPinnedPlace, diff --git a/test/legacy_test/test_Tensor_to.py b/test/legacy_test/test_Tensor_to.py index c9901cb68d780..9821fac861621 100644 --- a/test/legacy_test/test_Tensor_to.py +++ b/test/legacy_test/test_Tensor_to.py @@ -55,6 +55,13 @@ def test_Tensor_to_device(self): else: self.assertTrue(placex_str, "Place(" + place + ")") + def test_Tensor_to_device2(self): + x = paddle.to_tensor([1, 2, 3]) + y = paddle.to_tensor([1, 2, 3], place="cpu") + + y.to(x.place) + self.assertTrue(x.place, y.place) + def test_Tensor_to_device_dtype(self): tensorx = paddle.to_tensor([1, 2, 3]) places = ["cpu"]