Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -218,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
@Deprecated
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = this.predict(dtrain, true, 0, false, false);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
this.boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* Update with customize obj func
*
* @param dtrain training data
* @param iter The current training iteration.
* @param obj customized objective class
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
float[][] predicts = this.predict(dtrain, true, 0, false, false);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
}
@Deprecated
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
this.boost(dtrain, 0, grad, hess);
}
/**
* update with give grad and hess
* Update with give grad and hess
*
* @param dtrain training data
* @param iter The current training iteration.
* @param grad first order of gradient
* @param hess seconde order of gradient
* @throws XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
dtrain.getHandle(), grad, hess));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
dtrain.getHandle(), iter, grad, hess));
}
/**

View File

@@ -110,7 +110,7 @@ class XGBoostJNI {
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,

View File

@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.update(dtrain.jDMatrix, iter)
}
@throws(classOf[XGBoostError])
@deprecated
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
/**
* update with customize obj func
*
* @param dtrain training data
* @param iter The current training iteration
* @param obj customized objective class
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, iter, obj)
}
@throws(classOf[XGBoostError])
@deprecated
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
/**
* update with give grad and hess
*
* @param dtrain training data
* @param iter The current training iteration
* @param grad first order of gradient
* @param hess seconde order of gradient
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, iter, grad, hess)
}
/**