Skip to content

Commit

Permalink
Add serializable.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 3, 2019
1 parent 85691f9 commit c601456
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
34 changes: 32 additions & 2 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
#ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_

#include <dmlc/io.h> // deprecated
#include <xgboost/logging.h>

#include <string>

#include <map>
Expand Down Expand Up @@ -577,6 +577,36 @@ inline std::map<std::string, std::string> fromJson(std::map<std::string, Json> c
}
return res;
}

} // namespace xgboost

#include <rabit/rabit.h>

namespace xgboost {

struct Serializable : public rabit::Serializable {
virtual ~Serializable() = default;
/*!
* \brief load the model from a stream
* \param fi stream where to load the model from
*/
virtual void Load(dmlc::Stream *fi) override = 0;
/*!
* \brief saves the model to a stream
* \param fo stream where to save the model to
*/
virtual void Save(dmlc::Stream *fo) const override = 0;

/*!
* \brief load the model from a json object
* \param in json object where to load the model from
*/
virtual void Load(Json const& in) = 0;
/*!
* \breif saves the model to a json object
* \param out json container where to save the model to
*/
virtual void Save(Json* out) const = 0;
};
} // namespace xgboost

#endif // XGBOOST_JSON_H_
8 changes: 3 additions & 5 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <rabit/rabit.h>

#include <xgboost/base.h>
#include <xgboost/gbm.h>
#include <xgboost/metric.h>
Expand Down Expand Up @@ -42,7 +40,7 @@ namespace xgboost {
*
* \endcode
*/
class Learner : public rabit::Serializable {
class Learner : public Serializable {
public:
/*! \brief virtual destructor */
~Learner() override = default;
Expand All @@ -51,14 +49,14 @@ class Learner : public rabit::Serializable {
*/
virtual void Configure() = 0;

virtual void Load(Json const& in) = 0;
virtual void Load(Json const& in) override = 0;
/*!
* \brief load model from stream
* \param fi input stream.
*/
void Load(dmlc::Stream* fi) override = 0;

virtual void Save(Json* out) const = 0;
virtual void Save(Json* out) const override = 0;
/*!
* \brief save model to stream.
* \param fo output stream
Expand Down

0 comments on commit c601456

Please sign in to comment.