Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept string for ArrayInterface constructor. #5799

Merged
merged 1 commit into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright 2019 by Contributors
* \file array_interface.h
* \brief Basic structure holding a reference to arrow columnar data format.
* \brief View of __array_interface__
*/
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
Expand All @@ -11,6 +11,7 @@
#include <string>
#include <utility>

#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/logging.h"
Expand Down Expand Up @@ -113,6 +114,7 @@ class ArrayInterfaceHandler {
get<Array const>(
obj.at("data"))
.at(0))));
CHECK(p_data);
return p_data;
}

Expand Down Expand Up @@ -186,7 +188,7 @@ class ArrayInterfaceHandler {
return 0;
}

static std::pair<size_t, size_t> ExtractShape(
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
auto typestr = get<String const>(column.at("typestr"));
Expand All @@ -201,12 +203,12 @@ class ArrayInterfaceHandler {
}

if (j_shape.size() == 1) {
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
} else {
CHECK_EQ(j_shape.size(), 2)
<< "Only 1D or 2-D arrays currently supported.";
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
}
}
template <typename T>
Expand All @@ -219,7 +221,6 @@ class ArrayInterfaceHandler {
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr;


auto shape = ExtractShape(column);

T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
Expand All @@ -229,10 +230,8 @@ class ArrayInterfaceHandler {

// A view over __array_interface__
class ArrayInterface {
public:
ArrayInterface() = default;
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
void Initialize(std::map<std::string, Json> const &column,
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
Expand Down Expand Up @@ -263,6 +262,27 @@ class ArrayInterface {
this->CheckType();
}

public:
ArrayInterface() = default;
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
auto jinterface = Json::Load({str.c_str(), str.size()});
if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask);
return;
}
if (IsA<Array>(jinterface)) {
CHECK_EQ(get<Array const>(jinterface).size(), 1)
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
return;
}
}

explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}

void CheckType() const {
if (type[1] == 'f' && type[2] == '4') {
return;
Expand Down Expand Up @@ -291,6 +311,7 @@ class ArrayInterface {
}

XGBOOST_DEVICE float GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
if (type[1] == 'f' && type[2] == '4') {
return reinterpret_cast<float*>(data)[idx];
} else if (type[1] == 'f' && type[2] == '8') {
Expand Down Expand Up @@ -318,8 +339,8 @@ class ArrayInterface {
}

RBitField8 valid;
int32_t num_rows;
int32_t num_cols;
bst_row_t num_rows;
bst_feature_t num_cols;
void* data;
char type[3];
};
Expand Down
2 changes: 1 addition & 1 deletion src/data/data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
auto const& j_arr = get<Array>(j_interface);
CHECK_EQ(j_arr.size(), 1)
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(get<Object const>(j_arr[0]));
ArrayInterface array_interface(interface_str);
std::string key{c_key};
CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask";
Expand Down
51 changes: 51 additions & 0 deletions tests/cpp/data/test_array_interface.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/host_device_vector.h>
#include "../helpers.h"
#include "../../../src/data/array_interface.h"

namespace xgboost {
// Builds an ArrayInterface from a generated interface description and checks
// that the reported shape and data pointer match the backing storage.
TEST(ArrayInterface, Initialize) {
  constexpr size_t kRows = 10;
  constexpr size_t kCols = 10;

  HostDeviceVector<float> storage;
  RandomDataGenerator generator{kRows, kCols, 0};
  auto interface_desc = generator.GenerateArrayInterface(&storage);

  auto view = ArrayInterface(interface_desc);

  ASSERT_EQ(view.num_rows, kRows);
  ASSERT_EQ(view.num_cols, kCols);
  ASSERT_EQ(view.data, storage.ConstHostPointer());
}

// Adds the interface fields to `column` one at a time and checks that
// ArrayInterfaceHandler::ExtractData<float> throws dmlc::Error until the
// description has a version, data, a typestr and a non-null data pointer.
TEST(ArrayInterface, Error) {
constexpr size_t kRows = 16, kCols = 10;
// Start with an object that only carries a 1-element "shape".
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
// "data" payload for the failure cases: a null address plus a boolean flag.
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};

// NOTE(review): the test relies on this reference observing the keys that are
// added to `column` below -- confirm against Json::operator[] semantics.
auto const& column_obj = get<Object>(column);
// missing version
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
column["version"] = Integer(static_cast<Integer::Int>(1));
// missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
column["data"] = j_data;
// missing typestr
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
column["typestr"] = String("<f4");
// nullptr is not valid
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);

// Swap in a real host pointer; with every field present extraction must succeed.
HostDeviceVector<float> storage;
// The returned description is unused here; presumably the call is made to
// populate `storage` -- TODO confirm.
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
j_data = {
Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))),
Json(Boolean(false))};
column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj));
}

} // namespace xgboost
8 changes: 7 additions & 1 deletion tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,13 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
this->GenerateDense(storage);
Json array_interface {Object()};
array_interface["data"] = std::vector<Json>(2);
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(storage->DevicePointer()));
if (storage->DeviceCanRead()) {
array_interface["data"][0] =
Integer(reinterpret_cast<int64_t>(storage->ConstDevicePointer()));
} else {
array_interface["data"][0] =
Integer(reinterpret_cast<int64_t>(storage->ConstHostPointer()));
}
array_interface["data"][1] = Boolean(false);

array_interface["shape"] = std::vector<Json>(2);
Expand Down