-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add one_hot operator. #7819
Merged
Merged
Add one_hot operator. #7819
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/operators/one_hot_op.h" | ||
#include "paddle/framework/framework.pb.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class OneHotOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("X"), | ||
"Input(X) of OneHotOp should not be null."); | ||
PADDLE_ENFORCE(ctx->HasOutput("Out"), | ||
"Output(Out) of OneHotOp should not be null."); | ||
|
||
auto x_dims = ctx->GetInputDim("X"); | ||
PADDLE_ENFORCE_GE(x_dims.size(), 2, | ||
"Rank of Input(X) should be at least 2."); | ||
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, | ||
"Last dimension of Input(X) should be 1."); | ||
|
||
int depth = ctx->Attrs().Get<int>("depth"); | ||
|
||
PADDLE_ENFORCE_GT(depth, 0, "Should provide a positive depth (%d).", depth); | ||
|
||
framework::DDim out_dims(x_dims); | ||
out_dims[out_dims.size() - 1] = depth; | ||
ctx->SetOutputDim("Out", out_dims); | ||
ctx->ShareLoD("X", /* --> */ "Out"); | ||
} | ||
}; | ||
|
||
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", | ||
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. " | ||
"The last dimension of X should be 1. Each value of X is an index " | ||
"to indicate the position."); | ||
AddOutput("Out", | ||
"(Tensor, Tensor<float>) Output tensor with same rank as X. " | ||
"The tensor consists of one-hot representations of values in X."); | ||
AddAttr<int>("depth", | ||
"A positive integer to specify the length of one-hot vector."); | ||
AddAttr<int>("dtype", | ||
"An integer to specify the data type of one-hot " | ||
"vector. The default value is FP32.") | ||
.SetDefault(paddle::framework::proto::DataType::FP32); | ||
AddComment(R"DOC( | ||
One Hot Operator. This operator creates the one-hot representations for input | ||
index values. The following example will help to explain the function of this | ||
operator: | ||
|
||
X is a LoDTensor: | ||
X.lod = [[0, 1, 4]] | ||
X.shape = [4, 1] | ||
X.data = [[1], [1], [3], [0]] | ||
|
||
set depth = 4 | ||
|
||
Out is a LoDTensor: | ||
Out.lod = [[0, 1, 4]] | ||
Out.shape = [4, 4] | ||
Out.data = [[0., 1., 0., 0.], | ||
[0., 1., 0., 0.], | ||
[0., 0., 0., 1.], | ||
[1., 0., 0., 0.]] | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR(one_hot, ops::OneHotOp, ops::OneHotOpMaker, | ||
paddle::framework::EmptyGradOpMaker); | ||
REGISTER_OP_CPU_KERNEL( | ||
one_hot, ops::OneHotKernel<paddle::platform::CPUDeviceContext, int>, | ||
ops::OneHotKernel<paddle::platform::CPUDeviceContext, int64_t>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/operators/one_hot_op.h" | ||
#include "paddle/platform/cuda_helper.h" | ||
#include "paddle/platform/gpu_info.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using platform::PADDLE_CUDA_NUM_THREADS; | ||
|
||
template <typename InT, typename OutT> | ||
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, | ||
const int64_t numel, const int depth) { | ||
int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (idx < numel) { | ||
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; | ||
} | ||
} | ||
|
||
template <typename DeviceContext, typename InT> | ||
struct OneHotOpCUDAFunctor { | ||
const framework::LoDTensor* in_; | ||
framework::LoDTensor* out_; | ||
const DeviceContext& ctx_; | ||
int depth_; | ||
|
||
OneHotOpCUDAFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, | ||
int depth, const DeviceContext& ctx) | ||
: in_(in), out_(out), depth_(depth), ctx_(ctx) {} | ||
|
||
template <typename OutT> | ||
void operator()() const { | ||
auto* p_in_data = in_->data<InT>(); | ||
auto numel = in_->numel(); | ||
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace()); | ||
auto stream = ctx_.stream(); | ||
math::set_constant(ctx_, out_, 0.0); | ||
|
||
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / | ||
PADDLE_CUDA_NUM_THREADS, | ||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( | ||
p_in_data, p_out_data, numel, depth_); | ||
} | ||
}; | ||
|
||
using LoDTensor = framework::LoDTensor; | ||
template <typename DeviceContext, typename T> | ||
class OneHotCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* in = context.Input<LoDTensor>("X"); | ||
auto* out = context.Output<LoDTensor>("Out"); | ||
int depth = context.Attr<int>("depth"); | ||
|
||
framework::VisitDataType( | ||
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")), | ||
OneHotOpCUDAFunctor<DeviceContext, T>( | ||
in, out, depth, context.template device_context<DeviceContext>())); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL( | ||
one_hot, ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int>, | ||
ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/math/math_function.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename DeviceContext, typename InT> | ||
struct OneHotOpFunctor { | ||
const framework::LoDTensor* in_; | ||
framework::LoDTensor* out_; | ||
int depth_; | ||
const DeviceContext& ctx_; | ||
|
||
OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, | ||
int depth, const DeviceContext& ctx) | ||
: in_(in), out_(out), depth_(depth), ctx_(ctx) {} | ||
|
||
template <typename OutT> | ||
void operator()() const { | ||
auto* p_in_data = in_->data<InT>(); | ||
auto numel = in_->numel(); | ||
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace()); | ||
math::set_constant(ctx_, out_, 0.0); | ||
|
||
for (int i = 0; i < numel; ++i) { | ||
PADDLE_ENFORCE_GE(p_in_data[i], 0, | ||
"Illegal index value, should be at least 0."); | ||
PADDLE_ENFORCE_LT(p_in_data[i], depth_, | ||
"Illegal index value, should be less than depth (%d).", | ||
depth_); | ||
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0; | ||
} | ||
} | ||
}; | ||
|
||
using LoDTensor = framework::LoDTensor; | ||
template <typename DeviceContext, typename T> | ||
class OneHotKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* in = context.Input<LoDTensor>("X"); | ||
auto* out = context.Output<LoDTensor>("Out"); | ||
int depth = context.Attr<int>("depth"); | ||
|
||
framework::VisitDataType( | ||
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")), | ||
OneHotOpFunctor<DeviceContext, T>( | ||
in, out, depth, context.template device_context<DeviceContext>())); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
import numpy as np | ||
import math | ||
from op_test import OpTest | ||
import paddle.v2.fluid as fluid | ||
import paddle.v2.fluid.core as core | ||
import paddle.v2.fluid.framework as framework | ||
from paddle.v2.fluid.framework import Program, program_guard | ||
|
||
|
||
class TestOneHotOp(OpTest): | ||
def setUp(self): | ||
self.op_type = 'one_hot' | ||
depth = 10 | ||
dimension = 12 | ||
x_lod = [[0, 4, 5, 8, 11]] | ||
x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])] | ||
x = np.array(x).astype('int').reshape([x_lod[0][-1], 1]) | ||
|
||
out = np.zeros(shape=(np.product(x.shape[:-1]), | ||
depth)).astype('float32') | ||
|
||
for i in xrange(np.product(x.shape)): | ||
out[i, x[i]] = 1.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
==>
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, but I think current implement is ok. |
||
|
||
self.inputs = {'X': (x, x_lod)} | ||
self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)} | ||
self.outputs = {'Out': (out, x_lod)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
class TestOneHotOp_default_dtype(OpTest): | ||
def setUp(self): | ||
self.op_type = 'one_hot' | ||
depth = 10 | ||
dimension = 12 | ||
x_lod = [[0, 4, 5, 8, 11]] | ||
x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])] | ||
x = np.array(x).astype('int').reshape([x_lod[0][-1], 1]) | ||
|
||
out = np.zeros(shape=(np.product(x.shape[:-1]), | ||
depth)).astype('float32') | ||
|
||
for i in xrange(np.product(x.shape)): | ||
out[i, x[i]] = 1.0 | ||
|
||
self.inputs = {'X': (x, x_lod)} | ||
self.attrs = {'depth': depth} | ||
self.outputs = {'Out': (out, x_lod)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
class TestOneHotOp_exception(OpTest): | ||
def setUp(self): | ||
self.op_type = 'one_hot' | ||
self.depth = 10 | ||
self.place = core.CPUPlace() | ||
self.dimension = 12 | ||
self.x = core.LoDTensor() | ||
x_lod = [[0, 4, 5, 8, 11]] | ||
data = [np.random.randint(11, 20) for i in xrange(x_lod[0][-1])] | ||
data = np.array(data).astype('int').reshape([x_lod[0][-1], 1]) | ||
self.x.set(data, self.place) | ||
self.x.set_lod(x_lod) | ||
|
||
def test_check_output(self): | ||
program = Program() | ||
with program_guard(program): | ||
x = fluid.layers.data( | ||
name='x', shape=[self.dimension], dtype='float32', lod_level=1) | ||
block = program.current_block() | ||
one_hot_out = block.create_var( | ||
name="one_hot_out", | ||
type=core.VarDesc.VarType.LOD_TENSOR, | ||
dtype='float32') | ||
block.append_op( | ||
type='one_hot', | ||
inputs={'X': x}, | ||
attrs={'depth': self.depth}, | ||
outputs={'Out': one_hot_out}) | ||
exe = fluid.Executor(self.place) | ||
|
||
def run(): | ||
exe.run(feed={'x': self.x}, | ||
fetch_list=[one_hot_out], | ||
return_numpy=False) | ||
|
||
self.assertRaises(core.EnforceNotMet, run) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should enforce that a positive depth is provided.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done.