From 9b23f558a22a401aebc7351f9633489d1edcf307 Mon Sep 17 00:00:00 2001 From: Bobby Date: Wed, 26 Jun 2024 18:19:23 +0800 Subject: [PATCH] Add missing for DMatrix when creating it from Iterator --- include/xgboost/c_api.h | 2 ++ .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 17 ++++++++++- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 2 +- .../xgboost4j/src/native/xgboost4j.cpp | 6 ++-- jvm-packages/xgboost4j/src/native/xgboost4j.h | 4 +-- .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 29 ++++++++++++++----- src/c_api/c_api.cc | 9 +++--- 7 files changed, 51 insertions(+), 18 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 85897412f9a6..16817bf5ad1c 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -402,6 +402,7 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*) * \param data_handle The handle to the data. * \param callback The callback to get the data. * \param cache_info Additional information about cache file, can be null. + * \param missing Which value to represent missing value. * \param out The created DMatrix * \return 0 when success, -1 when failure happens. */ @@ -409,6 +410,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( DataIterHandle data_handle, XGBCallbackDataIterNext* callback, const char* cache_info, + float missing, DMatrixHandle *out); /** diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 0e88c25d3fda..ddb873c945a6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -44,6 +44,20 @@ public static enum SparseType { * @throws XGBoostError */ public DMatrix(Iterator iter, String cacheInfo) throws XGBoostError { + this(iter, cacheInfo, Float.NaN); + } + + /** + * Create DMatrix from iterator. + * + * @param iter The data iterator of mini batch to provide the data. + * @param cacheInfo Cache path information, used for external memory setting, can be null. + * @param missing the missing value + * @throws XGBoostError + */ + public DMatrix(Iterator iter, + String cacheInfo, + float missing) throws XGBoostError { if (iter == null) { throw new NullPointerException("iter: null"); } @@ -51,7 +65,8 @@ public DMatrix(Iterator iter, String cacheInfo) throws XGBoostErro int batchSize = 32 << 10; Iterator batchIter = new DataBatch.BatchIterator(iter, batchSize); long[] out = new long[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, + cacheInfo, missing, out)); handle = out[0]; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index b410d2be1d02..00413636e0f0 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -54,7 +54,7 @@ static void checkCall(int ret) throws XGBoostError { public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, - String cache_info, long[] out); + String cache_info, float missing, long[] out); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, int shapeParam, diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index cfab645ed6bf..d8f169157e3a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -214,7 +214,7 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter - (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) { + (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jfloat jmissing, jlongArray jout) { DMatrixHandle result; std::unique_ptr> cache_info; if (jcache_info != nullptr) { @@ -222,8 +222,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro jenv->ReleaseStringUTFChars(jcache_info, ptr); }}; } + auto missing = static_cast(jmissing); int ret = - XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result); + XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), + missing,&result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index c8e48cfc9de9..f8657b5a61a1 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -26,10 +26,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixCreateFromDataIter - * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I + * Signature: (Ljava/util/Iterator;Ljava/lang/String;F[J)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter - (JNIEnv *, jclass, jobject, jstring, jlongArray); + (JNIEnv *, jclass, jobject, jstring, jfloat, jlongArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index b6ffe84e30e9..f6f914a94e68 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -15,15 +15,18 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Random; import junit.framework.TestCase; -import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; import ml.dmlc.xgboost4j.LabeledPoint; +import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; @@ -36,6 +39,18 @@ */ public class DMatrixTest { + @Test + public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError { + //create DMatrix from DataIterator + java.util.List blist = new java.util.LinkedList(); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 3, 4, 5})); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{11, 13, 14, 15})); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 24, 25})); + DMatrix dmat = new DMatrix(blist.iterator(), null, 15); + + assert dmat.nonMissingNum() == 11; + } + @Test public void testCreateFromDataIterator() throws XGBoostError { //create DMatrix from DataIterator @@ -45,7 +60,7 @@ public void testCreateFromDataIterator() throws XGBoostError { java.util.List blist = new java.util.LinkedList(); for (int i = 0; i < nrep; ++i) { LabeledPoint p = new LabeledPoint( - 0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5}); + 0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5}); blist.add(p); labelall.add(p.label()); } @@ -290,7 +305,7 @@ public void testCreateFromDenseMatrixRef() throws XGBoostError { } finally { if (dmat0 != null) { dmat0.dispose(); - } else if (data0 != null){ + } else if (data0 != null) { data0.dispose(); } } @@ -309,9 +324,9 @@ public void testTrainWithDenseMatrixRef() throws XGBoostError { // (3,1) -> 2 // (2,3) -> 3 float[][] data = new float[][]{ - new float[]{4f, 5f}, - new float[]{3f, 1f}, - new float[]{2f, 3f} + new float[]{4f, 5f}, + new float[]{3f, 1f}, + new float[]{2f, 3f} }; data0 = new BigDenseMatrix(3, 2); for (int i = 0; i < data0.nrow; i++) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 45160baea51f..b0c84ef43468 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -253,7 +253,9 @@ XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) { XGB_DLL int XGDMatrixCreateFromDataIter( void *data_handle, // a Java iterator XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp - const char *cache_info, DMatrixHandle *out) { + const char *cache_info, + float missing, + DMatrixHandle *out) { API_BEGIN(); std::string scache; @@ -264,10 +266,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( data_handle, callback); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr { - DMatrix::Create( - &adapter, std::numeric_limits::quiet_NaN(), - 1, scache - ) + DMatrix::Create(&adapter, missing,1, scache) }; API_END(); }