Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using string view to find JSON value. #8332

Merged
merged 1 commit into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,31 @@ using I32Array = JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
using I64Array = JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;

class JsonObject : public Value {
std::map<std::string, Json> object_;
public:
using Map = std::map<std::string, Json, std::less<>>;

private:
Map object_;

public:
JsonObject() : Value(ValueKind::kObject) {}
JsonObject(std::map<std::string, Json>&& object) noexcept; // NOLINT
JsonObject(Map&& object) noexcept; // NOLINT
JsonObject(JsonObject const& that) = delete;
JsonObject(JsonObject && that) noexcept;
JsonObject(JsonObject&& that) noexcept;

void Save(JsonWriter* writer) const override;

// silence the partially overridden warning
Json& operator[](int ind) override { return Value::operator[](ind); }
Json& operator[](std::string const& key) override { return object_[key]; }

std::map<std::string, Json> const& GetObject() && { return object_; }
std::map<std::string, Json> const& GetObject() const & { return object_; }
std::map<std::string, Json> & GetObject() & { return object_; }
Map const& GetObject() && { return object_; }
Map const& GetObject() const& { return object_; }
Map& GetObject() & { return object_; }

bool operator==(Value const& rhs) const override;

static bool IsClassOf(Value const* value) {
return value->Type() == ValueKind::kObject;
}
static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kObject; }
~JsonObject() override = default;
};

Expand Down Expand Up @@ -559,16 +561,13 @@ std::vector<T> const& GetImpl(JsonTypedArray<T, kind> const& val) {
}

// Object
template <typename T,
typename std::enable_if<
std::is_same<T, JsonObject>::value>::type* = nullptr>
std::map<std::string, Json>& GetImpl(T& val) { // NOLINT
template <typename T, typename std::enable_if<std::is_same<T, JsonObject>::value>::type* = nullptr>
JsonObject::Map& GetImpl(T& val) { // NOLINT
return val.GetObject();
}
template <typename T,
typename std::enable_if<
std::is_same<T, JsonObject const>::value>::type* = nullptr>
std::map<std::string, Json> const& GetImpl(T& val) { // NOLINT
typename std::enable_if<std::is_same<T, JsonObject const>::value>::type* = nullptr>
JsonObject::Map const& GetImpl(T& val) { // NOLINT
return val.GetObject();
}
} // namespace detail
Expand Down
11 changes: 11 additions & 0 deletions include/xgboost/string_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifndef XGBOOST_STRING_VIEW_H_
#define XGBOOST_STRING_VIEW_H_
#include <xgboost/logging.h>
#include <xgboost/span.h>

#include <algorithm>
#include <iterator>
Expand All @@ -19,6 +20,7 @@ struct StringView {
size_t size_{0};

public:
using value_type = CharT; // NOLINT
using iterator = const CharT*; // NOLINT
using const_iterator = iterator; // NOLINT
using reverse_iterator = std::reverse_iterator<const_iterator>; // NOLINT
Expand Down Expand Up @@ -77,5 +79,14 @@ inline bool operator==(StringView l, StringView r) {
}

// Inequality for string views, defined as the negation of operator==.
inline bool operator!=(StringView l, StringView r) { return !(l == r); }

// Ordering for string views: each operand is wrapped in a read-only Span over its
// characters and the comparison is delegated to Span's operator<.
// NOTE(review): presumably a lexicographic comparison -- confirm against xgboost/span.h.
inline bool operator<(StringView l, StringView r) {
  using CharSpan = common::Span<StringView::value_type const>;
  CharSpan const lhs{l.c_str(), l.size()};
  CharSpan const rhs{r.c_str(), r.size()};
  return lhs < rhs;
}

inline bool operator<(std::string const& l, StringView r) { return StringView{l} < r; }

// Mixed ordering (std::string on the right): views the string as a StringView and
// defers to the StringView/StringView operator<.
inline bool operator<(StringView l, std::string const& r) {
  StringView const rhs{r};
  return l < rhs;
}
} // namespace xgboost
#endif // XGBOOST_STRING_VIEW_H_
6 changes: 3 additions & 3 deletions src/c_api/c_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void TypeCheck(Json const &value, StringView name) {
}

template <typename JT>
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
auto const &RequiredArg(Json const &in, StringView key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
Expand All @@ -269,11 +269,11 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func)
}

template <typename JT, typename T>
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
auto const &OptionalArg(Json const &in, StringView key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend() && !IsA<Null>(it->second)) {
TypeCheck<JT>(it->second, StringView{key});
TypeCheck<JT>(it->second, key);
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
Expand Down
8 changes: 4 additions & 4 deletions src/common/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
std::swap(that.object_, this->object_);
}

JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
// Builds an object value by taking over the contents of an already-populated map.
// `object` is a plain rvalue reference (Map is not a deduced template parameter), so
// std::move is the matching cast; std::forward is reserved for forwarding references.
JsonObject::JsonObject(Map&& object) noexcept
    : Value(ValueKind::kObject), object_{std::move(object)} {}

bool JsonObject::operator==(Value const& rhs) const {
if (!IsA<JsonObject>(&rhs)) {
Expand Down Expand Up @@ -502,7 +502,7 @@ Json JsonReader::ParseArray() {
Json JsonReader::ParseObject() {
GetConsecutiveChar('{');

std::map<std::string, Json> data;
Object::Map data;
SkipSpaces();
char ch = PeekNextChar();

Expand Down Expand Up @@ -777,7 +777,7 @@ std::string UBJReader::DecodeStr() {

Json UBJReader::ParseObject() {
auto marker = PeekNextChar();
std::map<std::string, Json> results;
Object::Map results;

while (marker != '}') {
auto str = this->DecodeStr();
Expand Down
16 changes: 8 additions & 8 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ArrayInterfaceHandler {
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };

template <typename PtrType>
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const &obj) {
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
auto data_it = obj.find("data");
if (data_it == obj.cend()) {
LOG(FATAL) << "Empty data passed in.";
Expand All @@ -109,7 +109,7 @@ class ArrayInterfaceHandler {
return p_data;
}

static void Validate(std::map<std::string, Json> const &array) {
static void Validate(Object::Map const &array) {
auto version_it = array.find("version");
if (version_it == array.cend()) {
LOG(FATAL) << "Missing `version' field for array interface";
Expand All @@ -136,7 +136,7 @@ class ArrayInterfaceHandler {

// Find null mask (validity mask) field
// Mask object is also an array interface, but with different requirements.
static size_t ExtractMask(std::map<std::string, Json> const &column,
static size_t ExtractMask(Object::Map const &column,
common::Span<RBitField8::value_type> *p_out) {
auto &s_mask = *p_out;
if (column.find("mask") != column.cend()) {
Expand Down Expand Up @@ -208,7 +208,7 @@ class ArrayInterfaceHandler {
}

template <int32_t D>
static void ExtractShape(std::map<std::string, Json> const &array, size_t (&out_shape)[D]) {
static void ExtractShape(Object::Map const &array, size_t (&out_shape)[D]) {
auto const &j_shape = get<Array const>(array.at("shape"));
std::vector<size_t> shape_arr(j_shape.size(), 0);
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
Expand All @@ -229,7 +229,7 @@ class ArrayInterfaceHandler {
* \brief Extracts the optional `strides' field and returns whether the array is c-contiguous.
*/
template <int32_t D>
static bool ExtractStride(std::map<std::string, Json> const &array, size_t itemsize,
static bool ExtractStride(Object::Map const &array, size_t itemsize,
size_t (&shape)[D], size_t (&stride)[D]) {
auto strides_it = array.find("strides");
// No stride is provided
Expand Down Expand Up @@ -272,7 +272,7 @@ class ArrayInterfaceHandler {
return std::equal(stride_tmp, stride_tmp + D, stride);
}

static void *ExtractData(std::map<std::string, Json> const &array, size_t size) {
static void *ExtractData(Object::Map const &array, size_t size) {
Validate(array);
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
if (!p_data) {
Expand Down Expand Up @@ -378,7 +378,7 @@ class ArrayInterface {
* to a vector of size n_samples. For inputs like weights, this should be a 1
* dimension column vector even though user might provide a matrix.
*/
void Initialize(std::map<std::string, Json> const &array) {
void Initialize(Object::Map const &array) {
ArrayInterfaceHandler::Validate(array);

auto typestr = get<String const>(array.at("typestr"));
Expand Down Expand Up @@ -413,7 +413,7 @@ class ArrayInterface {

public:
ArrayInterface() = default;
explicit ArrayInterface(std::map<std::string, Json> const &array) { this->Initialize(array); }
explicit ArrayInterface(Object::Map const &array) { this->Initialize(array); }

explicit ArrayInterface(Json const &array) {
if (IsA<Object>(array)) {
Expand Down
22 changes: 10 additions & 12 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ struct DeviceAUCCache {
};

template <bool is_multi>
void InitCacheOnce(common::Span<float const> predts, int32_t device,
std::shared_ptr<DeviceAUCCache>* p_cache) {
void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
Expand Down Expand Up @@ -167,7 +166,7 @@ std::tuple<double, double, double>
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
InitCacheOnce<false>(predts, p_cache);

/**
* Create sorted index for each class
Expand Down Expand Up @@ -196,8 +195,7 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
}

double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
common::Span<double> tp, common::Span<double> auc,
std::shared_ptr<DeviceAUCCache> cache, size_t n_classes) {
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
dh::XGBDeviceAllocator<char> alloc;
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
Expand Down Expand Up @@ -330,7 +328,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
auto local_area = d_results.subspan(0, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
}

/**
Expand Down Expand Up @@ -434,7 +432,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
tp[c] = 1.0f;
}
});
return ScaleClasses(d_results, local_area, tp, auc, cache, n_classes);
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
}

void MultiClassSortedIdx(common::Span<float const> predts,
Expand All @@ -458,7 +456,7 @@ double GPUMultiClassROCAUC(common::Span<float const> predts,
std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, device, p_cache);
InitCacheOnce<true>(predts, p_cache);

/**
* Create sorted index for each class
Expand Down Expand Up @@ -486,7 +484,7 @@ std::pair<double, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
InitCacheOnce<false>(predts, p_cache);

dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
dh::XGBCachingDeviceAllocator<char> alloc;
Expand Down Expand Up @@ -606,7 +604,7 @@ std::tuple<double, double, double>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
InitCacheOnce<false>(predts, p_cache);

/**
* Create sorted index for each class
Expand Down Expand Up @@ -647,7 +645,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, device, p_cache);
InitCacheOnce<true>(predts, p_cache);

/**
* Create sorted index for each class
Expand Down Expand Up @@ -827,7 +825,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
}

auto &cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
InitCacheOnce<false>(predts, p_cache);

dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
Expand Down
3 changes: 1 addition & 2 deletions tests/cpp/common/test_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,7 @@ TEST(Json, WrongCasts) {
ASSERT_ANY_THROW(get<Number>(json));
}
{
Json json = Json{ Object{std::map<std::string, Json>{
{"key", Json{String{"value"}}}} } };
Json json = Json{Object{{{"key", Json{String{"value"}}}}}};
ASSERT_ANY_THROW(get<Number>(json));
}
}
Expand Down
Loading