[jvm-packages] Add getNumFeature method (#6075)
* Add getNumFeature to the Java API * Add getNumFeature to the Scala API * Add unit tests for getNumFeature Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
93e9af43bb
commit
da61d9460b
@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get number of model features.
|
||||
* @return the number of features.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public long getNumFeature() throws XGBoostError {
|
||||
long[] numFeature = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||
return numFeature[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal initialization function.
|
||||
* @param cacheMats The cached DMatrix.
|
||||
|
||||
@ -120,6 +120,7 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||
|
||||
// rabit functions
|
||||
public final static native int RabitInit(String[] args);
|
||||
|
||||
@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
.asScala.mapValues(_.doubleValue).toSeq: _*)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of model features.
|
||||
*
|
||||
* @return number of features
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getNumFeature: Long = booster.getNumFeature
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
|
||||
def toByteArray: Array[Byte] = {
|
||||
|
||||
@ -848,6 +848,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
* Signature: (J[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
bst_ulong num_feature;
|
||||
int ret = XGBoosterGetNumFeature(handle, &num_feature);
|
||||
JVM_CHECK_CALL(ret);
|
||||
jlong jnum_feature = num_feature;
|
||||
jenv->SetLongArrayRegion(jout, 0, 1, &jnum_feature);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
|
||||
@ -271,6 +271,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetNumFeature
|
||||
* Signature: (J[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
|
||||
@ -646,4 +646,18 @@ public class BoosterImplTest {
|
||||
TestCase.assertEquals(attr.get("bb"), "BB");
|
||||
TestCase.assertEquals(attr.get("cc"), "CC");
|
||||
}
|
||||
|
||||
/**
|
||||
* test get number of features from a booster
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumFeature() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
TestCase.assertEquals(booster.getNumFeature(), 127);
|
||||
}
|
||||
}
|
||||
|
||||
@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
|
||||
assert(prevBooster == nextBooster)
|
||||
}
|
||||
|
||||
test("test getting number of features from a booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
|
||||
TestCase.assertEquals(booster.getNumFeature, 127)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user