Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -218,34 +218,48 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
this.boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update with a customized objective function
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param obj customized objective class
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
this.boost(dtrain, iter, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
this.boost(dtrain, 0, grad, hess);
|
||||
}
|
||||
|
||||
/**
|
||||
* update with give grad and hess
|
||||
* Update with given grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration.
|
||||
* @param grad first order of gradient
|
||||
* @param hess second order of gradient
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
|
||||
public void boost(DMatrix dtrain, int iter, float[] grad, float[] hess) throws XGBoostError {
|
||||
if (grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
|
||||
dtrain.getHandle(), grad, hess));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterTrainOneIter(handle,
|
||||
dtrain.getHandle(), iter, grad, hess));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -110,7 +110,7 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
|
||||
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad,
|
||||
public final static native int XGBoosterTrainOneIter(long handle, long dtrain, int iter, float[] grad,
|
||||
float[] hess);
|
||||
|
||||
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats,
|
||||
|
||||
@@ -106,27 +106,41 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.update(dtrain.jDMatrix, iter)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update with a customized objective function
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param obj customized objective class
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, obj)
|
||||
def update(dtrain: DMatrix, iter: Int, obj: ObjectiveTrait): Unit = {
|
||||
booster.update(dtrain.jDMatrix, iter, obj)
|
||||
}
|
||||
|
||||
@throws(classOf[XGBoostError])
|
||||
@deprecated
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update with given grad and hess
|
||||
*
|
||||
* @param dtrain training data
|
||||
* @param iter The current training iteration
|
||||
* @param grad first order of gradient
|
||||
* @param hess second order of gradient
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, grad, hess)
|
||||
def boost(dtrain: DMatrix, iter: Int, grad: Array[Float], hess: Array[Float]): Unit = {
|
||||
booster.boost(dtrain.jDMatrix, iter, grad, hess)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user