From 80a4ee0433c32f559b8914fa26a0ab697f827438 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Tue, 13 Apr 2021 16:16:43 +0800 Subject: [PATCH] fix expand op lack of float16 --- python/paddle/fluid/layers/nn.py | 3 ++- python/paddle/tensor/manipulation.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c4f4754cc7794..565c134ae9d95 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10332,7 +10332,8 @@ def expand(x, expand_times, name=None): inputs = {"X": [x]} attrs = {} check_variable_and_dtype( - x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'expand') check_type(expand_times, 'expand_times', (list, tuple, Variable), 'expand') if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True: raise ValueError( diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 377435a50008a..696775434b967 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1432,7 +1432,8 @@ def expand(x, shape, name=None): 'Elements in shape must be 1-D Tensors or integers.') check_variable_and_dtype( - x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand') if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError("When the data type of input 'x' for expand is bool, "