[JVM] Refactor, add filesys API
This commit is contained in:
@@ -15,6 +15,10 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -67,7 +71,7 @@ public class BoosterImplTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterBasic() throws XGBoostError {
|
||||
public void testBoosterBasic() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@@ -94,15 +98,20 @@ public class BoosterImplTest {
|
||||
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
||||
|
||||
//predict raw output
|
||||
float[][] predicts = booster.predict(testMat, true);
|
||||
float[][] predicts = booster.predict(testMat, true, 0);
|
||||
|
||||
//eval
|
||||
IEvaluation eval = new EvalError();
|
||||
//error must be less than 0.1
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||
|
||||
//test dump model
|
||||
// save and load
|
||||
File temp = File.createTempFile("temp", "model");
|
||||
temp.deleteOnExit();
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
|
||||
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user