Allow JVM-Package to access inplace predict method (#9167)
--------- Co-authored-by: Stephan T. Lavavej <stl@nuwen.net> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Joe <25804777+ByteSizedJoe@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,21 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
private int version = 0;
|
||||
/**
|
||||
* Type of prediction, used for inplace_predict.
|
||||
*/
|
||||
public enum PredictionType {
|
||||
kValue(0),
|
||||
kMargin(1);
|
||||
|
||||
private Integer ptype;
|
||||
private PredictionType(final Integer ptype) {
|
||||
this.ptype = ptype;
|
||||
}
|
||||
public Integer getPType() {
|
||||
return ptype;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new Booster with empty stage.
|
||||
@@ -375,6 +390,97 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
*
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing) throws XGBoostError {
|
||||
int[] iteration_range = new int[2];
|
||||
iteration_range[0] = 0;
|
||||
iteration_range[1] = 0;
|
||||
return this.inplace_predict(data, nrow, ncol,
|
||||
missing, iteration_range, PredictionType.kValue, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||
* example, if a random forest is trained with 100 rounds.
|
||||
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||
* built during [10, 20) (half open set) rounds are used in this
|
||||
* prediction.
|
||||
*
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing, int[] iteration_range) throws XGBoostError {
|
||||
return this.inplace_predict(data, nrow, ncol,
|
||||
missing, iteration_range, PredictionType.kValue, null);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Perform thread-safe prediction.
|
||||
*
|
||||
* @param data Flattened input matrix of features for prediction
|
||||
* @param nrow The number of preditions to make (count of input matrix rows)
|
||||
* @param ncol The number of features in the model (count of input matrix columns)
|
||||
* @param missing Value indicating missing element in the <code>data</code> input matrix
|
||||
* @param iteration_range Specifies which layer of trees are used in prediction. For
|
||||
* example, if a random forest is trained with 100 rounds.
|
||||
* Specifying `iteration_range=[10, 20)`, then only the forests
|
||||
* built during [10, 20) (half open set) rounds are used in this
|
||||
* prediction.
|
||||
* @param predict_type What kind of prediction to run.
|
||||
* @return predict Result matrix
|
||||
*/
|
||||
public float[][] inplace_predict(float[] data,
|
||||
int nrow,
|
||||
int ncol,
|
||||
float missing,
|
||||
int[] iteration_range,
|
||||
PredictionType predict_type,
|
||||
float[] base_margin) throws XGBoostError {
|
||||
if (iteration_range.length != 2) {
|
||||
throw new XGBoostError(new String("Iteration range is expected to be [begin, end)."));
|
||||
}
|
||||
int ptype = predict_type.getPType();
|
||||
|
||||
int begin = iteration_range[0];
|
||||
int end = iteration_range[1];
|
||||
|
||||
float[][] rawPredicts = new float[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredictFromDense(handle, data, nrow, ncol,
|
||||
missing,
|
||||
begin, end, ptype, base_margin, rawPredicts));
|
||||
|
||||
int col = rawPredicts[0].length / nrow;
|
||||
float[][] predicts = new float[nrow][col];
|
||||
int r, c;
|
||||
for (int i = 0; i < rawPredicts[0].length; i++) {
|
||||
r = i / col;
|
||||
c = i % col;
|
||||
predicts[r][c] = rawPredicts[0][i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict leaf indices given the data
|
||||
*
|
||||
|
||||
@@ -119,6 +119,10 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
|
||||
int ntree_limit, float[][] predicts);
|
||||
|
||||
public final static native int XGBoosterPredictFromDense(long handle, float[] data,
|
||||
long nrow, long ncol, float missing, int iteration_begin, int iteration_end, int predict_type, float[] margin,
|
||||
float[][] predicts);
|
||||
|
||||
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||
|
||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||
@@ -154,10 +158,6 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixSetInfoFromInterface(
|
||||
long handle, String field, String json);
|
||||
|
||||
@Deprecated
|
||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
|
||||
|
||||
public final static native int XGQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user