[jvm-packages] Add getNumFeature method (#6075)
* Add getNumFeature to the Java API * Add getNumFeature to the Scala API * Add unit tests for getNumFeature Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -646,4 +646,18 @@ public class BoosterImplTest {
|
||||
TestCase.assertEquals(attr.get("bb"), "BB");
|
||||
TestCase.assertEquals(attr.get("cc"), "CC");
|
||||
}
|
||||
|
||||
/**
|
||||
* test get number of features from a booster
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Test
|
||||
public void testGetNumFeature() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
TestCase.assertEquals(booster.getNumFeature(), 127);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
|
||||
assert(prevBooster == nextBooster)
|
||||
}
|
||||
|
||||
test("test getting number of features from a booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
|
||||
TestCase.assertEquals(booster.getNumFeature, 127)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user