Skip to content

Commit

Permalink
Merge pull request #10683 from reyoung/feature/tensor_support_uint8
Browse files Browse the repository at this point in the history
Make tensor support uint8
  • Loading branch information
reyoung authored May 16, 2018
2 parents 8b1b756 + 9b7cd7f commit e528862
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 6 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/data_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() {
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);

#undef RegType
return retv;
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
case proto::VarType::BOOL:
visitor.template operator()<bool>();
break;
case proto::VarType::UINT8:
visitor.template operator()<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template operator()<int16_t>();
break;
default:
PADDLE_THROW("Not supported");
PADDLE_THROW("Not supported %d", type);
}
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ message VarType {
FP64 = 6;
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;

// Other types that may need additional descriptions
LOD_TENSOR = 7;
Expand Down
17 changes: 13 additions & 4 deletions paddle/fluid/framework/lod_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) {
ASSERT_FALSE(CheckAbsLoD(abs_lod0));
}

TEST(LoDTensor, RecordIO) {
template <typename T>
static void TestRecordIO() {
LoDTensor tensor;
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
T* tmp = tensor.mutable_data<T>(make_ddim({4, 5}), platform::CPUPlace());
for (int i = 0; i < 20; ++i) {
tmp[i] = i;
tmp[i] = static_cast<T>(i);
}

std::stringstream* stream = new std::stringstream();
Expand All @@ -247,7 +248,7 @@ TEST(LoDTensor, RecordIO) {

auto assert_tensor_ok = [](const LoDTensor& tensor) {
for (int i = 0; i < 20; ++i) {
ASSERT_EQ(tensor.data<int>()[i], i);
ASSERT_EQ(tensor.data<T>()[i], static_cast<T>(i));
}
};

Expand All @@ -265,5 +266,13 @@ TEST(LoDTensor, RecordIO) {
}
}

TEST(LoDTensor, RecordIO) {
TestRecordIO<int>();
TestRecordIO<int16_t>();
TestRecordIO<uint8_t>();
TestRecordIO<float>();
TestRecordIO<double>();
}

} // namespace framework
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/operators/math/math_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>;
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;

DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
Expand Down

0 comments on commit e528862

Please sign in to comment.