From 2911597f3daec5a94e80d2a43a29e4b872befad0 Mon Sep 17 00:00:00 2001 From: Edi Bice Date: Wed, 28 Jun 2017 16:34:51 -0400 Subject: [PATCH] [jvm-packages] Expose prediction feature contribution on the Java side (#2441) * Exposed prediction feature contribution on the Java side * was not supplying the newly added argument * Exposed from Scala-side as well * formatting (keep declaration in one line unless exceeding 100 chars) --- .../java/ml/dmlc/xgboost4j/java/Booster.java | 29 +++++++++++++++---- .../ml/dmlc/xgboost4j/scala/Booster.scala | 16 ++++++++-- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 92438121d..3b1476e54 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -131,7 +131,7 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError native error */ public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { - float[][] predicts = this.predict(dtrain, true, 0, false); + float[][] predicts = this.predict(dtrain, true, 0, false, false); List gradients = obj.getGradient(predicts, dtrain); boost(dtrain, gradients.get(0), gradients.get(1)); } @@ -219,12 +219,14 @@ public class Booster implements Serializable, KryoSerializable { * @param outputMargin output margin * @param treeLimit limit number of trees, 0 means all trees. 
* @param predLeaf prediction minimum to keep leafs + * @param predContribs prediction feature contributions * @return predict results */ private synchronized float[][] predict(DMatrix data, boolean outputMargin, int treeLimit, - boolean predLeaf) throws XGBoostError { + boolean predLeaf, + boolean predContribs) throws XGBoostError { int optionMask = 0; if (outputMargin) { optionMask = 1; @@ -232,6 +234,9 @@ public class Booster implements Serializable, KryoSerializable { if (predLeaf) { optionMask = 2; } + if (predContribs) { + optionMask = 4; + } float[][] rawPredicts = new float[1][]; XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts)); @@ -256,7 +261,19 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError */ public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError { - return this.predict(data, false, treeLimit, true); + return this.predict(data, false, treeLimit, true, false); + } + + /** + * Output feature contributions toward predictions of given data + * + * @param data The input data. + * @param treeLimit Number of trees to include, 0 means all trees. + * @return The feature contributions and bias. 
+ * @throws XGBoostError native error + */ + public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError { + return this.predict(data, false, treeLimit, false, true); } /** @@ -267,7 +284,7 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError native error */ public float[][] predict(DMatrix data) throws XGBoostError { - return this.predict(data, false, 0, false); + return this.predict(data, false, 0, false, false); } /** @@ -278,7 +295,7 @@ public class Booster implements Serializable, KryoSerializable { * @return predict results */ public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError { - return this.predict(data, outputMargin, 0, false); + return this.predict(data, outputMargin, 0, false, false); } /** @@ -290,7 +307,7 @@ public class Booster implements Serializable, KryoSerializable { * @return predict results */ public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError { - return this.predict(data, outputMargin, treeLimit, false); + return this.predict(data, outputMargin, treeLimit, false, false); } /** diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 760aee828..174c68804 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -135,11 +135,23 @@ class Booster private[xgboost4j](private var booster: JBooster) * @throws XGBoostError native error */ @throws(classOf[XGBoostError]) - def predictLeaf(data: DMatrix, treeLimit: Int = 0) - : Array[Array[Float]] = { + def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = { booster.predictLeaf(data.jDMatrix, treeLimit) } + /** + * Output feature contributions toward predictions of given data + * + * @param data dmatrix storing the input + * @param 
treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @return The feature contributions and bias. + * @throws XGBoostError native error + */ + @throws(classOf[XGBoostError]) + def predictContrib(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = { + booster.predictContrib(data.jDMatrix, treeLimit) + } + /** * save model to modelPath *