From b29b8c2f34cd0be58e8cadd9bb388e85898d4ae4 Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Fri, 1 Nov 2019 14:21:19 -0700 Subject: [PATCH] [jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4966) * [phase 1] expose sets of rabit configurations to spark layer * add back mutable import * disable ring_mincount till https://github.com/dmlc/rabit/pull/106d * Revert "disable ring_mincount till https://github.com/dmlc/rabit/pull/106d" This reverts commit 65e95a98e24f5eb53c6ba9ef9b2379524258984d. * apply latest rabit * fix build error * apply https://github.com/dmlc/xgboost/pull/4880 * downgrade cmake in rabit * point to rabit with DMLC_ROOT fix * relative path of rabit install prefix * split rabit parameters to another trait * misc * misc * Delete .classpath * Delete .classpath * Delete .classpath * Update XGBoostClassifier.scala * Update XGBoostRegressor.scala * Update GeneralParams.scala * Update GeneralParams.scala * Update GeneralParams.scala * Update GeneralParams.scala * Delete .classpath * Update RabitParams.scala * Update .gitignore * Update .gitignore * apply rabitParams to training * use string as rabit parameter value type * cleanup * add rabitEnv check * point to dmlc/rabit * per feedback * update private scope * misc * update rabit * add rabit_timtout, fix failing test. * split tests * allow build jvm with rabit mock * pass mock failures to rabit with test * add mock error and graceful handle rabit assertion error test * split mvn test * remove sign for test * update rabit * build jvm_packages with rabit mock * point back to dmlc/rabit * per feedback, update scala header * cleanup pom * per feedback * try fix lint * fix lint * per feedback, remove bootstrap_cache * per feedback 2 * try replace dev profile with passing mvn property * fix build error * remove mvn property and replace with env setting to build test jar * per feedback * revert copyright headlines, point to dmlc/rabit * revert python lint * remove multiple failure test case as retry is not enabled in spark * Update core.py * Update core.py * per feedback, style fix --- CMakeLists.txt | 40 ++---- jvm-packages/.gitignore | 3 +- jvm-packages/create_jni.py | 5 + .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 40 ++++-- .../scala/spark/XGBoostClassifier.scala | 2 +- .../scala/spark/XGBoostEstimatorCommon.scala | 2 +- .../scala/spark/XGBoostRegressor.scala | 2 +- .../scala/spark/params/GeneralParams.scala | 2 +- .../scala/spark/params/RabitParams.scala | 40 ++++++ .../scala/spark/XGBoostConfigureSuite.scala | 9 +- .../spark/XGBoostRabitRegressionSuite.scala | 120 ++++++++++++++++++ .../java/ml/dmlc/xgboost4j/java/Rabit.java | 13 +- rabit | 2 +- tests/ci_build/build_jvm_packages.sh | 2 +- tests/travis/run_test.sh | 1 + 15 files changed, 232 insertions(+), 51 deletions(-) create mode 100644 jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala create mode 100644 jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala diff --git a/CMakeLists.txt b/CMakeLists.txt index c70212f6b..f2316882b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,36 +92,16 @@ set_target_properties(dmlc PROPERTIES list(APPEND LINKED_LIBRARIES_PRIVATE dmlc) # rabit -# full rabit doesn't build on windows, so we can't import it as subdirectory -if(MINGW OR R_LIB OR WIN32) - set(RABIT_SOURCES - rabit/src/engine_empty.cc - rabit/src/c_api.cc) -else () - if(RABIT_MOCK) - set(RABIT_SOURCES - rabit/src/allreduce_base.cc - 
rabit/src/allreduce_robust.cc
-      rabit/src/engine_mock.cc
-      rabit/src/c_api.cc)
-  else()
-    set(RABIT_SOURCES
-      rabit/src/allreduce_base.cc
-      rabit/src/allreduce_robust.cc
-      rabit/src/engine.cc
-      rabit/src/c_api.cc)
-  endif(RABIT_MOCK)
-endif (MINGW OR R_LIB OR WIN32)
-add_library(rabit STATIC ${RABIT_SOURCES})
-target_include_directories(rabit PRIVATE
-  $
-  $)
-set_target_properties(rabit
-  PROPERTIES
-  CXX_STANDARD 11
-  CXX_STANDARD_REQUIRED ON
-  POSITION_INDEPENDENT_CODE ON)
-list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
+set(RABIT_BUILD_DMLC OFF)
+set(DMLC_ROOT ${xgboost_SOURCE_DIR}/dmlc-core)
+set(RABIT_WITH_R_LIB ${R_LIB})
+add_subdirectory(rabit)
+
+if (RABIT_MOCK)
+  list(APPEND LINKED_LIBRARIES_PRIVATE rabit_mock_static)
+else()
+  list(APPEND LINKED_LIBRARIES_PRIVATE rabit)
+endif(RABIT_MOCK)
 
 # Exports some R specific definitions and objects
 if (R_LIB)
diff --git a/jvm-packages/.gitignore b/jvm-packages/.gitignore
index d1d4390d6..becd9e300 100644
--- a/jvm-packages/.gitignore
+++ b/jvm-packages/.gitignore
@@ -1,2 +1,3 @@
 tracker.py
-build.sh
\ No newline at end of file
+build.sh
+
diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py
index 4d627cb7c..11c33d584 100755
--- a/jvm-packages/create_jni.py
+++ b/jvm-packages/create_jni.py
@@ -89,6 +89,11 @@ if __name__ == "__main__":
         maybe_parallel_build = ""
 
     args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
+
+    # build with the Rabit mock engine when the RABIT_MOCK environment variable is set
+    if os.getenv("RABIT_MOCK", None) is not None:
+        args.append("-DRABIT_MOCK:BOOL=ON")
+
     run("cmake .. " + " ".join(args) + maybe_generator)
     run("cmake --build . --config Release" + maybe_parallel_build)
 
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 5bd847e0f..422cad9a9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -21,11 +21,11 @@ import java.nio.file.Files
 
 import scala.collection.{AbstractIterator, mutable}
 import scala.util.Random
+import scala.collection.JavaConverters._
 
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
-import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
 import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -221,6 +221,24 @@
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
+
+  private[spark] def buildRabitParams: Map[String, String] = Map(
+    "rabit_reduce_ring_mincount" ->
+      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
+    "rabit_debug" ->
+      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
+    "rabit_timeout" ->
+      (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
+    "rabit_timeout_sec" -> {
+      if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
+        overridedParams("rabit_timeout").toString
+      } else {
+        "1800"
+      }
+    },
+    "DMLC_WORKER_CONNECT_RETRY" ->
+      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
+  )
 }
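Review note: the mapping above is the heart of phase 1 — user-facing Spark keys are renamed into the env keys that Rabit.init consumes, and a single `rabit_timeout` integer drives both the `rabit_timeout` on/off switch and the `rabit_timeout_sec` value. Here is a standalone sketch for quick inspection (illustrative only; `RabitParamMappingExample` and `userParams` are hypothetical names and the values are arbitrary):

```scala
// Hypothetical, self-contained mirror of buildRabitParams, for review purposes only.
object RabitParamMappingExample extends App {
  val userParams = Map[String, Any](
    "rabit_ring_reduce_threshold" -> 1, // force ring-based allreduce at any size
    "verbosity" -> 3,                   // verbosity 3 turns rabit_debug on
    "rabit_timeout" -> 30,              // >= 0 enables the timeout; value becomes rabit_timeout_sec
    "dmlc_worker_connect_retry" -> 5    // forwarded as DMLC_WORKER_CONNECT_RETRY
  )
  // Reproduces the translation performed by buildRabitParams above:
  val rabitEnv = Map(
    "rabit_reduce_ring_mincount" ->
      userParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
    "rabit_debug" ->
      (userParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
    "rabit_timeout" ->
      (userParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
    "rabit_timeout_sec" -> {
      if (userParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
        userParams("rabit_timeout").toString
      } else "1800"
    },
    "DMLC_WORKER_CONNECT_RETRY" ->
      userParams.getOrElse("dmlc_worker_connect_retry", 5).toString
  )
  // Expect: mincount=1, rabit_debug=true, rabit_timeout=true, rabit_timeout_sec=30
  rabitEnv.foreach { case (k, v) => println(s"$k=$v") }
}
```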
 
 /**
@@ -320,7 +338,9 @@ object XGBoost extends Serializable {
       s" ${TaskContext.getPartitionId()}")
     }
     val taskId = TaskContext.getPartitionId().toString
+    val attempt = TaskContext.get().attemptNumber.toString
     rabitEnv.put("DMLC_TASK_ID", taskId)
+    rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
     rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
 
     try {
@@ -333,7 +353,7 @@ object XGBoost extends Serializable {
       Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
     } catch {
       case xgbException: XGBoostError =>
-        logger.error(s"XGBoost worker $taskId has failed due to ", xgbException)
+        logger.error(s"XGBoost worker $taskId has failed $attempt times due to ", xgbException)
         throw xgbException
     } finally {
       Rabit.shutdown()
@@ -490,8 +510,9 @@ object XGBoost extends Serializable {
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()): (Booster, Map[String, Array[Float]]) = {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
-    val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
-      buildXGBRuntimeParams
+    val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
+    val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
+    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
       checkpointPath)
@@ -510,13 +531,14 @@ object XGBoost extends Serializable {
     val parallelismTracker = new SparkParallelismTracker(sc,
       xgbExecParams.timeoutRequestWorkers,
       xgbExecParams.numWorkers)
-    val rabitEnv = tracker.getWorkerEnvs
+
+    tracker.getWorkerEnvs().putAll(xgbRabitParams)
     val boostersAndMetrics = if (hasGroup) {
-      trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
-        checkpointRound, prevBooster, evalSetsMap)
+      trainForRanking(transformedTrainingData.left.get, xgbExecParams,
+        tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
     } else {
-      trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
-        checkpointRound, prevBooster, evalSetsMap)
+      trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
+        tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
     }
     val sparkJobThread = new Thread() {
       override def run() {
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index db4936430..d371603e3 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -50,7 +50,7 @@ class XGBoostClassifier (
   def this(xgboostParams: Map[String, Any]) = this(
     Identifiable.randomUID("xgbc"), xgboostParams)
 
-  XGBoostToMLlibParams(xgboostParams)
+  XGBoost2MLlibParams(xgboostParams)
 
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
index 1213a8f72..a69097c80 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
@@ -21,7 +21,7 @@ import 
ml.dmlc.xgboost4j.scala.spark.params._ import org.apache.spark.ml.param.shared.HasWeightCol private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams - with BoosterParams with ParamMapFuncs with NonParamVariables { + with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables { def needDeterministicRepartitioning: Boolean = { getCheckpointPath.nonEmpty && getCheckpointInterval > 0 diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index e2f22c7af..210eaab3b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -54,7 +54,7 @@ class XGBoostRegressor ( def this(xgboostParams: Map[String, Any]) = this( Identifiable.randomUID("xgbr"), xgboostParams) - XGBoostToMLlibParams(xgboostParams) + XGBoost2MLlibParams(xgboostParams) def setWeightCol(value: String): this.type = set(weightCol, value) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index 076dff42e..ab0e33278 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -241,7 +241,7 @@ trait HasNumClass extends Params { private[spark] trait ParamMapFuncs extends Params { - def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = { + def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = { for ((paramName, paramValue) <- xgboostParams) { if ((paramName == "booster" && paramValue != "gbtree") || (paramName == "updater" && paramValue != "grow_histmaker,prune" && diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala new file mode 100644 index 000000000..6b811e0d1 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala @@ -0,0 +1,40 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.spark.ml.param._
+
+private[spark] trait RabitParams extends Params {
+  /**
+   * Rabit parameters passed through Rabit.init into the native layer:
+   * rabit_ring_reduce_threshold - minimal element count at which ring-based allreduce is enabled
+   * rabit_timeout - how long to wait before exiting once Rabit observes a failure; set to -1 to disable
+   * dmlc_worker_connect_retry - number of retries when connecting to the tracker
+   * dmlc_worker_stop_process_on_error - exit the process when Rabit hits an assertion or error
+   */
+  final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
+    "threshold count to enable allreduce/broadcast with ring based topology",
+    ParamValidators.gtEq(1))
+
+  final val rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
+    "timeout threshold after rabit observed failures")
+
+  final val rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
+    "number of retries a worker attempts before failing", ParamValidators.gtEq(1))
+
+  setDefault(rabitRingReduceThreshold -> (32 << 10), rabitConnectRetry -> 5, rabitTimeout -> -1)
+}
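Usage sketch (illustrative, not part of the patch): because XGBoostEstimatorCommon above now mixes RabitParams into every estimator, the knobs can be supplied through the constructor param map under the snake_case keys that buildRabitParams reads; the values here are arbitrary examples.

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Arbitrary example values; keys follow buildRabitParams in XGBoost.scala.
val classifier = new XGBoostClassifier(Map(
  "objective" -> "binary:logistic",
  "num_round" -> 10,
  "num_workers" -> 2,
  "rabit_ring_reduce_threshold" -> 1, // use ring allreduce from the first element
  "rabit_timeout" -> 60,              // shut down 60s after Rabit observes a failure
  "dmlc_worker_connect_retry" -> 3    // three attempts to reach the tracker
))
```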
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
index fe16bcda5..4b3d8d7c9 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
@@ -16,7 +16,10 @@
 package ml.dmlc.xgboost4j.scala.spark
 
+import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
+
+import scala.collection.JavaConverters._
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 
@@ -28,7 +31,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
 
   test("nthread configuration must be no larger than spark.task.cpus") {
     val training = buildDataFrame(Classification.train)
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
       "objective" -> "binary:logistic", "num_workers" -> numWorkers,
       "nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
     intercept[IllegalArgumentException] {
@@ -40,7 +43,7 @@ class XGBoostConfigureSuite extends FunSuite with PerTest {
     // TODO write an isolated test for Booster.
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator, null)
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
       "objective" -> "binary:logistic",
       "num_round" -> 5, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
@@ -52,7 +55,7 @@
     val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
     ss.conf.set("spark.ssl.enabled", true)
 
-    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+    val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
       "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
     val training = buildDataFrame(Classification.train)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
new file mode 100644
index 000000000..12ba9366a
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala
@@ -0,0 +1,120 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
+import org.apache.spark.TaskFailedListener
+import org.apache.spark.SparkException
+import scala.collection.JavaConverters._
+import org.apache.spark.sql._
+import org.scalatest.FunSuite
+
+class XGBoostRabitRegressionSuite extends FunSuite with PerTest {
+  val predictionErrorMin = 0.00001f
+  val maxFailure = 2
+
+  override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
+    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    .config("spark.kryo.classesToRegister", classOf[Booster].getName)
+    .master(s"local[${numWorkers},${maxFailure}]")
+
+  private def waitAndCheckSparkShutdown(waitMilliSec: Int): Boolean = {
+    var totalWaitedTime = 0L
+    while (!ss.sparkContext.isStopped && totalWaitedTime <= waitMilliSec) {
+      Thread.sleep(10)
+      totalWaitedTime += 10
+    }
+    ss.sparkContext.isStopped
+  }
+
+  test("test classification prediction parity w/o ring reduce") {
+    val training = buildDataFrame(Classification.train)
+    val testDF = buildDataFrame(Classification.test)
+
+    val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
+      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
+
+    val model1 = new XGBoostClassifier(xgbSettings).fit(training)
+    val prediction1 = model1.transform(testDF).select("prediction").collect()
+
+    val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
+      .fit(training)
+
+    assert(Rabit.rabitEnvs.asScala.size > 3)
+    Rabit.rabitEnvs.asScala.foreach( item => {
+      if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
+    })
+
+    val prediction2 = model2.transform(testDF).select("prediction").collect()
+    // check parity w/o rabit cache
+    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
+      assert(p1 == p2)
+    }
+  }
+
+  test("test regression prediction parity w/o ring reduce") {
+    val training = buildDataFrame(Regression.train)
+    val testDM = new DMatrix(Regression.test.iterator, null)
+    val testDF = buildDataFrame(Classification.test)
+    val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
+      "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
+    val model1 = new XGBoostRegressor(xgbSettings).fit(training)
+
+    val prediction1 = model1.transform(testDF).select("prediction").collect()
+
+    val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
+    ).fit(training)
+    assert(Rabit.rabitEnvs.asScala.size > 3)
+    Rabit.rabitEnvs.asScala.foreach( item => {
+      if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1")
+    })
+    // check parity of batch predictions w/o rabit cache
+    val prediction2 = model2.transform(testDF).select("prediction").collect()
+    prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
+      assert(math.abs(p1 - p2) < predictionErrorMin)
+    }
+  }
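Review note: a sketch of the failure-injection hook used by the next test (illustrative only). It assumes xgboost4j was built with the mock engine (RABIT_MOCK=ON, as the CI scripts at the end of this patch arrange); reading the four comma-separated fields as rank,version,seq,ndeath follows rabit's mock convention and should be treated as an assumption here.

```scala
import ml.dmlc.xgboost4j.java.Rabit
import scala.collection.JavaConverters._

// Rank 0 will fail at its 8th allreduce/broadcast synchronization.
Rabit.mockList = List("0,8,0,0").asJava
// Every subsequent Rabit.init in this JVM appends "mock=0,8,0,0" to the native
// init args (see the Rabit.java hunk below), so the next fit() hits the failure.
```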
"objective" -> "binary:logistic", + "num_round" -> 5, + "num_workers" -> numWorkers, + "rabit_timeout" -> 0)) + .fit(training) + } catch { + case e: Throwable => // swallow anything + } finally { + // assume all tasks throw exception almost same time + // 100ms should be enough to exhaust all retries + assert(waitAndCheckSparkShutdown(100) == true) + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index 35b500757..7e019dc65 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -3,6 +3,8 @@ package ml.dmlc.xgboost4j.java; import java.io.Serializable; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.LinkedList; +import java.util.List; import java.util.Map; /** @@ -51,18 +53,25 @@ public class Rabit { throw new XGBoostError(XGBoostJNI.XGBGetLastError()); } } - + // used as way to test/debug passed rabit init parameters + public static Map rabitEnvs; + public static List mockList = new LinkedList<>(); /** * Initialize the rabit library on current working thread. * @param envs The additional environment variables to pass to rabit. * @throws XGBoostError */ public static void init(Map envs) throws XGBoostError { - String[] args = new String[envs.size()]; + rabitEnvs = envs; + String[] args = new String[envs.size() + mockList.size()]; int idx = 0; for (java.util.Map.Entry e : envs.entrySet()) { args[idx++] = e.getKey() + '=' + e.getValue(); } + // pass list of rabit mock strings eg mock=0,1,0,0 + for(String mock : mockList) { + args[idx++] = "mock=" + mock; + } checkCall(XGBoostJNI.RabitInit(args)); } diff --git a/rabit b/rabit index 9a7ac85d7..2f2534716 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 9a7ac85d7eb65b1a0b904e1fa8d5a01b910adda4 +Subproject commit 2f253471680f1bdafc1dfa17395ca0f309fe96de diff --git a/tests/ci_build/build_jvm_packages.sh b/tests/ci_build/build_jvm_packages.sh index e342c8f89..8190aa1e1 100755 --- a/tests/ci_build/build_jvm_packages.sh +++ b/tests/ci_build/build_jvm_packages.sh @@ -15,7 +15,7 @@ spark_version=$1 rm -rf build/ cd jvm-packages - +export RABIT_MOCK=ON mvn --no-transfer-progress package -Dspark.version=${spark_version} set +x diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 8890c15bd..20216b4cd 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -26,6 +26,7 @@ fi if [ ${TASK} == "java_test" ]; then set -e + export RABIT_MOCK=ON cd jvm-packages mvn -q clean install -DskipTests -Dmaven.test.skip mvn -q test