Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -28,6 +28,7 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/c_api/c_api_utils.h"
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)V
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
||||
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
||||
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
||||
JVM_CHECK_CALL(ret);
|
||||
//release
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
|
||||
JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
|
||||
jfloatArray jhess) {
|
||||
API_BEGIN();
|
||||
BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
|
||||
DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
|
||||
CHECK(handle);
|
||||
CHECK(dtrain);
|
||||
bst_ulong n_samples{0};
|
||||
JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
|
||||
|
||||
bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
|
||||
jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
|
||||
jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
|
||||
CHECK(grad);
|
||||
CHECK(hess);
|
||||
|
||||
xgboost::bst_target_t n_targets{1};
|
||||
if (len != n_samples && n_samples != 0) {
|
||||
CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
|
||||
n_targets = len / n_samples;
|
||||
}
|
||||
|
||||
auto ctx = xgboost::detail::BoosterCtx(handle);
|
||||
auto [s_grad, s_hess] =
|
||||
xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
|
||||
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
|
||||
s_hess.c_str());
|
||||
|
||||
// release
|
||||
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
||||
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
||||
|
||||
return ret;
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)I
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
Reference in New Issue
Block a user