Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -28,6 +28,7 @@
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/c_api/c_api_utils.h"
#define JVM_CHECK_CALL(__expr) \
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V
* Method: XGBoosterTrainOneIter
* Signature: (JJI[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
JVM_CHECK_CALL(ret);
//release
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
jfloatArray jhess) {
API_BEGIN();
BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
CHECK(handle);
CHECK(dtrain);
bst_ulong n_samples{0};
JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
CHECK(grad);
CHECK(hess);
xgboost::bst_target_t n_targets{1};
if (len != n_samples && n_samples != 0) {
CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
n_targets = len / n_samples;
}
auto ctx = xgboost::detail::BoosterCtx(handle);
auto [s_grad, s_hess] =
xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
s_hess.c_str());
// release
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
return ret;
API_END();
}
/*

View File

@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)I
* Method: XGBoosterTrainOneIter
* Signature: (JJI[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
(JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
#ifdef __cplusplus