diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
new file mode 100644
index 000000000..ae7e296ad
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
@@ -0,0 +1,139 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.scala.Booster
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.SparkContext
+
+/**
+ * A class that allows the user to save checkpoint boosters every few rounds. If a previous job
+ * fails, the restarted job can resume training from a saved booster instead of starting from
+ * scratch. This class provides the interface and helper methods for the checkpoint
+ * functionality.
+ *
+ * @param sc the SparkContext object
+ * @param checkpointPath the HDFS path to store checkpoint boosters
+ */
+private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
+  private val logger = LogFactory.getLog("XGBoostSpark")
+  private val modelSuffix = ".model"
+
+  private def getPath(version: Int) = {
+    s"$checkpointPath/$version$modelSuffix"
+  }
+
+  private def getExistingVersions: Seq[Int] = {
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
+      Seq()
+    } else {
+      fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
+        case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
+      }
+    }
+  }
+
+  /**
+   * Load the existing checkpoint with the highest version.
+   *
+   * @return the booster with the highest version, null if no checkpoint is available.
+   */
+  private[spark] def loadBooster: Booster = {
+    val versions = getExistingVersions
+    if (versions.nonEmpty) {
+      val version = versions.max
+      val fullPath = getPath(version)
+      logger.info(s"Start training from previous booster at $fullPath")
+      val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
+      model.booster.booster.setVersion(version)
+      model.booster
+    } else {
+      null
+    }
+  }
+
+  /**
+   * Clean up all previous models and save a new model.
+   *
+   * @param model the xgboost model to save
+   */
+  private[spark] def updateModel(model: XGBoostModel): Unit = {
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
+    val fullPath = getPath(model.version)
+    logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
+    model.saveModelAsHadoopFile(fullPath)(sc)
+    prevModelPaths.foreach(path => fs.delete(path, true))
+  }
+
+  /**
+   * Clean up checkpoint boosters whose round (version / 2) is higher than or equal to the
+   * given round.
+   *
+   * @param round the number of rounds in the current training job
+   */
+  private[spark] def cleanUpHigherVersions(round: Int): Unit = {
+    val higherVersions = getExistingVersions.filter(_ / 2 >= round)
+    higherVersions.foreach { version =>
+      val fs = FileSystem.get(sc.hadoopConfiguration)
+      fs.delete(new Path(getPath(version)), true)
+    }
+  }
+
+  /**
+   * Calculate the list of rounds at which to save checkpoints, based on savingFreq and the
+   * total number of rounds for the training. Concretely, the saving rounds start at
+   * prevRounds + savingFreq and increase by savingFreq at each step until they reach the
+   * total number of rounds. If savingFreq is 0, checkpointing is disabled and the method
+   * returns Seq(round).
+   *
+   * @param savingFreq the number of rounds between two consecutive checkpoints
+   * @param round the total number of rounds for the training
+   * @return a seq of integers, each representing a round at which a checkpoint is saved
+   */
+  private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
+    if (checkpointPath.nonEmpty && savingFreq > 0) {
+      val prevRounds = getExistingVersions.map(_ / 2)
+      val firstSavingRound = (0 +: prevRounds).max + savingFreq
+      (firstSavingRound until round by savingFreq) :+ round
+    } else if (savingFreq <= 0) {
+      Seq(round)
+    } else {
+      throw new IllegalArgumentException("parameter \"checkpoint_path\" should also be set.")
+    }
+  }
+}
+
+object CheckpointManager {
+
+  private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
+    val checkpointPath: String = params.get("checkpoint_path") match {
+      case None => ""
+      case Some(path: String) => path
+      case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
+        " an instance of String.")
+    }
+
+    val savingFreq: Int = params.get("saving_frequency") match {
+      case None => 0
+      case Some(freq: Int) => freq
+      case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
+        " an instance of Int.")
+    }
+    (checkpointPath, savingFreq)
+  }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 3ad724a94..3d342ff07 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -20,7 +20,6 @@ import java.io.File
 
 import scala.collection.mutable
 import scala.util.Random
-
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -101,23 +100,19 @@ object XGBoost extends Serializable {
       data: RDD[XGBLabeledPoint],
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
-      numWorkers: Int,
       round: Int,
       obj: ObjectiveTrait,
       eval: EvalTrait,
       useExternalMemory: Boolean,
-      missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
-    val partitionedData = if (data.getNumPartitions != numWorkers) {
-      logger.info(s"repartitioning training set to $numWorkers partitions")
-      data.repartition(numWorkers)
-    } else {
-      data
-    }
-    val partitionedBaseMargin = partitionedData.map(_.baseMargin)
+      missing: Float,
+      prevBooster: Booster
+    ): RDD[(Booster, Map[String, Array[Float]])] = {
+
+    val partitionedBaseMargin = data.map(_.baseMargin)
     // to workaround the empty partitions in training dataset,
     // this might not be the most efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
-    partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
+    data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
       if (labeledPoints.isEmpty) {
         throw new XGBoostError(
           s"detected an empty partition in the training data, partition ID:" +
@@ -145,7 +140,7 @@ object XGBoost extends Serializable {
         val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
         val booster = SXGBoost.train(watches.train, params, round, watches.toMap, metrics, obj, eval,
-          earlyStoppingRound = numEarlyStoppingRounds)
+          earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
         Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
       } finally {
         Rabit.shutdown()
@@ -330,34 +325,58 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
         " an instance of Long.")
     }
+    val (checkpointPath, savingFreq) = CheckpointManager.extractParams(params)
+    val partitionedData = repartitionForTraining(trainingData, nWorkers)
 
-    val tracker = startTracker(nWorkers, trackerConf)
-    try {
-      val sc = trainingData.sparkContext
-      val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
-      val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
-      val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
-        tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
-      val sparkJobThread = new Thread() {
-        override def run() {
-          // force the job
-          boostersAndMetrics.foreachPartition(() => _)
-        }
+    val sc = trainingData.sparkContext
+    val checkpointManager = new CheckpointManager(sc, checkpointPath)
+    checkpointManager.cleanUpHigherVersions(round)
+
+    var prevBooster = checkpointManager.loadBooster
+    // Train up to each saving round in turn and save the partially trained booster
+    checkpointManager.getSavingRounds(savingFreq, round).map {
+      savingRound: Int =>
+        val tracker = startTracker(nWorkers, trackerConf)
+        try {
+          val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
+          val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
+          val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
+            tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
+          val sparkJobThread = new Thread() {
+            override def run() {
+              // force the job
+              boostersAndMetrics.foreachPartition(() => _)
+            }
+          }
+          sparkJobThread.setUncaughtExceptionHandler(tracker)
+          sparkJobThread.start()
+          val isClsTask = isClassificationTask(params)
+          val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+          logger.info(s"Rabit returns with exit code $trackerReturnVal")
+          val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
+            sparkJobThread, isClsTask)
+          if (isClsTask) {
+            model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
+              params.getOrElse("num_class", "2").toString.toInt
+          }
+          if (savingRound < round) {
+            prevBooster = model.booster
+            checkpointManager.updateModel(model)
+          }
+          model
+        } finally {
+          tracker.stop()
        }
-      sparkJobThread.setUncaughtExceptionHandler(tracker)
-      sparkJobThread.start()
-      val isClsTask = isClassificationTask(params)
-      val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
-      logger.info(s"Rabit returns with exit code $trackerReturnVal")
-      val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
-        sparkJobThread, isClsTask)
-      if (isClsTask){
-        model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
-          params.getOrElse("num_class", "2").toString.toInt
-      }
-      model
-    } finally {
-      tracker.stop()
+    }.last
+  }
+
+
+  private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
+    if (trainingData.getNumPartitions != nWorkers) {
+      logger.info(s"repartitioning training set to $nWorkers partitions")
+      trainingData.repartition(nWorkers)
+    } else {
+      trainingData
+    }
+  }
@@ -405,6 +424,7 @@ object XGBoost extends Serializable {
     xgBoostModel.setPredictionCol(predCol)
   }
 
+
   /**
    * Load XGBoost model from path in HDFS-compatible file system
    *
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
index e0d0f82f6..4b77eec4b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala
@@ -344,6 +344,8 @@ abstract class XGBoostModel(protected var _booster: Booster)
 
   def booster: Booster = _booster
 
+  def version: Int = this.booster.booster.getVersion
+
   override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
 
   override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index 96dada6cb..106514f96 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -77,6 +77,21 @@ trait GeneralParams extends Params {
     " request new Workers if numCores are insufficient. The timeout will be disabled if this" +
     " value is set smaller than or equal to 0.")
 
+  /**
+   * The HDFS folder to load and save checkpoint boosters. default: `empty_string`
+   */
+  val checkpointPath = new Param[String](this, "checkpoint_path", "the HDFS folder to load and " +
+    "save checkpoints. The job will try to load the existing booster as the starting point for " +
+    "training. If saving_frequency is also set, the job will save a checkpoint every few rounds.")
+
+  /**
+   * The frequency to save checkpoint boosters. default: 0
+   */
+  val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
+    " the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
+    " the same settings, it will load the existing booster instead of training from scratch." +
+    " Checkpointing will be disabled if set to 0.")
+
   /**
    * Rabit tracker configurations. The parameter must be provided as an instance of the
    * TrackerConf class, which has the following definition:
@@ -112,6 +127,7 @@ trait GeneralParams extends Params {
   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
     customObj -> null, customEval -> null, missing -> Float.NaN,
-    trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
+    trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
+    checkpointPath -> "", savingFrequency -> 0
   )
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
new file mode 100644
index 000000000..f0c9ba697
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
@@ -0,0 +1,80 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+import java.nio.file.Files
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.{SparkConf, SparkContext}
+
+class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
+  var sc: SparkContext = _
+
+  override def beforeAll(): Unit = {
+    val conf: SparkConf = new SparkConf()
+      .setMaster("local[*]")
+      .setAppName("XGBoostSuite")
+    sc = new SparkContext(conf)
+  }
+
+  private lazy val (model4, model8) = {
+    import DataUtils._
+    val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+      "objective" -> "binary:logistic")
+    (XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, sc.defaultParallelism),
+      XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, sc.defaultParallelism))
+  }
+
+  test("test update/load models") {
+    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val manager = new CheckpointManager(sc, tmpPath)
+    manager.updateModel(model4)
+    var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "4.model")
+    assert(manager.loadBooster.booster.getVersion == 4)
+
+    manager.updateModel(model8)
+    files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "8.model")
+    assert(manager.loadBooster.booster.getVersion == 8)
+  }
+
+  test("test cleanUpHigherVersions") {
+    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val manager = new CheckpointManager(sc, tmpPath)
+    manager.updateModel(model8)
+    manager.cleanUpHigherVersions(round = 8)
+    assert(new File(s"$tmpPath/8.model").exists())
+
+    manager.cleanUpHigherVersions(round = 4)
+    assert(!new File(s"$tmpPath/8.model").exists())
+  }
+
+  test("test saving rounds") {
+    val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
+    val manager = new CheckpointManager(sc, tmpPath)
+    assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
+    assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
+    manager.updateModel(model4)
+    assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
+  }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 6aff01efd..053dff156 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -16,13 +16,14 @@
 package ml.dmlc.xgboost4j.scala.spark
 
+import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
 
 import scala.util.Random
-
 import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
@@ -73,13 +74,14 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
 
   test("build RDD containing boosters with the specified worker number") {
     val trainingRDD = sc.parallelize(Classification.train)
+    val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
     val boosterRDD = XGBoost.buildDistributedBoosters(
-      trainingRDD,
+      partitionedRDD,
       List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
         "objective" -> "binary:logistic").toMap,
       new java.util.HashMap[String, String](),
-      numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
-      missing = Float.NaN)
+      round = 5, eval = null, obj = null, useExternalMemory = true,
+      missing = Float.NaN, prevBooster = null)
     val boosterCount = boosterRDD.count()
     assert(boosterCount === 2)
   }
@@ -335,4 +337,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
       assert(XGBoost.isClassificationTask(params) == isClassificationTask)
     }
   }
+
+  test("training with saving checkpoint boosters") {
+    import DataUtils._
+    val eval = new EvalError()
+    val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
+    val testSetDMatrix = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
+    val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "saving_frequency" -> 2).toMap
+    val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
+      nWorkers = numWorkers)
+    def error(model: XGBoostModel): Float = eval.eval(
+      model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
+
+    // Check only one model is kept after training
+    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+    assert(files.length == 1)
+    assert(files.head.getPath.getName == "8.model")
+    val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
+
+    // Train the next model based on the previous model
+    val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
+      nWorkers = numWorkers)
+    assert(error(tmpModel) > error(prevModel))
+    assert(error(prevModel) > error(nextModel))
+    assert(error(nextModel) < 0.1)
+  }
 }
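To make the intended workflow concrete, here is a minimal driver-side sketch of checkpointed training (not part of the patch). It assumes an existing `trainingRDD` of Spark ML labeled points, as in the test above; the HDFS folder is illustrative, while `trainWithRDD`, `checkpoint_path`, and `saving_frequency` are the entry point and parameters this change introduces.

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoost

// Hypothetical checkpoint folder; any HDFS-compatible path works.
val paramMap = Map(
  "eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic",
  // folder where checkpoint boosters are loaded from and saved to
  "checkpoint_path" -> "hdfs://namenode/tmp/xgb-checkpoints",
  // save a checkpoint booster every 2 rounds
  "saving_frequency" -> 2)

// If this job fails and is resubmitted with the same parameters, training
// resumes from the highest-version checkpoint instead of from round 0.
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 100, nWorkers = 16)
```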
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 40e00db7c..91fc2cb4d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -34,6 +34,7 @@ public class Booster implements Serializable, KryoSerializable { private static final Log logger = LogFactory.getLog(Booster.class); // handle to the booster. private long handle = 0; + private int version = 0; /** * Create a new Booster with empty stage. @@ -416,6 +417,14 @@ public class Booster implements Serializable, KryoSerializable { return modelInfos[0]; } + public int getVersion() { + return this.version; + } + + public void setVersion(int version) { + this.version = version; + } + /** * * @return the saved byte array. @@ -436,16 +445,18 @@ public class Booster implements Serializable, KryoSerializable { int loadRabitCheckpoint() throws XGBoostError { int[] out = new int[1]; XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); - return out[0]; + version = out[0]; + return version; } /** - * Save the booster model into thread-local rabit checkpoint. + * Save the booster model into thread-local rabit checkpoint and increment the version. * This is only used in distributed training. * @throws XGBoostError */ void saveRabitCheckpoint() throws XGBoostError { XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); + version += 1; } /** @@ -481,6 +492,7 @@ public class Booster implements Serializable, KryoSerializable { // making Booster serializable private void writeObject(java.io.ObjectOutputStream out) throws IOException { try { + out.writeInt(version); out.writeObject(this.toByteArray()); } catch (XGBoostError ex) { ex.printStackTrace(); @@ -492,6 +504,7 @@ public class Booster implements Serializable, KryoSerializable { throws IOException, ClassNotFoundException { try { this.init(null); + this.version = in.readInt(); byte[] bytes = (byte[])in.readObject(); XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); } catch (XGBoostError ex) { @@ -520,6 +533,7 @@ public class Booster implements Serializable, KryoSerializable { int serObjSize = serObj.length; System.out.println("==== serialized obj size " + serObjSize); output.writeInt(serObjSize); + output.writeInt(version); output.write(serObj); } catch (XGBoostError ex) { ex.printStackTrace(); @@ -532,6 +546,7 @@ public class Booster implements Serializable, KryoSerializable { try { this.init(null); int serObjSize = input.readInt(); + this.version = input.readInt(); System.out.println("==== the size of the object: " + serObjSize); byte[] bytes = new byte[serObjSize]; input.readBytes(bytes); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 8021e878f..df030105d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -57,6 +57,18 @@ public class XGBoost { return Booster.loadModel(in); } + /** + * Train a booster given parameters. + * + * @param dtrain Data to be trained. + * @param params Parameters. + * @param round Number of boosting iterations. 
+   * @param watches a group of items to be evaluated during training; this allows the user to
+   *                watch performance on the validation set.
+   * @param obj customized objective
+   * @param eval customized evaluation
+   * @return The trained booster.
+   */
   public static Booster train(
       DMatrix dtrain,
       Map params,
@@ -67,6 +79,23 @@
     return train(dtrain, params, round, watches, null, obj, eval, 0);
   }
 
+  /**
+   * Train a booster given parameters.
+   *
+   * @param dtrain Data to be trained.
+   * @param params Parameters.
+   * @param round Number of boosting iterations.
+   * @param watches a group of items to be evaluated during training; this allows the user to
+   *                watch performance on the validation set.
+   * @param metrics array containing the evaluation metrics for each matrix in watches for each
+   *                iteration
+   * @param earlyStoppingRound if non-zero, training will be stopped
+   *                           after a specified number of consecutive
+   *                           increases in any evaluation metric.
+   * @param obj customized objective
+   * @param eval customized evaluation
+   * @return The trained booster.
+   */
   public static Booster train(
       DMatrix dtrain,
       Map params,
@@ -76,6 +105,37 @@
       IObjective obj,
       IEvaluation eval,
       int earlyStoppingRound) throws XGBoostError {
+    return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
+  }
+
+  /**
+   * Train a booster given parameters.
+   *
+   * @param dtrain Data to be trained.
+   * @param params Parameters.
+   * @param round Number of boosting iterations.
+   * @param watches a group of items to be evaluated during training; this allows the user to
+   *                watch performance on the validation set.
+   * @param metrics array containing the evaluation metrics for each matrix in watches for each
+   *                iteration
+   * @param earlyStoppingRound if non-zero, training will be stopped
+   *                           after a specified number of consecutive
+   *                           increases in any evaluation metric.
+   * @param obj customized objective
+   * @param eval customized evaluation
+   * @param booster train from scratch if set to null; train from an existing booster if not null.
+   * @return The trained booster.
+ */ + public static Booster train( + DMatrix dtrain, + Map params, + int round, + Map watches, + float[][] metrics, + IObjective obj, + IEvaluation eval, + int earlyStoppingRound, + Booster booster) throws XGBoostError { //collect eval matrixs String[] evalNames; @@ -104,20 +164,24 @@ public class XGBoost { } //initialize booster - Booster booster = new Booster(params, allMats); - - int version = booster.loadRabitCheckpoint(); + if (booster == null) { + // Start training on a new booster + booster = new Booster(params, allMats); + booster.loadRabitCheckpoint(); + } else { + // Start training on an existing booster + booster.setParams(params); + } //begin to train - for (int iter = version / 2; iter < round; iter++) { - if (version % 2 == 0) { + for (int iter = booster.getVersion() / 2; iter < round; iter++) { + if (booster.getVersion() % 2 == 0) { if (obj != null) { booster.update(dtrain, obj); } else { booster.update(dtrain, iter); } booster.saveRabitCheckpoint(); - version += 1; } //evaluation @@ -149,7 +213,6 @@ public class XGBoost { } } booster.saveRabitCheckpoint(); - version += 1; } return booster; } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 174c68804..4d0c839f2 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError import scala.collection.JavaConverters._ import scala.collection.mutable -class Booster private[xgboost4j](private var booster: JBooster) +class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) extends Serializable with KryoSerializable { /** diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index fc05e899d..76c04921a 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala import java.io.InputStream -import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError} +import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError} import scala.collection.JavaConverters._ /** @@ -41,6 +41,7 @@ object XGBoost { * increases in any evaluation metric. * @param obj customized objective * @param eval customized evaluation + * @param booster train from scratch if set to null; train from an existing booster if not null. * @return The trained booster. 
*/ @throws(classOf[XGBoostError]) @@ -52,13 +53,19 @@ object XGBoost { metrics: Array[Array[Float]] = null, obj: ObjectiveTrait = null, eval: EvalTrait = null, - earlyStoppingRound: Int = 0): Booster = { + earlyStoppingRound: Int = 0, + booster: Booster = null): Booster = { val jWatches = watches.mapValues(_.jDMatrix).asJava + val jBooster = if (booster == null) { + null + } else { + booster.booster + } val xgboostInJava = JXGBoost.train( dtrain.jDMatrix, // we have to filter null value for customized obj and eval params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, - round, jWatches, metrics, obj, eval, earlyStoppingRound) + round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) new Booster(xgboostInJava) } diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 0d790c9ce..1a7ad9d68 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -15,10 +15,7 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; +import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; @@ -347,4 +344,55 @@ public class BoosterImplTest { int nfold = 5; String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null); } + + /** + * test train from existing model + * + * @throws XGBoostError + */ + @Test + public void testTrainFromExistingModel() throws XGBoostError, IOException { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + IEvaluation eval = new EvalError(); + + Map paramMap = new HashMap() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //set watchList + HashMap watches = new HashMap(); + + watches.put("train", trainMat); + watches.put("test", testMat); + + // Train without saving temp booster + int round = 4; + Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0); + float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat); + + // Train with temp Booster + round = 2; + Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0); + float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat); + + // Save tempBooster to bytestream and load back + int prevVersion = tempBooster.getVersion(); + ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray()); + tempBooster = XGBoost.loadModel(in); + in.close(); + tempBooster.setVersion(prevVersion); + + // Continue training using tempBooster + round = 4; + Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); + float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); + TestCase.assertTrue(booster1error == booster2error); + TestCase.assertTrue(tempBoosterError > booster2error); + } }
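For the single-machine entry point, a corresponding sketch (not part of the patch) of continuing training from an existing booster via the new `booster` argument; it mirrors `testTrainFromExistingModel` above, and the agaricus files are the demo data the tests use.

```scala
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => SXGBoost}

val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val params = Map("eta" -> "1", "max_depth" -> "2", "objective" -> "binary:logistic")
val watches = Map("train" -> trainMat)

// Train the first 2 rounds; each completed round advances the internal
// version counter by 2, so the partial booster ends up at version 4.
val partial = SXGBoost.train(trainMat, params, 2, watches)

// Resume from the partially trained booster and finish rounds 3 and 4.
// The training loop starts at version / 2, so the first 2 rounds are skipped,
// and the result matches training all 4 rounds in a single call.
val full = SXGBoost.train(trainMat, params, 4, watches, booster = partial)
```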