[jvm-packages] Save models into a tmp folder every few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files which were missing in the previous commit

* Rename TmpBooster to checkpoint
This commit is contained in:
Yun Ni
2017-12-29 08:36:41 -08:00
committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -15,10 +15,7 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
@@ -347,4 +344,55 @@ public class BoosterImplTest {
int nfold = 5;
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
}
/**
 * Tests that training can be continued from an existing model.
 *
 * <p>A booster trained with a round count of 4 in a single call is compared against a booster
 * that is first trained with a round count of 2, round-tripped through its byte-array form, and
 * then handed back to {@code XGBoost.train} with the round count set to 4 again. Both boosters
 * are expected to yield the same evaluation error on the test matrix, and the intermediate
 * booster is expected to be strictly worse than the resumed one.
 *
 * @throws XGBoostError propagated from the XGBoost calls made by this test
 * @throws IOException propagated from loading the model from, or closing, the in-memory stream
 */
@Test
public void testTrainFromExistingModel() throws XGBoostError, IOException {
  DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
  DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
  IEvaluation eval = new EvalError();

  // A plain HashMap is used instead of double-brace initialization, which would create an
  // anonymous HashMap subclass holding a reference to the enclosing test instance.
  Map<String, Object> paramMap = new HashMap<String, Object>();
  paramMap.put("eta", 1.0);
  paramMap.put("max_depth", 2);
  paramMap.put("silent", 1);
  paramMap.put("objective", "binary:logistic");

  // set watchList
  HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
  watches.put("train", trainMat);
  watches.put("test", testMat);

  final int fullRounds = 4;
  final int partialRounds = 2;

  // Baseline: train in one go, without any intermediate booster.
  Booster booster1 =
      XGBoost.train(trainMat, paramMap, fullRounds, watches, null, null, null, 0);
  float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);

  // Train the intermediate booster with the smaller round count.
  Booster tempBooster =
      XGBoost.train(trainMat, paramMap, partialRounds, watches, null, null, null, 0);
  float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);

  // Round-trip the intermediate booster through a byte stream. The version is captured before
  // and restored after the round trip (presumably it is not part of the serialized bytes --
  // TODO confirm against Booster.toByteArray).
  int prevVersion = tempBooster.getVersion();
  Booster restoredBooster;
  // try-with-resources guarantees the stream is closed even if loadModel throws.
  try (ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray())) {
    restoredBooster = XGBoost.loadModel(in);
  }
  restoredBooster.setVersion(prevVersion);

  // Continue training from the restored booster, using the same round count as the baseline.
  Booster booster2 = XGBoost.train(
      trainMat, paramMap, fullRounds, watches, null, null, null, 0, restoredBooster);
  float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);

  // Exact float equality is intentional: the original test compares the two errors with ==.
  TestCase.assertTrue(
      "resumed booster error " + booster2error
          + " should equal single-run booster error " + booster1error,
      booster1error == booster2error);
  TestCase.assertTrue(
      "intermediate booster error " + tempBoosterError
          + " should be greater than resumed booster error " + booster2error,
      tempBoosterError > booster2error);
}
}