[jvm-packages] Add getNumFeature method (#6075)

* Add getNumFeature to the Java API
* Add getNumFeature to the Scala API
* Add unit tests for getNumFeature

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Hristo Iliev 2020-09-08 06:57:46 +03:00 committed by GitHub
parent 93e9af43bb
commit da61d9460b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 66 additions and 0 deletions

View File

@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable {
version += 1; version += 1;
} }
/**
* Get number of model features.
* @return the number of features.
* @throws XGBoostError
*/
public long getNumFeature() throws XGBoostError {
long[] numFeature = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
/** /**
* Internal initialization function. * Internal initialization function.
* @param cacheMats The cached DMatrix. * @param cacheMats The cached DMatrix.

View File

@ -120,6 +120,7 @@ class XGBoostJNI {
public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle); public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
// rabit functions // rabit functions
public final static native int RabitInit(String[] args); public final static native int RabitInit(String[] args);

View File

@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
.asScala.mapValues(_.doubleValue).toSeq: _*) .asScala.mapValues(_.doubleValue).toSeq: _*)
} }
/**
* Get the number of model features.
*
* @return number of features
*/
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature
def getVersion: Int = booster.getVersion def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = { def toByteArray: Array[Byte] = {

View File

@ -848,6 +848,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
return XGBoosterSaveRabitCheckpoint(handle); return XGBoosterSaveRabitCheckpoint(handle);
} }
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong num_feature;
int ret = XGBoosterGetNumFeature(handle, &num_feature);
JVM_CHECK_CALL(ret);
jlong jnum_feature = num_feature;
jenv->SetLongArrayRegion(jout, 0, 1, &jnum_feature);
return ret;
}
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit * Method: RabitInit

View File

@ -271,6 +271,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *, jclass, jlong, jlongArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit * Method: RabitInit

View File

@ -646,4 +646,18 @@ public class BoosterImplTest {
TestCase.assertEquals(attr.get("bb"), "BB"); TestCase.assertEquals(attr.get("bb"), "BB");
TestCase.assertEquals(attr.get("cc"), "CC"); TestCase.assertEquals(attr.get("cc"), "CC");
} }
/**
* test get number of features from a booster
*
* @throws XGBoostError
*/
@Test
public void testGetNumFeature() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
TestCase.assertEquals(booster.getNumFeature(), 127);
}
} }

View File

@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster) val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
assert(prevBooster == nextBooster) assert(prevBooster == nextBooster)
} }
test("test getting number of features from a booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val booster = trainBooster(trainMat, testMat)
TestCase.assertEquals(booster.getNumFeature, 127)
}
} }