Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
smallv0221 committed Oct 20, 2021
1 parent 74b18ea commit 401062e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
17 changes: 15 additions & 2 deletions paddle/fluid/operators/bincount_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,29 @@ void BincountCUDAInner(const framework::ExecutionContext& context) {
}
auto input_x = framework::EigenVector<InputT>::Flatten(*input);

framework::Tensor input_max_t;
framework::Tensor input_min_t, input_max_t;
auto* input_max_data =
input_max_t.mutable_data<InputT>({1}, context.GetPlace());
auto* input_min_data =
input_min_t.mutable_data<InputT>({1}, context.GetPlace());

auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);

auto* place = context.template device_context<DeviceContext>().eigen_device();
input_max_scala.device(*place) = input_x.maximum();
input_min_scala.device(*place) = input_x.minimum();

Tensor input_max_cpu;
Tensor input_min_cpu, input_max_cpu;
TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);

InputT input_min = input_min_cpu.data<InputT>()[0];

PADDLE_ENFORCE_GE(
input_min, static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));

int64_t output_size =
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/operators/bincount_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ void BincountInner(const framework::ExecutionContext& context) {
return;
}

PADDLE_ENFORCE_GE(
*std::min_element(input_data, input_data + input_numel),
static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));

int64_t output_size = static_cast<int64_t>(*std::max_element(
input_data, input_data + input_numel)) +
1L;
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/fluid/tests/unittests/test_bincount_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_static_graph(self):
'weights': w},
fetch_list=[output])
actual = np.array(res[0])
expected = np.bincount(inputs, weights=weights)
expected = np.bincount(img, weights=w)
self.assertTrue(
(actual == expected).all(),
msg='bincount output is wrong, out =' + str(actual))
Expand All @@ -70,6 +70,16 @@ def run_network(self, net_func):
with fluid.dygraph.guard():
net_func()

def test_input_value_error(self):
"""Test input tensor should be non-negative."""

def net_func():
input_value = paddle.to_tensor([1, 2, 3, 4, -5])
paddle.bincount(input_value)

with self.assertRaises(ValueError):
self.run_network(net_func)

def test_input_shape_error(self):
"""Test input tensor should be 1-D tansor."""

Expand Down Expand Up @@ -97,7 +107,7 @@ def net_func():
input_value = paddle.to_tensor([1., 2., 3., 4., 5.])
paddle.bincount(input_value)

with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
self.run_network(net_func)

def test_weights_shape_error(self):
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,15 +1318,16 @@ def bincount(x, weights=None, minlength=0, name=None):
result2 = paddle.bincount(x, weights=w)
print(result2) # [0., 2.19999981, 0.40000001, 0., 0.50000000, 0.50000000]
"""
check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount')
if x.dtype not in [paddle.int32, paddle.int64]:
raise TypeError("Elements in Input(x) should all be integers")

if paddle.min(x) < 0:
raise ValueError("Elements in Input(x) should all be non-negative")
if in_dygraph_mode():
return _C_ops.bincount(x, weights, "minlength", minlength)

helper = LayerHelper('bincount', **locals())

check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount')

if weights is not None:
check_variable_and_dtype(weights, 'Weights',
['int32', 'int64', 'float32', 'float64'],
Expand Down

0 comments on commit 401062e

Please sign in to comment.