[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4876)

* Expose sets of rabit configurations to spark layer
This commit is contained in:
Chen Qin
2019-10-18 12:07:31 -07:00
committed by Jiaming Yuan
parent 31030a8d3a
commit 86ed01c4bb
73 changed files with 343 additions and 115 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,11 +21,11 @@ import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -155,7 +155,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
overridedParams
}
def buildXGBRuntimeParams: XGBoostExecutionParams = {
private[spark] def buildXGBRuntimeParams: XGBoostExecutionParams = {
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
val round = overridedParams("num_round").asInstanceOf[Int]
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
@@ -221,6 +221,25 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
private[spark] def buildRabitParams: Map[String, String] = Map(
"rabit_reduce_ring_mincount" ->
overridedParams.getOrElse("rabit_reduce_ring_mincount", 32<<10).toString,
"rabit_reduce_buffer" ->
overridedParams.getOrElse("rabit_reduce_buffer", "256MB").toString,
"rabit_bootstrap_cache" ->
overridedParams.getOrElse("rabit_bootstrap_cache", false).toString,
"rabit_debug" ->
overridedParams.getOrElse("rabit_debug", false).toString,
"rabit_timeout" ->
overridedParams.getOrElse("rabit_timeout", false).toString,
"rabit_timeout_sec" ->
overridedParams.getOrElse("rabit_timeout_sec", 1800).toString,
"DMLC_WORKER_CONNECT_RETRY" ->
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" ->
overridedParams.getOrElse("dmlc_worker_stop_process_on_error", false).toString
)
}
/**
@@ -321,7 +340,6 @@ object XGBoost extends Serializable {
}
val taskId = TaskContext.getPartitionId().toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
try {
Rabit.init(rabitEnv)
@@ -490,8 +508,9 @@ object XGBoost extends Serializable {
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
(Booster, Map[String, Array[Float]]) = {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
buildXGBRuntimeParams
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
checkpointPath)
@@ -510,12 +529,12 @@ object XGBoost extends Serializable {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
val rabitEnv = tracker.getWorkerEnvs
val rabitEnv = tracker.getWorkerEnvs.asScala ++ xgbRabitParams
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv.asJava,
checkpointRound, prevBooster, evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv.asJava,
checkpointRound, prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ class XGBoostClassifier (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
XGBoost2MLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
import org.apache.spark.ml.param.shared.HasWeightCol
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with ParamMapFuncs with NonParamVariables {
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ class XGBoostRegressor (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
XGBoost2MLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -241,7 +241,7 @@ trait HasNumClass extends Params {
private[spark] trait ParamMapFuncs extends Params {
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,62 @@
/*
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
private[spark] trait RabitParams extends Params {
/**
* Rabit worker configurations. These parameters were passed to Rabit.Init and decide
* rabit_reduce_ring_mincount - threshold of enable ring based allreduce/broadcast operations.
* rabit_reduce_buffer - buffer size to recv and run reduction
* rabit_bootstrap_cache - enable save allreduce cache before loadcheckpoint
* rabit_debug - enable more verbose rabit logging to stdout
* rabit_timeout - enable sidecar thread after rabit observed failures
* rabit_timeout_sec - wait interval before exit after rabit observed failures
* dmlc_worker_connect_retry - number of retrys to tracker
* dmlc_worker_stop_process_on_error - exit process when rabit see assert/error
*/
final val ringReduceMin = new IntParam(this, "rabitReduceRingMincount",
"minimal counts of enable allreduce/broadcast with ring based topology",
ParamValidators.gtEq(1))
final def reduceBuffer: Param[String] = new Param[String](this, "rabitReduceBuffer",
"buffer size (MB/GB) allocated to each xgb trainner recv and run reduction",
(buf: String) => buf.contains("MB") || buf.contains("GB"))
final def bootstrapCache: BooleanParam = new BooleanParam(this, "rabitBootstrapCache",
"enable save allreduce cache before loadcheckpoint, used to allow failed task retry")
final def rabitDebug: BooleanParam = new BooleanParam(this, "rabitDebug",
"enable more verbose rabit logging to stdout")
final def rabitTimeout: BooleanParam = new BooleanParam(this, "rabitTimeout",
"enable failure timeout sidecar threads")
final def timeoutInterval: IntParam = new IntParam(this, "rabitTimeoutSec",
"timeout threshold after rabit observed failures", (interval: Int) => interval > 0)
final def connectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
"number of retry worker do before fail", ParamValidators.gtEq(1))
final def exitOnError: BooleanParam = new BooleanParam(this, "dmlcWorkerStopProcessOnError",
"exit process when rabit see assert error")
setDefault(ringReduceMin -> (32 << 10), reduceBuffer -> "256MB", bootstrapCache -> false,
rabitDebug -> false, connectRetry -> 5, rabitTimeout -> false, timeoutInterval -> 1800,
exitOnError -> false)
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2019 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.