[jvm-packages] Expose prediction feature contribution on the Java side (#2441)
* Exposed prediction feature contribution on the Java side * was not supplying the newly added argument * Exposed from Scala-side as well * formatting (keep declaration in one line unless exceeding 100 chars)
This commit is contained in:
parent
d01a31088b
commit
2911597f3d
@ -131,7 +131,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false);
|
||||
float[][] predicts = this.predict(dtrain, true, 0, false, false);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
}
|
||||
@ -219,12 +219,14 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @param outputMargin output margin
|
||||
* @param treeLimit limit number of trees, 0 means all trees.
|
||||
* @param predLeaf prediction minimum to keep leafs
|
||||
* @param predContribs prediction feature contributions
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] predict(DMatrix data,
|
||||
boolean outputMargin,
|
||||
int treeLimit,
|
||||
boolean predLeaf) throws XGBoostError {
|
||||
boolean predLeaf,
|
||||
boolean predContribs) throws XGBoostError {
|
||||
int optionMask = 0;
|
||||
if (outputMargin) {
|
||||
optionMask = 1;
|
||||
@ -232,6 +234,9 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
if (predLeaf) {
|
||||
optionMask = 2;
|
||||
}
|
||||
if (predContribs) {
|
||||
optionMask = 4;
|
||||
}
|
||||
float[][] rawPredicts = new float[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
treeLimit, rawPredicts));
|
||||
@ -256,7 +261,19 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
|
||||
return this.predict(data, false, treeLimit, true);
|
||||
return this.predict(data, false, treeLimit, true, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Output feature contributions toward predictions of given data
|
||||
*
|
||||
* @param data The input data.
|
||||
* @param treeLimit Number of trees to include, 0 means all trees.
|
||||
* @return The feature contributions and bias.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
|
||||
return this.predict(data, false, treeLimit, true, true);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -267,7 +284,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public float[][] predict(DMatrix data) throws XGBoostError {
|
||||
return this.predict(data, false, 0, false);
|
||||
return this.predict(data, false, 0, false, false);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -278,7 +295,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @return predict results
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
|
||||
return this.predict(data, outputMargin, 0, false);
|
||||
return this.predict(data, outputMargin, 0, false, false);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -290,7 +307,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
* @return predict results
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
|
||||
return this.predict(data, outputMargin, treeLimit, false);
|
||||
return this.predict(data, outputMargin, treeLimit, false, false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -135,11 +135,23 @@ class Booster private[xgboost4j](private var booster: JBooster)
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predictLeaf(data: DMatrix, treeLimit: Int = 0)
|
||||
: Array[Array[Float]] = {
|
||||
def predictLeaf(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
|
||||
booster.predictLeaf(data.jDMatrix, treeLimit)
|
||||
}
|
||||
|
||||
/**
|
||||
* Output feature contributions toward predictions of given data
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return The feature contributions and bias.
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predictContrib(data: DMatrix, treeLimit: Int = 0) : Array[Array[Float]] = {
|
||||
booster.predictContrib(data.jDMatrix, treeLimit)
|
||||
}
|
||||
|
||||
/**
|
||||
* save model to modelPath
|
||||
*
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user