[jvm-packages] Save models into a tmp folder every few rounds (#2964)
* [jvm-packages] Train Booster from an existing model
* Align Scala API with Java API
* Existing model should not load rabit checkpoint
* Address minor comments
* Implement saving temporary boosters and loading previous booster
* Add more unit tests for loadPrevBooster
* Add params to XGBoostEstimator
* (1) Move repartition out of the temp model saving loop (2) Address CR comments
* Catch a corner case of training the next model with fewer rounds
* Address comments
* Refactor newly added methods into TmpBoosterManager
* Add two files which were missing in the previous commit
* Rename TmpBooster to checkpoint
This commit is contained in:
@@ -15,10 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
@@ -347,4 +344,55 @@ public class BoosterImplTest {
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* test train from existing model
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Test
|
||||
public void testTrainFromExistingModel() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
// Train without saving temp booster
|
||||
int round = 4;
|
||||
Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||
float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Train with temp Booster
|
||||
round = 2;
|
||||
Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Save tempBooster to bytestream and load back
|
||||
int prevVersion = tempBooster.getVersion();
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||
tempBooster = XGBoost.loadModel(in);
|
||||
in.close();
|
||||
tempBooster.setVersion(prevVersion);
|
||||
|
||||
// Continue training using tempBooster
|
||||
round = 4;
|
||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||
TestCase.assertTrue(booster1error == booster2error);
|
||||
TestCase.assertTrue(tempBoosterError > booster2error);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user