diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 5f694a63a500..da2c80622ad3 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015 by Contributors. + * Copyright 2015-2019 by Contributors. * \brief XGBoost Amalgamation. * This offers an alternative way to compile the entire library from this single file. * @@ -66,6 +66,8 @@ #include "../src/common/common.cc" #include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" +#include "../src/common/json.cc" +#include "../src/common/io.cc" // c_api #include "../src/c_api/c_api.cc" diff --git a/include/xgboost/json.h b/include/xgboost/json.h new file mode 100644 index 000000000000..46836f1add48 --- /dev/null +++ b/include/xgboost/json.h @@ -0,0 +1,530 @@ +/*! + * Copyright (c) by Contributors 2019 + */ +#ifndef XGBOOST_JSON_H_ +#define XGBOOST_JSON_H_ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace xgboost { + +class Json; +class JsonReader; +class JsonWriter; + +class Value { + public: + /*!\brief Simplified implementation of LLVM RTTI. 
*/ + enum class ValueKind { + String, + Number, + Integer, + Object, // std::map + Array, // std::vector + Raw, + Boolean, + Null + }; + + explicit Value(ValueKind _kind) : kind_{_kind} {} + + ValueKind Type() const { return kind_; } + virtual ~Value() = default; + + virtual void Save(JsonWriter* writer) = 0; + + virtual Json& operator[](std::string const & key) = 0; + virtual Json& operator[](int ind) = 0; + + virtual bool operator==(Value const& rhs) const = 0; + virtual Value& operator=(Value const& rhs) = 0; + + std::string TypeStr() const; + + private: + ValueKind kind_; +}; + +template +bool IsA(Value const* value) { + return T::isClassOf(value); +} + +template +T* Cast(U* value) { + if (IsA(value)) { + return dynamic_cast(value); + } else { + throw std::runtime_error( + "Invalid cast, from " + value->TypeStr() + " to " + T().TypeStr()); + } +} + +class JsonString : public Value { + std::string str_; + public: + JsonString() : Value(ValueKind::String) {} + JsonString(std::string const& str) : // NOLINT + Value(ValueKind::String), str_{str} {} + JsonString(std::string&& str) : // NOLINT + Value(ValueKind::String), str_{std::move(str)} {} + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + std::string const& getString() && { return str_; } + std::string const& getString() const & { return str_; } + std::string& getString() & { return str_; } + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::String; + } +}; + +class JsonArray : public Value { + std::vector vec_; + + public: + JsonArray() : Value(ValueKind::Array) {} + JsonArray(std::vector&& arr) : // NOLINT + Value(ValueKind::Array), vec_{std::move(arr)} {} + JsonArray(std::vector const& arr) : // NOLINT + Value(ValueKind::Array), vec_{arr} {} + JsonArray(JsonArray const& that) = delete; + 
JsonArray(JsonArray && that); + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + std::vector const& getArray() && { return vec_; } + std::vector const& getArray() const & { return vec_; } + std::vector& getArray() & { return vec_; } + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Array; + } +}; + +class JsonRaw : public Value { + std::string str_; + + public: + explicit JsonRaw(std::string&& str) : + Value(ValueKind::Raw), + str_{std::move(str)}{} // NOLINT + JsonRaw() : Value(ValueKind::Raw) {} + + std::string const& getRaw() && { return str_; } + std::string const& getRaw() const & { return str_; } + std::string& getRaw() & { return str_; } + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Raw; + } +}; + +class JsonObject : public Value { + std::map object_; + + public: + JsonObject() : Value(ValueKind::Object) {} + JsonObject(std::map&& object); // NOLINT + JsonObject(JsonObject const& that) = delete; + JsonObject(JsonObject && that); + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + std::map const& getObject() && { return object_; } + std::map const& getObject() const & { return object_; } + std::map & getObject() & { return object_; } + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Object; + } + virtual ~JsonObject() = default; +}; + +class 
JsonNumber : public Value { + public: + using Float = float; + + private: + Float number_; + + public: + JsonNumber() : Value(ValueKind::Number) {} + JsonNumber(double value) : Value(ValueKind::Number) { // NOLINT + number_ = value; + } + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + Float const& getNumber() && { return number_; } + Float const& getNumber() const & { return number_; } + Float& getNumber() & { return number_; } + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Number; + } +}; + +class JsonNull : public Value { + public: + JsonNull() : Value(ValueKind::Null) {} + JsonNull(std::nullptr_t) : Value(ValueKind::Null) {} // NOLINT + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Null; + } +}; + +/*! \brief Describes both true and false. */ +class JsonBoolean : public Value { + bool boolean_; + + public: + JsonBoolean() : Value(ValueKind::Boolean) {} // NOLINT + // Ambigious with JsonNumber. 
+ template ::value || + std::is_same::value>::type* = nullptr> + JsonBoolean(Bool value) : // NOLINT + Value(ValueKind::Boolean), boolean_{value} {} + + void Save(JsonWriter* writer) override; + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + bool const& getBoolean() && { return boolean_; } + bool const& getBoolean() const & { return boolean_; } + bool& getBoolean() & { return boolean_; } + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Boolean; + } +}; + +struct StringView { + char const* str_; + size_t size_; + + public: + StringView() = default; + StringView(char const* str, size_t size) : str_{str}, size_{size} {} + + char const& operator[](size_t p) const { return str_[p]; } + char const& at(size_t p) const { // NOLINT + CHECK_LT(p, size_); + return str_[p]; + } + size_t size() const { return size_; } // NOLINT + // Copies a portion of string. Since we don't have std::from_chars and friends here, so + // copying substring is necessary for appending `\0`. It's not too bad since string by + // default has small vector optimization, which is enabled by most if not all modern + // compilers for numeric values. + std::string substr(size_t beg, size_t n) const { // NOLINT + CHECK_LE(beg, size_); + return std::string {str_ + beg, n < (size_ - beg) ? n : (size_ - beg)}; + } + char const* c_str() const { return str_; } // NOLINT +}; + +/*! + * \brief Data structure representing JSON format. + * + * Limitation: UTF-8 is not properly supported. Code points above ASCII are + * invalid. + * + * Examples: + * + * \code + * // Create a JSON object. 
+ * Json object { Object() }; + * // Assign key "key" with a JSON string "Value"; + * object["key"] = String("Value"); + * // Assign key "arr" with a empty JSON Array; + * object["arr"] = Array(); + * \endcode + */ +class Json { + friend JsonWriter; + + public: + /*! \brief Load a Json object from string. */ + static Json Load(StringView str, bool ignore_specialization = false); + /*! \brief Pass your own JsonReader. */ + static Json Load(JsonReader* reader); + /*! \brief Dump json into stream. */ + static void Dump(Json json, std::ostream* stream, + bool pretty = ConsoleLogger::ShouldLog( + ConsoleLogger::LogVerbosity::kDebug)); + + Json() : ptr_{new JsonNull} {} + + // number + explicit Json(JsonNumber number) : ptr_{new JsonNumber(number)} {} + Json& operator=(JsonNumber number) { + ptr_.reset(new JsonNumber(std::move(number))); + return *this; + } + + // array + explicit Json(JsonArray list) : + ptr_ {new JsonArray(std::move(list))} {} + Json& operator=(JsonArray array) { + ptr_.reset(new JsonArray(std::move(array))); + return *this; + } + + // raw + explicit Json(JsonRaw str) : + ptr_{new JsonRaw(std::move(str))} {} + Json& operator=(JsonRaw str) { + ptr_.reset(new JsonRaw(std::move(str))); + return *this; + } + + // object + explicit Json(JsonObject object) : + ptr_{new JsonObject(std::move(object))} {} + Json& operator=(JsonObject object) { + ptr_.reset(new JsonObject(std::move(object))); + return *this; + } + // string + explicit Json(JsonString str) : + ptr_{new JsonString(std::move(str))} {} + Json& operator=(JsonString str) { + ptr_.reset(new JsonString(std::move(str))); + return *this; + } + // bool + explicit Json(JsonBoolean boolean) : + ptr_{new JsonBoolean(std::move(boolean))} {} + Json& operator=(JsonBoolean boolean) { + ptr_.reset(new JsonBoolean(std::move(boolean))); + return *this; + } + // null + explicit Json(JsonNull null) : + ptr_{new JsonNull(std::move(null))} {} + Json& operator=(JsonNull null) { + ptr_.reset(new 
JsonNull(std::move(null))); + return *this; + } + + // copy + Json(Json const& other) : ptr_{other.ptr_} {} + Json& operator=(Json const& other); + // move + Json(Json&& other) : ptr_{std::move(other.ptr_)} {} + Json& operator=(Json&& other) { + ptr_ = std::move(other.ptr_); + return *this; + } + + /*! \brief Index Json object with a std::string, used for Json Object. */ + Json& operator[](std::string const & key) const { return (*ptr_)[key]; } + /*! \brief Index Json object with int, used for Json Array. */ + Json& operator[](int ind) const { return (*ptr_)[ind]; } + + /*! \Brief Return the reference to stored Json value. */ + Value const& GetValue() const & { return *ptr_; } + Value const& GetValue() && { return *ptr_; } + Value& GetValue() & { return *ptr_; } + + bool operator==(Json const& rhs) const { + return *ptr_ == *(rhs.ptr_); + } + + private: + std::shared_ptr ptr_; +}; + +template +bool IsA(Json const j) { + auto const& v = j.GetValue(); + return IsA(&v); +} + +namespace detail { + +// Number +template ::value>::type* = nullptr> +JsonNumber::Float& GetImpl(T& val) { // NOLINT + return val.getNumber(); +} +template ::value>::type* = nullptr> +double const& GetImpl(T& val) { // NOLINT + return val.getNumber(); +} + +// String +template ::value>::type* = nullptr> +std::string& GetImpl(T& val) { // NOLINT + return val.getString(); +} +template ::value>::type* = nullptr> +std::string const& GetImpl(T& val) { // NOLINT + return val.getString(); +} + +// Boolean +template ::value>::type* = nullptr> +bool& GetImpl(T& val) { // NOLINT + return val.getBoolean(); +} +template ::value>::type* = nullptr> +bool const& GetImpl(T& val) { // NOLINT + return val.getBoolean(); +} + +template ::value>::type* = nullptr> +std::string& GetImpl(T& val) { // NOLINT + return val.getRaw(); +} +template ::value>::type* = nullptr> +std::string const& GetImpl(T& val) { // NOLINT + return val.getRaw(); +} + +// Array +template ::value>::type* = nullptr> +std::vector& GetImpl(T& val) 
{ // NOLINT + return val.getArray(); +} +template ::value>::type* = nullptr> +std::vector const& GetImpl(T& val) { // NOLINT + return val.getArray(); +} + +// Object +template ::value>::type* = nullptr> +std::map& GetImpl(T& val) { // NOLINT + return val.getObject(); +} +template ::value>::type* = nullptr> +std::map const& GetImpl(T& val) { // NOLINT + return val.getObject(); +} + +} // namespace detail + +/*! + * \brief Get Json value. + * + * \tparam T One of the Json value type. + * + * \param json + * \return Value contained in Json object of type T. + */ +template +auto get(U& json) -> decltype(detail::GetImpl(*Cast(&json.GetValue())))& { // NOLINT + auto& value = *Cast(&json.GetValue()); + return detail::GetImpl(value); +} + +using Object = JsonObject; +using Array = JsonArray; +using Number = JsonNumber; +using Boolean = JsonBoolean; +using String = JsonString; +using Null = JsonNull; +using Raw = JsonRaw; + +// Utils tailored for XGBoost. + +template +Object toJson(dmlc::Parameter const& param) { + Object obj; + for (auto const& kv : param.__DICT__()) { + obj[kv.first] = kv.second; + } + return obj; +} + +inline std::map fromJson(std::map const& param) { + std::map res; + for (auto const& kv : param) { + res[kv.first] = get(kv.second); + } + return res; +} + +} // namespace xgboost +#endif // XGBOOST_JSON_H_ diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h new file mode 100644 index 000000000000..d0323354d1a3 --- /dev/null +++ b/include/xgboost/json_io.h @@ -0,0 +1,217 @@ +/*! 
+ * Copyright (c) by Contributors 2019
+ */
+#ifndef XGBOOST_JSON_IO_H_
+#define XGBOOST_JSON_IO_H_
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace xgboost {
+
+/*!
+ * \brief A char-based std::basic_stringstream whose constructor raises the stream
+ *        precision to a std::numeric_limits max_digits10 value, so that numbers
+ *        written through it are printed with that many significant digits.
+ */
+template
+class FixedPrecisionStreamContainer : public std::basic_stringstream<
+  char, std::char_traits, Allocator> {
+ public:
+  FixedPrecisionStreamContainer() {
+    this->precision(std::numeric_limits::max_digits10);
+  }
+};
+
+using FixedPrecisionStream = FixedPrecisionStreamContainer>;
+
+/*
+ * \brief A reader that can be specialised.
+ *
+ * Why specialization?
+ *
+ * First of all, we don't like specialization.  This is purely a performance concern.
+ * Distributed environments frequently serialize models, so at some point this could be
+ * a bottleneck for training performance.  There are many other techniques for obtaining
+ * better performance, but all of them require implementing their own allocator(s) or
+ * using SIMD instructions.  And few of them can provide an easy-to-modify structure,
+ * since they assume a fixed memory layout.
+ *
+ * In XGBoost we provide specialized logic for parsing/writing tree models and linear
+ * models, where dense numeric values are present, including weights, node ids etc.
+ *
+ * Plan for removing the specialization:
+ *
+ * We plan to upstream this implementation into DMLC as it matures.  For XGBoost, most of
+ * the time spent in load/dump is actually `sprintf`.
+ *
+ * To enable specialization, register a keyword that corresponds to
+ * a key in the Json object.  For example in:
+ *
+ * \code
+ * { "key": {...} }
+ * \endcode
+ *
+ * To add special logic for parsing {...}, one can call:
+ *
+ * \code
+ * JsonReader::registry("key", [](StringView str, size_t* pos){ ... return JsonRaw(...); });
+ * \endcode
+ *
+ * Where str is a view of the entire input string, while pos is a pointer to the current
+ * position.  The function must return a raw object.
Later after obtaining a parsed object, say + * `Json obj`, you can obtain * the raw object by calling `obj["key"]' then perform the + * specialized parsing on it. + * + * See `LinearSelectRaw` and `LinearReader` in combination as an example. + */ +class JsonReader { + protected: + size_t constexpr static kMaxNumLength = + std::numeric_limits::max_digits10 + 1; + + struct SourceLocation { + size_t pos_; // current position in raw_str_ + + public: + SourceLocation() : pos_(0) {} + explicit SourceLocation(size_t pos) : pos_{pos} {} + size_t Pos() const { return pos_; } + + SourceLocation& Forward(char c = 0) { + pos_++; + return *this; + } + } cursor_; + + StringView raw_str_; + bool ignore_specialization_; + + protected: + void SkipSpaces(); + + char GetNextChar() { + if (cursor_.Pos() == raw_str_.size()) { + return -1; + } + char ch = raw_str_[cursor_.Pos()]; + cursor_.Forward(); + return ch; + } + + char PeekNextChar() { + if (cursor_.Pos() == raw_str_.size()) { + return -1; + } + char ch = raw_str_[cursor_.Pos()]; + return ch; + } + + char GetNextNonSpaceChar() { + SkipSpaces(); + return GetNextChar(); + } + + char GetChar(char c) { + char result = GetNextNonSpaceChar(); + if (result != c) { Expect(c, result); } + return result; + } + + void Error(std::string msg) const; + + // Report expected character + void Expect(char c, char got) { + std::string msg = "Expecting: \""; + msg += c; + msg += "\", got: \""; + msg += std::string {got} + " \""; + Error(msg); + } + + virtual Json ParseString(); + virtual Json ParseObject(); + virtual Json ParseArray(); + virtual Json ParseNumber(); + virtual Json ParseBoolean(); + virtual Json ParseNull(); + + Json Parse(); + + private: + using Fn = std::function; + + public: + explicit JsonReader(StringView str, bool ignore = false) : + raw_str_{str}, + ignore_specialization_{ignore} {} + explicit JsonReader(StringView str, size_t pos, bool ignore = false) : + cursor_{pos}, + raw_str_{str}, + ignore_specialization_{ignore} {} + + 
virtual ~JsonReader() = default; + + Json Load(); + + static std::map& getRegistry() { + static std::map set; + return set; + } + + static std::map const& registry( + std::string const& key, Fn fn) { + getRegistry()[key] = fn; + return getRegistry(); + } +}; + +class JsonWriter { + static constexpr size_t kIndentSize = 2; + FixedPrecisionStream convertor_; + + size_t n_spaces_; + std::ostream* stream_; + bool pretty_; + + public: + JsonWriter(std::ostream* stream, bool pretty) : + n_spaces_{0}, stream_{stream}, pretty_{pretty} {} + + virtual ~JsonWriter() = default; + + void NewLine() { + if (pretty_) { + *stream_ << u8"\n" << std::string(n_spaces_, ' '); + } + } + + void BeginIndent() { + n_spaces_ += kIndentSize; + } + void EndIndent() { + n_spaces_ -= kIndentSize; + } + + void Write(std::string str) { + *stream_ << str; + } + void Write(StringView str) { + stream_->write(str.c_str(), str.size()); + } + + void Save(Json json); + + virtual void Visit(JsonArray const* arr); + virtual void Visit(JsonObject const* obj); + virtual void Visit(JsonNumber const* num); + virtual void Visit(JsonRaw const* raw); + virtual void Visit(JsonNull const* null); + virtual void Visit(JsonString const* str); + virtual void Visit(JsonBoolean const* boolean); +}; +} // namespace xgboost + +#endif // XGBOOST_JSON_IO_H_ diff --git a/src/common/io.cc b/src/common/io.cc new file mode 100644 index 000000000000..025e70ffb780 --- /dev/null +++ b/src/common/io.cc @@ -0,0 +1,67 @@ +/*! 
+ * Copyright (c) by Contributors 2019
+ */
+#if defined(__unix__)
+#include
+#include
+#include
+#endif  // defined(__unix__)
+#include
+#include
+
+#include "xgboost/logging.h"
+
+namespace xgboost {
+namespace common {
+
+/*!
+ * \brief Read the whole content of a file into a string.
+ *
+ * When `__unix__` is defined the file is sized with stat(), opened with open(),
+ * hinted with posix_fadvise(POSIX_FADV_SEQUENTIAL) and read with read(); otherwise
+ * C stdio (fopen/fseek/ftell/fread) is used.
+ *
+ * \param fname Path of the file to load.
+ *
+ * \return A string whose size is the reported file size plus one.  The file bytes come
+ *         first; the extra trailing byte, and any bytes that could not be read because
+ *         the file ended early, keep the '\0' value given by std::string::resize().
+ *
+ * Failures to open or read the file are reported with LOG(FATAL), including
+ * strerror(errno).  The code is written assuming LOG(FATAL) does not return
+ * (NOTE(review): confirm against xgboost/logging.h).
+ */
+std::string LoadSequentialFile(std::string fname) {
+  auto OpenErr = [&fname]() {
+    std::string msg;
+    msg = "Opening " + fname + " failed: ";
+    msg += strerror(errno);
+    LOG(FATAL) << msg;
+  };
+  auto ReadErr = [&fname]() {
+    std::string msg {"Error in reading file: "};
+    msg += fname;
+    msg += ": ";
+    msg += strerror(errno);
+    LOG(FATAL) << msg;
+  };
+
+  std::string buffer;
+#if defined(__unix__)
+  struct stat fs;
+  if (stat(fname.c_str(), &fs) != 0) {
+    OpenErr();
+  }
+
+  size_t f_size_bytes = fs.st_size;
+  buffer.resize(f_size_bytes + 1);
+  int32_t fd = open(fname.c_str(), O_RDONLY);
+  if (fd < 0) {
+    // Previously an invalid descriptor was handed straight to read().
+    OpenErr();
+  }
+  posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
+
+  // read() may return fewer bytes than requested, so keep reading until the expected
+  // size is reached, end of file is hit, or an error occurs.
+  size_t total_read = 0;
+  bool read_failed = false;
+  int read_errno = 0;
+  while (total_read < f_size_bytes) {
+    ssize_t bytes_read = read(fd, &buffer[total_read], f_size_bytes - total_read);
+    if (bytes_read < 0) {
+      read_failed = true;
+      read_errno = errno;  // Saved because close() below may overwrite errno.
+      break;
+    }
+    if (bytes_read == 0) {
+      break;  // End of file before the size reported by stat().
+    }
+    total_read += static_cast<size_t>(bytes_read);
+  }
+  close(fd);
+  if (read_failed) {
+    errno = read_errno;
+    ReadErr();
+  }
+#else
+  // Binary mode: the size obtained from ftell() is only meaningful as a byte count when
+  // no text-mode translation is applied while reading.
+  FILE *f = fopen(fname.c_str(), "rb");
+  if (f == nullptr) {
+    OpenErr();
+  }
+
+  bool failed = fseek(f, 0, SEEK_END) != 0;
+  auto const fsize = failed ? -1 : ftell(f);
+  failed = failed || fsize < 0 || fseek(f, 0, SEEK_SET) != 0;
+
+  if (!failed) {
+    size_t const n_bytes = static_cast<size_t>(fsize);
+    buffer.resize(n_bytes + 1);
+    size_t const n_read = fread(&buffer[0], 1, n_bytes, f);
+    // A short count is only an error when the stream's error flag is set.
+    failed = n_read != n_bytes && ferror(f) != 0;
+  }
+  int const io_errno = errno;  // Saved because fclose() below may overwrite errno.
+  fclose(f);
+  if (failed) {
+    errno = io_errno;
+    ReadErr();
+  }
+#endif  // defined(__unix__)
+  return buffer;
+}
+
+}  // namespace common
+}  // namespace xgboost
diff --git a/src/common/io.h b/src/common/io.h
index 29d68abec09a..6dac70c3d79d 100644
--- a/src/common/io.h
+++ b/src/common/io.h
@@ -70,6 +70,10 @@ class PeekableInStream : public dmlc::Stream {
   /*! \brief internal buffer */
   std::string buffer_;
 };
+
+// Loads a whole file into a string; optimized for sequential reads on unix-like systems.
+std::string LoadSequentialFile(std::string fname);
+
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_IO_H_
diff --git a/src/common/json.cc b/src/common/json.cc
new file mode 100644
index 000000000000..917aa1748379
--- /dev/null
+++ b/src/common/json.cc
@@ -0,0 +1,624 @@
+/*!
+ * Copyright (c) by Contributors 2019 + */ +#include + +#include "xgboost/logging.h" +#include "xgboost/json.h" +#include "xgboost/json_io.h" +#include "../common/timer.h" + +namespace xgboost { + +void JsonWriter::Save(Json json) { + json.ptr_->Save(this); +} + +void JsonWriter::Visit(JsonArray const* arr) { + this->Write("["); + auto const& vec = arr->getArray(); + size_t size = vec.size(); + for (size_t i = 0; i < size; ++i) { + auto const& value = vec[i]; + this->Save(value); + if (i != size-1) { Write(", "); } + } + this->Write("]"); +} + +void JsonWriter::Visit(JsonObject const* obj) { + this->Write("{"); + this->BeginIndent(); + this->NewLine(); + + size_t i = 0; + size_t size = obj->getObject().size(); + + for (auto& value : obj->getObject()) { + this->Write("\"" + value.first + "\": "); + this->Save(value.second); + + if (i != size-1) { + this->Write(","); + this->NewLine(); + } + i++; + } + this->EndIndent(); + this->NewLine(); + this->Write("}"); +} + +void JsonWriter::Visit(JsonNumber const* num) { + convertor_ << num->getNumber(); + auto const& str = convertor_.str(); + this->Write(StringView{str.c_str(), str.size()}); + convertor_.str(""); +} + +void JsonWriter::Visit(JsonRaw const* raw) { + auto const& str = raw->getRaw(); + this->Write(str); +} + +void JsonWriter::Visit(JsonNull const* null) { + this->Write("null"); +} + +void JsonWriter::Visit(JsonString const* str) { + std::string buffer; + buffer += '"'; + auto const& string = str->getString(); + for (size_t i = 0; i < string.length(); i++) { + const char ch = string[i]; + if (ch == '\\') { + if (i < string.size() && string[i+1] == 'u') { + buffer += "\\"; + } else { + buffer += "\\\\"; + } + } else if (ch == '"') { + buffer += "\\\""; + } else if (ch == '\b') { + buffer += "\\b"; + } else if (ch == '\f') { + buffer += "\\f"; + } else if (ch == '\n') { + buffer += "\\n"; + } else if (ch == '\r') { + buffer += "\\r"; + } else if (ch == '\t') { + buffer += "\\t"; + } else if (static_cast(ch) <= 
0x1f) { + // Unit separator + char buf[8]; + snprintf(buf, sizeof buf, "\\u%04x", ch); + buffer += buf; + } else { + buffer += ch; + } + } + buffer += '"'; + this->Write(buffer); +} + +void JsonWriter::Visit(JsonBoolean const* boolean) { + bool val = boolean->getBoolean(); + if (val) { + this->Write(u8"true"); + } else { + this->Write(u8"false"); + } +} + +// Value +std::string Value::TypeStr() const { + switch (kind_) { + case ValueKind::String: return "String"; break; + case ValueKind::Number: return "Number"; break; + case ValueKind::Object: return "Object"; break; + case ValueKind::Array: return "Array"; break; + case ValueKind::Boolean: return "Boolean"; break; + case ValueKind::Null: return "Null"; break; + case ValueKind::Raw: return "Raw"; break; + case ValueKind::Integer: return "Integer"; break; + } + return ""; +} + +// Only used for keeping old compilers happy about non-reaching return +// statement. +Json& DummyJsonObject() { + static Json obj; + return obj; +} + +// Json Object +JsonObject::JsonObject(JsonObject && that) : + Value(ValueKind::Object), object_{std::move(that.object_)} {} + +JsonObject::JsonObject(std::map&& object) + : Value(ValueKind::Object), object_{std::move(object)} {} + +Json& JsonObject::operator[](std::string const & key) { + return object_[key]; +} + +Json& JsonObject::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonObject::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return object_ == Cast(&rhs)->getObject(); +} + +Value& JsonObject::operator=(Value const &rhs) { + JsonObject const* casted = Cast(&rhs); + object_ = casted->getObject(); + return *this; +} + +void JsonObject::Save(JsonWriter* writer) { + writer->Visit(this); +} + +// Json String +Json& JsonString::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by 
string."; + return DummyJsonObject(); +} + +Json& JsonString::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer." + << " Please try obtaining std::string first."; + return DummyJsonObject(); +} + +bool JsonString::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return Cast(&rhs)->getString() == str_; +} + +Value & JsonString::operator=(Value const &rhs) { + JsonString const* casted = Cast(&rhs); + str_ = casted->getString(); + return *this; +} + +// FIXME: UTF-8 parsing support. +void JsonString::Save(JsonWriter* writer) { + writer->Visit(this); +} + +// Json Array +JsonArray::JsonArray(JsonArray && that) : + Value(ValueKind::Array), vec_{std::move(that.vec_)} {} + +Json& JsonArray::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonArray::operator[](int ind) { + return vec_.at(ind); +} + +bool JsonArray::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + auto& arr = Cast(&rhs)->getArray(); + return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin()); +} + +Value & JsonArray::operator=(Value const &rhs) { + JsonArray const* casted = Cast(&rhs); + vec_ = casted->getArray(); + return *this; +} + +void JsonArray::Save(JsonWriter* writer) { + writer->Visit(this); +} + +// Json raw +Json& JsonRaw::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonRaw::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonRaw::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + auto& arr = Cast(&rhs)->getRaw(); + return std::equal(arr.cbegin(), arr.cend(), str_.cbegin()); +} + +Value & 
JsonRaw::operator=(Value const &rhs) { + auto const* casted = Cast(&rhs); + str_ = casted->getRaw(); + return *this; +} + +void JsonRaw::Save(JsonWriter* writer) { + writer->Visit(this); +} + +// Json Number +Json& JsonNumber::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonNumber::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonNumber::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return number_ == Cast(&rhs)->getNumber(); +} + +Value & JsonNumber::operator=(Value const &rhs) { + JsonNumber const* casted = Cast(&rhs); + number_ = casted->getNumber(); + return *this; +} + +void JsonNumber::Save(JsonWriter* writer) { + writer->Visit(this); +} + +// Json Null +Json& JsonNull::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonNull::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonNull::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return true; +} + +Value & JsonNull::operator=(Value const &rhs) { + Cast(&rhs); // Checking only. 
+ return *this; +} + +void JsonNull::Save(JsonWriter* writer) { + writer->Write("null"); +} + +// Json Boolean +Json& JsonBoolean::operator[](std::string const & key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonBoolean::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonBoolean::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return boolean_ == Cast(&rhs)->getBoolean(); +} + +Value & JsonBoolean::operator=(Value const &rhs) { + JsonBoolean const* casted = Cast(&rhs); + boolean_ = casted->getBoolean(); + return *this; +} + +void JsonBoolean::Save(JsonWriter *writer) { + writer->Visit(this); +} + +size_t constexpr JsonReader::kMaxNumLength; + +Json JsonReader::Parse() { + while (true) { + SkipSpaces(); + char c = PeekNextChar(); + if (c == -1) { break; } + + if (c == '{') { + return ParseObject(); + } else if ( c == '[' ) { + return ParseArray(); + } else if ( c == '-' || std::isdigit(c) ) { + return ParseNumber(); + } else if ( c == '\"' ) { + return ParseString(); + } else if ( c == 't' || c == 'f' ) { + return ParseBoolean(); + } else if (c == 'n') { + return ParseNull(); + } else { + Error("Unknown construct"); + } + } + return Json(); +} + +Json JsonReader::Load() { + Json result = Parse(); + return result; +} + +void JsonReader::Error(std::string msg) const { + // just copy it. + std::istringstream str_s(raw_str_.substr(0, raw_str_.size())); + + msg += ", around character: " + std::to_string(cursor_.Pos()); + msg += '\n'; + + constexpr size_t kExtend = 8; + auto beg = cursor_.Pos() - kExtend < 0 ? 0 : cursor_.Pos() - kExtend; + auto end = cursor_.Pos() + kExtend >= raw_str_.size() ? 
+      raw_str_.size() : cursor_.Pos() + kExtend;
+
+  msg += " ";
+  msg += raw_str_.substr(beg, end - beg);
+  msg += '\n';
+
+  msg += " ";
+  for (size_t i = beg; i < cursor_.Pos() - 1; ++i) {
+    msg += '~';
+  }
+  msg += '^';
+  for (size_t i = cursor_.Pos(); i < end; ++i) {
+    msg += '~';
+  }
+  LOG(FATAL) << msg;
+}
+
+// Json class
+/*! \brief Advance the cursor past any whitespace found at the current position. */
+void JsonReader::SkipSpaces() {
+  while (cursor_.Pos() < raw_str_.size()) {
+    char c = raw_str_[cursor_.Pos()];
+    // std::isspace is undefined for negative arguments, and a plain `char`
+    // holding a non-ASCII byte can be negative; convert to unsigned char first.
+    if (std::isspace(static_cast<unsigned char>(c))) {
+      cursor_.Forward(c);
+    } else {
+      break;
+    }
+  }
+}
+
+/*!
+ * \brief Scan `str` for the first double quote that is neither at index 0 nor
+ *        preceded by a backslash.
+ *
+ * NOTE(review): the result is only used to size a local string that is then
+ * discarded, so this function has no observable effect.  No caller is visible in
+ * this part of the file; confirm before removing.
+ */
+void ParseStr(std::string const& str) {
+  size_t end = 0;
+  for (size_t i = 0; i < str.size(); ++i) {
+    if (str[i] == '"' && i > 0 && str[i-1] != '\\') {
+      end = i;
+      break;
+    }
+  }
+  std::string result;
+  result.resize(end);
+}
+
+/*!
+ * \brief Parse a double-quoted JSON string starting at the cursor.
+ *
+ * The two-character escapes \r \n \t \b \f \\ \" \/ are decoded.  A `\u` escape is
+ * copied through as the two characters `\` and `u` (the hex digits follow as
+ * ordinary characters); it is not decoded.  Any other escape is reported through
+ * Error().  An EOF or raw line-break character read inside the string is handed to
+ * Expect('"', ch).
+ *
+ * \return A Json holding the parsed string.
+ */
+Json JsonReader::ParseString() {
+  char ch { GetChar('\"') };  // NOLINT
+  std::string str;
+  while (true) {
+    ch = GetNextChar();
+    if (ch == '\\') {
+      char next = static_cast<char>(GetNextChar());
+      switch (next) {
+        case 'r':  str += u8"\r"; break;
+        case 'n':  str += u8"\n"; break;
+        case '\\': str += u8"\\"; break;
+        case 't':  str += u8"\t"; break;
+        case '\"': str += u8"\""; break;
+        // The remaining two-character escapes allowed by the JSON grammar.
+        case 'b':  str += u8"\b"; break;
+        case 'f':  str += u8"\f"; break;
+        case '/':  str += u8"/"; break;
+        case 'u':
+          // Keep the escape sequence as-is instead of decoding the code point.
+          str += ch;
+          str += 'u';
+          break;
+        default: Error("Unknown escape");
+      }
+    } else {
+      if (ch == '\"') break;
+      str += ch;
+    }
+    if (ch == EOF || ch == '\r' || ch == '\n') {
+      Expect('\"', ch);
+    }
+  }
+  return Json(std::move(str));
+}
+
+/*!
+ * \brief Parse the literal `null`.
+ *
+ * Reads the next non-space character plus the three characters after it and
+ * reports an error through Error() unless together they spell "null".
+ */
+Json JsonReader::ParseNull() {
+  char ch = GetNextNonSpaceChar();
+  std::string buffer{ch};
+  for (size_t i = 0; i < 3; ++i) {
+    buffer.push_back(GetNextChar());
+  }
+  if (buffer != "null") {
+    Error("Expecting null value \"null\"");
+  }
+  return Json{JsonNull()};
+}
+
+/*!
+ * \brief Parse a JSON array starting at the cursor.
+ *
+ * Elements are parsed with Parse() and must be separated by ','.  A ']' found
+ * where an element could start closes the array, so both `[]` and `[ ]` yield an
+ * empty array.
+ *
+ * \return A Json holding the parsed elements in order.
+ */
+Json JsonReader::ParseArray() {
+  std::vector<Json> data;
+
+  GetChar('[');
+  while (true) {
+    // Skip whitespace first so that an empty array written as "[ ]" is recognised.
+    SkipSpaces();
+    if (PeekNextChar() == ']') {
+      GetChar(']');
+      return Json(std::move(data));
+    }
+    // Parse() returns a temporary, so the element is moved into the vector.
+    data.push_back(Parse());
+    char ch = GetNextNonSpaceChar();
+    if (ch == ']') break;
+    if (ch != ',') {
+      Expect(',', ch);
+    }
+  }
+
+  return Json(std::move(data));
+}
+
+/*!
+ * \brief Parse a JSON object starting at the cursor.
+ *
+ * Each member is a string key, a ':' and a value.  Unless specialization is
+ * ignored, a key found in the registry has its value parsed by the registered
+ * parser, which receives the raw input and a pointer to the cursor position; all
+ * other values go through Parse().  When a key repeats, the last value wins.
+ *
+ * \return A Json holding the parsed members.
+ */
+Json JsonReader::ParseObject() {
+  GetChar('{');
+
+  std::map<std::string, Json> data;
+  // The value returned by GetChar('{') is presumably the brace itself (GetChar is
+  // not visible here), so it cannot be used to spot "{}".  Peek at what follows the
+  // brace instead, the same way ParseArray detects an empty array.
+  SkipSpaces();
+  if (PeekNextChar() == '}') {
+    GetChar('}');
+    return Json(std::move(data));
+  }
+
+  while (true) {
+    SkipSpaces();
+    char ch = PeekNextChar();
+    if (ch != '"') {
+      Expect('"', ch);
+    }
+    Json key = ParseString();
+    auto const& name = get<String>(key);
+
+    ch = GetNextNonSpaceChar();
+
+    if (ch != ':') {
+      Expect(':', ch);
+    }
+
+    Json value;
+    // Look the key up once; skip the lookup entirely when specialization is off.
+    auto const& registry = getRegistry();
+    auto const parser = ignore_specialization_ ? registry.cend() : registry.find(name);
+    if (parser != registry.cend()) {
+      LOG(DEBUG) << "Using specialized parser for: " << name;
+      value = parser->second(raw_str_, &(cursor_.pos_));
+    } else {
+      value = Parse();
+    }
+
+    data[name] = std::move(value);
+
+    ch = GetNextNonSpaceChar();
+
+    if (ch == '}') break;
+    if (ch != ',') {
+      Expect(',', ch);
+    }
+  }
+
+  return Json(std::move(data));
+}
+
+/*!
+ * \brief Parse a floating point number starting at the cursor.
+ *
+ * At most kMaxNumLength characters are handed to std::stof; the cursor is then
+ * advanced by the number of characters std::stof consumed.  std::stof throws when
+ * no conversion can be performed.
+ */
+Json JsonReader::ParseNumber() {
+  std::string substr = raw_str_.substr(cursor_.Pos(), kMaxNumLength);
+  size_t pos = 0;
+
+  Number::Float number{0};
+  number = std::stof(substr, &pos);
+  for (size_t i = 0; i < pos; ++i) {
+    GetNextChar();
+  }
+  return Json(number);
+}
+
+/*!
+ * \brief Parse the literal `true` or `false`.
+ *
+ * The first non-space character selects which literal is expected ('t' selects
+ * "true", anything else "false"); the remaining characters are read verbatim and
+ * the whole word must match, otherwise an error is reported through Error().
+ */
+Json JsonReader::ParseBoolean() {
+  char const first = GetNextNonSpaceChar();
+  bool const result = (first == 't');
+  std::string const expected = result ? "true" : "false";
+
+  // Read the rest of the literal without skipping whitespace, so that input such
+  // as "t r u e" is rejected, and compare the whole word so that a wrong first
+  // character is rejected too.
+  std::string buffer{first};
+  for (size_t i = 1; i < expected.size(); ++i) {
+    buffer.push_back(GetNextChar());
+  }
+  if (buffer != expected) {
+    if (result) {
+      Error("Expecting boolean value \"true\".");
+    } else {
+      Error("Expecting boolean value \"false\".");
+    }
+  }
+  return Json{JsonBoolean{result}};
+}
+
+// This is an ad-hoc solution for handling numeric values in a standard way.  We need
+// a locale-independent way of reading from and writing to streams.
+// FIXME(trivialfis): Remove this.
+/*!
+ * \brief Scope guard that sets the global C++ locale to "C" on construction and
+ *        puts the previously active global locale back on destruction.
+ */
+class GlobalCLocale {
+  std::locale ori_;
+
+ public:
+  GlobalCLocale() : ori_{std::locale()} {
+    std::string const name {"C"};
+    try {
+      std::locale::global(std::locale(name.c_str()));
+    } catch (std::runtime_error const&) {
+      LOG(FATAL) << "Failed to set locale: " << name;
+    }
+  }
+  // A copy would restore the saved locale a second time when it is destroyed.
+  GlobalCLocale(GlobalCLocale const&) = delete;
+  GlobalCLocale& operator=(GlobalCLocale const&) = delete;
+  ~GlobalCLocale() {
+    std::locale::global(ori_);
+  }
+};
+
+/*!
+ * \brief Parse a JSON document from `str`.
+ *
+ * The global locale is switched to "C" for the duration of the call and the
+ * elapsed parsing time is printed.
+ *
+ * \param str                   Text to parse.
+ * \param ignore_specialization Forwarded to JsonReader.
+ * \return The parsed document.
+ */
+Json Json::Load(StringView str, bool ignore_specialization) {
+  GlobalCLocale guard;
+  LOG(WARNING) << "Json serialization is still experimental."
+                  " Output schema is subject to change in the future.";
+  JsonReader reader(str, ignore_specialization);
+  common::Timer t;
+  t.Start();
+  Json json{reader.Load()};
+  t.Stop();
+  t.PrintElapsed("Json::load");
+  return json;
+}
+
+/*!
+ * \brief Parse a JSON document using a caller-provided reader.
+ *
+ * The global locale is switched to "C" for the duration of the call and the
+ * elapsed parsing time is printed.
+ *
+ * \param reader Reader to load from; must not be null.
+ * \return The parsed document.
+ */
+Json Json::Load(JsonReader* reader) {
+  GlobalCLocale guard;
+  common::Timer t;
+  t.Start();
+  Json json{reader->Load()};
+  t.Stop();
+  t.PrintElapsed("Json::load");
+  return json;
+}
+
+/*!
+ * \brief Serialize `json` to `stream`.
+ *
+ * The global locale is switched to "C" for the duration of the call and the
+ * elapsed time is printed.
+ *
+ * \param json   Document to write.
+ * \param stream Destination stream.
+ * \param pretty Passed to JsonWriter as its second argument, so the caller's
+ *               choice is honoured instead of always being `true`.
+ */
+void Json::Dump(Json json, std::ostream *stream, bool pretty) {
+  GlobalCLocale guard;
+  LOG(WARNING) << "Json serialization is still experimental."
+                  " Output schema is subject to change in the future.";
+  JsonWriter writer(stream, pretty);
+  common::Timer t;
+  t.Start();
+  writer.Save(json);
+  t.Stop();
+  t.PrintElapsed("Json::dump");
+}
+
+Json& Json::operator=(Json const &other) = default;
+}  // namespace xgboost
diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc
new file mode 100644
index 000000000000..a599c3f19fbb
--- /dev/null
+++ b/tests/cpp/common/test_json.cc
@@ -0,0 +1,371 @@
+/*!
+ * Copyright (c) by Contributors 2019
+ */
+#include
+#include
+#include
+#include
+
+#include "xgboost/json.h"
+#include "xgboost/logging.h"
+#include "xgboost/json_io.h"
+#include "../../../src/common/io.h"
+
+namespace xgboost {
+
+// Returns the JSON text of a small model document used as a fixture by the
+// Indexing and LoadDump tests.
+std::string GetModelStr() {
+  std::string model_json = R"json(
+{
+  "model_parameter": {
+    "base_score": "0.5",
+    "num_class": "0",
+    "num_feature": "10"
+  },
+  "train_parameter": {
+    "debug_verbose": "0",
+    "disable_default_eval_metric": "0",
+    "dsplit": "auto",
+    "nthread": "0",
+    "seed": "0",
+    "seed_per_iteration": "0",
+    "test_flag": "",
+    "tree_method": "gpu_hist"
+  },
+  "configuration": {
+    "booster": "gbtree",
+    "n_gpus": "1",
+    "num_class": "0",
+    "num_feature": "10",
+    "objective": "reg:linear",
+    "predictor": "gpu_predictor",
+    "tree_method": "gpu_hist",
+    "updater": "grow_gpu_hist"
+  },
+  "objective": "reg:linear",
+  "booster": "gbtree",
+  "gbm": {
+    "GBTreeModelParam": {
+      "num_feature": "10",
+      "num_output_group": "1",
+      "num_roots": "1",
+      "size_leaf_vector": "0"
+    },
+    "trees": [{
+        "TreeParam": {
+          "num_feature": "10",
+          "num_roots": "1",
+          "size_leaf_vector": "0"
+        },
+        "num_nodes": "9",
+        "nodes": [
+          {
+            "depth": 0,
+            "gain": 31.8892,
+            "hess": 10,
+            "left": 1,
+            "missing": 1,
+            "nodeid": 0,
+            "right": 2,
+            "split_condition": 0.580717,
+            "split_index": 2
+          },
+          {
+            "depth": 1,
+            "gain": 1.5625,
+            "hess": 3,
+            "left": 5,
+            "missing": 5,
+            "nodeid": 2,
+            "right": 6,
+            "split_condition": 0.160345,
+            "split_index": 0
+          },
+          {
+            "depth": 2,
+            "gain": 0.25,
+            "hess": 2,
+            "left": 7,
+            "missing": 7,
+            "nodeid": 6,
+            "right": 8,
+            "split_condition": 0.62788,
+            "split_index": 0
+          },
+          {
+            "hess": 1,
+            "leaf": 0.375,
+            "nodeid": 8
+          },
+          {
+            "hess": 1,
+            "leaf": 0.075,
+            "nodeid": 7
+          },
+          {
+            "hess": 1,
+            "leaf": -0.075,
+            "nodeid": 5
+          },
+          {
+            "depth": 3,
+            "gain": 10.4866,
+            "hess": 7,
+            "left": 3,
+            "missing": 3,
+            "nodeid": 1,
+            "right": 4,
+            "split_condition": 0.238748,
+            "split_index": 1
+          },
+          {
+            "hess": 6,
+            "leaf": 1.54286,
+            "nodeid": 4
+          },
+          {
+            "hess": 1,
+            "leaf": 0.225,
+            "nodeid": 3
+          }
+        ],
+        "leaf_vector": []
+      }],
+    "tree_info": [0]
+  }
+}
+)json";
+  return model_json;
+}
+
+// Loading a nested object must not fail.
+TEST(Json, TestParseObject) {
+  std::string str = R"obj({"TreeParam" : {"num_feature": "10"}})obj";
+  auto json = Json::Load(StringView{str.c_str(), str.size()});
+}
+
+// A bare number is a valid document.
+TEST(Json, ParseNumber) {
+  std::string str = "31.8892";
+  auto json = Json::Load(StringView{str.c_str(), str.size()});
+  ASSERT_NEAR(get(json), 31.8892f, kRtEps);
+}
+
+// An array of objects keeps its length and element contents.
+TEST(Json, ParseArray) {
+  std::string str = R"json(
+{
+  "nodes": [
+    {
+      "depth": 3,
+      "gain": 10.4866,
+      "hess": 7,
+      "left": 3,
+      "missing": 3,
+      "nodeid": 1,
+      "right": 4,
+      "split_condition": 0.238748,
+      "split_index": 1
+    },
+    {
+      "hess": 6,
+      "leaf": 1.54286,
+      "nodeid": 4
+    },
+    {
+      "hess": 1,
+      "leaf": 0.225,
+      "nodeid": 3
+    }
+  ]
+}
+)json";
+  auto json = Json::Load(StringView{str.c_str(), str.size()}, true);
+  json = json["nodes"];
+  std::vector arr = get(json);
+  ASSERT_EQ(arr.size(), 3);
+  Json v0 = arr[0];
+  ASSERT_EQ(get(v0["depth"]), 3);
+}
+
+// `null` is written as "null" and read back as a null value.
+TEST(Json, Null) {
+  Json json {JsonNull()};
+  std::stringstream ss;
+  Json::Dump(json, &ss);
+  ASSERT_EQ(ss.str(), "null");
+
+  std::string null_input {R"null({"key": null })null"};
+
+  json = Json::Load({null_input.c_str(), null_input.size()});
+  ASSERT_TRUE(IsA(json["key"]));
+}
+
+// An empty array loads as an array with no elements.
+TEST(Json, EmptyArray) {
+  std::string str = R"json(
+{
+  "leaf_vector": []
+}
+)json";
+  auto json = Json::Load(StringView{str.c_str(), str.size()}, true);
+  auto arr = get(json["leaf_vector"]);
+  ASSERT_EQ(arr.size(), 0);
+}
+
+// Both boolean literals are recognised.
+TEST(Json, Boolean) {
+  std::string str = R"json(
+{
+  "left_child": true,
+  "right_child": false
+}
+)json";
+  Json j {Json::Load(StringView{str.c_str(), str.size()}, true)};
+  ASSERT_EQ(get(j["left_child"]), true);
+  ASSERT_EQ(get(j["right_child"]), false);
+}
+
+// Nested members are reachable through chained string indexing.
+TEST(Json, Indexing) {
+  auto str = GetModelStr();
+  JsonReader reader(StringView{str.c_str(), str.size()}, true);
+  Json j {Json::Load(&reader)};
+  auto& value_1 = j["model_parameter"];
+  auto& value = value_1["base_score"];
+  std::string result = Cast(&value.GetValue())->getString();
+
+  ASSERT_EQ(result, "0.5");
+}
+
+// Members of an object can be created and overwritten through operator[].
+TEST(Json, AssigningObjects) {
+  {
+    Json json;
+    json = JsonObject();
+    json["Okay"] = JsonArray();
+    ASSERT_EQ(get(json["Okay"]).size(), 0);
+  }
+
+  {
+    std::map objects;
+    Json json_objects { JsonObject() };
+    std::vector arr_0 (1, Json(3.3));
+    json_objects["tree_parameters"] = JsonArray(arr_0);
+    std::vector json_arr = get(json_objects["tree_parameters"]);
+    ASSERT_NEAR(get(json_arr[0]), 3.3f, kRtEps);
+  }
+
+  {
+    Json json_object { JsonObject() };
+    auto str = JsonString("1");
+    auto& k = json_object["1"];
+    k = str;
+    auto& m = json_object["1"];
+    std::string value = get(m);
+    ASSERT_EQ(value, "1");
+    ASSERT_EQ(get(json_object["1"]), "1");
+  }
+}
+
+// Assigning through get() replaces the stored array.
+TEST(Json, AssigningArray) {
+  Json json;
+  json = JsonArray();
+  std::vector tmp_0 {Json(Number(1)), Json(Number(2))};
+  json = tmp_0;
+  std::vector tmp_1 {Json(Number(3))};
+  get(json) = tmp_1;
+  std::vector res = get(json);
+  ASSERT_EQ(get(res[0]), 3);
+}
+
+// get() yields a reference into the document; a copy of it does not alias.
+TEST(Json, AssigningNumber) {
+  {
+    // right value
+    Json json = Json{ Number(4) };
+    get(json) = 15;
+    ASSERT_EQ(get(json), 15);
+  }
+
+  {
+    // left value ref
+    Json json = Json{ Number(4) };
+    Number::Float& ref = get(json);
+    ref = 15;
+    ASSERT_EQ(get(json), 15);
+  }
+
+  {
+    // left value
+    Json json = Json{ Number(4) };
+    double value = get(json);
+    ASSERT_EQ(value, 4);
+    value = 15;  // NOLINT
+    ASSERT_EQ(get(json), 4);
+  }
+}
+
+// Same aliasing checks as AssigningNumber, for strings.
+TEST(Json, AssigningString) {
+  {
+    // right value
+    Json json = Json{ String("str") };
+    get(json) = "modified";
+    ASSERT_EQ(get(json), "modified");
+  }
+
+  {
+    // left value ref
+    Json json = Json{ String("str") };
+    std::string& ref = get(json);
+    ref = "modified";
+    ASSERT_EQ(get(json), "modified");
+  }
+
+  {
+    // left value
+    Json json = Json{ String("str") };
+    std::string value = get(json);
+    value = "modified";
+    ASSERT_EQ(get(json), "str");
+  }
+}
+
+// A document written to disk and read back compares equal to the original.
+TEST(Json, LoadDump) {
+  std::string buffer = GetModelStr();
+  Json origin {Json::Load(StringView{buffer.c_str(), buffer.size()}, true)};
+
+  dmlc::TemporaryDirectory tempdir;
+  // Join with an explicit separator so the file is created inside the temporary
+  // directory.  Without it the name is glued onto the directory name and the file
+  // ends up beside the directory (presumably escaping its cleanup); a doubled
+  // separator is harmless should `path` already end with one.
+  std::string const path = tempdir.path + "/test_model_dump";
+
+  std::ofstream fout (path);
+  Json::Dump(origin, &fout);
+  fout.close();
+
+  buffer = common::LoadSequentialFile(path);
+  Json load_back {Json::Load(StringView(buffer.c_str(), buffer.size()), true)};
+
+  ASSERT_EQ(load_back, origin);
+}
+
+// For now Json does not interpret unicode: \u escapes are copied through unchanged.
+TEST(Json, CopyUnicode) {
+  std::string json_str = R"json(
+{"m": ["\ud834\udd1e", "\u20ac", "\u0416", "\u00f6"]}
+)json";
+  Json loaded {Json::Load(StringView{json_str.c_str(), json_str.size()}, true)};
+
+  std::stringstream ss_1;
+  Json::Dump(loaded, &ss_1);
+
+  std::string dumped_string = ss_1.str();
+  ASSERT_NE(dumped_string.find("\\u20ac"), std::string::npos);
+}
+
+// Requesting the wrong value type through get() must throw.
+TEST(Json, WrongCasts) {
+  {
+    Json json = Json{ String{"str"} };
+    ASSERT_ANY_THROW(get(json));
+  }
+  {
+    Json json = Json{ Array{ std::vector{ Json{ Number{1} } } } };
+    ASSERT_ANY_THROW(get(json));
+  }
+  {
+    Json json = Json{ Object{std::map{
+          {"key", Json{String{"value"}}}} } };
+    ASSERT_ANY_THROW(get(json));
+  }
+}
+}  // namespace xgboost