Skip to content

Commit

Permalink
Add missing-value support for DMatrix when creating it from an Iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent 0646c9e commit 9b23f55
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 18 deletions.
2 changes: 2 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
* \param data_handle The handle to the data.
* \param callback The callback to get the data.
* \param cache_info Additional information about cache file, can be null.
* \param missing The value that represents a missing entry in the data.
* \param out The created DMatrix
* \return 0 when success, -1 when failure happens.
*/
XGB_DLL int XGDMatrixCreateFromDataIter(
DataIterHandle data_handle,  // handle to the data source
XGBCallbackDataIterNext* callback,  // callback used to fetch the data
const char* cache_info,  // additional cache file information, may be null
float missing,  // value that represents a missing entry
DMatrixHandle *out);  // receives the created DMatrix

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,29 @@ public static enum SparseType {
* @throws XGBoostError
*/
public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
// Delegate to the three-argument constructor with NaN as the missing value.
this(iter, cacheInfo, Float.NaN);
}

/**
* Create DMatrix from an iterator of labeled points, with a caller-chosen missing value.
*
* @param iter The data iterator of mini batch to provide the data; must not be null.
* @param cacheInfo Cache path information, used for external memory setting, can be null.
* @param missing The value that represents a missing entry; it is passed through to the
*                native XGDMatrixCreateFromDataIter call.
* @throws XGBoostError propagated from XGBoostJNI.checkCall on the native return code
* @throws NullPointerException if iter is null
*/
public DMatrix(Iterator<LabeledPoint> iter,
String cacheInfo,
float missing) throws XGBoostError {
if (iter == null) {
throw new NullPointerException("iter: null");
}
// 32k (32 << 10 = 32768) as batch size
int batchSize = 32 << 10;
// Wrap the point iterator in a batch iterator (presumably up to batchSize points per
// DataBatch -- confirm against DataBatch.BatchIterator).
Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
// Single-slot array that receives the native DMatrix handle.
long[] out = new long[1];
// NOTE(review): the next call has no `missing` argument and looks like the pre-change line
// retained by the diff view; the call after it forwards `missing`. Confirm that only the
// latter exists in the real file.
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter,
cacheInfo, missing, out));
handle = out[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static void checkCall(int ret) throws XGBoostError {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);

// Native entry point for building a DMatrix from a batch iterator. The JNI implementation
// forwards `missing` to the C API and stores the resulting handle in `out`.
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
String cache_info, float missing, long[] out);

public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices,
float[] data, int shapeParam,
Expand Down
6 changes: 4 additions & 2 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,16 +214,18 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
* Signature: (Ljava/util/Iterator;Ljava/lang/String;F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jfloat jmissing, jlongArray jout) {
DMatrixHandle result;
std::unique_ptr<char const, Deleter<char const>> cache_info;
if (jcache_info != nullptr) {
cache_info = {jenv->GetStringUTFChars(jcache_info, nullptr), [&](char const *ptr) {
jenv->ReleaseStringUTFChars(jcache_info, ptr);
}};
}
auto missing = static_cast<float>(jmissing);
int ret =
XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result);
XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(),
missing,&result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
return ret;
Expand Down
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
*/
package ml.dmlc.xgboost4j.java;

import java.io.*;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import org.junit.Test;

import static org.junit.Assert.assertArrayEquals;
Expand All @@ -36,6 +39,18 @@
*/
public class DMatrixTest {

/**
 * Builds a DMatrix from an iterator while declaring 15 as the missing value and checks
 * that the one entry equal to 15 is not counted as present.
 *
 * @throws XGBoostError if DMatrix creation or the native query fails
 */
@Test
public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError {
  // Three rows of four feature values each (12 values); exactly one value equals 15.
  java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 3, 4, 5}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{11, 13, 14, 15}));
  blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 24, 25}));

  // Create the DMatrix from the iterator, treating 15 as the missing value.
  DMatrix dmat = new DMatrix(blist.iterator(), null, 15f);
  try {
    // Use a JUnit assertion: the Java `assert` keyword is skipped unless the JVM is
    // started with -ea, which would let this test pass without checking anything.
    TestCase.assertEquals(11L, dmat.nonMissingNum());
  } finally {
    // Release the native handle even when the assertion fails.
    dmat.dispose();
  }
}

@Test
public void testCreateFromDataIterator() throws XGBoostError {
//create DMatrix from DataIterator
Expand All @@ -45,7 +60,7 @@ public void testCreateFromDataIterator() throws XGBoostError {
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
LabeledPoint p = new LabeledPoint(
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label());
}
Expand Down Expand Up @@ -290,7 +305,7 @@ public void testCreateFromDenseMatrixRef() throws XGBoostError {
} finally {
if (dmat0 != null) {
dmat0.dispose();
} else if (data0 != null){
} else if (data0 != null) {
data0.dispose();
}
}
Expand All @@ -309,9 +324,9 @@ public void testTrainWithDenseMatrixRef() throws XGBoostError {
// (3,1) -> 2
// (2,3) -> 3
float[][] data = new float[][]{
new float[]{4f, 5f},
new float[]{3f, 1f},
new float[]{2f, 3f}
new float[]{4f, 5f},
new float[]{3f, 1f},
new float[]{2f, 3f}
};
data0 = new BigDenseMatrix(3, 2);
for (int i = 0; i < data0.nrow; i++)
Expand Down
9 changes: 4 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
XGB_DLL int XGDMatrixCreateFromDataIter(
void *data_handle, // a Java iterator
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
const char *cache_info, DMatrixHandle *out) {
const char *cache_info,
float missing,
DMatrixHandle *out) {
API_BEGIN();

std::string scache;
Expand All @@ -264,10 +266,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
data_handle, callback);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix> {
DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(),
1, scache
)
DMatrix::Create(&adapter, missing,1, scache)
};
API_END();
}
Expand Down

0 comments on commit 9b23f55

Please sign in to comment.