Move DataType enum inside VarType #8447

Merged 81 commits on Feb 16, 2018

Commits (81)
e18bc02
Move Pod Types from DataType enum to Type enum
Feb 14, 2018
e6c68a2
Fixed data_type.h
Feb 14, 2018
49f3934
Fix type in TensorDesc
Feb 14, 2018
25c8349
Add comment to framework.proto
Feb 14, 2018
4bfb927
Fixed type in data_type.h
Feb 14, 2018
9d5ff82
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
f380ee4
Updated format of type in data_type.h
Feb 14, 2018
27e796b
Fix var_desc.h
Feb 14, 2018
c5af325
Fix op_kernel_type.h
Feb 14, 2018
66532a2
Fixed data_type_transform_test.cc
Feb 14, 2018
5c0ca47
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
4136fb4
Fix operator.h
Feb 14, 2018
72f738b
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
b4794d6
Fixed data_type_transform.cc
Feb 14, 2018
efc517f
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
ff004a6
Fixed op_kernel_type_test.cc
Feb 14, 2018
68d435c
Fix operator.cc
Feb 14, 2018
5459285
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
050f4ae
Fixed data_layout_transform_test.cc
Feb 14, 2018
e335170
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
f8e98f7
Fix var_desc.cc
Feb 14, 2018
a0d1a43
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
d11ba0d
Fixed assign_value_op.cc
Feb 14, 2018
b6e0f66
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
b14f84d
Fixed assign_value_op.h
Feb 14, 2018
ec1d63a
fixed protobuf.cc
Feb 14, 2018
f69a9c2
Fix data_layout_transform_test.cc and op_kernel_type_test.cc
Feb 14, 2018
b567ce7
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
b5a3d1e
Fixed rnn_memory_helper_op.cc
Feb 14, 2018
4dcf87e
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
a99e87e
Fix program_desc_test.cc
Feb 14, 2018
4ec6d28
Fixed fill_constant_batch_size_like_op.cc
Feb 14, 2018
85204cb
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
8d00568
Fix operator_test.cc
Feb 14, 2018
4be98b8
Fixed fill_constant_op.cc
Feb 14, 2018
39c1dba
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
0706435
Fixed gaussian_random_op.cc
Feb 14, 2018
9c83b5a
Fixed uniform_random_op.cc
Feb 14, 2018
ff5804b
Fixed edit_distance_op.cc
Feb 14, 2018
3337454
Fixed fill_constant_batch_size_like_op.cc
Feb 14, 2018
a9d7567
Fixed rnn_memory_helper_op.cc
Feb 14, 2018
a5a7bfd
Fixed chunk_eval_op.cc
Feb 14, 2018
e7ddd9a
Fixed assign_value_op.cc
Feb 14, 2018
c7e7e90
Fixed assign_value_op.h
Feb 14, 2018
d862193
Fixed cast_op.h
Feb 14, 2018
39cf0ac
Fixed cast_op.h
Feb 14, 2018
82ff241
Fix fill constant op
Feb 14, 2018
0e06e60
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
82a07f2
Fixed clang for assign_value_op.cc
Feb 14, 2018
359ba78
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
c58f255
Fix one_hot_op.h
Feb 14, 2018
62b48a5
Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refi…
Feb 14, 2018
7190e28
Fix one_hot_op.cc
Feb 14, 2018
f71f11e
Fix fill_op.cc
Feb 14, 2018
42e2fd8
Fixed sum_op.cc
Feb 14, 2018
a1e0660
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
bfdb906
Fixed sum_op clang
Feb 14, 2018
83412f4
Fix uniform_random_op.cc
Feb 14, 2018
e58b11f
Fix gaussian_random_op.cc
Feb 14, 2018
e1024bb
Fix backward.cc
Feb 14, 2018
b8c5fdb
Fix protobuf.cc
Feb 14, 2018
5b59ee3
Fixed prune_test.cc
Feb 14, 2018
e9d6f9b
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
4f23dbd
Fixed op_registry_test.cc
Feb 14, 2018
d1c3a71
Merge remote-tracking branch 'origin/develop' into refine_pod
Feb 14, 2018
394828b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Feb 14, 2018
a9c448e
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
174fa88
Fix data_device_transform_test.cu
Feb 14, 2018
df13499
Fix travis error
Feb 14, 2018
93d2885
Merge branch 'refine_pod' of https://github.com/abhinavarora/Paddle i…
Feb 14, 2018
a0fb646
Fixed one_hot_op.cu
Feb 14, 2018
7343b20
Fixed op_registry_test.cc
Feb 14, 2018
8ab2b5c
Fixed nccl_op.cc
Feb 14, 2018
fccaa4c
Fixing python tests
Feb 14, 2018
6fda466
Revert "Fixing python tests"
Feb 15, 2018
e81a481
Fixing Pybind to remove data type
Feb 15, 2018
9c1d6a1
Fixing tensor.py
Feb 15, 2018
a9860dd
fix merge conflict
Feb 15, 2018
c533f77
Updated the new files:
Feb 15, 2018
7d66a09
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Feb 15, 2018
9ff7b17
Resolve error in merge conflict of fill_constant_batch_size_like_op.cc
Feb 16, 2018
2 changes: 1 addition & 1 deletion paddle/fluid/framework/backward.cc
Original file line number Diff line number Diff line change
@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
  auto* param = block_desc->FindVarRecursive(pname);
  auto* grad = block_desc->FindVar(arg);
  if (param == nullptr) {
-   grad->SetDataType(proto::DataType::FP32);
+   grad->SetDataType(proto::VarType::FP32);
  } else {
    grad->SetDataType(param->GetDataType());
  }
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_device_transform_test.cu
@@ -51,10 +51,10 @@ class TestOpWithKernel : public OperatorWithKernel {
      const ExecutionContext& ctx) const override {
    if (Attr<bool>("use_gpu")) {
      VLOG(3) << "force use gpu kernel";
-     return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
+     return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
    } else {
      VLOG(3) << "use default kernel";
-     return OpKernelType(proto::DataType::FP32,
+     return OpKernelType(proto::VarType::FP32,
                          ctx.Input<Tensor>("input")->place());
    }
  }
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_layout_transform_test.cc
@@ -27,9 +27,9 @@ TEST(DataTransform, DataLayoutFunction) {
  in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
  in.set_layout(DataLayout::kNHWC);

- auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
+ auto kernel_nhwc = OpKernelType(proto::VarType::FP32, place,
                                  DataLayout::kNHWC, LibraryType::kPlain);
- auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
+ auto kernel_ncwh = OpKernelType(proto::VarType::FP32, place,
                                  DataLayout::kNCHW, LibraryType::kPlain);

  TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
54 changes: 27 additions & 27 deletions paddle/fluid/framework/data_type.h
@@ -20,89 +20,89 @@ limitations under the License. */
namespace paddle {
namespace framework {

- inline proto::DataType ToDataType(std::type_index type) {
+ inline proto::VarType::Type ToDataType(std::type_index type) {
  using namespace paddle::framework::proto;
  if (typeid(float).hash_code() == type.hash_code()) {
-   return DataType::FP32;
+   return proto::VarType::FP32;
  } else if (typeid(double).hash_code() == type.hash_code()) {
-   return DataType::FP64;
+   return proto::VarType::FP64;
  } else if (typeid(int).hash_code() == type.hash_code()) {
-   return DataType::INT32;
+   return proto::VarType::INT32;
  } else if (typeid(int64_t).hash_code() == type.hash_code()) {
-   return DataType::INT64;
+   return proto::VarType::INT64;
  } else if (typeid(bool).hash_code() == type.hash_code()) {
-   return DataType::BOOL;
+   return proto::VarType::BOOL;
  } else {
    PADDLE_THROW("Not supported");
  }
}

- inline std::type_index ToTypeIndex(proto::DataType type) {
+ inline std::type_index ToTypeIndex(proto::VarType::Type type) {
  using namespace paddle::framework::proto;
  switch (type) {
-   case DataType::FP32:
+   case proto::VarType::FP32:
      return typeid(float);
-   case DataType::FP64:
+   case proto::VarType::FP64:
      return typeid(double);
-   case DataType::INT32:
+   case proto::VarType::INT32:
      return typeid(int);
-   case DataType::INT64:
+   case proto::VarType::INT64:
      return typeid(int64_t);
-   case DataType::BOOL:
+   case proto::VarType::BOOL:
      return typeid(bool);
    default:
      PADDLE_THROW("Not support type %d", type);
  }
}

template <typename Visitor>
- inline void VisitDataType(proto::DataType type, Visitor visitor) {
+ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
  using namespace paddle::framework::proto;
  switch (type) {
-   case DataType::FP32:
+   case proto::VarType::FP32:
      visitor.template operator()<float>();
      break;
-   case DataType::FP64:
+   case proto::VarType::FP64:
      visitor.template operator()<double>();
      break;
-   case DataType::INT32:
+   case proto::VarType::INT32:
      visitor.template operator()<int>();
      break;
-   case DataType::INT64:
+   case proto::VarType::INT64:
      visitor.template operator()<int64_t>();
      break;
-   case DataType::BOOL:
+   case proto::VarType::BOOL:
      visitor.template operator()<bool>();
      break;
    default:
      PADDLE_THROW("Not supported");
  }
}

- inline std::string DataTypeToString(const proto::DataType type) {
+ inline std::string DataTypeToString(const proto::VarType::Type type) {
  using namespace paddle::framework::proto;
  switch (type) {
-   case DataType::FP16:
+   case proto::VarType::FP16:
      return "float16";
-   case DataType::FP32:
+   case proto::VarType::FP32:
      return "float32";
-   case DataType::FP64:
+   case proto::VarType::FP64:
      return "float64";
-   case DataType::INT16:
+   case proto::VarType::INT16:
      return "int16";
-   case DataType::INT32:
+   case proto::VarType::INT32:
      return "int32";
-   case DataType::INT64:
+   case proto::VarType::INT64:
      return "int64";
-   case DataType::BOOL:
+   case proto::VarType::BOOL:
      return "bool";
    default:
      PADDLE_THROW("Not support type %d", type);
  }
}

inline std::ostream& operator<<(std::ostream& out,
-                               const proto::DataType& type) {
+                               const proto::VarType::Type& type) {
  out << DataTypeToString(type);
  return out;
}
10 changes: 5 additions & 5 deletions paddle/fluid/framework/data_type_transform.cc
@@ -65,19 +65,19 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
  auto ctx = pool.Get(in.place());

  switch (src_type) {
-   case proto::DataType::FP32:
+   case proto::VarType::FP32:
      framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
      break;
-   case proto::DataType::FP64:
+   case proto::VarType::FP64:
      framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
      break;
-   case proto::DataType::INT32:
+   case proto::VarType::INT32:
      framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
      break;
-   case proto::DataType::INT64:
+   case proto::VarType::INT64:
      framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
      break;
-   case proto::DataType::BOOL:
+   case proto::VarType::BOOL:
      framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
      break;
    default:
6 changes: 3 additions & 3 deletions paddle/fluid/framework/data_type_transform_test.cc
@@ -32,11 +32,11 @@ TEST(DataTypeTransform, CPUTransform) {
    ptr[i] = i / 3;
  }

- auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
+ auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);
- auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
+ auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);
- auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
+ auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);

  TransDataType(kernel_fp32, kernel_fp64, in, &out);
41 changes: 21 additions & 20 deletions paddle/fluid/framework/framework.proto
@@ -91,33 +91,34 @@ message OpProto {
  required string comment = 5;
}

- enum DataType {
-   BOOL = 0;
-   INT16 = 1;
-   INT32 = 2;
-   INT64 = 3;
-   FP16 = 4;
-   FP32 = 5;
-   FP64 = 6;
- }
-
message VarType {
  enum Type {
-   LOD_TENSOR = 1;
-   SELECTED_ROWS = 2;
-   FEED_MINIBATCH = 3;
-   FETCH_LIST = 4;
-   STEP_SCOPES = 5;
-   LOD_RANK_TABLE = 6;
-   LOD_TENSOR_ARRAY = 7;
-   PLACE_LIST = 8;
-   READER = 9;
+   // Pod Types
+   BOOL = 0;
+   INT16 = 1;
+   INT32 = 2;
+   INT64 = 3;
+   FP16 = 4;
+   FP32 = 5;
+   FP64 = 6;
+
+   // Other types that may need additional descriptions
+   LOD_TENSOR = 7;
+   SELECTED_ROWS = 8;
+   FEED_MINIBATCH = 9;
+   FETCH_LIST = 10;
+   STEP_SCOPES = 11;
+   LOD_RANK_TABLE = 12;
+   LOD_TENSOR_ARRAY = 13;
+   PLACE_LIST = 14;
+   READER = 15;
  }

  required Type type = 1;

  message TensorDesc {
-   required DataType data_type = 1;
+   // Should only be PODType. Is enforced in C++
+   required Type data_type = 1;
    repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
  }
  optional TensorDesc selected_rows = 2;
6 changes: 3 additions & 3 deletions paddle/fluid/framework/op_kernel_type.h
@@ -40,20 +40,20 @@ struct OpKernelType {
  // place, data_type, library_type kinds less than 2^8
  constexpr static int LEFT_SHIFT = 8;

- proto::DataType data_type_;
+ proto::VarType::Type data_type_;
  DataLayout data_layout_;
  platform::Place place_;
  LibraryType library_type_;

- OpKernelType(proto::DataType data_type, platform::Place place,
+ OpKernelType(proto::VarType::Type data_type, platform::Place place,
               DataLayout data_layout = DataLayout::kAnyLayout,
               LibraryType library_type = LibraryType::kPlain)
      : data_type_(data_type),
        data_layout_(data_layout),
        place_(place),
        library_type_(library_type) {}

- OpKernelType(proto::DataType data_type,
+ OpKernelType(proto::VarType::Type data_type,
               const platform::DeviceContext& dev_ctx,
               DataLayout data_layout = DataLayout::kAnyLayout,
               LibraryType library_type = LibraryType::kPlain)
4 changes: 2 additions & 2 deletions paddle/fluid/framework/op_kernel_type_test.cc
@@ -18,7 +18,7 @@ limitations under the License. */

TEST(OpKernelType, ToString) {
  using OpKernelType = paddle::framework::OpKernelType;
- using DataType = paddle::framework::proto::DataType;
+ using DataType = paddle::framework::proto::VarType;
  using CPUPlace = paddle::platform::CPUPlace;
  using DataLayout = paddle::framework::DataLayout;
  using LibraryType = paddle::framework::LibraryType;
@@ -33,7 +33,7 @@

TEST(OpKernelType, Hash) {
  using OpKernelType = paddle::framework::OpKernelType;
- using DataType = paddle::framework::proto::DataType;
+ using DataType = paddle::framework::proto::VarType;
  using CPUPlace = paddle::platform::CPUPlace;
  using CUDAPlace = paddle::platform::CUDAPlace;
  using DataLayout = paddle::framework::DataLayout;
8 changes: 4 additions & 4 deletions paddle/fluid/framework/op_registry_test.cc
@@ -226,7 +226,7 @@ class OpWithKernelTest : public OperatorWithKernel {

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-   return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+   return framework::OpKernelType(proto::VarType::FP32, ctx.device_context());
  }
};

@@ -290,9 +290,9 @@ class OpWithMultiKernelTest : public OperatorWithKernel {

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-   return framework::OpKernelType(
-       proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
-       framework::LibraryType::kCUDNN);
+   return framework::OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0),
+                                  DataLayout::kAnyLayout,
+                                  framework::LibraryType::kCUDNN);
  }
};
4 changes: 2 additions & 2 deletions paddle/fluid/framework/operator.cc
@@ -569,7 +569,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
  }
}

- proto::DataType OperatorWithKernel::IndicateDataType(
+ proto::VarType::Type OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  auto& scope = ctx.scope();
  int data_type = -1;
@@ -595,7 +595,7 @@ proto::DataType OperatorWithKernel::IndicateDataType(
    }
  }
  PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
- return static_cast<proto::DataType>(data_type);
+ return static_cast<proto::VarType::Type>(data_type);
}

OpKernelType OperatorWithKernel::GetExpectedKernelType(
4 changes: 2 additions & 2 deletions paddle/fluid/framework/operator.h
@@ -394,9 +394,9 @@ class OperatorWithKernel : public OperatorBase {
                             const OpKernelType& expected_kernel_type) const;

 private:
- // indicate kernel DataType by input data. Defaultly all input data must be
+ // indicate kernel DataType by input data. By default all input data must be
  // same.
- proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
+ proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
  void RunImpl(const Scope& scope, const platform::Place& place) const final;
};
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator_test.cc
@@ -119,7 +119,7 @@ class OpWithKernelTest : public OperatorWithKernel {
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
-   return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
+   return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
  }
};
8 changes: 4 additions & 4 deletions paddle/fluid/framework/program_desc_test.cc
@@ -24,13 +24,13 @@ TEST(ProgramDesc, copy_ctor) {
  auto* x = global_block->Var("X");
  x->SetType(proto::VarType::LOD_TENSOR);
  x->SetLoDLevel(0);
- x->SetDataType(proto::FP32);
+ x->SetDataType(proto::VarType::FP32);
  x->SetShape({1000, 784});

  auto* y = global_block->Var("Y");
  y->SetType(proto::VarType::LOD_TENSOR);
  y->SetLoDLevel(0);
- y->SetDataType(proto::FP32);
+ y->SetDataType(proto::VarType::FP32);
  y->SetShape({784, 100});

  auto* op = global_block->AppendOp();
@@ -86,13 +86,13 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
  auto* x = global_block->Var("X");
  x->SetType(proto::VarType::LOD_TENSOR);
  x->SetLoDLevel(0);
- x->SetDataType(proto::FP32);
+ x->SetDataType(proto::VarType::FP32);
  x->SetShape({1000, 784});

  auto* y = global_block->Var("Y");
  y->SetType(proto::VarType::LOD_TENSOR);
  y->SetLoDLevel(0);
- y->SetDataType(proto::FP32);
+ y->SetDataType(proto::VarType::FP32);
  y->SetShape({784, 100});

  auto* op = global_block->AppendOp();
2 changes: 1 addition & 1 deletion paddle/fluid/framework/prune_test.cc
@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
  for (auto kv : outputs) {
    for (auto v : kv.second) {
      auto var = block->Var(v);
-     var->SetDataType(paddle::framework::proto::DataType::FP32);
+     var->SetDataType(paddle::framework::proto::VarType::FP32);
    }
  }