Skip to content

Commit

Permalink
Add JVM_CHECK_CALL. (#5199)
Browse files Browse the repository at this point in the history
* Added a check call macro in jvm package, prevents executing other functions
from jvm when error occurred in XGBoost. For example, when prediction fails jvm
should not try to allocate memory based on the output prediction size.
  • Loading branch information
trivialfis authored Feb 18, 2020
1 parent 0110754 commit 9f77c18
Showing 1 changed file with 52 additions and 5 deletions.
57 changes: 52 additions & 5 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
#include <vector>
#include <string>

#define JVM_CHECK_CALL(__expr) \
{ \
int __errcode = (__expr); \
if (__errcode != 0) { \
return __errcode; \
} \
}

// helper functions
// set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
Expand Down Expand Up @@ -177,6 +185,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
}
int ret = XGDMatrixCreateFromDataIter(
jiter, XGBoost4jCallbackDataIterNext, cache_info, &result);
JVM_CHECK_CALL(ret);
if (cache_info) {
jenv->ReleaseStringUTFChars(jcache_info, cache_info);
}
Expand All @@ -194,6 +203,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
JVM_CHECK_CALL(ret);
if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
}
Expand All @@ -214,7 +224,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr,
(unsigned int const *)indices,
(float const *)data,
nindptr, nelem, jcol, &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
Expand All @@ -237,7 +251,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);

jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr,
(unsigned int const *)indices,
(float const *)data,
nindptr, nelem, jrow, &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
Expand All @@ -258,6 +276,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
bst_ulong nrow = (bst_ulong)jnrow;
bst_ulong ncol = (bst_ulong)jncol;
jint ret = (jint) XGDMatrixCreateFromMat((float const *)jdataRef, nrow, ncol, jmiss, &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
return ret;
}
Expand All @@ -275,6 +294,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
bst_ulong nrow = (bst_ulong)jnrow;
bst_ulong ncol = (bst_ulong)jncol;
jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//release
jenv->ReleaseFloatArrayElements(jdata, data, 0);
Expand All @@ -296,6 +316,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat

// default to not allowing slicing with group ID specified -- feel free to add if necessary
jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
Expand Down Expand Up @@ -325,6 +346,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinar
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixSaveBinary(handle, fname, jsilent);
JVM_CHECK_CALL(ret);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return ret;
}
Expand All @@ -342,6 +364,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, field);
jenv->ReleaseFloatArrayElements(jarray, array, 0);
Expand All @@ -360,6 +383,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseIntArrayElements(jarray, array, 0);
Expand All @@ -379,6 +403,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatI
bst_ulong len;
float *result;
int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);

jsize jlen = (jsize) len;
Expand All @@ -401,6 +426,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntIn
bst_ulong len;
unsigned int *result;
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result);
JVM_CHECK_CALL(ret);
if (field) jenv->ReleaseStringUTFChars(jfield, field);

jsize jlen = (jsize) len;
Expand All @@ -420,6 +446,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1];
int ret = (jint) XGDMatrixNumRow(handle, result);
JVM_CHECK_CALL(ret);
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
return ret;
}
Expand All @@ -442,6 +469,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate
}
BoosterHandle result;
int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
return ret;
}
Expand Down Expand Up @@ -469,6 +497,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam
const char* name = jenv->GetStringUTFChars(jname, 0);
const char* value = jenv->GetStringUTFChars(jvalue, 0);
int ret = XGBoosterSetParam(handle, name, value);
JVM_CHECK_CALL(ret);
//release
if (name) jenv->ReleaseStringUTFChars(jname, name);
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
Expand Down Expand Up @@ -500,6 +529,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneI
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
JVM_CHECK_CALL(ret);
//release
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
Expand Down Expand Up @@ -537,6 +567,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt
dmlc::BeginPtr(dmats),
dmlc::BeginPtr(evchars),
len, &result);
JVM_CHECK_CALL(ret);
jstring jinfo = nullptr;
if (result != nullptr) {
jinfo = jenv->NewStringUTF(result);
Expand All @@ -556,8 +587,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
DMatrixHandle dmat = (DMatrixHandle) jdmat;
bst_ulong len;
float *result;
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, 0,
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit,
/* training = */ 0, // Currently this parameter is not supported by JVM
&len, (const float **) &result);
JVM_CHECK_CALL(ret);
if (len) {
jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen);
Expand All @@ -578,7 +611,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
const char* fname = jenv->GetStringUTFChars(jfname, 0);

int ret = XGBoosterLoadModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname,fname);
JVM_CHECK_CALL(ret);
if (fname) {
jenv->ReleaseStringUTFChars(jfname,fname);
}
return ret;
}

Expand All @@ -593,7 +629,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
const char* fname = jenv->GetStringUTFChars(jfname, 0);

int ret = XGBoosterSaveModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
JVM_CHECK_CALL(ret);
if (fname) {
jenv->ReleaseStringUTFChars(jfname, fname);
}
return ret;
}

Expand All @@ -608,6 +647,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
int ret = XGBoosterLoadModelFromBuffer(
handle, buffer, jenv->GetArrayLength(jbytes));
JVM_CHECK_CALL(ret);
jenv->ReleaseByteArrayElements(jbytes, buffer, 0);
return ret;
}
Expand All @@ -623,6 +663,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
bst_ulong len = 0;
const char* result;
int ret = XGBoosterGetModelRaw(handle, &len, &result);
JVM_CHECK_CALL(ret);

if (result) {
jbyteArray jarray = jenv->NewByteArray(len);
Expand All @@ -646,6 +687,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
char **result;

int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result);
JVM_CHECK_CALL(ret);

jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
Expand Down Expand Up @@ -697,6 +739,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
(const char **) dmlc::BeginPtr(feature_names_char),
(const char **) dmlc::BeginPtr(feature_types_char),
jwith_stats, format, &len, (const char ***) &result);
JVM_CHECK_CALL(ret);

jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
Expand All @@ -719,6 +762,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNa
bst_ulong len = 0;
char **result;
int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result);
JVM_CHECK_CALL(ret);

jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
Expand All @@ -742,6 +786,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
const char* result;
int success;
int ret = XGBoosterGetAttr(handle, key, &result, &success);
JVM_CHECK_CALL(ret);
//release
if (key) jenv->ReleaseStringUTFChars(jkey, key);

Expand All @@ -764,6 +809,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
const char* key = jenv->GetStringUTFChars(jkey, 0);
const char* value = jenv->GetStringUTFChars(jvalue, 0);
int ret = XGBoosterSetAttr(handle, key, value);
JVM_CHECK_CALL(ret);
//release
if (key) jenv->ReleaseStringUTFChars(jkey, key);
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
Expand All @@ -780,6 +826,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
JVM_CHECK_CALL(ret);
jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret;
Expand Down

0 comments on commit 9f77c18

Please sign in to comment.