[jvm-packages] Saving models into a tmp folder every a few rounds (#2964)
* [jvm-packages] Train Booster from an existing model * Align Scala API with Java API * Existing model should not load rabit checkpoint * Address minor comments * Implement saving temporary boosters and loading previous booster * Add more unit tests for loadPrevBooster * Add params to XGBoostEstimator * (1) Move repartition out of the temp model saving loop (2) Address CR comments * Catch a corner case of training next model with fewer rounds * Address comments * Refactor newly added methods into TmpBoosterManager * Add two files which is missing in previous commit * Rename TmpBooster to checkpoint
This commit is contained in:
parent
eedca8c8ec
commit
9004ca03ca
@ -0,0 +1,139 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
|
import org.apache.spark.SparkContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class which allows user to save checkpoint boosters every a few rounds. If a previous job
|
||||||
|
* fails, the job can restart training from a saved booster instead of from scratch. This class
|
||||||
|
* provides interface and helper methods for the checkpoint functionality.
|
||||||
|
*
|
||||||
|
* @param sc the sparkContext object
|
||||||
|
* @param checkpointPath the hdfs path to store checkpoint boosters
|
||||||
|
*/
|
||||||
|
private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
|
||||||
|
private val logger = LogFactory.getLog("XGBoostSpark")
|
||||||
|
private val modelSuffix = ".model"
|
||||||
|
|
||||||
|
private def getPath(version: Int) = {
|
||||||
|
s"$checkpointPath/$version$modelSuffix"
|
||||||
|
}
|
||||||
|
|
||||||
|
private def getExistingVersions: Seq[Int] = {
|
||||||
|
val fs = FileSystem.get(sc.hadoopConfiguration)
|
||||||
|
if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
|
||||||
|
Seq()
|
||||||
|
} else {
|
||||||
|
fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
|
||||||
|
case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load existing checkpoint with the highest version.
|
||||||
|
*
|
||||||
|
* @return the booster with the highest version, null if no checkpoints available.
|
||||||
|
*/
|
||||||
|
private[spark] def loadBooster: Booster = {
|
||||||
|
val versions = getExistingVersions
|
||||||
|
if (versions.nonEmpty) {
|
||||||
|
val version = versions.max
|
||||||
|
val fullPath = getPath(version)
|
||||||
|
logger.info(s"Start training from previous booster at $fullPath")
|
||||||
|
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
|
||||||
|
model.booster.booster.setVersion(version)
|
||||||
|
model.booster
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean up all previous models and save a new model
|
||||||
|
*
|
||||||
|
* @param model the xgboost model to save
|
||||||
|
*/
|
||||||
|
private[spark] def updateModel(model: XGBoostModel): Unit = {
|
||||||
|
val fs = FileSystem.get(sc.hadoopConfiguration)
|
||||||
|
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
|
||||||
|
val fullPath = getPath(model.version)
|
||||||
|
logger.info(s"Saving checkpoint model with version ${model.version} to $fullPath")
|
||||||
|
model.saveModelAsHadoopFile(fullPath)(sc)
|
||||||
|
prevModelPaths.foreach(path => fs.delete(path, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean up checkpoint boosters with version higher than or equal to the round.
|
||||||
|
*
|
||||||
|
* @param round the number of rounds in the current training job
|
||||||
|
*/
|
||||||
|
private[spark] def cleanUpHigherVersions(round: Int): Unit = {
|
||||||
|
val higherVersions = getExistingVersions.filter(_ / 2 >= round)
|
||||||
|
higherVersions.foreach { version =>
|
||||||
|
val fs = FileSystem.get(sc.hadoopConfiguration)
|
||||||
|
fs.delete(new Path(getPath(version)), true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate a list of checkpoint rounds to save checkpoints based on the savingFreq and
|
||||||
|
* total number of rounds for the training. Concretely, the saving rounds start with
|
||||||
|
* prevRounds + savingFreq, and increase by savingFreq in each step until it reaches total
|
||||||
|
* number of rounds. If savingFreq is 0, the checkpoint will be disabled and the method
|
||||||
|
* returns Seq(round)
|
||||||
|
*
|
||||||
|
* @param savingFreq the increase on rounds during each step of training
|
||||||
|
* @param round the total number of rounds for the training
|
||||||
|
* @return a seq of integers, each represent the index of round to save the checkpoints
|
||||||
|
*/
|
||||||
|
private[spark] def getSavingRounds(savingFreq: Int, round: Int): Seq[Int] = {
|
||||||
|
if (checkpointPath.nonEmpty && savingFreq > 0) {
|
||||||
|
val prevRounds = getExistingVersions.map(_ / 2)
|
||||||
|
val firstSavingRound = (0 +: prevRounds).max + savingFreq
|
||||||
|
(firstSavingRound until round by savingFreq) :+ round
|
||||||
|
} else if (savingFreq <= 0) {
|
||||||
|
Seq(round)
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object CheckpointManager {
|
||||||
|
|
||||||
|
private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
|
||||||
|
val checkpointPath: String = params.get("checkpoint_path") match {
|
||||||
|
case None => ""
|
||||||
|
case Some(path: String) => path
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
||||||
|
" an instance of String.")
|
||||||
|
}
|
||||||
|
|
||||||
|
val savingFreq: Int = params.get("saving_frequency") match {
|
||||||
|
case None => 0
|
||||||
|
case Some(freq: Int) => freq
|
||||||
|
case _ => throw new IllegalArgumentException("parameter \"saving_frequency\" must be" +
|
||||||
|
" an instance of Int.")
|
||||||
|
}
|
||||||
|
(checkpointPath, savingFreq)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -20,7 +20,6 @@ import java.io.File
|
|||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
@ -101,23 +100,19 @@ object XGBoost extends Serializable {
|
|||||||
data: RDD[XGBLabeledPoint],
|
data: RDD[XGBLabeledPoint],
|
||||||
params: Map[String, Any],
|
params: Map[String, Any],
|
||||||
rabitEnv: java.util.Map[String, String],
|
rabitEnv: java.util.Map[String, String],
|
||||||
numWorkers: Int,
|
|
||||||
round: Int,
|
round: Int,
|
||||||
obj: ObjectiveTrait,
|
obj: ObjectiveTrait,
|
||||||
eval: EvalTrait,
|
eval: EvalTrait,
|
||||||
useExternalMemory: Boolean,
|
useExternalMemory: Boolean,
|
||||||
missing: Float): RDD[(Booster, Map[String, Array[Float]])] = {
|
missing: Float,
|
||||||
val partitionedData = if (data.getNumPartitions != numWorkers) {
|
prevBooster: Booster
|
||||||
logger.info(s"repartitioning training set to $numWorkers partitions")
|
): RDD[(Booster, Map[String, Array[Float]])] = {
|
||||||
data.repartition(numWorkers)
|
|
||||||
} else {
|
val partitionedBaseMargin = data.map(_.baseMargin)
|
||||||
data
|
|
||||||
}
|
|
||||||
val partitionedBaseMargin = partitionedData.map(_.baseMargin)
|
|
||||||
// to workaround the empty partitions in training dataset,
|
// to workaround the empty partitions in training dataset,
|
||||||
// this might not be the best efficient implementation, see
|
// this might not be the best efficient implementation, see
|
||||||
// (https://github.com/dmlc/xgboost/issues/1277)
|
// (https://github.com/dmlc/xgboost/issues/1277)
|
||||||
partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
|
data.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
|
||||||
if (labeledPoints.isEmpty) {
|
if (labeledPoints.isEmpty) {
|
||||||
throw new XGBoostError(
|
throw new XGBoostError(
|
||||||
s"detected an empty partition in the training data, partition ID:" +
|
s"detected an empty partition in the training data, partition ID:" +
|
||||||
@ -145,7 +140,7 @@ object XGBoost extends Serializable {
|
|||||||
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
|
||||||
val booster = SXGBoost.train(watches.train, params, round,
|
val booster = SXGBoost.train(watches.train, params, round,
|
||||||
watches.toMap, metrics, obj, eval,
|
watches.toMap, metrics, obj, eval,
|
||||||
earlyStoppingRound = numEarlyStoppingRounds)
|
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
|
||||||
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
|
||||||
} finally {
|
} finally {
|
||||||
Rabit.shutdown()
|
Rabit.shutdown()
|
||||||
@ -330,34 +325,58 @@ object XGBoost extends Serializable {
|
|||||||
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
|
case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
|
||||||
" an instance of Long.")
|
" an instance of Long.")
|
||||||
}
|
}
|
||||||
|
val (checkpointPath, savingFeq) = CheckpointManager.extractParams(params)
|
||||||
|
val partitionedData = repartitionForTraining(trainingData, nWorkers)
|
||||||
|
|
||||||
val tracker = startTracker(nWorkers, trackerConf)
|
val sc = trainingData.sparkContext
|
||||||
try {
|
val checkpointManager = new CheckpointManager(sc, checkpointPath)
|
||||||
val sc = trainingData.sparkContext
|
checkpointManager.cleanUpHigherVersions(round)
|
||||||
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
|
||||||
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, trainingData.sparkContext)
|
var prevBooster = checkpointManager.loadBooster
|
||||||
val boostersAndMetrics = buildDistributedBoosters(trainingData, overriddenParams,
|
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||||
tracker.getWorkerEnvs, nWorkers, round, obj, eval, useExternalMemory, missing)
|
checkpointManager.getSavingRounds(savingFeq, round).map {
|
||||||
val sparkJobThread = new Thread() {
|
savingRound: Int =>
|
||||||
override def run() {
|
val tracker = startTracker(nWorkers, trackerConf)
|
||||||
// force the job
|
try {
|
||||||
boostersAndMetrics.foreachPartition(() => _)
|
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
|
||||||
}
|
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
|
||||||
|
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
|
||||||
|
tracker.getWorkerEnvs, savingRound, obj, eval, useExternalMemory, missing, prevBooster)
|
||||||
|
val sparkJobThread = new Thread() {
|
||||||
|
override def run() {
|
||||||
|
// force the job
|
||||||
|
boostersAndMetrics.foreachPartition(() => _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
||||||
|
sparkJobThread.start()
|
||||||
|
val isClsTask = isClassificationTask(params)
|
||||||
|
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
||||||
|
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
||||||
|
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
|
||||||
|
sparkJobThread, isClsTask)
|
||||||
|
if (isClsTask){
|
||||||
|
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
|
||||||
|
params.getOrElse("num_class", "2").toString.toInt
|
||||||
|
}
|
||||||
|
if (savingRound < round) {
|
||||||
|
prevBooster = model.booster
|
||||||
|
checkpointManager.updateModel(model)
|
||||||
|
}
|
||||||
|
model
|
||||||
|
} finally {
|
||||||
|
tracker.stop()
|
||||||
}
|
}
|
||||||
sparkJobThread.setUncaughtExceptionHandler(tracker)
|
}.last
|
||||||
sparkJobThread.start()
|
}
|
||||||
val isClsTask = isClassificationTask(params)
|
|
||||||
val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
|
|
||||||
logger.info(s"Rabit returns with exit code $trackerReturnVal")
|
private[spark] def repartitionForTraining(trainingData: RDD[XGBLabeledPoint], nWorkers: Int) = {
|
||||||
val model = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics,
|
if (trainingData.getNumPartitions != nWorkers) {
|
||||||
sparkJobThread, isClsTask)
|
logger.info(s"repartitioning training set to $nWorkers partitions")
|
||||||
if (isClsTask){
|
trainingData.repartition(nWorkers)
|
||||||
model.asInstanceOf[XGBoostClassificationModel].numOfClasses =
|
} else {
|
||||||
params.getOrElse("num_class", "2").toString.toInt
|
trainingData
|
||||||
}
|
|
||||||
model
|
|
||||||
} finally {
|
|
||||||
tracker.stop()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,6 +424,7 @@ object XGBoost extends Serializable {
|
|||||||
xgBoostModel.setPredictionCol(predCol)
|
xgBoostModel.setPredictionCol(predCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load XGBoost model from path in HDFS-compatible file system
|
* Load XGBoost model from path in HDFS-compatible file system
|
||||||
*
|
*
|
||||||
|
|||||||
@ -344,6 +344,8 @@ abstract class XGBoostModel(protected var _booster: Booster)
|
|||||||
|
|
||||||
def booster: Booster = _booster
|
def booster: Booster = _booster
|
||||||
|
|
||||||
|
def version: Int = this.booster.booster.getVersion
|
||||||
|
|
||||||
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
|
override def copy(extra: ParamMap): XGBoostModel = defaultCopy(extra)
|
||||||
|
|
||||||
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
|
override def write: MLWriter = new XGBoostModel.XGBoostModelModelWriter(this)
|
||||||
|
|||||||
@ -77,6 +77,21 @@ trait GeneralParams extends Params {
|
|||||||
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
|
" request new Workers if numCores are insufficient. The timeout will be disabled if this" +
|
||||||
" value is set smaller than or equal to 0.")
|
" value is set smaller than or equal to 0.")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The hdfs folder to load and save checkpoint boosters. default: `empty_string`
|
||||||
|
*/
|
||||||
|
val checkpointPath = new Param[String](this, "checkpoint_path", "the hdfs folder to load and " +
|
||||||
|
"save checkpoints. The job will try to load the existing booster as the starting point for " +
|
||||||
|
"training. If saving_frequency is also set, the job will save a checkpoint every a few rounds.")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The frequency to save checkpoint boosters. default: 0
|
||||||
|
*/
|
||||||
|
val savingFrequency = new IntParam(this, "saving_frequency", "if checkpoint_path is also set," +
|
||||||
|
" the job will save checkpoints at this frequency. If the job fails and gets restarted with" +
|
||||||
|
" same setting, it will load the existing booster instead of training from scratch." +
|
||||||
|
" Checkpoint will be disabled if set to 0.")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
* Rabit tracker configurations. The parameter must be provided as an instance of the
|
||||||
* TrackerConf class, which has the following definition:
|
* TrackerConf class, which has the following definition:
|
||||||
@ -112,6 +127,7 @@ trait GeneralParams extends Params {
|
|||||||
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
|
||||||
useExternalMemory -> false, silent -> 0,
|
useExternalMemory -> false, silent -> 0,
|
||||||
customObj -> null, customEval -> null, missing -> Float.NaN,
|
customObj -> null, customEval -> null, missing -> Float.NaN,
|
||||||
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L
|
trackerConf -> TrackerConf(), seed -> 0, timeoutRequestWorkers -> 30 * 60 * 1000L,
|
||||||
|
checkpointPath -> "", savingFrequency -> 0
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,80 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
import java.nio.file.Files
|
||||||
|
|
||||||
|
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||||
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
|
import org.apache.spark.{SparkConf, SparkContext}
|
||||||
|
|
||||||
|
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||||
|
var sc: SparkContext = _
|
||||||
|
|
||||||
|
override def beforeAll(): Unit = {
|
||||||
|
val conf: SparkConf = new SparkConf()
|
||||||
|
.setMaster("local[*]")
|
||||||
|
.setAppName("XGBoostSuite")
|
||||||
|
sc = new SparkContext(conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
private lazy val (model4, model8) = {
|
||||||
|
import DataUtils._
|
||||||
|
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
|
||||||
|
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic")
|
||||||
|
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, sc.defaultParallelism),
|
||||||
|
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, sc.defaultParallelism))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test update/load models") {
|
||||||
|
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||||
|
val manager = new CheckpointManager(sc, tmpPath)
|
||||||
|
manager.updateModel(model4)
|
||||||
|
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
|
assert(files.length == 1)
|
||||||
|
assert(files.head.getPath.getName == "4.model")
|
||||||
|
assert(manager.loadBooster.booster.getVersion == 4)
|
||||||
|
|
||||||
|
manager.updateModel(model8)
|
||||||
|
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
|
assert(files.length == 1)
|
||||||
|
assert(files.head.getPath.getName == "8.model")
|
||||||
|
assert(manager.loadBooster.booster.getVersion == 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test cleanUpHigherVersions") {
|
||||||
|
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||||
|
val manager = new CheckpointManager(sc, tmpPath)
|
||||||
|
manager.updateModel(model8)
|
||||||
|
manager.cleanUpHigherVersions(round = 8)
|
||||||
|
assert(new File(s"$tmpPath/8.model").exists())
|
||||||
|
|
||||||
|
manager.cleanUpHigherVersions(round = 4)
|
||||||
|
assert(!new File(s"$tmpPath/8.model").exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
test("test saving rounds") {
|
||||||
|
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||||
|
val manager = new CheckpointManager(sc, tmpPath)
|
||||||
|
assertResult(Seq(7))(manager.getSavingRounds(savingFreq = 0, round = 7))
|
||||||
|
assertResult(Seq(2, 4, 6, 7))(manager.getSavingRounds(savingFreq = 2, round = 7))
|
||||||
|
manager.updateModel(model4)
|
||||||
|
assertResult(Seq(4, 6, 7))(manager.getSavingRounds(2, 7))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -16,13 +16,14 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala.spark
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import java.nio.file.Files
|
||||||
import java.util.concurrent.LinkedBlockingDeque
|
import java.util.concurrent.LinkedBlockingDeque
|
||||||
|
|
||||||
import scala.util.Random
|
import scala.util.Random
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.Rabit
|
import ml.dmlc.xgboost4j.java.Rabit
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||||
|
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
|
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
|
||||||
@ -73,13 +74,14 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
|
|
||||||
test("build RDD containing boosters with the specified worker number") {
|
test("build RDD containing boosters with the specified worker number") {
|
||||||
val trainingRDD = sc.parallelize(Classification.train)
|
val trainingRDD = sc.parallelize(Classification.train)
|
||||||
|
val partitionedRDD = XGBoost.repartitionForTraining(trainingRDD, 2)
|
||||||
val boosterRDD = XGBoost.buildDistributedBoosters(
|
val boosterRDD = XGBoost.buildDistributedBoosters(
|
||||||
trainingRDD,
|
partitionedRDD,
|
||||||
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||||
"objective" -> "binary:logistic").toMap,
|
"objective" -> "binary:logistic").toMap,
|
||||||
new java.util.HashMap[String, String](),
|
new java.util.HashMap[String, String](),
|
||||||
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = true,
|
round = 5, eval = null, obj = null, useExternalMemory = true,
|
||||||
missing = Float.NaN)
|
missing = Float.NaN, prevBooster = null)
|
||||||
val boosterCount = boosterRDD.count()
|
val boosterCount = boosterRDD.count()
|
||||||
assert(boosterCount === 2)
|
assert(boosterCount === 2)
|
||||||
}
|
}
|
||||||
@ -335,4 +337,33 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
|||||||
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
|
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("training with saving checkpoint boosters") {
|
||||||
|
import DataUtils._
|
||||||
|
val eval = new EvalError()
|
||||||
|
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||||
|
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||||
|
|
||||||
|
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
|
||||||
|
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||||
|
"saving_frequency" -> 2).toMap
|
||||||
|
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||||
|
nWorkers = numWorkers)
|
||||||
|
def error(model: XGBoostModel): Float = eval.eval(
|
||||||
|
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
|
||||||
|
|
||||||
|
// Check only one model is kept after training
|
||||||
|
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||||
|
assert(files.length == 1)
|
||||||
|
assert(files.head.getPath.getName == "8.model")
|
||||||
|
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
|
||||||
|
|
||||||
|
// Train next model based on prev model
|
||||||
|
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
|
||||||
|
nWorkers = numWorkers)
|
||||||
|
assert(error(tmpModel) > error(prevModel))
|
||||||
|
assert(error(prevModel) > error(nextModel))
|
||||||
|
assert(error(nextModel) < 0.1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
private static final Log logger = LogFactory.getLog(Booster.class);
|
private static final Log logger = LogFactory.getLog(Booster.class);
|
||||||
// handle to the booster.
|
// handle to the booster.
|
||||||
private long handle = 0;
|
private long handle = 0;
|
||||||
|
private int version = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new Booster with empty stage.
|
* Create a new Booster with empty stage.
|
||||||
@ -416,6 +417,14 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return modelInfos[0];
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int getVersion() {
|
||||||
|
return this.version;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setVersion(int version) {
|
||||||
|
this.version = version;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @return the saved byte array.
|
* @return the saved byte array.
|
||||||
@ -436,16 +445,18 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
int loadRabitCheckpoint() throws XGBoostError {
|
int loadRabitCheckpoint() throws XGBoostError {
|
||||||
int[] out = new int[1];
|
int[] out = new int[1];
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||||
return out[0];
|
version = out[0];
|
||||||
|
return version;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save the booster model into thread-local rabit checkpoint.
|
* Save the booster model into thread-local rabit checkpoint and increment the version.
|
||||||
* This is only used in distributed training.
|
* This is only used in distributed training.
|
||||||
* @throws XGBoostError
|
* @throws XGBoostError
|
||||||
*/
|
*/
|
||||||
void saveRabitCheckpoint() throws XGBoostError {
|
void saveRabitCheckpoint() throws XGBoostError {
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||||
|
version += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -481,6 +492,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
// making Booster serializable
|
// making Booster serializable
|
||||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||||
try {
|
try {
|
||||||
|
out.writeInt(version);
|
||||||
out.writeObject(this.toByteArray());
|
out.writeObject(this.toByteArray());
|
||||||
} catch (XGBoostError ex) {
|
} catch (XGBoostError ex) {
|
||||||
ex.printStackTrace();
|
ex.printStackTrace();
|
||||||
@ -492,6 +504,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
throws IOException, ClassNotFoundException {
|
throws IOException, ClassNotFoundException {
|
||||||
try {
|
try {
|
||||||
this.init(null);
|
this.init(null);
|
||||||
|
this.version = in.readInt();
|
||||||
byte[] bytes = (byte[])in.readObject();
|
byte[] bytes = (byte[])in.readObject();
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||||
} catch (XGBoostError ex) {
|
} catch (XGBoostError ex) {
|
||||||
@ -520,6 +533,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
int serObjSize = serObj.length;
|
int serObjSize = serObj.length;
|
||||||
System.out.println("==== serialized obj size " + serObjSize);
|
System.out.println("==== serialized obj size " + serObjSize);
|
||||||
output.writeInt(serObjSize);
|
output.writeInt(serObjSize);
|
||||||
|
output.writeInt(version);
|
||||||
output.write(serObj);
|
output.write(serObj);
|
||||||
} catch (XGBoostError ex) {
|
} catch (XGBoostError ex) {
|
||||||
ex.printStackTrace();
|
ex.printStackTrace();
|
||||||
@ -532,6 +546,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
try {
|
try {
|
||||||
this.init(null);
|
this.init(null);
|
||||||
int serObjSize = input.readInt();
|
int serObjSize = input.readInt();
|
||||||
|
this.version = input.readInt();
|
||||||
System.out.println("==== the size of the object: " + serObjSize);
|
System.out.println("==== the size of the object: " + serObjSize);
|
||||||
byte[] bytes = new byte[serObjSize];
|
byte[] bytes = new byte[serObjSize];
|
||||||
input.readBytes(bytes);
|
input.readBytes(bytes);
|
||||||
|
|||||||
@ -57,6 +57,18 @@ public class XGBoost {
|
|||||||
return Booster.loadModel(in);
|
return Booster.loadModel(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
public static Booster train(
|
public static Booster train(
|
||||||
DMatrix dtrain,
|
DMatrix dtrain,
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
@ -67,6 +79,23 @@ public class XGBoost {
|
|||||||
return train(dtrain, params, round, watches, null, obj, eval, 0);
|
return train(dtrain, params, round, watches, null, obj, eval, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
|
* iteration
|
||||||
|
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||||
|
* after a specified number of consecutive
|
||||||
|
* increases in any evaluation metric.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
public static Booster train(
|
public static Booster train(
|
||||||
DMatrix dtrain,
|
DMatrix dtrain,
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
@ -76,6 +105,37 @@ public class XGBoost {
|
|||||||
IObjective obj,
|
IObjective obj,
|
||||||
IEvaluation eval,
|
IEvaluation eval,
|
||||||
int earlyStoppingRound) throws XGBoostError {
|
int earlyStoppingRound) throws XGBoostError {
|
||||||
|
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a booster given parameters.
|
||||||
|
*
|
||||||
|
* @param dtrain Data to be trained.
|
||||||
|
* @param params Parameters.
|
||||||
|
* @param round Number of boosting iterations.
|
||||||
|
* @param watches a group of items to be evaluated during training, this allows user to watch
|
||||||
|
* performance on the validation set.
|
||||||
|
* @param metrics array containing the evaluation metrics for each matrix in watches for each
|
||||||
|
* iteration
|
||||||
|
* @param earlyStoppingRound if non-zero, training would be stopped
|
||||||
|
* after a specified number of consecutive
|
||||||
|
* increases in any evaluation metric.
|
||||||
|
* @param obj customized objective
|
||||||
|
* @param eval customized evaluation
|
||||||
|
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||||
|
* @return The trained booster.
|
||||||
|
*/
|
||||||
|
public static Booster train(
|
||||||
|
DMatrix dtrain,
|
||||||
|
Map<String, Object> params,
|
||||||
|
int round,
|
||||||
|
Map<String, DMatrix> watches,
|
||||||
|
float[][] metrics,
|
||||||
|
IObjective obj,
|
||||||
|
IEvaluation eval,
|
||||||
|
int earlyStoppingRound,
|
||||||
|
Booster booster) throws XGBoostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
@ -104,20 +164,24 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//initialize booster
|
//initialize booster
|
||||||
Booster booster = new Booster(params, allMats);
|
if (booster == null) {
|
||||||
|
// Start training on a new booster
|
||||||
int version = booster.loadRabitCheckpoint();
|
booster = new Booster(params, allMats);
|
||||||
|
booster.loadRabitCheckpoint();
|
||||||
|
} else {
|
||||||
|
// Start training on an existing booster
|
||||||
|
booster.setParams(params);
|
||||||
|
}
|
||||||
|
|
||||||
//begin to train
|
//begin to train
|
||||||
for (int iter = version / 2; iter < round; iter++) {
|
for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
||||||
if (version % 2 == 0) {
|
if (booster.getVersion() % 2 == 0) {
|
||||||
if (obj != null) {
|
if (obj != null) {
|
||||||
booster.update(dtrain, obj);
|
booster.update(dtrain, obj);
|
||||||
} else {
|
} else {
|
||||||
booster.update(dtrain, iter);
|
booster.update(dtrain, iter);
|
||||||
}
|
}
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
version += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//evaluation
|
//evaluation
|
||||||
@ -149,7 +213,6 @@ public class XGBoost {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
version += 1;
|
|
||||||
}
|
}
|
||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError
|
|||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
|
||||||
class Booster private[xgboost4j](private var booster: JBooster)
|
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||||
extends Serializable with KryoSerializable {
|
extends Serializable with KryoSerializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
|||||||
|
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost, XGBoostError}
|
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -41,6 +41,7 @@ object XGBoost {
|
|||||||
* increases in any evaluation metric.
|
* increases in any evaluation metric.
|
||||||
* @param obj customized objective
|
* @param obj customized objective
|
||||||
* @param eval customized evaluation
|
* @param eval customized evaluation
|
||||||
|
* @param booster train from scratch if set to null; train from an existing booster if not null.
|
||||||
* @return The trained booster.
|
* @return The trained booster.
|
||||||
*/
|
*/
|
||||||
@throws(classOf[XGBoostError])
|
@throws(classOf[XGBoostError])
|
||||||
@ -52,13 +53,19 @@ object XGBoost {
|
|||||||
metrics: Array[Array[Float]] = null,
|
metrics: Array[Array[Float]] = null,
|
||||||
obj: ObjectiveTrait = null,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait = null,
|
eval: EvalTrait = null,
|
||||||
earlyStoppingRound: Int = 0): Booster = {
|
earlyStoppingRound: Int = 0,
|
||||||
|
booster: Booster = null): Booster = {
|
||||||
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
val jWatches = watches.mapValues(_.jDMatrix).asJava
|
||||||
|
val jBooster = if (booster == null) {
|
||||||
|
null
|
||||||
|
} else {
|
||||||
|
booster.booster
|
||||||
|
}
|
||||||
val xgboostInJava = JXGBoost.train(
|
val xgboostInJava = JXGBoost.train(
|
||||||
dtrain.jDMatrix,
|
dtrain.jDMatrix,
|
||||||
// we have to filter null value for customized obj and eval
|
// we have to filter null value for customized obj and eval
|
||||||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
|
||||||
round, jWatches, metrics, obj, eval, earlyStoppingRound)
|
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
||||||
new Booster(xgboostInJava)
|
new Booster(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -15,10 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.FileOutputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
@ -347,4 +344,55 @@ public class BoosterImplTest {
|
|||||||
int nfold = 5;
|
int nfold = 5;
|
||||||
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
|
String[] evalHist = XGBoost.crossValidation(trainMat, param, round, nfold, null, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test train from existing model
|
||||||
|
*
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testTrainFromExistingModel() throws XGBoostError, IOException {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
IEvaluation eval = new EvalError();
|
||||||
|
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("eta", 1.0);
|
||||||
|
put("max_depth", 2);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//set watchList
|
||||||
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
|
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
// Train without saving temp booster
|
||||||
|
int round = 4;
|
||||||
|
Booster booster1 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||||
|
float booster1error = eval.eval(booster1.predict(testMat, true, 0), testMat);
|
||||||
|
|
||||||
|
// Train with temp Booster
|
||||||
|
round = 2;
|
||||||
|
Booster tempBooster = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0);
|
||||||
|
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);
|
||||||
|
|
||||||
|
// Save tempBooster to bytestream and load back
|
||||||
|
int prevVersion = tempBooster.getVersion();
|
||||||
|
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
|
||||||
|
tempBooster = XGBoost.loadModel(in);
|
||||||
|
in.close();
|
||||||
|
tempBooster.setVersion(prevVersion);
|
||||||
|
|
||||||
|
// Continue training using tempBooster
|
||||||
|
round = 4;
|
||||||
|
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
|
||||||
|
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
|
||||||
|
TestCase.assertTrue(booster1error == booster2error);
|
||||||
|
TestCase.assertTrue(tempBoosterError > booster2error);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user