From f1e9bbcee52159d4bd5f7d25ef539777ceac147c Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 5 Dec 2022 12:23:21 +0800 Subject: [PATCH] [breakinig] [jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix (#8461) --- .../src/native/xgboost4j-gpu.cpp | 7 ++ .../xgboost4j-gpu/src/native/xgboost4j-gpu.cu | 18 ++++- .../dmlc/xgboost4j/gpu/java/BoosterTest.java | 4 +- .../dmlc/xgboost4j/gpu/java/DMatrixTest.java | 4 +- ...Suite.scala => QuantileDMatrixSuite.scala} | 7 +- .../scala/rapids/spark/GpuPreXGBoost.scala | 6 +- .../ml/dmlc/xgboost4j/java/ColumnBatch.java | 2 +- .../xgboost4j/java/DeviceQuantileDMatrix.java | 68 ----------------- .../dmlc/xgboost4j/java/QuantileDMatrix.java | 75 +++++++++++++++++++ .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 4 + ...ileDMatrix.scala => QuantileDMatrix.scala} | 26 +++---- .../xgboost4j/src/native/xgboost4j.cpp | 15 ++++ jvm-packages/xgboost4j/src/native/xgboost4j.h | 8 ++ 13 files changed, 150 insertions(+), 94 deletions(-) rename jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/{DeviceQuantileDMatrixSuite.scala => QuantileDMatrixSuite.scala} (94%) delete mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java create mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java rename jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/{DeviceQuantileDMatrix.scala => QuantileDMatrix.scala} (72%) diff --git a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp index f55e4f837..698da6244 100644 --- a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp +++ b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp @@ -20,6 +20,13 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j common::AssertGPUSupport(); API_END(); } +XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, + jobject jdata_iter, jobject jref_iter, + char const *config, jlongArray jout) { + API_BEGIN(); + common::AssertGPUSupport(); + API_END(); +} } // namespace jni } // namespace xgboost #endif // XGBOOST_USE_CUDA diff --git a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu index bf3f6a0db..317be01ad 100644 --- a/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu +++ b/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu @@ -379,7 +379,7 @@ int Next(DataIterHandle self) { } } // anonymous namespace -XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, +XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin, jint jnthread, @@ -392,5 +392,21 @@ XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass setHandle(jenv, jout, result); return ret; } + +XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, + jobject jdata_iter, jobject jref_iter, + char const *config, jlongArray jout) { + xgboost::jni::DataIteratorProxy proxy(jdata_iter); + DMatrixHandle result; + + std::unique_ptr ref_proxy{nullptr}; + if (jref_iter) { + ref_proxy = std::make_unique(jref_iter); + } + auto ret = XGQuantileDMatrixCreateFromCallback( + &proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result); + setHandle(jenv, jout, result); + return ret; +} } // namespace jni } // namespace xgboost diff --git a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java index c6109a236..49d17b6be 100644 --- a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java +++ b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java @@ -34,7 +34,7 @@ import ai.rapids.cudf.CSVOptions; import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.ColumnBatch; import ml.dmlc.xgboost4j.java.DMatrix; -import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix; +import ml.dmlc.xgboost4j.java.QuantileDMatrix; import ml.dmlc.xgboost4j.java.XGBoost; import ml.dmlc.xgboost4j.java.XGBoostError; @@ -107,7 +107,7 @@ public class BoosterTest { List tables = new LinkedList<>(); tables.add(batch); - DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1); + DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1); //set watchList HashMap watches1 = new HashMap<>(); watches1.put("train", incrementalDMatrix); diff --git a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java index b08694658..ea9f422e1 100644 --- a/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java @@ -29,7 +29,7 @@ import org.junit.Test; import ai.rapids.cudf.Table; import ml.dmlc.xgboost4j.java.DMatrix; -import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix; +import ml.dmlc.xgboost4j.java.QuantileDMatrix; import ml.dmlc.xgboost4j.java.ColumnBatch; import ml.dmlc.xgboost4j.java.XGBoostError; @@ -117,7 +117,7 @@ public class DMatrixTest { tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0)); tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1)); - DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1); + DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1); float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2)); float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2)); diff --git a/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrixSuite.scala b/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala similarity index 94% rename from jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrixSuite.scala rename to jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala index a98054e67..ba8c5fa9a 100644 --- a/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrixSuite.scala +++ b/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala @@ -22,9 +22,9 @@ import ai.rapids.cudf.Table import org.scalatest.FunSuite import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch -class DeviceQuantileDMatrixSuite extends FunSuite { +class QuantileDMatrixSuite extends FunSuite { - test("DeviceQuantileDMatrix test") { + test("QuantileDMatrix test") { val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f) val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f) @@ -51,8 +51,7 @@ class DeviceQuantileDMatrixSuite extends FunSuite { val batches = new ArrayBuffer[CudfColumnBatch]() batches += new CudfColumnBatch(X_0, y_0, w_0, m_0) batches += new CudfColumnBatch(X_1, y_1, w_1, m_1) - val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1) - + val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1) assert(dmatrix.getLabel.sameElements(label1 ++ label2)) assert(dmatrix.getWeight.sameElements(weight1 ++ weight2)) assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2)) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index 6fbe6e129..d28ae55e5 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -20,7 +20,7 @@ import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch -import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix} +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix} import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor} import org.apache.commons.logging.LogFactory @@ -532,7 +532,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { } /** - * Build DeviceQuantileDMatrix based on GpuColumnBatches + * Build QuantileDMatrix based on GpuColumnBatches * * @param iter a sequence of GpuColumnBatch * @param indices indicate the feature, label, weight, base margin column ids. @@ -546,7 +546,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { missing: Float, maxBin: Int): DMatrix = { val rapidsIterator = new RapidsIterator(iter, indices) - new DeviceQuantileDMatrix(rapidsIterator, missing, maxBin, 1) + new QuantileDMatrix(rapidsIterator, missing, maxBin, 1) } // zip all the Columnar RDDs into one RDD containing named column data batch. diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java index c151fc749..2ac481193 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java @@ -71,7 +71,7 @@ public abstract class ColumnBatch implements AutoCloseable { /** * Get the cuda array interface of the label columns. * The returned value must not be null or empty if we're creating - * {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)} + * {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)} */ public abstract String getLabelsArrayInterface(); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java deleted file mode 100644 index 849e7a723..000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java +++ /dev/null @@ -1,68 +0,0 @@ -package ml.dmlc.xgboost4j.java; - -import java.util.Iterator; - -/** - * DeviceQuantileDMatrix will only be used to train - */ -public class DeviceQuantileDMatrix extends DMatrix { - /** - * Create DeviceQuantileDMatrix from iterator based on the cuda array interface - * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface - * @param missing the missing value - * @param maxBin the max bin - * @param nthread the parallelism - * @throws XGBoostError - */ - public DeviceQuantileDMatrix( - Iterator iter, - float missing, - int maxBin, - int nthread) throws XGBoostError { - super(0); - long[] out = new long[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback( - iter, missing, maxBin, nthread, out)); - handle = out[0]; - } - - @Override - public void setLabel(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel."); - } - - @Override - public void setWeight(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight."); - } - - @Override - public void setBaseMargin(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setLabel(float[] labels) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel."); - } - - @Override - public void setWeight(float[] weights) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight."); - } - - @Override - public void setBaseMargin(float[] baseMargin) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setBaseMargin(float[][] baseMargin) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setGroup(int[] group) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup."); - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java new file mode 100644 index 000000000..6cd189e69 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java @@ -0,0 +1,75 @@ +package ml.dmlc.xgboost4j.java; + +import java.util.Iterator; + +/** + * QuantileDMatrix will only be used to train + */ +public class QuantileDMatrix extends DMatrix { + /** + * Create QuantileDMatrix from iterator based on the cuda array interface + * + * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface + * @param missing the missing value + * @param maxBin the max bin + * @param nthread the parallelism + * @throws XGBoostError + */ + public QuantileDMatrix( + Iterator iter, + float missing, + int maxBin, + int nthread) throws XGBoostError { + super(0); + long[] out = new long[1]; + String conf = getConfig(missing, maxBin, nthread); + XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback( + iter, (java.util.Iterator)null, conf, out)); + handle = out[0]; + } + + @Override + public void setLabel(Column column) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setLabel."); + } + + @Override + public void setWeight(Column column) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setWeight."); + } + + @Override + public void setBaseMargin(Column column) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin."); + } + + @Override + public void setLabel(float[] labels) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setLabel."); + } + + @Override + public void setWeight(float[] weights) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setWeight."); + } + + @Override + public void setBaseMargin(float[] baseMargin) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin."); + } + + @Override + public void setBaseMargin(float[][] baseMargin) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin."); + } + + @Override + public void setGroup(int[] group) throws XGBoostError { + throw new XGBoostError("QuantileDMatrix does not support setGroup."); + } + + private String getConfig(float missing, int maxBin, int nthread) { + return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}", + missing, maxBin, nthread); + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index afe576598..63d536527 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -149,9 +149,13 @@ class XGBoostJNI { public final static native int XGDMatrixSetInfoFromInterface( long handle, String field, String json); + @Deprecated public final static native int XGDeviceQuantileDMatrixCreateFromCallback( java.util.Iterator iter, float missing, int nthread, int maxBin, long[] out); + public final static native int XGQuantileDMatrixCreateFromCallback( + java.util.Iterator iter, java.util.Iterator ref, String config, long[] out); + public final static native int XGDMatrixCreateFromArrayInterfaceColumns( String featureJson, float missing, int nthread, long[] out); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala similarity index 72% rename from jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrix.scala rename to jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala index efe98bd42..cf72746d2 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DeviceQuantileDMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrix.scala @@ -18,13 +18,13 @@ package ml.dmlc.xgboost4j.scala import _root_.scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, DeviceQuantileDMatrix => JDeviceQuantileDMatrix} +import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix} -class DeviceQuantileDMatrix private[scala]( - private[scala] override val jDMatrix: JDeviceQuantileDMatrix) extends DMatrix(jDMatrix) { +class QuantileDMatrix private[scala]( + private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) { /** - * Create DeviceQuantileDMatrix from iterator based on the cuda array interface + * Create QuantileDMatrix from iterator based on the cuda array interface * * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface * @param missing the missing value @@ -33,7 +33,7 @@ class DeviceQuantileDMatrix private[scala]( * @throws XGBoostError */ def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) { - this(new JDeviceQuantileDMatrix(iter.asJava, missing, maxBin, nthread)) + this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread)) } /** @@ -43,7 +43,7 @@ class DeviceQuantileDMatrix private[scala]( */ @throws(classOf[XGBoostError]) override def setLabel(labels: Array[Float]): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.") + throw new XGBoostError("QuantileDMatrix does not support setLabel.") /** * set weight of each instance @@ -52,7 +52,7 @@ class DeviceQuantileDMatrix private[scala]( */ @throws(classOf[XGBoostError]) override def setWeight(weights: Array[Float]): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.") + throw new XGBoostError("QuantileDMatrix does not support setWeight.") /** * if specified, xgboost will start from this init margin @@ -62,7 +62,7 @@ class DeviceQuantileDMatrix private[scala]( */ @throws(classOf[XGBoostError]) override def setBaseMargin(baseMargin: Array[Float]): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.") + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.") /** * if specified, xgboost will start from this init margin @@ -72,7 +72,7 @@ class DeviceQuantileDMatrix private[scala]( */ @throws(classOf[XGBoostError]) override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.") + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.") /** * Set group sizes of DMatrix (used for ranking) @@ -81,27 +81,27 @@ class DeviceQuantileDMatrix private[scala]( */ @throws(classOf[XGBoostError]) override def setGroup(group: Array[Int]): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.") + throw new XGBoostError("QuantileDMatrix does not support setGroup.") /** * Set label of DMatrix from cuda array interface */ @throws(classOf[XGBoostError]) override def setLabel(column: Column): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.") + throw new XGBoostError("QuantileDMatrix does not support setLabel.") /** * set weight of dmatrix from column array interface */ @throws(classOf[XGBoostError]) override def setWeight(column: Column): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.") + throw new XGBoostError("QuantileDMatrix does not support setWeight.") /** * set base margin of dmatrix from column array interface */ @throws(classOf[XGBoostError]) override def setBaseMargin(column: Column): Unit = - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.") + throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.") } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 749fa5b40..5ca2dc42d 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -962,6 +962,9 @@ namespace jni { jfloat jmissing, jint jmax_bin, jint jnthread, jlongArray jout); + XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, + jobject jdata_iter, jobject jref_iter, + char const *config, jlongArray jout); } // namespace jni } // namespace xgboost @@ -977,6 +980,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout); } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGQuantileDMatrixCreateFromCallback + * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback + (JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, jlongArray jout) { + char const *conf = jenv->GetStringUTFChars(jconf, 0); + return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter, + conf, jout); +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixSetInfoFromInterface diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 5afe92b52..adc5e814c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -345,6 +345,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGQuantileDMatrixCreateFromCallback + * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback + (JNIEnv *, jclass, jobject, jobject, jstring, jlongArray); + + /* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixCreateFromArrayInterfaceColumns * Signature: (Ljava/lang/String;FI[J)I */