[jvm-packages] Add getNumFeature method (#6075)
* Add getNumFeature to the Java API * Add getNumFeature to the Scala API * Add unit tests for getNumFeature Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get number of model features.
|
||||
* @return the number of features.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public long getNumFeature() throws XGBoostError {
|
||||
long[] numFeature = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||
return numFeature[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal initialization function.
|
||||
* @param cacheMats The cached DMatrix.
|
||||
|
||||
@@ -120,6 +120,7 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
|
||||
@@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
.asScala.mapValues(_.doubleValue).toSeq: _*)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of model features.
|
||||
*
|
||||
* @return number of features
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getNumFeature: Long = booster.getNumFeature
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
|
||||
def toByteArray: Array[Byte] = {
|
||||
|
||||
Reference in New Issue
Block a user