[jvm-packages] Saving models into a tmp folder every few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training the next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files which were missing in the previous commit

* Rename TmpBooster to checkpoint
This commit is contained in:
Yun Ni
2017-12-29 08:36:41 -08:00
committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -34,6 +34,7 @@ public class Booster implements Serializable, KryoSerializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
* Create a new Booster with empty stage.
@@ -416,6 +417,14 @@ public class Booster implements Serializable, KryoSerializable {
return modelInfos[0];
}
public int getVersion() {
return this.version;
}
public void setVersion(int version) {
this.version = version;
}
/**
*
* @return the saved byte array.
@@ -436,16 +445,18 @@ public class Booster implements Serializable, KryoSerializable {
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
version = out[0];
return version;
}
/**
* Save the booster model into thread-local rabit checkpoint.
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}
/**
@@ -481,6 +492,7 @@ public class Booster implements Serializable, KryoSerializable {
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeInt(version);
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
ex.printStackTrace();
@@ -492,6 +504,7 @@ public class Booster implements Serializable, KryoSerializable {
throws IOException, ClassNotFoundException {
try {
this.init(null);
this.version = in.readInt();
byte[] bytes = (byte[])in.readObject();
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
@@ -520,6 +533,7 @@ public class Booster implements Serializable, KryoSerializable {
int serObjSize = serObj.length;
System.out.println("==== serialized obj size " + serObjSize);
output.writeInt(serObjSize);
output.writeInt(version);
output.write(serObj);
} catch (XGBoostError ex) {
ex.printStackTrace();
@@ -532,6 +546,7 @@ public class Booster implements Serializable, KryoSerializable {
try {
this.init(null);
int serObjSize = input.readInt();
this.version = input.readInt();
System.out.println("==== the size of the object: " + serObjSize);
byte[] bytes = new byte[serObjSize];
input.readBytes(bytes);

View File

@@ -57,6 +57,18 @@ public class XGBoost {
return Booster.loadModel(in);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
@@ -67,6 +79,23 @@ public class XGBoost {
return train(dtrain, params, round, watches, null, obj, eval, 0);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training would be stopped
* after a specified number of consecutive
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
@@ -76,6 +105,37 @@ public class XGBoost {
IObjective obj,
IEvaluation eval,
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training would be stopped
* after a specified number of consecutive
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRound,
Booster booster) throws XGBoostError {
//collect eval matrixs
String[] evalNames;
@@ -104,20 +164,24 @@ public class XGBoost {
}
//initialize booster
Booster booster = new Booster(params, allMats);
int version = booster.loadRabitCheckpoint();
if (booster == null) {
// Start training on a new booster
booster = new Booster(params, allMats);
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
booster.setParams(params);
}
//begin to train
for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation
@@ -149,7 +213,6 @@ public class XGBoost {
}
}
booster.saveRabitCheckpoint();
version += 1;
}
return booster;
}

View File

@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
class Booster private[xgboost4j](private var booster: JBooster)
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {
/**

View File

@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import scala.collection.JavaConverters._
/**
@@ -41,6 +41,7 @@ object XGBoost {
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
@@ -52,13 +53,19 @@ object XGBoost {
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
earlyStoppingRound: Int = 0): Booster = {
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound)
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
new Booster(xgboostInJava)
}

View File

@@ -15,10 +15,7 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
@@ -347,4 +344,55 @@ public class BoosterImplTest {
int nfold = 5;
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
}
/**
* test train from existing model
*
* @throws XGBoostError
*/
@Test
public void testTrainFromExistingModel() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
IEvaluation eval = new EvalError();
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//set watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
// Train without saving temp booster
int round = 4;
Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);
// Train with temp Booster
round = 2;
Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
// Save tempBooster to bytestream and load back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in);
in.close();
tempBooster.setVersion(prevVersion);
// Continue training using tempBooster
round = 4;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error);
TestCase.assertTrue(tempBoosterError > booster2error);
}
}