Skip to content

Commit

Permalink
Initial commit for jvm.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 21, 2020
1 parent 094a3c7 commit c1d2c27
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
* \brief Set data on a DMatrix proxy.
*
* \param handle a DMatrix proxy created by XGProxyDMatrixCreate
* \param interface string for CUDA array interface.
* \param interface Null-terminated string for CUDA array interface.
*/
XGB_DLL int XGDMatrixSetDataCudaArrayInterface(
DMatrixHandle handle,
Expand Down
11 changes: 11 additions & 0 deletions jvm-packages/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ find_package(JNI REQUIRED)

# Shared JNI library: the C++ JNI bindings, the CUDA JNI bindings and the
# XGBoost object sources.
# NOTE(review): xgboost4j.cu is listed unconditionally, while the CUDA compile
# options for this target are only applied under `if (USE_CUDA)`; confirm that
# builds without CUDA are still expected to work.
add_library(xgboost4j SHARED
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cu
${XGBOOST_OBJ_SOURCES})
target_include_directories(xgboost4j
PRIVATE
Expand All @@ -11,6 +12,16 @@ target_include_directories(xgboost4j
${PROJECT_SOURCE_DIR}/dmlc-core/include
${PROJECT_SOURCE_DIR}/rabit/include)

# CUDA-only settings for the JNI target: add the bundled cub headers to the
# include path and pass extra flags to sources compiled as CUDA. Every flag is
# wrapped in a COMPILE_LANGUAGE:CUDA generator expression, so C++ sources are
# unaffected. The --std=c++11 flag is additionally skipped when the C++
# compiler is MSVC.
# NOTE(review): GEN_CODE is not defined in this file; presumably it carries the
# -gencode architecture flags set by the top-level build.
if (USE_CUDA)
target_include_directories(xgboost4j PRIVATE ${xgboost_SOURCE_DIR}/cub/)
target_compile_options(xgboost4j PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>
$<$<AND:$<NOT:$<CXX_COMPILER_ID:MSVC>>,$<COMPILE_LANGUAGE:CUDA>>:--std=c++11>
$<$<COMPILE_LANGUAGE:CUDA>:${GEN_CODE}>)
endif (USE_CUDA)

set_output_directory(xgboost4j ${PROJECT_SOURCE_DIR}/lib)
set_target_properties(
xgboost4j PROPERTIES
Expand Down
78 changes: 70 additions & 8 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <vector>
#include <string>

#include "../../../../src/data/array_interface.h"

#define JVM_CHECK_CALL(__expr) \
{ \
int __errcode = (__expr); \
Expand All @@ -43,12 +45,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}

// global JVM
static JavaVM* global_jvm = nullptr;
/*!
 * \brief Accessor for the JavaVM pointer shared by the native code.
 *
 * The pointer lives in a function-local static and is returned by reference,
 * so callers can both read it and assign to it.
 *
 * \return Reference to the stored JavaVM pointer (null until assigned).
 */
JavaVM*& GlobalJvm() {
  static JavaVM* instance = nullptr;
  return instance;
}

// overrides JNI on load
// Library-load hook: records the JavaVM pointer passed in `vm` so native
// callbacks can later obtain a JNIEnv from it, then returns JNI_VERSION_1_6.
// NOTE(review): this hunk is shown without +/- markers; the `global_jvm`
// assignment appears to be the removed line and the `GlobalJvm()` assignment
// its replacement.
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
global_jvm = vm;
GlobalJvm() = vm;
return JNI_VERSION_1_6;
}

Expand All @@ -58,9 +62,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
DataHolderHandle set_function_handle) {
jobject jiter = static_cast<jobject>(data_handle);
JNIEnv* jenv;
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
if (jni_status == JNI_EDETACHED) {
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
} else {
CHECK(jni_status == JNI_OK);
}
Expand Down Expand Up @@ -148,13 +152,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
jenv->DeleteLocalRef(iterClass);
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
return ret_value;
} catch(dmlc::Error e) {
} catch(dmlc::Error const& e) {
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
return -1;
Expand Down Expand Up @@ -952,3 +956,61 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce

return 0;
}

// Declaration of the CUDA-side implementation; the definition is compiled
// from xgboost4j.cu.
jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(
    JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
    jint jnthread, jlongArray jout);

/*
 * JNI entry point that forwards all of its arguments, unchanged and in the
 * same order, to XGDeviceQuantileDMatrixCreateFromCallbackImpl and returns
 * that function's status code.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback( // NOLINT
    JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
    jint jnthread, jlongArray jout) {
  jint const status = XGDeviceQuantileDMatrixCreateFromCallbackImpl(
      jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
  return status;
}

extern int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
    char const *c_json_strs, float missing, int nthread, int max_bin,
    DMatrixHandle *out);

/*
 * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method: XGDMatrixCreateFromArrayInterfaceColumns
 * Signature: (Ljava/lang/String;FI[J)I
 *
 * Builds a DMatrix from a JSON string describing array-interface columns and
 * stores the resulting handle in jout[0].
 *
 * Returns the status code of the underlying native call, or -1 when the Java
 * string could not be converted to UTF-8.
 *
 * NOTE(review): the signature above lists (String, float, int, long[]) while
 * the native parameter list below is (String, int, float, int, long[]);
 * confirm against the Java declaration.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns(
    JNIEnv *jenv, jclass jcls, jstring jjson_columns, jint jmax_bin,
    jfloat jmissing, jint jnthread, jlongArray jout) {
  DMatrixHandle result = nullptr;
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);
  if (cjson_columns == nullptr) {
    // GetStringUTFChars returns null when the copy could not be made; do not
    // hand a null pointer to the JSON parser.
    return -1;
  }
  int ret = XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
      cjson_columns, jmissing, jnthread, jmax_bin, &result);
  // Release the UTF buffer before the status check so it is freed on every
  // path, including when JVM_CHECK_CALL leaves the function on an error.
  jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
  JVM_CHECK_CALL(ret);
  setHandle(jenv, jout, result);
  return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetInfoFromInterface
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
// JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
// (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
// DMatrixHandle handle = (DMatrixHandle) jhandle;
// const char* field = jenv->GetStringUTFChars(jfield, 0);
// const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0);

// int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);
// JVM_CHECK_CALL(ret);
// //release
// if (field) jenv->ReleaseStringUTFChars(jfield, field);
// if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
// return ret;
// }
220 changes: 220 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#include <jni.h>
#include <thrust/device_vector.h>

#include "../../../../src/data/array_interface.h"
#include "../../../../src/common/device_helpers.cuh"

extern JavaVM*& GlobalJvm();
extern void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);

namespace xgboost {
namespace spark {

/*!
 * \brief Read-only raw pointer to the contiguous storage of a host vector.
 *
 * \param data Host vector to inspect.
 * \return The value of data.data().
 */
template <typename T>
T const* RawPtr(std::vector<T> const& data) {
  T const* storage = data.data();
  return storage;
}

/*!
 * \brief Read-only raw pointer to the storage of a thrust device vector.
 *
 * \param data Device vector to inspect.
 * \return The raw pointer wrapped by data.data().
 */
template <typename T>
T const* RawPtr(thrust::device_vector<T> const& data) {
  auto wrapped = data.data();
  return wrapped.get();
}

template <typename DCont, typename VCont>
void CopyInterface(std::vector<xgboost::ArrayInterface> const& interface_arr,
cudaMemcpyKind kind,
std::vector<DCont> *p_data,
std::vector<VCont>* p_mask,
xgboost::Json* p_out) {
p_data->resize(interface_arr.size());
p_mask->resize(interface_arr.size());
for (size_t c = 0; c < interface_arr.size(); ++c) {
auto const& interface = interface_arr.at(c);
size_t element_size = interface.type[2];
size_t size = element_size * interface.num_rows * interface.num_cols;

auto& data = (*p_data)[c];
auto& mask = (*p_mask)[c];
data.resize(size);
cudaMemcpyAsync(interface.data, RawPtr(data),
size, cudaMemcpyDeviceToHost);

mask.resize(interface.valid.Size());
cudaMemcpyAsync(interface.valid.Data(), RawPtr(mask),
interface.valid.Size(), kind);

auto& out = (*p_out)[c];
out["data"] = Integer(reinterpret_cast<Integer::Int>(RawPtr(data)));
out["shape"] = Array(
std::vector<Json>{Json(Integer(interface.num_rows)),
Json(Integer(interface.num_cols))});

out["mask"] = Object();
out["mask"]["data"] = Integer(reinterpret_cast<Integer::Int>(RawPtr(mask)));
out["mask"]["shape"] = Array(
std::vector<Json>{Json(Integer(interface.num_rows)),
Json(Integer(interface.num_cols))});
}
}

namespace xgboost {
namespace spark {
/*!
 * \brief Storage for a sequence of batches.
 *
 * Each batch owns a list of value buffers (`data`), a list of mask buffers
 * (`valid`) and one JSON document (`interfaces`). The three vectors are kept
 * at the same length by Resize().
 *
 * \tparam DCont Container type used for one column's values.
 * \tparam VCont Container type used for one column's validity mask.
 */
template <typename DCont, typename VCont>
struct ColumnContainer {
  std::vector<std::vector<DCont>> data;
  std::vector<std::vector<VCont>> valid;
  std::vector<Json> interfaces;

  /*! \brief Resize all three per-batch vectors to hold `n` batches. */
  void Resize(size_t n) {
    interfaces.resize(n);
    valid.resize(n);
    data.resize(n);
  }
};

/*! \brief Column storage backed by std::vector buffers. */
using ColumnHost = ColumnContainer<std::vector<char>, std::vector<std::uint8_t>>;

class DataIteratorProxy {
DMatrixHandle proxy_;
JNIEnv *jenv_;
int jni_status_;
jobject jiter_;

ColumnHost host_columns_;

size_t it_ {0};
size_t n_batches_ {0};

public:
explicit DataIteratorProxy(jobject jiter) : jiter_{jiter} {
XGProxyDMatrixCreate(&proxy_);
jni_status_ =
GlobalJvm()->GetEnv(reinterpret_cast<void **>(&jenv_), JNI_VERSION_1_6);
this->InitializeLoop();
this->Reset();
}
~DataIteratorProxy() {
XGDMatrixFree(proxy_);
}

DMatrixHandle GetDMatrixHandle() const { return proxy_; }

void InitializeLoop() {
while (true) {
try {
jclass iterClass = jenv_->FindClass("java/util/Iterator");
jmethodID has_next = jenv_->GetMethodID(iterClass, "hasNext", "()Z");
jmethodID next =
jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;");
if (jenv_->CallBooleanMethod(jiter_, has_next)) {
jobject batch = jenv_->CallObjectMethod(jiter_, next);
if (!batch) {
CHECK(jenv_->ExceptionOccurred());
jenv_->ExceptionDescribe();
}
jclass batch_class = jenv_->GetObjectClass(batch);
CHECK(batch_class);
jmethodID get_array_interface = jenv_->GetMethodID(
batch_class, "getArrayInterface", "()Ljava/lang/Object;");
CHECK(get_array_interface);

auto jinterface = static_cast<jstring>(
jenv_->CallObjectMethod(batch, get_array_interface));
CHECK(jinterface);
char const *c_interface_str =
jenv_->GetStringUTFChars(jinterface, nullptr);
CHECK(c_interface_str);
std::string interface_str {c_interface_str};
jenv_->ReleaseStringUTFChars(jinterface, c_interface_str);

++n_batches_;
host_columns_.Resize(n_batches_);

auto json_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto json_columns = get<Array const>(json_interface);
std::vector<ArrayInterface> interfaces(get<Array const>(json_interface).size());

for (auto& json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col));
interfaces.emplace_back(column);
}

host_columns_.interfaces.back() = json_interface;
CopyInterface(interfaces,
cudaMemcpyDeviceToHost,
&host_columns_.data.back(),
&host_columns_.valid.back(),
&host_columns_.interfaces.back());
} else {
break;
}
} catch (dmlc::Error const &e) {
if (jni_status_ == JNI_EDETACHED) {
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
}
}
}

void Reset() {
it_ = 0;
}

int Next() {
if (it_ == n_batches_) {
return 0;
}
auto json_interface = host_columns_.interfaces.at(it_);
auto json_columns = get<Array const>(json_interface);

std::vector<ArrayInterface> in(get<Array const>(json_interface).size());
for (auto& json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col));
in.emplace_back(column);
}

std::string temp;
Json::Dump(json_interface, &temp);
Json out { Json::Load({temp.c_str(), temp.size()}) };

std::vector<thrust::device_vector<char>> data;
std::vector<thrust::device_vector<uint8_t>> mask;
CopyInterface(in, cudaMemcpyHostToDevice, &data, &mask, &out);

std::string interface_str;
Json::Dump(out, &interface_str);
XGDMatrixSetDataCudaArrayInterface(proxy_, interface_str.c_str());
it_++;
return 1;
};
};
} // namespace spark
} // namespace xgboost

namespace {
/*! \brief Callback: cast the opaque handle to DataIteratorProxy and rewind it. */
void Reset(DataIterHandle self) {
  auto* proxy = static_cast<xgboost::spark::DataIteratorProxy*>(self);
  proxy->Reset();
}

/*!
 * \brief Callback: cast the opaque handle to DataIteratorProxy and advance it.
 * \return The value returned by DataIteratorProxy::Next().
 */
int Next(DataIterHandle self) {
  auto* proxy = static_cast<xgboost::spark::DataIteratorProxy*>(self);
  return proxy->Next();
}
}  // anonymous namespace

/*!
 * \brief Build a DMatrix by iterating over `jiter` and store its handle in
 *        jout[0].
 *
 * \param jenv      JNI environment of the calling thread.
 * \param jcls      Calling Java class (unused).
 * \param jiter     Java iterator producing the batches.
 * \param jmissing  Value treated as missing.
 * \param jmax_bin  Maximum number of bins.
 * \param jnthread  Number of threads.
 * \param jout      Output array; element 0 receives the handle on success.
 * \return Status code of XGDMatrixCreateFromCallback.
 *
 * NOTE(review): this definition sits inside `namespace xgboost { namespace
 * spark {`, while xgboost4j.cpp declares the function at global scope; the
 * two names will not match at link time unless one side is changed.
 */
jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
                                                   jobject jiter,
                                                   jfloat jmissing,
                                                   jint jmax_bin, jint jnthread,
                                                   jlongArray jout) {
  xgboost::spark::DataIteratorProxy proxy(jiter);
  DMatrixHandle result = nullptr;
  auto ret =
      XGDMatrixCreateFromCallback(&proxy, proxy.GetDMatrixHandle(), Reset, Next,
                                  jmissing, jnthread, jmax_bin, &result);
  if (ret != 0) {
    // Do not publish a handle that was never produced.
    return ret;
  }
  setHandle(jenv, jout, result);
  return ret;
}

} // namespace spark
} // namespace xgboost
3 changes: 3 additions & 0 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ class ArrayInterfaceHandler {
class ArrayInterface {
public:
ArrayInterface() = default;
explicit ArrayInterface(std::string const& str) {
auto jinterface = Json::Load({str.c_str(), str.size()});
}
explicit ArrayInterface(std::map<std::string, Json> const& column) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
Expand Down

0 comments on commit c1d2c27

Please sign in to comment.