[jvm-packages] Saving models into a tmp folder every few rounds (#2964)

* [jvm-packages] Train Booster from an existing model

* Align Scala API with Java API

* Existing model should not load rabit checkpoint

* Address minor comments

* Implement saving temporary boosters and loading previous booster

* Add more unit tests for loadPrevBooster

* Add params to XGBoostEstimator

* (1) Move repartition out of the temp model saving loop (2) Address CR comments

* Catch a corner case of training the next model with fewer rounds

* Address comments

* Refactor newly added methods into TmpBoosterManager

* Add two files that were missing from the previous commit

* Rename TmpBooster to checkpoint
Yun Ni 2017-12-29 08:36:41 -08:00 committed by Nan Zhu
parent eedca8c8ec
commit 9004ca03ca
11 changed files with 481 additions and 60 deletions

View File

@@ -0,0 +1,139 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
/**
* A class that allows the user to save checkpoint boosters every few rounds. If a previous job
* fails, training can restart from a saved booster instead of starting from scratch. This class
* provides the interface and helper methods for the checkpoint functionality.
*
* @param sc the SparkContext object
* @param checkpointPath the HDFS path under which checkpoint boosters are stored
*/
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
private val logger = LogFactory.getLog("XGBoostSpark")
private val modelSuffix = ".model"
private def getPath(version: Int) = {
s"$checkpointPath/$version$modelSuffix"
}
private def getExistingVersions: Seq[Int] = {
val fs = FileSystem.get(sc.hadoopConfiguration)
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
Seq()
} else {
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
}
}
}
/**
* Load existing checkpoint with the highest version.
*
* @return the booster with the highest version, or null if no checkpoint is available.
*/
private[spark] def loadBooster: Booster = {
val versions = getExistingVersions
if (versions.nonEmpty) {
val version = versions.max
val fullPath = getPath(version)
logger.info(s"Start training from previous booster at $fullPath")
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
model.booster.booster.setVersion(version)
model.booster
} else {
null
}
}
/**
* Clean up all previous models and save a new model
*
* @param model the xgboost model to save
*/
private[spark] def updateModel(model: XGBoostModel): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(model.version)
logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
model.saveModelAsHadoopFile(fullPath)(sc)
prevModelPaths.foreach(path => fs.delete(path, true))
}
/**
* Clean up checkpoint boosters whose round number (version / 2) is greater than or equal to
* the given round.
*
* @param round the number of rounds in the current training job
*/
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
higherVersions.foreach { version =>
val fs = FileSystem.get(sc.hadoopConfiguration)
fs.delete(new Path(getPath(version)), true)
}
}
/**
* Calculate the list of rounds at which to save checkpoints, based on savingFreq and the
* total number of rounds for the training. Concretely, the saving rounds start at
* prevRounds + savingFreq and increase by savingFreq at each step until they reach the
* total number of rounds. If savingFreq is 0, checkpointing is disabled and the method
* returns Seq(round).
*
* @param savingFreq the number of rounds between consecutive checkpoints
* @param round the total number of rounds for the training
* @return a seq of integers, each representing a round at which a checkpoint is saved
*/
private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
if (checkpointPath.nonEmpty && savingFreq > 0) {
val prevRounds = getExistingVersions.map(_ / 2)
val firstSavingRound = (0 +: prevRounds).max + savingFreq
(firstSavingRound until round by savingFreq) :+ round
} else if (savingFreq <= 0) {
Seq(round)
} else {
throw new IllegalArgumentException("parameter \"checkpoint_path\" should also be set.")
}
}
}
object CheckpointManager {
private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None => ""
case Some(path: String) => path
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
" an instance of String.")
}
val savingFreq: Int = params.get("saving_frequency") match {
case None => 0
case Some(freq: Int) => freq
case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
" an instance of Int.")
}
(checkpointPath, savingFreq)
}
}

View File

@@ -20,7 +20,6 @@ import java.io.File
import scala.collection.mutable
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -101,23 +100,19 @@ object XGBoost extends Serializable {
data: RDD[XGBLabeledPoint],
params: Map[String, Any],
rabitEnv: java.util.Map[String, String],
numWorkers: Int,
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
useExternalMemory: Boolean,
missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
val partitionedData = if (data.getNumPartitions != numWorkers) {
logger.info(s"repartitioning training set to $numWorkers partitions")
data.repartition(numWorkers)
} else {
data
}
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
missing: Float,
prevBooster: Booster
): RDD[(Booster, Map[String, Array[Float]])] = {
val partitionedBaseMargin = data.map(_.baseMargin)
// to work around empty partitions in the training dataset;
// this might not be the most efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
if (labeledPoints.isEmpty) {
throw new XGBoostError(
s"detected an empty partition in the training data, partition ID:" +
@@ -145,7 +140,7 @@ object XGBoost extends Serializable {
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
val booster = SXGBoost.train(watches.train, params, round,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds)
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} finally {
Rabit.shutdown()
@@ -330,34 +325,58 @@ object XGBoost extends Serializable {
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
" an instance of Long.")
}
val (checkpointPath, savingFreq) = CheckpointManager.extractParams(params)
val partitionedData = repartitionForTraining(trainingData, nWorkers)
val tracker = startTracker(nWorkers, trackerConf)
try {
val sc = trainingData.sparkContext
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, checkpointPath)
checkpointManager.cleanUpHigherVersions(round)
var prevBooster = checkpointManager.loadBooster
// Train up to each saving round in turn and checkpoint the partially completed booster
checkpointManager.getSavingRounds(savingFreq, round).map {
savingRound: Int =>
val tracker = startTracker(nWorkers, trackerConf)
try {
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
val sparkJobThread = new Thread() {
override def run() {
// force the job
boostersAndMetrics.foreachPartition(() => _)
}
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
if (savingRound < round) {
prevBooster = model.booster
checkpointManager.updateModel(model)
}
model
} finally {
tracker.stop()
}
sparkJobThread.setUncaughtExceptionHandler(tracker)
sparkJobThread.start()
val isClsTask = isClassificationTask(params)
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
logger.info(s"Rabit returns with exit code $trackerReturnVal")
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
sparkJobThread, isClsTask)
if (isClsTask){
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
params.getOrElse("num_class", "2").toString.toInt
}
model
} finally {
tracker.stop()
}.last
}
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
if (trainingData.getNumPartitions != nWorkers) {
logger.info(s"repartitioning training set to $nWorkers partitions")
trainingData.repartition(nWorkers)
} else {
trainingData
}
}
@@ -405,6 +424,7 @@ object XGBoost extends Serializable {
xgBoostModel.setPredictionCol(predCol)
}
/**
* Load XGBoost model from path in HDFS-compatible file system
*

View File

@@ -344,6 +344,8 @@ abstract class XGBoostModel(protected var _booster: Booster)
def booster: Booster = _booster
def version: Int = this.booster.booster.getVersion
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)

View File

@@ -77,6 +77,21 @@ trait GeneralParams extends Params {
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
" value is set smaller than or equal to 0.")
/**
* The HDFS folder used to load and save checkpoint boosters. default: empty string
*/
val checkpointPath = new Param[String](this, "checkpoint_path", "the HDFS folder to load and " +
"save checkpoints. The job will try to load the existing booster as the starting point for " +
"training. If saving_frequency is also set, the job will save a checkpoint every few rounds.")
/**
* The frequency to save checkpoint boosters. default: 0
*/
val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
" the job will save checkpoints at this frequency. If the job fails and is restarted with the" +
" same settings, it will load the existing booster instead of training from scratch." +
" Checkpointing is disabled if set to 0.")
/**
* Rabit tracker configurations. The parameter must be provided as an instance of the
* TrackerConf class, which has the following definition:
@@ -112,6 +127,7 @@ trait GeneralParams extends Params {
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
useExternalMemory -> false, silent -> 0,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
checkpointPath -> "", savingFrequency -> 0
)
}

View File

@@ -0,0 +1,80 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkConf, SparkContext}
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
}
private lazy val (model4, model8) = {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, sc.defaultParallelism),
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, sc.defaultParallelism))
}
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model4)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadBooster.booster.getVersion == 4)
manager.updateModel(model8)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadBooster.booster.getVersion == 8)
}
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateModel(model8)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(round = 4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test saving rounds") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
manager.updateModel(model4)
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
}
}

View File

@@ -16,13 +16,14 @@
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -73,13 +74,14 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
test("build RDD containing boosters with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
val boosterRDD = XGBoost.buildDistributedBoosters(
trainingRDD,
partitionedRDD,
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap,
new java.util.HashMap[String, String](),
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN)
round = 5, eval = null, obj = null, useExternalMemory = true,
missing = Float.NaN, prevBooster = null)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
}
@@ -335,4 +337,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
}
test("training with saving checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"saving_frequency" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
// Check that only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
}
}

View File

@@ -34,6 +34,7 @@ public class Booster implements Serializable, KryoSerializable {
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
private int version = 0;
/**
* Create a new Booster with empty stage.
@@ -416,6 +417,14 @@ public class Booster implements Serializable, KryoSerializable {
return modelInfos[0];
}
public int getVersion() {
return this.version;
}
public void setVersion(int version) {
this.version = version;
}
/**
*
* @return the saved byte array.
@@ -436,16 +445,18 @@
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
version = out[0];
return version;
}
/**
* Save the booster model into thread-local rabit checkpoint.
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}
/**
@@ -481,6 +492,7 @@
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeInt(version);
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
ex.printStackTrace();
@@ -492,6 +504,7 @@
throws IOException, ClassNotFoundException {
try {
this.init(null);
this.version = in.readInt();
byte[] bytes = (byte[])in.readObject();
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
@@ -520,6 +533,7 @@
int serObjSize = serObj.length;
System.out.println("==== serialized obj size " + serObjSize);
output.writeInt(serObjSize);
output.writeInt(version);
output.write(serObj);
} catch (XGBoostError ex) {
ex.printStackTrace();
@@ -532,6 +546,7 @@
try {
this.init(null);
int serObjSize = input.readInt();
this.version = input.readInt();
System.out.println("==== the size of the object: " + serObjSize);
byte[] bytes = new byte[serObjSize];
input.readBytes(bytes);

View File

@@ -57,6 +57,18 @@ public class XGBoost {
return Booster.loadModel(in);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training; this allows the user to watch
* performance on the validation set.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
@@ -67,6 +79,23 @@
return train(dtrain, params, round, watches, null, obj, eval, 0);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training; this allows the user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training will be stopped
* after the specified number of consecutive
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
@@ -76,6 +105,37 @@
IObjective obj,
IEvaluation eval,
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training; this allows the user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training will be stopped
* after the specified number of consecutive
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
public static Booster train(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation eval,
int earlyStoppingRound,
Booster booster) throws XGBoostError {
// collect eval matrices
String[] evalNames;
@@ -104,20 +164,24 @@
}
//initialize booster
Booster booster = new Booster(params, allMats);
int version = booster.loadRabitCheckpoint();
if (booster == null) {
// Start training on a new booster
booster = new Booster(params, allMats);
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
booster.setParams(params);
}
//begin to train
for (int iter = version / 2; iter < round; iter++) {
if (version % 2 == 0) {
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
booster.saveRabitCheckpoint();
version += 1;
}
//evaluation
@@ -149,7 +213,6 @@
}
}
booster.saveRabitCheckpoint();
version += 1;
}
return booster;
}

View File

@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.JavaConverters._
import scala.collection.mutable
class Booster private[xgboost4j](private var booster: JBooster)
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {
/**

View File

@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import scala.collection.JavaConverters._
/**
@@ -41,6 +41,7 @@ object XGBoost {
* increases in any evaluation metric.
* @param obj customized objective
* @param eval customized evaluation
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
@@ -52,13 +53,19 @@ object XGBoost {
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
earlyStoppingRound: Int = 0): Booster = {
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound)
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
new Booster(xgboostInJava)
}

View File

@@ -15,10 +15,7 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
@@ -347,4 +344,55 @@ public class BoosterImplTest {
int nfold = 5;
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
}
/**
* Test training from an existing model.
*
* @throws XGBoostError
*/
@Test
public void testTrainFromExistingModel() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
IEvaluation eval = new EvalError();
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//set watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
// Train without saving temp booster
int round = 4;
Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);
// Train with temp Booster
round = 2;
Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
// Save tempBooster to a byte stream and load it back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in);
in.close();
tempBooster.setVersion(prevVersion);
// Continue training using tempBooster
round = 4;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error);
TestCase.assertTrue(tempBoosterError > booster2error);
}
}