[jvm-packages] Expose prediction feature contribution on the Java side (#2441)
* Exposed prediction feature contribution on the Java side * was not supplying the newly added argument * Exposed from Scala-side as well * formatting (keep declaration in one line unless exceeding 100 chars)
This commit is contained in:
parent
d01a31088b
commit
2911597f3d
@ -131,7 +131,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||||
float[][] predicts = this.predict(dtrain, true, 0, false);
|
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||||
}
|
}
|
||||||
@ -219,12 +219,14 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @param outputMargin output margin
|
* @param outputMargin output margin
|
||||||
* @param treeLimit limit number of trees, 0 means all trees.
|
* @param treeLimit limit number of trees, 0 means all trees.
|
||||||
* @param predLeaf prediction minimum to keep leafs
|
* @param predLeaf prediction minimum to keep leafs
|
||||||
|
* @param predContribs prediction feature contributions
|
||||||
* @return predict results
|
* @return predict results
|
||||||
*/
|
*/
|
||||||
private synchronized float[][] predict(DMatrix data,
|
private synchronized float[][] predict(DMatrix data,
|
||||||
boolean outputMargin,
|
boolean outputMargin,
|
||||||
int treeLimit,
|
int treeLimit,
|
||||||
boolean predLeaf) throws XGBoostError {
|
boolean predLeaf,
|
||||||
|
boolean predContribs) throws XGBoostError {
|
||||||
int optionMask = 0;
|
int optionMask = 0;
|
||||||
if (outputMargin) {
|
if (outputMargin) {
|
||||||
optionMask = 1;
|
optionMask = 1;
|
||||||
@ -232,6 +234,9 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
if (predLeaf) {
|
if (predLeaf) {
|
||||||
optionMask = 2;
|
optionMask = 2;
|
||||||
}
|
}
|
||||||
|
if (predContribs) {
|
||||||
|
optionMask = 4;
|
||||||
|
}
|
||||||
float[][] rawPredicts = new float[1][];
|
float[][] rawPredicts = new float[1][];
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||||
treeLimit, rawPredicts));
|
treeLimit, rawPredicts));
|
||||||
@ -256,7 +261,19 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
|
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
|
||||||
return this.predict(data, false, treeLimit, true);
|
return this.predict(data, false, treeLimit, true, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output feature contributions toward predictions of given data
|
||||||
|
*
|
||||||
|
* @param data The input data.
|
||||||
|
* @param treeLimit Number of trees to include, 0 means all trees.
|
||||||
|
* @return The feature contributions and bias.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
|
||||||
|
return this.predict(data, false, treeLimit, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -267,7 +284,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||||
return this.predict(data, false, 0, false);
|
return this.predict(data, false, 0, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -278,7 +295,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @return predict results
|
* @return predict results
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
|
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
|
||||||
return this.predict(data, outputMargin, 0, false);
|
return this.predict(data, outputMargin, 0, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -290,7 +307,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @return predict results
|
* @return predict results
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
|
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
|
||||||
return this.predict(data, outputMargin, treeLimit, false);
|
return this.predict(data, outputMargin, treeLimit, false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -135,11 +135,23 @@ class Booster private[xgboost4j](private var booster: JBooster)
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
|
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
|
||||||
: Array[Array[Float]] = {
|
|
||||||
booster.predictLeaf(data.jDMatrix, treeLimit)
|
booster.predictLeaf(data.jDMatrix, treeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output feature contributions toward predictions of given data
|
||||||
|
*
|
||||||
|
* @param data dmatrix storing the input
|
||||||
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
|
* @return The feature contributions and bias.
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def predictContrib(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
|
||||||
|
booster.predictContrib(data.jDMatrix, treeLimit)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* save model to modelPath
|
* save model to modelPath
|
||||||
*
|
*
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user