[jvm-packages] Add getNumFeature method (#6075)

* Add getNumFeature to the Java API
* Add getNumFeature to the Scala API
* Add unit tests for getNumFeature

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Hristo Iliev
2020-09-08 06:57:46 +03:00
committed by GitHub
parent 93e9af43bb
commit da61d9460b
7 changed files with 66 additions and 0 deletions

View File

@@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable {
version += 1;
}
/**
* Get number of model features.
* @return the number of features.
* @throws XGBoostError
*/
public long getNumFeature() throws XGBoostError {
long[] numFeature = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
/**
* Internal initialization function.
* @param cacheMats The cached DMatrix.

View File

@@ -120,6 +120,7 @@ class XGBoostJNI {
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
// rabit functions
public final static native int RabitInit(String[] args);

View File

@@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}
/**
* Get the number of model features.
*
* @return number of features
*/
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature
def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = {