[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)

- Introduce an abstract XGBoost Estimator
- Update to the latest XGBoost parameters
  - Add all XGBoost parameters supported in XGBoost4j-spark.
  - Add setter and getter for these parameters.
  - Remove the deprecated parameters
- Address the missing value handling
- Remove any ETL operations in XGBoost
- Rework the GPU plugin
- Expand sanity tests for CPU and GPU consistency
This commit is contained in:
Bobby Wang
2024-09-11 15:54:19 +08:00
committed by GitHub
parent d94f6679fc
commit 67c8c96784
75 changed files with 4537 additions and 7556 deletions

View File

@@ -519,4 +519,39 @@ public class DMatrix {
CSR,
CSC
}
/**
 * A class to hold the quantile information.
 *
 * <p>Instances are simple holders for the two arrays handed to the constructor.
 * The arrays are stored and returned by reference (no defensive copy), so callers
 * should treat them as read-only.
 */
public class QuantileCut {
  // cut ptr
  final long[] indptr;
  // cut values
  final float[] values;

  /**
   * @param indptr the cut pointer array
   * @param values the cut values array
   */
  QuantileCut(long[] indptr, float[] values) {
    this.indptr = indptr;
    this.values = values;
  }

  /**
   * @return the cut pointer array given at construction time
   */
  public long[] getIndptr() {
    return indptr;
  }

  /**
   * @return the cut values array given at construction time
   */
  public float[] getValues() {
    return values;
  }
}
/**
 * Fetch the quantile cut of this DMatrix through the native library.
 *
 * <p>The native call writes its results into the first slot of two one-element
 * holder arrays, which are then wrapped in a {@link QuantileCut}.
 *
 * @return QuantileCut built from the cut pointers and cut values
 * @throws XGBoostError if the native call reports a failure
 */
public QuantileCut getQuantileCut() throws XGBoostError {
  // One-element holders used as out-parameters for the JNI call.
  final long[][] cutPtrHolder = new long[1][];
  final float[][] cutValuesHolder = new float[1][];

  final int status =
      XGBoostJNI.XGDMatrixGetQuantileCut(this.handle, cutPtrHolder, cutValuesHolder);
  XGBoostJNI.checkCall(status);

  final long[] cutPtr = cutPtrHolder[0];
  final float[] cutValues = cutValuesHolder[0];
  return new QuantileCut(cutPtr, cutValues);
}
}

View File

@@ -1,75 +0,0 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
 * QuantileDMatrix will only be used to train.
 *
 * <p>All meta-info setters inherited from {@link DMatrix} are overridden to throw
 * {@link XGBoostError}.
 */
public class QuantileDMatrix extends DMatrix {
  /**
   * Create QuantileDMatrix from iterator based on the cuda array interface
   *
   * @param iter    the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
   * @param missing the missing value
   * @param maxBin  the max bin
   * @param nthread the parallelism
   * @throws XGBoostError if the native creation call fails
   */
  public QuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin,
      int nthread) throws XGBoostError {
    super(0);
    long[] out = new long[1];
    String conf = getConfig(missing, maxBin, nthread);
    // No reference iterator is supplied; the cast only selects the overload.
    XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
        iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
    handle = out[0];
  }

  @Override
  public void setLabel(Column column) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setLabel.");
  }

  @Override
  public void setWeight(Column column) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setWeight.");
  }

  @Override
  public void setBaseMargin(Column column) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
  }

  @Override
  public void setLabel(float[] labels) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setLabel.");
  }

  @Override
  public void setWeight(float[] weights) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setWeight.");
  }

  @Override
  public void setBaseMargin(float[] baseMargin) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
  }

  @Override
  public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
  }

  @Override
  public void setGroup(int[] group) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setGroup.");
  }

  /**
   * Build the JSON configuration string passed to the native constructor.
   *
   * <p>Formatting is pinned to {@link java.util.Locale#ROOT}: with the default locale,
   * {@code %f} may emit a comma as decimal separator (and {@code %d} localized digits),
   * which would make the resulting text invalid JSON.
   */
  private String getConfig(float missing, int maxBin, int nthread) {
    return String.format(java.util.Locale.ROOT,
        "{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
        missing, maxBin, nthread);
  }
}

View File

@@ -172,7 +172,7 @@ class XGBoostJNI {
long handle, String field, String json);
public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
@@ -180,4 +180,7 @@ class XGBoostJNI {
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
public final static native int XGDMatrixGetQuantileCut(long handle, long[][] outIndptr, float[][] outValues);
}

View File

@@ -365,4 +365,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
// Kryo deserialization hook: reads a JBooster from `input` and stores it in `booster`.
override def read(kryo: Kryo, input: Input): Unit = {
  val restored = kryo.readObject(input, classOf[JBooster])
  booster = restored
}
// a flag to indicate if the device is set for the GPU transform
var deviceIsSet = false
}

View File

@@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError}

View File

@@ -1,107 +0,0 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}
/**
 * Scala wrapper around the Java QuantileDMatrix.
 *
 * Every meta-info setter inherited from DMatrix is overridden to throw XGBoostError.
 */
class QuantileDMatrix private[scala](
    private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {

  /**
   * Create QuantileDMatrix from iterator based on the cuda array interface
   *
   * @param iter    the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
   * @param missing the missing value
   * @param maxBin  the max bin
   * @param nthread the parallelism
   * @throws XGBoostError
   */
  // Written with `=` rather than procedure syntax, which is deprecated for constructors.
  def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) =
    this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))

  /**
   * set label of dmatrix
   *
   * @param labels labels
   */
  @throws(classOf[XGBoostError])
  override def setLabel(labels: Array[Float]): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setLabel.")

  /**
   * set weight of each instance
   *
   * @param weights weights
   */
  @throws(classOf[XGBoostError])
  override def setWeight(weights: Array[Float]): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setWeight.")

  /**
   * if specified, xgboost will start from this init margin
   * can be used to specify initial prediction to boost from
   *
   * @param baseMargin base margin
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Float]): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

  /**
   * if specified, xgboost will start from this init margin
   * can be used to specify initial prediction to boost from
   *
   * @param baseMargin base margin
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

  /**
   * Set group sizes of DMatrix (used for ranking)
   *
   * @param group group size as array
   */
  @throws(classOf[XGBoostError])
  override def setGroup(group: Array[Int]): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setGroup.")

  /**
   * Set label of DMatrix from cuda array interface
   */
  @throws(classOf[XGBoostError])
  override def setLabel(column: Column): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setLabel.")

  /**
   * set weight of dmatrix from column array interface
   */
  @throws(classOf[XGBoostError])
  override def setWeight(column: Column): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setWeight.")

  /**
   * set base margin of dmatrix from column array interface
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(column: Column): Unit =
    throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")
}

View File

@@ -1,7 +1,6 @@
//
// Created by bobwang on 2021/9/8.
//
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_USE_CUDA
#include <jni.h>
@@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();

View File

@@ -1,10 +1,13 @@
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#include <jni.h>
#include <xgboost/c_api.h>
#include "../../../../src/common/device_helpers.cuh"
#include "../../../../src/common/cuda_pinned_allocator.h"
#include "../../../../src/common/device_vector.cuh" // for device_vector
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h"
#include <xgboost/c_api.h>
namespace xgboost {
namespace jni {
@@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
// Callback trampoline: treats the opaque handle as a DataIteratorProxy and forwards to its Next().
int Next(DataIterHandle self) {
  auto *proxy = static_cast<xgboost::jni::DataIteratorProxy *>(self);
  return proxy->Next();
}
template <typename T>
using Deleter = std::function<void(T *)>;
} // anonymous namespace
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
@@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
DMatrixHandle result;
DMatrixHandle ref{nullptr};
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
if (jref_iter) {
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
if (jref != NULL) {
std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
[&](jlong *ptr) {
jenv->ReleaseLongArrayElements(jref, ptr, 0);
jenv->DeleteLocalRef(jref);
}};
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
}
auto ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
setHandle(jenv, jout, result);
return ret;
}

View File

@@ -20,6 +20,7 @@
#include <xgboost/c_api.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/string_view.h> // for StringView
#include <algorithm> // for copy_n
#include <cstddef>
@@ -30,8 +31,9 @@
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/c_api/c_api_utils.h"
#include "../../../../src/c_api/c_api_error.h"
#include "../../../../src/c_api/c_api_utils.h"
#include "../../../../src/data/array_interface.h" // for ArrayInterface
#define JVM_CHECK_CALL(__expr) \
{ \
@@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
jlongArray jout) {
std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
[&](char const *ptr) {
jenv->ReleaseStringUTFChars(jconf, ptr);
}};
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
conf.get(), jout);
}
@@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
return ret;
}
/*
 * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method: XGDMatrixGetQuantileCut
 * Signature: (J[[J[[F)I
 *
 * Queries the quantile cuts of the DMatrix behind `jhandle` and stores them into
 * element 0 of `j_indptr` (as a long[]) and element 0 of `j_values` (as a float[]).
 * Returns the status code of the underlying C API call.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
    JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
  using namespace xgboost;  // NOLINT
  auto handle = reinterpret_cast<DMatrixHandle>(jhandle);

  // Initialised so that a failed call can never leave them holding garbage.
  char const *str_indptr{nullptr};
  char const *str_data{nullptr};
  Json config{Object{}};
  auto str_config = Json::Dump(config);

  auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
  if (ret != 0) {
    // The C API signalled an error; the output strings must not be parsed.
    // Hand the status back so the Java side can raise it via checkCall.
    return ret;
  }

  // Both outputs are array-interface JSON strings describing 1-dim arrays.
  ArrayInterface<1> indptr{StringView{str_indptr}};
  ArrayInterface<1> data{StringView{str_data}};
  CHECK_GE(indptr.Shape(0), 2);

  // Cut ptr
  auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
  CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
  // The unsigned 64-bit pointers are copied bit-for-bit into signed jlong, so the
  // largest (last) entry must fit into the positive int64 range.
  CHECK_LT(indptr(indptr.Shape(0) - 1),
           static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
  static_assert(sizeof(jlong) == sizeof(std::uint64_t));
  jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
                           static_cast<jlong const *>(indptr.data));
  jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);

  // Cut values: the last pointer entry is the total number of cut values.
  auto n_cuts = indptr(indptr.Shape(0) - 1);
  jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
  CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
  jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
  jenv->SetObjectArrayElement(j_values, 0, jcuts_array);

  return ret;
}

View File

@@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
(JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetQuantileCut
* Signature: (J[[J[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
(JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
#ifdef __cplusplus
}
#endif

View File

@@ -258,8 +258,7 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
@Test
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
private DMatrix createFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
@@ -280,12 +279,17 @@ public class DMatrixTest {
label0[i] = random.nextFloat();
}
DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f);
dmat0.setLabel(label0);
DMatrix dm = new DMatrix(data0, nrow, ncol, -0.1f);
dm.setLabel(label0);
return dm;
}
@Test
public void testCreateFromDenseMatrixWithMissingValue() throws XGBoostError {
DMatrix dm = createFromDenseMatrix();
//check
TestCase.assertTrue(dmat0.rowNum() == 10);
TestCase.assertTrue(dmat0.getLabel().length == 10);
TestCase.assertTrue(dm.rowNum() == 10);
TestCase.assertTrue(dm.getLabel().length == 10);
}
@Test
@@ -493,4 +497,28 @@ public class DMatrixTest {
TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup()));
}
/**
 * Trains for one round on a dense 10x5 matrix so the quantile cuts exist, then checks
 * the shape and ordering of the cuts returned by {@code DMatrix.getQuantileCut()}.
 */
@Test
public void getGetQuantileCut() throws XGBoostError {
  DMatrix Xy = createFromDenseMatrix();
  Map<String, Object> params = new HashMap<String, Object>();
  HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
  watches.put("train", Xy);
  XGBoost.train(Xy, params, 1, watches, null, null);  // Create the cuts

  DMatrix.QuantileCut cuts = Xy.getQuantileCut();
  // Go through the public accessors rather than the package-private fields.
  long[] indptr = cuts.getIndptr();
  float[] values = cuts.getValues();

  // One pointer per feature (5 columns) plus the leading entry; expected value goes first.
  TestCase.assertEquals(6, indptr.length);
  for (int i = 1; i < indptr.length; ++i) {
    // Number of bins for each feature + min value.
    long nBins = indptr[i] - indptr[i - 1];
    TestCase.assertTrue(nBins >= 5);
    TestCase.assertTrue(nBins <= Xy.rowNum() + 1);
  }

  // The last pointer entry is the total number of cut values.
  TestCase.assertEquals(indptr[indptr.length - 1], values.length);
  for (int i = 1; i < indptr.length; ++i) {
    long begin = indptr[i - 1];
    long end = indptr[i];
    // Within one feature the cut values must be strictly increasing.
    for (long j = begin + 1; j < end; ++j) {
      TestCase.assertTrue(values[(int) j] > values[(int) j - 1]);
    }
  }
}
}