Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -218,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
this.boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param obj customized objective class
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
this.boost(dtrain, 0, grad, hess);
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
* Update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
|
||||
if (grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
|
||||
dtrain.getHandle(), grad, hess));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
|
||||
dtrain.getHandle(), iter, grad, hess));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -110,7 +110,7 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
|
||||
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
|
||||
public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
|
||||
float[] hess);
|
||||
|
||||
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,
|
||||
|
||||
@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with customize obj func
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter, obj)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param grad first order of gradient
|
||||
* @param hess seconde order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, iter, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/c_api/c_api_utils.h"
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
@@ -579,22 +580,44 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)V
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
||||
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
||||
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
||||
JVM_CHECK_CALL(ret);
|
||||
//release
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter(
|
||||
JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad,
|
||||
jfloatArray jhess) {
|
||||
API_BEGIN();
|
||||
BoosterHandle handle = reinterpret_cast<BoosterHandle *>(jhandle);
|
||||
DMatrixHandle dtrain = reinterpret_cast<DMatrixHandle *>(jdtrain);
|
||||
CHECK(handle);
|
||||
CHECK(dtrain);
|
||||
bst_ulong n_samples{0};
|
||||
JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples));
|
||||
|
||||
bst_ulong len = static_cast<bst_ulong>(jenv->GetArrayLength(jgrad));
|
||||
jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr);
|
||||
jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr);
|
||||
CHECK(grad);
|
||||
CHECK(hess);
|
||||
|
||||
xgboost::bst_target_t n_targets{1};
|
||||
if (len != n_samples && n_samples != 0) {
|
||||
CHECK_EQ(len % n_samples, 0) << "Invalid size of gradient.";
|
||||
n_targets = len / n_samples;
|
||||
}
|
||||
|
||||
auto ctx = xgboost::detail::BoosterCtx(handle);
|
||||
auto [s_grad, s_hess] =
|
||||
xgboost::detail::MakeGradientInterface(ctx, grad, hess, n_samples, n_targets);
|
||||
int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast<std::int32_t>(jiter), s_grad.c_str(),
|
||||
s_hess.c_str());
|
||||
|
||||
// release
|
||||
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
||||
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
||||
|
||||
return ret;
|
||||
API_END();
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -185,11 +185,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOne
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)I
|
||||
* Method: XGBoosterTrainOneIter
|
||||
* Signature: (JJI[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jfloatArray, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
@@ -386,19 +386,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
Reference in New Issue
Block a user