diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 5528b1e836b4..3111b957006a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -22,6 +22,14 @@ #include #include +#define JVM_CHECK_CALL(__expr) \ + { \ + int __errcode = (__expr); \ + if (__errcode != 0) { \ + return __errcode; \ + } \ + } + // helper functions // set handle void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { @@ -177,6 +185,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro } int ret = XGDMatrixCreateFromDataIter( jiter, XGBoost4jCallbackDataIterNext, cache_info, &result); + JVM_CHECK_CALL(ret); if (cache_info) { jenv->ReleaseStringUTFChars(jcache_info, cache_info); } @@ -194,6 +203,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro DMatrixHandle result; const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); + JVM_CHECK_CALL(ret); if (fname) { jenv->ReleaseStringUTFChars(jfname, fname); } @@ -214,7 +224,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro jfloat* data = jenv->GetFloatArrayElements(jdata, 0); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result); + jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, + (unsigned int const *)indices, + (float const *)data, + nindptr, nelem, jcol, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //Release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -237,7 +251,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result); + jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, + (unsigned int const *)indices, + (float const *)data, + nindptr, nelem, jrow, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -258,6 +276,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; jint ret = (jint) XGDMatrixCreateFromMat((float const *)jdataRef, nrow, ncol, jmiss, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; } @@ -275,6 +294,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseFloatArrayElements(jdata, data, 0); @@ -296,6 +316,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat // default to not allowing slicing with group ID specified -- feel free to add if necessary jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); //release jenv->ReleaseIntArrayElements(jindexset, indexset, 0); @@ -325,6 +346,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinar DMatrixHandle handle = (DMatrixHandle) jhandle; const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGDMatrixSaveBinary(handle, fname, jsilent); + JVM_CHECK_CALL(ret); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); return ret; } @@ -342,6 +364,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); + JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, field); jenv->ReleaseFloatArrayElements(jarray, array, 0); @@ -360,6 +383,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn jint* array = jenv->GetIntArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + JVM_CHECK_CALL(ret); //release if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jenv->ReleaseIntArrayElements(jarray, array, 0); @@ -379,6 +403,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatI bst_ulong len; float *result; int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); + JVM_CHECK_CALL(ret); if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; @@ -401,6 +426,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntIn bst_ulong len; unsigned int *result; int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); + JVM_CHECK_CALL(ret); if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; @@ -420,6 +446,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow DMatrixHandle handle = (DMatrixHandle) jhandle; bst_ulong result[1]; int ret = (jint) XGDMatrixNumRow(handle, result); + JVM_CHECK_CALL(ret); jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); return ret; } @@ -442,6 +469,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate } BoosterHandle result; int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result); + JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; } @@ -469,6 +497,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam const char* name = jenv->GetStringUTFChars(jname, 0); const char* value = jenv->GetStringUTFChars(jvalue, 0); int ret = XGBoosterSetParam(handle, name, value); + JVM_CHECK_CALL(ret); //release if (name) jenv->ReleaseStringUTFChars(jname, name); if (value) jenv->ReleaseStringUTFChars(jvalue, value); @@ -500,6 +529,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneI jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); + JVM_CHECK_CALL(ret); //release jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0); @@ -537,6 +567,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt dmlc::BeginPtr(dmats), dmlc::BeginPtr(evchars), len, &result); + JVM_CHECK_CALL(ret); jstring jinfo = nullptr; if (result != nullptr) { jinfo = jenv->NewStringUTF(result); @@ -556,8 +587,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict DMatrixHandle dmat = (DMatrixHandle) jdmat; bst_ulong len; float *result; - int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, 0, + int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, + /* training = */ 0, // Currently this parameter is not supported by JVM &len, (const float **) &result); + JVM_CHECK_CALL(ret); if (len) { jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); @@ -578,7 +611,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGBoosterLoadModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname,fname); + JVM_CHECK_CALL(ret); + if (fname) { + jenv->ReleaseStringUTFChars(jfname,fname); + } return ret; } @@ -593,7 +629,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel const char* fname = jenv->GetStringUTFChars(jfname, 0); int ret = XGBoosterSaveModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname, fname); + JVM_CHECK_CALL(ret); + if (fname) { + jenv->ReleaseStringUTFChars(jfname, fname); + } return ret; } @@ -608,6 +647,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0); int ret = XGBoosterLoadModelFromBuffer( handle, buffer, jenv->GetArrayLength(jbytes)); + JVM_CHECK_CALL(ret); jenv->ReleaseByteArrayElements(jbytes, buffer, 0); return ret; } @@ -623,6 +663,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR bst_ulong len = 0; const char* result; int ret = XGBoosterGetModelRaw(handle, &len, &result); + JVM_CHECK_CALL(ret); if (result) { jbyteArray jarray = jenv->NewByteArray(len); @@ -646,6 +687,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel char **result; int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -697,6 +739,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel (const char **) dmlc::BeginPtr(feature_names_char), (const char **) dmlc::BeginPtr(feature_types_char), jwith_stats, format, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -719,6 +762,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNa bst_ulong len = 0; char **result; int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result); + JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); @@ -742,6 +786,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr const char* result; int success; int ret = XGBoosterGetAttr(handle, key, &result, &success); + JVM_CHECK_CALL(ret); //release if (key) jenv->ReleaseStringUTFChars(jkey, key); @@ -764,6 +809,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr const char* key = jenv->GetStringUTFChars(jkey, 0); const char* value = jenv->GetStringUTFChars(jvalue, 0); int ret = XGBoosterSetAttr(handle, key, value); + JVM_CHECK_CALL(ret); //release if (key) jenv->ReleaseStringUTFChars(jkey, key); if (value) jenv->ReleaseStringUTFChars(jvalue, value); @@ -780,6 +826,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit BoosterHandle handle = (BoosterHandle) jhandle; int version; int ret = XGBoosterLoadRabitCheckpoint(handle, &version); + JVM_CHECK_CALL(ret); jint jversion = version; jenv->SetIntArrayRegion(jout, 0, 1, &jversion); return ret;