From cc71964a58664ae8f3e38ffb896144aafba67dc1 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 13 Jan 2020 12:59:27 +0800 Subject: [PATCH] Add JVM_CHECK_CALL. * Added a check call macro in jvm package, prevents executing other functions from jvm when error occurred in XGBoost. For example, when prediction fails jvm should not try to allocate memory based on the output prediction size. --- .../xgboost4j/src/native/xgboost4j.cpp | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index b6b9a8377340..c99ed7a7a545 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -22,6 +22,14 @@ #include #include +#define JVM_CHECK_CALL(__expr) \ + { \ + int __errcode = (__expr); \ + if (__errcode != 0) { \ + return __errcode; \ + } \ + } + // helper functions // set handle void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { @@ -177,6 +185,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro } int ret = XGDMatrixCreateFromDataIter( jiter, XGBoost4jCallbackDataIterNext, cache_info, &result); + JVM_CHECK_CALL(ret); if (cache_info) { jenv->ReleaseStringUTFChars(jcache_info, cache_info); } @@ -194,6 +203,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro DMatrixHandle result; const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); + JVM_CHECK_CALL(ret); if (fname) { jenv->ReleaseStringUTFChars(jfname, fname); } @@ -214,7 +224,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro jfloat* data = jenv->GetFloatArrayElements(jdata, 0); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result); + jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, + (unsigned int const *)indices, + (float const *)data, + nindptr, nelem, jcol, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //Release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -237,7 +251,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result); + jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, + (unsigned int const *)indices, + (float const *)data, + nindptr, nelem, jrow, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -258,6 +276,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; jint ret = (jint) XGDMatrixCreateFromMat((float const *)jdataRef, nrow, ncol, jmiss, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; } @@ -275,6 +294,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseFloatArrayElements(jdata, data, 0); @@ -296,6 +316,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat // default to not allowing slicing with group ID specified -- feel free to add if necessary jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseIntArrayElements(jindexset, indexset, 0); @@ -325,6 +346,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinar DMatrixHandle handle = (DMatrixHandle) jhandle; const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGDMatrixSaveBinary(handle, fname, jsilent); + JVM_CHECK_CALL(ret); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); return ret; } @@ -342,6 +364,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); + JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, field); jenv->ReleaseFloatArrayElements(jarray, array, 0); @@ -360,6 +383,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn jint* array = jenv->GetIntArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jenv->ReleaseIntArrayElements(jarray, array, 0); @@ -379,6 +403,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatI bst_ulong len; float *result; int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); + JVM_CHECK_CALL(ret); if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; @@ -401,6 +426,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntIn bst_ulong len; unsigned int *result; int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); + JVM_CHECK_CALL(ret); if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; @@ -420,6 +446,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow DMatrixHandle handle = (DMatrixHandle) jhandle; bst_ulong result[1]; int ret = (jint) XGDMatrixNumRow(handle, result); + JVM_CHECK_CALL(ret); jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); return ret; } @@ -442,6 +469,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate } BoosterHandle result; int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; } @@ -469,6 +497,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam const char* name = jenv->GetStringUTFChars(jname, 0); const char* value = jenv->GetStringUTFChars(jvalue, 0); int ret = XGBoosterSetParam(handle, name, value); + JVM_CHECK_CALL(ret); //release if (name) jenv->ReleaseStringUTFChars(jname, name); if (value) jenv->ReleaseStringUTFChars(jvalue, value); @@ -500,6 +529,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneI jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); + JVM_CHECK_CALL(ret); //release jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0); @@ -537,6 +567,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt dmlc::BeginPtr(dmats), dmlc::BeginPtr(evchars), len, &result); + JVM_CHECK_CALL(ret); jstring jinfo = nullptr; if (result != nullptr) { jinfo = jenv->NewStringUTF(result); @@ -557,6 +588,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict bst_ulong len; float *result; int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); + JVM_CHECK_CALL(ret); if (len) { jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); @@ -577,7 +609,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGBoosterLoadModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname,fname); + JVM_CHECK_CALL(ret); + if (fname) { + jenv->ReleaseStringUTFChars(jfname,fname); + } return ret; } @@ -592,7 +627,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGBoosterSaveModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname, fname); + JVM_CHECK_CALL(ret); + if (fname) { + jenv->ReleaseStringUTFChars(jfname, fname); + } return ret; } @@ -607,6 +645,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0); int ret = XGBoosterLoadModelFromBuffer( handle, buffer, jenv->GetArrayLength(jbytes)); + JVM_CHECK_CALL(ret); jenv->ReleaseByteArrayElements(jbytes, buffer, 0); return ret; } @@ -622,6 +661,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR bst_ulong len = 0; const char* result; int ret = XGBoosterGetModelRaw(handle, &len, &result); + JVM_CHECK_CALL(ret); if (result) { jbyteArray jarray = jenv->NewByteArray(len); @@ -645,6 +685,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel char **result; int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -696,6 +737,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel (const char **) dmlc::BeginPtr(feature_names_char), (const char **) dmlc::BeginPtr(feature_types_char), jwith_stats, format, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -718,6 +760,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNa bst_ulong len = 0; char **result; int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -741,6 +784,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr const char* result; int success; int ret = XGBoosterGetAttr(handle, key, &result, &success); + JVM_CHECK_CALL(ret); //release if (key) jenv->ReleaseStringUTFChars(jkey, key); @@ -763,6 +807,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr const char* key = jenv->GetStringUTFChars(jkey, 0); const char* value = jenv->GetStringUTFChars(jvalue, 0); int ret = XGBoosterSetAttr(handle, key, value); + JVM_CHECK_CALL(ret); //release if (key) jenv->ReleaseStringUTFChars(jkey, key); if (value) jenv->ReleaseStringUTFChars(jvalue, value); @@ -779,6 +824,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit BoosterHandle handle = (BoosterHandle) jhandle; int version; int ret = XGBoosterLoadRabitCheckpoint(handle, &version); + JVM_CHECK_CALL(ret); jint jversion = version; jenv->SetIntArrayRegion(jout, 0, 1, &jversion); return ret;