Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -218,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
@Deprecated
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = this.predict(dtrain, true, 0, false, false);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
this.boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* Update with customized objective function
*
* @param dtrain training data
* @param iter The current training iteration.
* @param obj customized objective class
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
float[][] predicts = this.predict(dtrain, true, 0, false, false);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
}
@Deprecated
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
this.boost(dtrain, 0, grad, hess);
}
/**
* update with given grad and hess
* Update with given grad and hess
*
* @param dtrain training data
* @param iter The current training iteration.
* @param grad first order of gradient
* @param hess second order of gradient
* @throws XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
dtrain.getHandle(), grad, hess));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
dtrain.getHandle(), iter, grad, hess));
}
/**

View File

@@ -110,7 +110,7 @@ class XGBoostJNI {
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,

View File

@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.update(dtrain.jDMatrix, iter)
}
@throws(classOf[XGBoostError])
@deprecated
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
/**
* update with customized objective function
*
* @param dtrain training data
* @param iter The current training iteration
* @param obj customized objective class
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, iter, obj)
}
@throws(classOf[XGBoostError])
@deprecated
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
/**
* update with given grad and hess
*
* @param dtrain training data
* @param iter The current training iteration
* @param grad first order of gradient
* @param hess second order of gradient
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, iter, grad, hess)
}
/**

View File

@@ -28,6 +28,7 @@
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/c_api/c_api_utils.h"
#define JVM_CHECK_CALL(__expr) \
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V
* Method: XGBoosterTrainOneIter
* Signature: (JJI[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
JVM_CHECK_CALL(ret);
//release
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
jfloatArray jhess) {
API_BEGIN();
BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
CHECK(handle);
CHECK(dtrain);
bst_ulong n_samples{0};
JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
CHECK(grad);
CHECK(hess);
xgboost::bst_target_t n_targets{1};
if (len != n_samples && n_samples != 0) {
CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
n_targets = len / n_samples;
}
auto ctx = xgboost::detail::BoosterCtx(handle);
auto [s_grad, s_hess] =
xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
s_hess.c_str());
// release
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
return ret;
API_END();
}
/*

View File

@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)I
* Method: XGBoosterTrainOneIter
* Signature: (JJI[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
(JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
@@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetStrFeatureInfo
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
#ifdef __cplusplus