Skip to content

Commit

Permalink
Fix build on big endian CPUs (#5617)
Browse files Browse the repository at this point in the history
* Fix build on big endian CPUs

* Clang-tidy
  • Loading branch information
hcho3 authored Apr 30, 2020
1 parent b9649e7 commit 8de7f19
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
19 changes: 18 additions & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <dmlc/base.h>
#include <dmlc/data.h>
#include <dmlc/serializer.h>
#include <rabit/rabit.h>
#include <xgboost/base.h>
#include <xgboost/span.h>
Expand Down Expand Up @@ -554,5 +555,21 @@ inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {

namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true);
}

namespace serializer {

template <>
struct Handler<xgboost::Entry> {
inline static void Write(Stream* strm, const xgboost::Entry& data) {
strm->Write(data.index);
strm->Write(data.fvalue);
}

inline static bool Read(Stream* strm, xgboost::Entry* data) {
return strm->Read(&data->index) && strm->Read(&data->fvalue);
}
};

} // namespace serializer
} // namespace dmlc
#endif // XGBOOST_DATA_H_
12 changes: 8 additions & 4 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ template <typename T>
void SaveScalarField(dmlc::Stream *strm, const std::string &name,
xgboost::DataType type, const T &field) {
strm->Write(name);
strm->Write(type);
strm->Write(static_cast<uint8_t>(type));
strm->Write(true); // is_scalar=True
strm->Write(field);
}
Expand All @@ -47,7 +47,7 @@ void SaveVectorField(dmlc::Stream *strm, const std::string &name,
xgboost::DataType type, std::pair<uint64_t, uint64_t> shape,
const std::vector<T>& field) {
strm->Write(name);
strm->Write(type);
strm->Write(static_cast<uint8_t>(type));
strm->Write(false); // is_scalar=False
strm->Write(shape.first);
strm->Write(shape.second);
Expand All @@ -71,7 +71,9 @@ void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name,
CHECK(strm->Read(&name)) << invalid;
CHECK_EQ(name, expected_name)
<< invalid << " Expected field: " << expected_name << ", got: " << name;
CHECK(strm->Read(&type)) << invalid;
uint8_t type_val;
CHECK(strm->Read(&type_val)) << invalid;
type = static_cast<xgboost::DataType>(type_val);
CHECK(type == expected_type)
<< invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
<< "got field type: " << static_cast<int>(type);
Expand All @@ -91,7 +93,9 @@ void LoadVectorField(dmlc::Stream* strm, const std::string& expected_name,
CHECK(strm->Read(&name)) << invalid;
CHECK_EQ(name, expected_name)
<< invalid << " Expected field: " << expected_name << ", got: " << name;
CHECK(strm->Read(&type)) << invalid;
uint8_t type_val;
CHECK(strm->Read(&type_val)) << invalid;
type = static_cast<xgboost::DataType>(type_val);
CHECK(type == expected_type)
<< invalid << "Expected field of type: " << static_cast<int>(expected_type) << ", "
<< "got field type: " << static_cast<int>(type);
Expand Down

0 comments on commit 8de7f19

Please sign in to comment.