From ff60e1463bca228a2a4fce7f083639e54e2a366e Mon Sep 17 00:00:00 2001 From: Bobby Date: Tue, 25 Jun 2024 18:19:39 +0800 Subject: [PATCH] supporting array --- .../ml/dmlc/xgboost4j/java/CudfColumn.java | 46 ++++++++++------ .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 55 +++++++++++++++++++ 2 files changed, 85 insertions(+), 16 deletions(-) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java index 32c64eadc360..b06813d58efc 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java @@ -16,10 +16,7 @@ package ml.dmlc.xgboost4j.java; -import ai.rapids.cudf.BaseDeviceMemoryBuffer; -import ai.rapids.cudf.BufferType; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.DType; +import ai.rapids.cudf.*; /** * This class is composing of base data with Apache Arrow format from Cudf ColumnVector. @@ -30,46 +27,63 @@ public class CudfColumn extends Column { private final long dataPtr; // gpu data buffer address private final long shape; // row count private final long validPtr; // gpu valid buffer address + private final long offsetPtr; // gpu offset address for list private final int typeSize; // type size in bytes private final String typeStr; // follow array interface spec private final long nullCount; // null count private String arrayInterface = null; // the cuda array interface - public static CudfColumn from(ColumnVector cv) { - BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA); - BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY); + public static CudfColumn from(ColumnVector column) { + ColumnView cv = column; + DType dType = cv.getType(); + long nullCount = cv.getNullCount(); + String floatType = " tables = new LinkedList<>(); + tables.add(new CudfColumnBatch(X1, y1, null, null)); + tables.add(new CudfColumnBatch(X2, y2, null, null)); + + try { + DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1); + } catch (XGBoostError e) { + throw new RuntimeException(e); + } + + System.out.println("--------------"); + + + } }