[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)

- Introduce an abstract XGBoost Estimator
- Update to the latest XGBoost parameters
  - Add all XGBoost parameters supported in XGBoost4j-spark.
  - Add setters and getters for these parameters.
  - Remove the deprecated parameters.
- Address the missing value handling
- Remove any ETL operations in XGBoost
- Rework the GPU plugin
- Expand sanity tests for CPU and GPU consistency
This commit is contained in:
Bobby Wang
2024-09-11 15:54:19 +08:00
committed by GitHub
parent d94f6679fc
commit 67c8c96784
75 changed files with 4537 additions and 7556 deletions

View File

@@ -1,7 +1,6 @@
//
// Created by bobwang on 2021/9/8.
//
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_USE_CUDA
#include <jni.h>
@@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();

View File

@@ -1,10 +1,13 @@
/**
* Copyright 2021-2024, XGBoost Contributors
*/
#include <jni.h>
#include <xgboost/c_api.h>
#include "../../../../src/common/device_helpers.cuh"
#include "../../../../src/common/cuda_pinned_allocator.h"
#include "../../../../src/common/device_vector.cuh" // for device_vector
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h"
#include <xgboost/c_api.h>
namespace xgboost {
namespace jni {
@@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
/**
 * Iterator callback: treats the opaque handle as a DataIteratorProxy and forwards to its
 * Next() method, returning whatever that call returns.
 */
int Next(DataIterHandle self) {
  auto *proxy = static_cast<xgboost::jni::DataIteratorProxy *>(self);
  return proxy->Next();
}
// Deleter type for std::unique_ptr built on std::function, so that a capturing lambda can be
// supplied as the deleter.
template <typename T>
using Deleter = std::function<void(T *)>;
} // anonymous namespace
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
@@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
}
/**
 * JNI-side implementation that builds a QuantileDMatrix from a Java data iterator.
 *
 * Wraps jdata_iter in a DataIteratorProxy, forwards to XGQuantileDMatrixCreateFromCallback
 * with the Reset/Next callbacks, stores the resulting handle in jout through setHandle, and
 * returns the status code of the C API call.
 *
 * NOTE(review): this span is rendered from a diff, so the removed `jref_iter` lines and
 * their `jref` replacements appear side by side.
 */
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
jobject jdata_iter, jlongArray jref,
char const *config, jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
DMatrixHandle result;
DMatrixHandle ref{nullptr};
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
if (jref_iter) {
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
// When jref is supplied, its first element is reinterpreted as the reference DMatrix handle.
if (jref != NULL) {
// The array elements are released (and the local reference deleted) when refptr is destroyed.
std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
[&](jlong *ptr) {
jenv->ReleaseLongArrayElements(jref, ptr, 0);
jenv->DeleteLocalRef(jref);
}};
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
}
auto ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
// Hand the created handle back to Java through the output long array.
setHandle(jenv, jout, result);
return ret;
}

View File

@@ -20,6 +20,7 @@
#include <xgboost/c_api.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <xgboost/string_view.h> // for StringView
#include <algorithm> // for copy_n
#include <cstddef>
@@ -30,8 +31,9 @@
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/c_api/c_api_utils.h"
#include "../../../../src/c_api/c_api_error.h"
#include "../../../../src/c_api/c_api_utils.h"
#include "../../../../src/data/array_interface.h" // for ArrayInterface
#define JVM_CHECK_CALL(__expr) \
{ \
@@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
// JNI entry point: obtains the configuration string from jconf via GetStringUTFChars and
// delegates to xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl, returning its result.
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
jlongArray jout) {
// The deleter hands the characters back to the JVM when conf goes out of scope.
std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
[&](char const *ptr) {
jenv->ReleaseStringUTFChars(jconf, ptr);
}};
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
conf.get(), jout);
}
@@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetQuantileCut
* Signature: (J[[J[[F)I
*/
/**
 * Copy the quantile cuts of a DMatrix into freshly created Java arrays.
 *
 * Calls XGDMatrixGetQuantileCut with an empty JSON config, parses the two returned
 * array-interface strings, and stores a new long[] with the cut pointers in j_indptr[0] and
 * a new float[] with the cut values in j_values[0].
 *
 * @param jenv     JNI environment.
 * @param jhandle  DMatrix handle passed from Java as a long.
 * @param j_indptr Output long[][]; element 0 receives the cut pointers.
 * @param j_values Output float[][]; element 0 receives the cut values.
 * @return The status of XGDMatrixGetQuantileCut, or a non-zero value when a Java array
 *         could not be allocated.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
    JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
  using namespace xgboost;  // NOLINT
  auto handle = reinterpret_cast<DMatrixHandle>(jhandle);

  char const *str_indptr{nullptr};
  char const *str_data{nullptr};
  Json config{Object{}};
  auto str_config = Json::Dump(config);

  auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
  // The output strings are only set on success; parsing them after a failed call would read
  // through invalid pointers, so report the error to the caller instead.
  if (ret != 0) {
    return ret;
  }

  ArrayInterface<1> indptr{StringView{str_indptr}};
  ArrayInterface<1> data{StringView{str_data}};
  CHECK_GE(indptr.Shape(0), 2);

  // Cut ptr: validate the buffer before copying it verbatim into a Java long[].
  CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
  CHECK_LT(indptr(indptr.Shape(0) - 1),
           static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
  static_assert(sizeof(jlong) == sizeof(std::uint64_t));
  auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
  if (j_indptr_array == nullptr) {
    // Allocation failed; a Java exception is pending and is raised once we return.
    return -1;
  }
  jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
                           static_cast<jlong const *>(indptr.data));
  jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);

  // Cut values: the last cut pointer is the total number of cut values.
  auto n_cuts = indptr(indptr.Shape(0) - 1);
  CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
  jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
  if (jcuts_array == nullptr) {
    // Allocation failed; a Java exception is pending and is raised once we return.
    return -1;
  }
  jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
  jenv->SetObjectArrayElement(j_values, 0, jcuts_array);

  return ret;
}

View File

@@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
(JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetQuantileCut
* Signature: (J[[J[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
(JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
#ifdef __cplusplus
}
#endif