[jvm-packages] [breaking] rework xgboost4j-spark and xgboost4j-spark-gpu (#10639)
- Introduce an abstract XGBoost Estimator - Update to the latest XGBoost parameters - Add all XGBoost parameters supported in XGBoost4j-spark. - Add setter and getter for these parameters. - Remove the deprecated parameters - Address the missing value handling - Remove any ETL operations in XGBoost - Rework the GPU plugin - Expand sanity tests for CPU and GPU consistency
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
//
|
||||
// Created by bobwang on 2021/9/8.
|
||||
//
|
||||
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
#include <jni.h>
|
||||
@@ -21,7 +20,7 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
jobject jdata_iter, jlongArray jref,
|
||||
char const *config, jlongArray jout) {
|
||||
API_BEGIN();
|
||||
common::AssertGPUSupport();
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
/**
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <jni.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include "../../../../src/common/device_helpers.cuh"
|
||||
#include "../../../../src/common/cuda_pinned_allocator.h"
|
||||
#include "../../../../src/common/device_vector.cuh" // for device_vector
|
||||
#include "../../../../src/data/array_interface.h"
|
||||
#include "jvm_utils.h"
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
@@ -396,6 +399,9 @@ void Reset(DataIterHandle self) {
|
||||
int Next(DataIterHandle self) {
|
||||
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using Deleter = std::function<void(T *)>;
|
||||
} // anonymous namespace
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
@@ -413,17 +419,23 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
jobject jdata_iter, jlongArray jref,
|
||||
char const *config, jlongArray jout) {
|
||||
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
|
||||
DMatrixHandle result;
|
||||
DMatrixHandle ref{nullptr};
|
||||
|
||||
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
|
||||
if (jref_iter) {
|
||||
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
|
||||
if (jref != NULL) {
|
||||
std::unique_ptr<jlong, Deleter<jlong>> refptr{jenv->GetLongArrayElements(jref, nullptr),
|
||||
[&](jlong *ptr) {
|
||||
jenv->ReleaseLongArrayElements(jref, ptr, 0);
|
||||
jenv->DeleteLocalRef(jref);
|
||||
}};
|
||||
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
|
||||
}
|
||||
|
||||
auto ret = XGQuantileDMatrixCreateFromCallback(
|
||||
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
|
||||
&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next, config, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/string_view.h> // for StringView
|
||||
|
||||
#include <algorithm> // for copy_n
|
||||
#include <cstddef>
|
||||
@@ -30,8 +31,9 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/c_api/c_api_utils.h"
|
||||
#include "../../../../src/c_api/c_api_error.h"
|
||||
#include "../../../../src/c_api/c_api_utils.h"
|
||||
#include "../../../../src/data/array_interface.h" // for ArrayInterface
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
{ \
|
||||
@@ -1330,16 +1332,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
|
||||
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf,
|
||||
JNIEnv *jenv, jclass jcls, jobject jdata_iter, jlongArray jref, jstring jconf,
|
||||
jlongArray jout) {
|
||||
std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
|
||||
[&](char const *ptr) {
|
||||
jenv->ReleaseStringUTFChars(jconf, ptr);
|
||||
}};
|
||||
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter,
|
||||
return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref,
|
||||
conf.get(), jout);
|
||||
}
|
||||
|
||||
@@ -1517,3 +1519,44 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetQuantileCut
|
||||
* Signature: (J[[J[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut(
|
||||
JNIEnv *jenv, jclass, jlong jhandle, jobjectArray j_indptr, jobjectArray j_values) {
|
||||
using namespace xgboost; // NOLINT
|
||||
auto handle = reinterpret_cast<DMatrixHandle>(jhandle);
|
||||
|
||||
char const *str_indptr;
|
||||
char const *str_data;
|
||||
Json config{Object{}};
|
||||
auto str_config = Json::Dump(config);
|
||||
|
||||
auto ret = XGDMatrixGetQuantileCut(handle, str_config.c_str(), &str_indptr, &str_data);
|
||||
|
||||
ArrayInterface<1> indptr{StringView{str_indptr}};
|
||||
ArrayInterface<1> data{StringView{str_data}};
|
||||
CHECK_GE(indptr.Shape(0), 2);
|
||||
|
||||
// Cut ptr
|
||||
auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
|
||||
CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
|
||||
CHECK_LT(indptr(indptr.Shape(0) - 1),
|
||||
static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
|
||||
static_assert(sizeof(jlong) == sizeof(std::uint64_t));
|
||||
jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
|
||||
static_cast<jlong const *>(indptr.data));
|
||||
jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);
|
||||
|
||||
// Cut values
|
||||
auto n_cuts = indptr(indptr.Shape(0) - 1);
|
||||
jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
|
||||
CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
|
||||
jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
|
||||
jenv->SetObjectArrayElement(j_values, 0, jcuts_array);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -402,10 +402,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jobject, jstring, jlongArray);
|
||||
(JNIEnv *, jclass, jobject, jlongArray, jstring, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -431,6 +431,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFea
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetQuantileCut
|
||||
* Signature: (J[[J[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuantileCut
|
||||
(JNIEnv *, jclass, jlong, jobjectArray, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user