From 9081009c3666c893a06023bdda4aff3fd0606f0b Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Thu, 15 Apr 2021 15:32:22 +0800 Subject: [PATCH] support `numpy.array/asarray(tensor) -> ndarray`, test=develop --- python/paddle/fluid/dygraph/varbase_patch_methods.py | 5 ++++- python/paddle/fluid/tests/unittests/test_var_base.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index ac594709867d1..64209aee875ba 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -434,6 +434,9 @@ def __nonzero__(self): def __bool__(self): return self.__nonzero__() + def __array__(self, dtype=None): + return self.numpy().astype(dtype) + for method_name, method in ( ("__bool__", __bool__), ("__nonzero__", __nonzero__), ("_to_static_var", _to_static_var), ("set_value", set_value), @@ -442,7 +445,7 @@ def __bool__(self): ("gradient", gradient), ("register_hook", register_hook), ("__str__", __str__), ("__repr__", __str__), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"), - ("__name__", "Tensor")): + ("__name__", "Tensor"), ("__array__", __array__)): setattr(core.VarBase, method_name, method) # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 1fea1935473a7..76c871f37216b 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -502,6 +502,15 @@ def test_var_base_to_np(self): np.array_equal(var.numpy(), fluid.framework._var_base_to_np(var))) + def test_var_base_as_np(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array) + self.assertTrue(np.array_equal(var.numpy(), np.array(var))) + self.assertTrue( + np.array_equal( + var.numpy(), np.array( + var, dtype=np.float32))) + def test_if(self): with fluid.dygraph.guard(): var1 = fluid.dygraph.to_variable(np.array([[[0]]]))