diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index d8ba6ea8e..20fb6f75c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable { version += 1; } + /** + * Get number of model features. + * @return the number of features. + * @throws XGBoostError + */ + public long getNumFeature() throws XGBoostError { + long[] numFeature = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature)); + return numFeature[0]; + } + /** * Internal initialization function. * @param cacheMats The cached DMatrix. diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 4eee147e9..bafa541e5 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -120,6 +120,7 @@ class XGBoostJNI { public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); public final static native int XGBoosterSaveRabitCheckpoint(long handle); + public final static native int XGBoosterGetNumFeature(long handle, long[] feature); // rabit functions public final static native int RabitInit(String[] args); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index bb2d5e9e5..e442c4f75 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) .asScala.mapValues(_.doubleValue).toSeq: _*) } + /** + * Get the number of model features. + * + * @return number of features + */ + @throws(classOf[XGBoostError]) + def getNumFeature: Long = booster.getNumFeature + def getVersion: Int = booster.getVersion def toByteArray: Array[Byte] = { diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 13529f6e1..b7461a7de 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -848,6 +848,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit return XGBoosterSaveRabitCheckpoint(handle); } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetNumFeature + * Signature: (J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { + BoosterHandle handle = (BoosterHandle) jhandle; + bst_ulong num_feature; + int ret = XGBoosterGetNumFeature(handle, &num_feature); + JVM_CHECK_CALL(ret); + jlong jnum_feature = num_feature; + jenv->SetLongArrayRegion(jout, 0, 1, &jnum_feature); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: RabitInit diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 3d0f0c468..fd9e0932a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -271,6 +271,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint (JNIEnv *, jclass, jlong); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetNumFeature + * Signature: (J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature + (JNIEnv *, jclass, jlong, jlongArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: RabitInit diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 3b565ab8f..a01d7d6cf 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -646,4 +646,18 @@ public class BoosterImplTest { TestCase.assertEquals(attr.get("bb"), "BB"); TestCase.assertEquals(attr.get("cc"), "CC"); } + + /** + * test get number of features from a booster + * + * @throws XGBoostError + */ + @Test + public void testGetNumFeature() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + TestCase.assertEquals(booster.getNumFeature(), 127); + } } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala index adea1b1ec..8f7bbd322 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite { val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster) assert(prevBooster == nextBooster) } + + test("test getting number of features from a booster") { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + val booster = trainBooster(trainMat, testMat) + + TestCase.assertEquals(booster.getNumFeature, 127) + } }