[jvm-packages] Expose prediction feature contribution on the Java side (#2441)

* Exposed prediction feature contribution on the Java side

* Fixed a call site that was not supplying the newly added argument

* Exposed from Scala-side as well

* formatting (keep declaration in one line unless exceeding 100 chars)
This commit is contained in:
Edi Bice 2017-06-28 16:34:51 -04:00 committed by Nan Zhu
parent d01a31088b
commit 2911597f3d
2 changed files with 37 additions and 8 deletions

View File

@ -131,7 +131,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
// Predict on the training data with outputMargin=true over all trees (treeLimit=0)
// and leaf output disabled; the five-argument form also passes predContribs=false.
float[][] predicts = this.predict(dtrain, true, 0, false);
float[][] predicts = this.predict(dtrain, true, 0, false, false);
// The objective returns a list whose first two entries are handed to boost() --
// presumably gradient and hessian; confirm against IObjective.
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
@ -219,12 +219,14 @@ public class Booster implements Serializable, KryoSerializable {
* @param outputMargin output margin
* @param treeLimit limit number of trees, 0 means all trees.
* @param predLeaf whether to output the leaf indices of the trees instead of prediction values
* @param predContribs whether to output per-feature contributions to the prediction
* @return predict results
*/
private synchronized float[][] predict(DMatrix data,
boolean outputMargin,
int treeLimit,
boolean predLeaf) throws XGBoostError {
boolean predLeaf,
boolean predContribs) throws XGBoostError {
int optionMask = 0;
if (outputMargin) {
optionMask = 1;
@ -232,6 +234,9 @@ public class Booster implements Serializable, KryoSerializable {
if (predLeaf) {
optionMask = 2;
}
if (predContribs) {
optionMask = 4;
}
float[][] rawPredicts = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
@ -256,7 +261,19 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError
*/
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
// Delegate with outputMargin=false and predLeaf=true; the five-argument form
// additionally passes predContribs=false.
return this.predict(data, false, treeLimit, true);
return this.predict(data, false, treeLimit, true, false);
}
/**
 * Output feature contributions toward predictions of given data
 *
 * @param data The input data.
 * @param treeLimit Number of trees to include, 0 means all trees.
 * @return The feature contributions and bias.
 * @throws XGBoostError native error
 */
public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
  // Request contributions only: predLeaf stays false. Passing true for both flags
  // produced the right option mask only because the predContribs assignment happens
  // to overwrite the predLeaf one inside the private predict().
  return this.predict(data, false, treeLimit, false, true);
}
/**
@ -267,7 +284,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
// Plain prediction: outputMargin=false, treeLimit=0 (all trees), predLeaf=false;
// the five-argument form additionally passes predContribs=false.
return this.predict(data, false, 0, false);
return this.predict(data, false, 0, false, false);
}
/**
@ -278,7 +295,7 @@ public class Booster implements Serializable, KryoSerializable {
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
// Forward the caller's outputMargin with treeLimit=0 (all trees) and predLeaf=false;
// the five-argument form additionally passes predContribs=false.
return this.predict(data, outputMargin, 0, false);
return this.predict(data, outputMargin, 0, false, false);
}
/**
@ -290,7 +307,7 @@ public class Booster implements Serializable, KryoSerializable {
* @return predict results
*/
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
// Forward the caller's outputMargin and treeLimit with predLeaf=false;
// the five-argument form additionally passes predContribs=false.
return this.predict(data, outputMargin, treeLimit, false);
return this.predict(data, outputMargin, treeLimit, false, false);
}
/**

View File

@ -135,11 +135,23 @@ class Booster private[xgboost4j](private var booster: JBooster)
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
: Array[Array[Float]] = {
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
// Unwrap the Scala DMatrix to its underlying Java DMatrix and delegate to the Java booster.
booster.predictLeaf(data.jDMatrix, treeLimit)
}
/**
 * Compute the feature contributions toward the predictions for the given data.
 *
 * @param data      DMatrix storing the input
 * @param treeLimit maximum number of trees used in the prediction; 0 (the default) uses all trees
 * @return the feature contributions and bias
 * @throws XGBoostError native error
 */
@throws(classOf[XGBoostError])
def predictContrib(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
  // Hand the wrapped Java DMatrix to the Java booster, which does the actual work.
  val javaMatrix = data.jDMatrix
  booster.predictContrib(javaMatrix, treeLimit)
}
/**
* save model to modelPath
*