[jvm-packages] Saving models into a tmp folder every few rounds (#2964)
* [jvm-packages] Train Booster from an existing model * Align Scala API with Java API * Existing model should not load rabit checkpoint * Address minor comments * Implement saving temporary boosters and loading previous booster * Add more unit tests for loadPrevBooster * Add params to XGBoostEstimator * (1) Move repartition out of the temp model saving loop (2) Address CR comments * Catch a corner case of training next model with fewer rounds * Address comments * Refactor newly added methods into TmpBoosterManager * Add two files which were missing in the previous commit * Rename TmpBooster to checkpoint
This commit is contained in:
@@ -34,6 +34,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||
// handle to the booster.
|
||||
private long handle = 0;
|
||||
private int version = 0;
|
||||
|
||||
/**
|
||||
* Create a new Booster with empty stage.
|
||||
@@ -416,6 +417,14 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
 * Returns the version counter currently stored on this booster.
 * <p>
 * The counter is assigned by {@code loadRabitCheckpoint()}, incremented by
 * {@code saveRabitCheckpoint()} and can be overwritten through {@link #setVersion(int)}.
 *
 * @return the current value of the {@code version} field
 */
public int getVersion() {
|
||||
return this.version;
|
||||
}
|
||||
|
||||
/**
 * Overwrites the version counter stored on this booster.
 * <p>
 * {@code XGBoost.train} starts its training loop at {@code getVersion() / 2}, so the value set
 * here decides at which iteration training on an existing booster resumes.
 *
 * @param version the new value of the {@code version} field
 */
public void setVersion(int version) {
|
||||
this.version = version;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return the saved byte array.
|
||||
@@ -436,16 +445,18 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
version = out[0];
|
||||
return version;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the booster model into thread-local rabit checkpoint.
|
||||
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
||||
* This is only used in distributed training.
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
// Count the checkpoint only after the native call succeeded (checkCall is expected to throw on
// failure -- TODO confirm); XGBoost.train reads this counter through getVersion().
version += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -481,6 +492,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeInt(version);
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
@@ -492,6 +504,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
throws IOException, ClassNotFoundException {
|
||||
try {
|
||||
this.init(null);
|
||||
this.version = in.readInt();
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
@@ -520,6 +533,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
int serObjSize = serObj.length;
|
||||
System.out.println("==== serialized obj size " + serObjSize);
|
||||
output.writeInt(serObjSize);
|
||||
output.writeInt(version);
|
||||
output.write(serObj);
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
@@ -532,6 +546,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
try {
|
||||
this.init(null);
|
||||
int serObjSize = input.readInt();
|
||||
this.version = input.readInt();
|
||||
System.out.println("==== the size of the object: " + serObjSize);
|
||||
byte[] bytes = new byte[serObjSize];
|
||||
input.readBytes(bytes);
|
||||
|
||||
@@ -57,6 +57,18 @@ public class XGBoost {
|
||||
return Booster.loadModel(in);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
@@ -67,6 +79,23 @@ public class XGBoost {
|
||||
return train(dtrain, params, round, watches, null, obj, eval, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* increases in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
@@ -76,6 +105,37 @@ public class XGBoost {
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRound) throws XGBoostError {
|
||||
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a booster given parameters.
|
||||
*
|
||||
* @param dtrain Data to be trained.
|
||||
* @param params Parameters.
|
||||
* @param round Number of boosting iterations.
|
||||
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||
* performance on the validation set.
|
||||
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||
* iteration
|
||||
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||
* after a specified number of consecutive
|
||||
* increases in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
public static Booster train(
|
||||
DMatrix dtrain,
|
||||
Map<String, Object> params,
|
||||
int round,
|
||||
Map<String, DMatrix> watches,
|
||||
float[][] metrics,
|
||||
IObjective obj,
|
||||
IEvaluation eval,
|
||||
int earlyStoppingRound,
|
||||
Booster booster) throws XGBoostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -104,20 +164,24 @@ public class XGBoost {
|
||||
}
|
||||
|
||||
//initialize booster
|
||||
Booster booster = new Booster(params, allMats);
|
||||
|
||||
int version = booster.loadRabitCheckpoint();
|
||||
if (booster == null) {
|
||||
// Start training on a new booster
|
||||
booster = new Booster(params, allMats);
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
booster.setParams(params);
|
||||
}
|
||||
|
||||
//begin to train
|
||||
for (int iter = version / 2; iter < round; iter++) {
|
||||
if (version % 2 == 0) {
|
||||
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
||||
if (booster.getVersion() % 2 == 0) {
|
||||
if (obj != null) {
|
||||
booster.update(dtrain, obj);
|
||||
} else {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
|
||||
//evaluation
|
||||
@@ -149,7 +213,6 @@ public class XGBoost {
|
||||
}
|
||||
}
|
||||
booster.saveRabitCheckpoint();
|
||||
version += 1;
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
class Booster private[xgboost4j](private var booster: JBooster)
|
||||
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
extends Serializable with KryoSerializable {
|
||||
|
||||
/**
|
||||
|
||||
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
@@ -41,6 +41,7 @@ object XGBoost {
|
||||
* increases in any evaluation metric.
|
||||
* @param obj customized objective
|
||||
* @param eval customized evaluation
|
||||
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||
* @return The trained booster.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
@@ -52,13 +53,19 @@ object XGBoost {
|
||||
metrics: Array[Array[Float]] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null,
|
||||
earlyStoppingRound: Int = 0): Booster = {
|
||||
earlyStoppingRound: Int = 0,
|
||||
booster: Booster = null): Booster = {
|
||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||
val jBooster = if (booster == null) {
|
||||
null
|
||||
} else {
|
||||
booster.booster
|
||||
}
|
||||
val xgboostInJava = JXGBoost.train(
|
||||
dtrain.jDMatrix,
|
||||
// we have to filter null value for customized obj and eval
|
||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
||||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||
new Booster(xgboostInJava)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
@@ -347,4 +344,55 @@ public class BoosterImplTest {
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests that training can be resumed from an existing booster.
* <p>
* A booster trained for 4 rounds from scratch is compared with a booster that is trained for
* 2 rounds, round-tripped through a byte array and then trained up to round 4 by passing it as
* the last argument of {@code XGBoost.train}. Both must reach the same test error, and the
* intermediate 2-round booster must have a larger test error than the resumed one.
|
||||
*
|
||||
* @throws XGBoostError propagated from the XGBoost4J calls made by the test
* @throws IOException propagated from the handling of the in-memory stream
|
||||
*/
|
||||
@Test
|
||||
public void testTrainFromExistingModel() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
|
||||
// Matrices to evaluate during training, keyed by display name
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
// Baseline: train all 4 rounds in one go, without an intermediate booster
|
||||
int round = 4;
|
||||
Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||
float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Train only the first 2 rounds; this booster plays the role of a saved intermediate model
|
||||
round = 2;
|
||||
Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||
|
||||
// Round-trip tempBooster through a byte array and load it back as a fresh Booster
|
||||
int prevVersion = tempBooster.getVersion();
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||
tempBooster = XGBoost.loadModel(in);
|
||||
in.close();
|
||||
// The version counter presumably does not survive toByteArray()/loadModel (TODO confirm), so it
// is restored by hand; XGBoost.train starts its loop at getVersion() / 2.
tempBooster.setVersion(prevVersion);
|
||||
|
||||
// Continue training on tempBooster; round is the total iteration count, not the number of
// additional rounds, because the training loop runs from getVersion() / 2 up to round
|
||||
round = 4;
|
||||
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||
// Resumed training must end with exactly the same test error as uninterrupted training ...
TestCase.assertTrue(booster1error == booster2error);
|
||||
// ... and the 2-round intermediate booster must be strictly worse than the resumed one
TestCase.assertTrue(tempBoosterError > booster2error);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user