[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4876)
* Expose sets of rabit configurations to spark layer
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -21,11 +21,11 @@ import java.nio.file.Files
|
||||
|
||||
import scala.collection.{AbstractIterator, mutable}
|
||||
import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
|
||||
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
@@ -155,7 +155,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
overridedParams
|
||||
}
|
||||
|
||||
def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||
private[spark] def buildXGBRuntimeParams: XGBoostExecutionParams = {
|
||||
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
|
||||
val round = overridedParams("num_round").asInstanceOf[Int]
|
||||
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
|
||||
@@ -221,6 +221,25 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
xgbExecParam.setRawParamMap(overridedParams)
|
||||
xgbExecParam
|
||||
}
|
||||
|
||||
private[spark] def buildRabitParams: Map[String, String] = Map(
|
||||
"rabit_reduce_ring_mincount" ->
|
||||
overridedParams.getOrElse("rabit_reduce_ring_mincount", 32<<10).toString,
|
||||
"rabit_reduce_buffer" ->
|
||||
overridedParams.getOrElse("rabit_reduce_buffer", "256MB").toString,
|
||||
"rabit_bootstrap_cache" ->
|
||||
overridedParams.getOrElse("rabit_bootstrap_cache", false).toString,
|
||||
"rabit_debug" ->
|
||||
overridedParams.getOrElse("rabit_debug", false).toString,
|
||||
"rabit_timeout" ->
|
||||
overridedParams.getOrElse("rabit_timeout", false).toString,
|
||||
"rabit_timeout_sec" ->
|
||||
overridedParams.getOrElse("rabit_timeout_sec", 1800).toString,
|
||||
"DMLC_WORKER_CONNECT_RETRY" ->
|
||||
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString,
|
||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" ->
|
||||
overridedParams.getOrElse("dmlc_worker_stop_process_on_error", false).toString
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -321,7 +340,6 @@ object XGBoost extends Serializable {
|
||||
}
|
||||
val taskId = TaskContext.getPartitionId().toString
|
||||
rabitEnv.put("DMLC_TASK_ID", taskId)
|
||||
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
|
||||
|
||||
try {
|
||||
Rabit.init(rabitEnv)
|
||||
@@ -490,8 +508,9 @@ object XGBoost extends Serializable {
|
||||
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
|
||||
(Booster, Map[String, Array[Float]]) = {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
|
||||
buildXGBRuntimeParams
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams
|
||||
val sc = trainingData.sparkContext
|
||||
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
|
||||
checkpointPath)
|
||||
@@ -510,12 +529,12 @@ object XGBoost extends Serializable {
|
||||
val parallelismTracker = new SparkParallelismTracker(sc,
|
||||
xgbExecParams.timeoutRequestWorkers,
|
||||
xgbExecParams.numWorkers)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val rabitEnv = tracker.getWorkerEnvs.asScala ++ xgbRabitParams
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
|
||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv.asJava,
|
||||
checkpointRound, prevBooster, evalSetsMap)
|
||||
} else {
|
||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
|
||||
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv.asJava,
|
||||
checkpointRound, prevBooster, evalSetsMap)
|
||||
}
|
||||
val sparkJobThread = new Thread() {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -50,7 +50,7 @@ class XGBoostClassifier (
|
||||
def this(xgboostParams: Map[String, Any]) = this(
|
||||
Identifiable.randomUID("xgbc"), xgboostParams)
|
||||
|
||||
XGBoostToMLlibParams(xgboostParams)
|
||||
XGBoost2MLlibParams(xgboostParams)
|
||||
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
|
||||
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
|
||||
with BoosterParams with ParamMapFuncs with NonParamVariables {
|
||||
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
|
||||
|
||||
def needDeterministicRepartitioning: Boolean = {
|
||||
getCheckpointPath.nonEmpty && getCheckpointInterval > 0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -54,7 +54,7 @@ class XGBoostRegressor (
|
||||
def this(xgboostParams: Map[String, Any]) = this(
|
||||
Identifiable.randomUID("xgbr"), xgboostParams)
|
||||
|
||||
XGBoostToMLlibParams(xgboostParams)
|
||||
XGBoost2MLlibParams(xgboostParams)
|
||||
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -241,7 +241,7 @@ trait HasNumClass extends Params {
|
||||
|
||||
private[spark] trait ParamMapFuncs extends Params {
|
||||
|
||||
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
|
||||
for ((paramName, paramValue) <- xgboostParams) {
|
||||
if ((paramName == "booster" && paramValue != "gbtree") ||
|
||||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark.params
|
||||
|
||||
import org.apache.spark.ml.param._
|
||||
|
||||
private[spark] trait RabitParams extends Params {
|
||||
/**
|
||||
* Rabit worker configurations. These parameters were passed to Rabit.Init and decide
|
||||
* rabit_reduce_ring_mincount - threshold of enable ring based allreduce/broadcast operations.
|
||||
* rabit_reduce_buffer - buffer size to recv and run reduction
|
||||
* rabit_bootstrap_cache - enable save allreduce cache before loadcheckpoint
|
||||
* rabit_debug - enable more verbose rabit logging to stdout
|
||||
* rabit_timeout - enable sidecar thread after rabit observed failures
|
||||
* rabit_timeout_sec - wait interval before exit after rabit observed failures
|
||||
* dmlc_worker_connect_retry - number of retrys to tracker
|
||||
* dmlc_worker_stop_process_on_error - exit process when rabit see assert/error
|
||||
*/
|
||||
final val ringReduceMin = new IntParam(this, "rabitReduceRingMincount",
|
||||
"minimal counts of enable allreduce/broadcast with ring based topology",
|
||||
ParamValidators.gtEq(1))
|
||||
|
||||
final def reduceBuffer: Param[String] = new Param[String](this, "rabitReduceBuffer",
|
||||
"buffer size (MB/GB) allocated to each xgb trainner recv and run reduction",
|
||||
(buf: String) => buf.contains("MB") || buf.contains("GB"))
|
||||
|
||||
final def bootstrapCache: BooleanParam = new BooleanParam(this, "rabitBootstrapCache",
|
||||
"enable save allreduce cache before loadcheckpoint, used to allow failed task retry")
|
||||
|
||||
final def rabitDebug: BooleanParam = new BooleanParam(this, "rabitDebug",
|
||||
"enable more verbose rabit logging to stdout")
|
||||
|
||||
final def rabitTimeout: BooleanParam = new BooleanParam(this, "rabitTimeout",
|
||||
"enable failure timeout sidecar threads")
|
||||
|
||||
final def timeoutInterval: IntParam = new IntParam(this, "rabitTimeoutSec",
|
||||
"timeout threshold after rabit observed failures", (interval: Int) => interval > 0)
|
||||
|
||||
final def connectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
|
||||
"number of retry worker do before fail", ParamValidators.gtEq(1))
|
||||
|
||||
final def exitOnError: BooleanParam = new BooleanParam(this, "dmlcWorkerStopProcessOnError",
|
||||
"exit process when rabit see assert error")
|
||||
|
||||
setDefault(ringReduceMin -> (32 << 10), reduceBuffer -> "256MB", bootstrapCache -> false,
|
||||
rabitDebug -> false, connectRetry -> 5, rabitTimeout -> false, timeoutInterval -> 1800,
|
||||
exitOnError -> false)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,7 +16,10 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
@@ -28,7 +31,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
|
||||
test("nthread configuration must be no larger than spark.task.cpus") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
|
||||
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
|
||||
intercept[IllegalArgumentException] {
|
||||
@@ -40,7 +43,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
// TODO write an isolated test for Booster.
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator, null)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
@@ -52,7 +55,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
|
||||
ss.conf.set("spark.ssl.enabled", true)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
/*
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
|
||||
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
|
||||
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
|
||||
|
||||
test("test parity classification prediction") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
|
||||
val model1 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
).fit(training)
|
||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
||||
|
||||
val model2 = new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
|
||||
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
|
||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
|
||||
|
||||
assert(Rabit.rabitEnvs.asScala.size > 7)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
|
||||
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
|
||||
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "1")
|
||||
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
|
||||
})
|
||||
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
assert(p1 == p2)
|
||||
}
|
||||
}
|
||||
|
||||
test("test parity regression prediction") {
|
||||
val training = buildDataFrame(Regression.train)
|
||||
val testDM = new DMatrix(Regression.test.iterator, null)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
|
||||
val model1 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
).fit(training)
|
||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
||||
|
||||
val model2 = new XGBoostRegressor(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"rabit_bootstrap_cache" -> true, "rabit_debug" -> true, "rabit_reduce_ring_mincount" -> 100,
|
||||
"rabit_reduce_buffer" -> "2MB", "DMLC_WORKER_CONNECT_RETRY" -> 1,
|
||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 5)).fit(training)
|
||||
assert(Rabit.rabitEnvs.asScala.size > 7)
|
||||
Rabit.rabitEnvs.asScala.foreach( item => {
|
||||
if (item._1.toString == "rabit_bootstrap_cache") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_debug") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "100")
|
||||
if (item._1.toString == "rabit_reduce_buffer") assert(item._2 == "2MB")
|
||||
if (item._1.toString == "dmlc_worker_connect_retry") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_timeout") assert(item._2 == "true")
|
||||
if (item._1.toString == "rabit_timeout_sec") assert(item._2 == "5")
|
||||
if (item._1.toString == "DMLC_WORKER_STOP_PROCESS_ON_ERROR") assert(item._2 == "false")
|
||||
})
|
||||
// check the equality of single instance prediction
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
// check parity w/o rabit cache
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
assert(math.abs(p1 - p2) < 0.00001f)
|
||||
}
|
||||
}
|
||||
|
||||
test("test graceful failure handle") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
// mock rank 0 failure during 4th allreduce synchronization
|
||||
Rabit.mockList = Array("0,4,0,0").toList.asJava
|
||||
intercept[XGBoostError] {
|
||||
new XGBoostClassifier(Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"rabit_timeout" -> true, "rabit_timeout_sec" -> 1,
|
||||
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> false)).fit(training)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2019 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
Reference in New Issue
Block a user