[jvm-packages] Add getNumFeature method (#6075)
* Add getNumFeature to the Java API * Add getNumFeature to the Scala API * Add unit tests for getNumFeature Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
93e9af43bb
commit
da61d9460b
@ -677,6 +677,17 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
version += 1;
|
version += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get number of model features.
|
||||||
|
* @return the number of features.
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public long getNumFeature() throws XGBoostError {
|
||||||
|
long[] numFeature = new long[1];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
|
||||||
|
return numFeature[0];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Internal initialization function.
|
* Internal initialization function.
|
||||||
* @param cacheMats The cached DMatrix.
|
* @param cacheMats The cached DMatrix.
|
||||||
|
|||||||
@ -120,6 +120,7 @@ class XGBoostJNI {
|
|||||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||||
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
|
||||||
|
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);
|
||||||
|
|
||||||
// rabit functions
|
// rabit functions
|
||||||
public final static native int RabitInit(String[] args);
|
public final static native int RabitInit(String[] args);
|
||||||
|
|||||||
@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
.asScala.mapValues(_.doubleValue).toSeq: _*)
|
.asScala.mapValues(_.doubleValue).toSeq: _*)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the number of model features.
|
||||||
|
*
|
||||||
|
* @return number of features
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getNumFeature: Long = booster.getNumFeature
|
||||||
|
|
||||||
def getVersion: Int = booster.getVersion
|
def getVersion: Int = booster.getVersion
|
||||||
|
|
||||||
def toByteArray: Array[Byte] = {
|
def toByteArray: Array[Byte] = {
|
||||||
|
|||||||
@ -848,6 +848,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
|
|||||||
return XGBoosterSaveRabitCheckpoint(handle);
|
return XGBoosterSaveRabitCheckpoint(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterGetNumFeature
|
||||||
|
* Signature: (J[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||||
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
||||||
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
|
bst_ulong num_feature;
|
||||||
|
int ret = XGBoosterGetNumFeature(handle, &num_feature);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
jlong jnum_feature = num_feature;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, &jnum_feature);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: RabitInit
|
* Method: RabitInit
|
||||||
|
|||||||
@ -271,6 +271,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterGetNumFeature
|
||||||
|
* Signature: (J[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
|
||||||
|
(JNIEnv *, jclass, jlong, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: RabitInit
|
* Method: RabitInit
|
||||||
|
|||||||
@ -646,4 +646,18 @@ public class BoosterImplTest {
|
|||||||
TestCase.assertEquals(attr.get("bb"), "BB");
|
TestCase.assertEquals(attr.get("bb"), "BB");
|
||||||
TestCase.assertEquals(attr.get("cc"), "CC");
|
TestCase.assertEquals(attr.get("cc"), "CC");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test get number of features from a booster
|
||||||
|
*
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testGetNumFeature() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
TestCase.assertEquals(booster.getNumFeature(), 127);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
|
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
|
||||||
assert(prevBooster == nextBooster)
|
assert(prevBooster == nextBooster)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("test getting number of features from a booster") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val booster = trainBooster(trainMat, testMat)
|
||||||
|
|
||||||
|
TestCase.assertEquals(booster.getNumFeature, 127)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user