[jvm-packages] Add getNumFeature method (#6075)

* Add getNumFeature to the Java API
* Add getNumFeature to the Scala API
* Add unit tests for getNumFeature

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Hristo Iliev
2020-09-08 06:57:46 +03:00
committed by GitHub
parent 93e9af43bb
commit da61d9460b
7 changed files with 66 additions and 0 deletions

View File

@@ -646,4 +646,18 @@ public class BoosterImplTest {
TestCase.assertEquals(attr.get("bb"), "BB");
TestCase.assertEquals(attr.get("cc"), "CC");
}
/**
* test get number of features from a booster
*
* @throws XGBoostError
*/
@Test
public void testGetNumFeature() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
TestCase.assertEquals(booster.getNumFeature(), 127);
}
}

View File

@@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
assert(prevBooster == nextBooster)
}
test("test getting number of features from a booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val booster = trainBooster(trainMat, testMat)
TestCase.assertEquals(booster.getNumFeature, 127)
}
}