[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4966)
* [phase 1] expose sets of rabit configurations to spark layer
* add back mutable import
* disable ring_mincount till https://github.com/dmlc/rabit/pull/106d
* Revert "disable ring_mincount till https://github.com/dmlc/rabit/pull/106d" — this reverts commit 65e95a98e24f5eb53c6ba9ef9b2379524258984d.
* apply latest rabit
* fix build error
* apply https://github.com/dmlc/xgboost/pull/4880
* downgrade cmake in rabit
* point to rabit with DMLC_ROOT fix
* relative path of rabit install prefix
* split rabit parameters to another trait
* misc
* misc
* Delete .classpath
* Delete .classpath
* Delete .classpath
* Update XGBoostClassifier.scala
* Update XGBoostRegressor.scala
* Update GeneralParams.scala
* Update GeneralParams.scala
* Update GeneralParams.scala
* Update GeneralParams.scala
* Delete .classpath
* Update RabitParams.scala
* Update .gitignore
* Update .gitignore
* apply rabitParams to training
* use string as rabit parameter value type
* cleanup
* add rabitEnv check
* point to dmlc/rabit
* per feedback
* update private scope
* misc
* update rabit
* add rabit_timeout, fix failing test
* split tests
* allow building jvm with rabit mock
* pass mock failures to rabit with test
* add mock error and gracefully handle rabit assertion error test
* split mvn test
* remove sign for test
* update rabit
* build jvm_packages with rabit mock
* point back to dmlc/rabit
* per feedback, update scala header
* cleanup pom
* per feedback
* try fix lint
* fix lint
* per feedback, remove bootstrap_cache
* per feedback 2
* try replace dev profile with passing mvn property
* fix build error
* remove mvn property and replace with env setting to build test jar
* per feedback
* revert copyright headlines, point to dmlc/rabit
* revert python lint
* remove multiple failure test case as retry is not enabled in spark
* Update core.py
* Update core.py
* per feedback, style fix
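In practical terms, the surfaced Rabit knobs travel in the same parameter map that already configures training. A minimal usage sketch, assuming the map-based constructor shown in the diff below; the rabit_* and dmlc_* keys are the ones read by buildRabitParams, and all values here are purely illustrative:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Sketch only: rabit_* / dmlc_* keys as consumed by buildRabitParams in the
// diff below; the remaining entries are ordinary training parameters.
val xgbParams: Map[String, Any] = Map(
  "objective" -> "binary:logistic",
  "num_round" -> 100,
  "num_workers" -> 4,
  "rabit_ring_reduce_threshold" -> (32 << 10), // switch to ring-based allreduce above this count
  "rabit_timeout" -> 300,                      // seconds to wait after Rabit observes a failure
  "dmlc_worker_connect_retry" -> 5             // tracker connection retries per worker
)
val classifier = new XGBoostClassifier(xgbParams)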
@@ -21,11 +21,11 @@ import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -221,6 +221,24 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
    xgbExecParam.setRawParamMap(overridedParams)
    xgbExecParam
  }

  private[spark] def buildRabitParams : Map[String, String] = Map(
    "rabit_reduce_ring_mincount" ->
      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
    "rabit_debug" ->
      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
    "rabit_timeout" ->
      (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
    "rabit_timeout_sec" -> {
      if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
        overridedParams.get("rabit_timeout").toString
      } else {
        "1800"
      }
    },
    "DMLC_WORKER_CONNECT_RETRY" ->
      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
  )
}

/**
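The mapping above resolves user-facing keys to the string values Rabit expects at init time. With none of the related keys supplied by the user, the getOrElse fallbacks produce the following defaults (a derived sketch, not output copied from the code):

// Default output of buildRabitParams when the user supplies none of the
// related keys (verbosity defaults to 0 and rabit_timeout to -1); values
// derived from the getOrElse fallbacks in the hunk above.
val defaultRabitParams = Map(
  "rabit_reduce_ring_mincount" -> (32 << 10).toString, // "32768"
  "rabit_debug" -> "false",                            // verbosity != 3
  "rabit_timeout" -> "false",                          // no explicit timeout requested
  "rabit_timeout_sec" -> "1800",                       // fallback wait before giving up
  "DMLC_WORKER_CONNECT_RETRY" -> "5"                   // tracker connect retries
)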
@@ -320,7 +338,9 @@ object XGBoost extends Serializable {
          s" ${TaskContext.getPartitionId()}")
      }
      val taskId = TaskContext.getPartitionId().toString
      val attempt = TaskContext.get().attemptNumber.toString
      rabitEnv.put("DMLC_TASK_ID", taskId)
      rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
      rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")

      try {
@@ -333,7 +353,7 @@ object XGBoost extends Serializable {
        Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
      } catch {
        case xgbException: XGBoostError =>
          logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
          logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
          throw xgbException
      } finally {
        Rabit.shutdown()
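Read together, the two hunks above describe the per-partition lifecycle: the tracker environment is extended with the task id, the attempt counter, and a stop-on-error switch, Rabit is initialised before training, failures are logged with the attempt count, and Rabit is always shut down. A condensed, self-contained sketch of that flow under those assumptions (not the library's exact code; the training body is elided and println stands in for the logger):

import java.util.{Map => JMap}

import org.apache.spark.TaskContext

import ml.dmlc.xgboost4j.java.{Rabit, XGBoostError}

// Sketch of the worker-side flow shown in the hunks above. `rabitEnv` is the
// tracker-provided environment map shipped to the worker; `train` stands in
// for the real distributed-booster call.
def runWorker[T](rabitEnv: JMap[String, String])(train: => T): T = {
  val taskId = TaskContext.getPartitionId().toString
  val attempt = TaskContext.get().attemptNumber.toString
  rabitEnv.put("DMLC_TASK_ID", taskId)
  rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
  // "false" lets a Rabit assertion surface as an exception instead of
  // stopping the worker process, so Spark can record the task failure.
  rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
  Rabit.init(rabitEnv)
  try {
    train
  } catch {
    case xgbException: XGBoostError =>
      println(s"XGBooster worker $taskId has failed $attempt times due to $xgbException")
      throw xgbException
  } finally {
    Rabit.shutdown()
  }
}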
@@ -490,8 +510,9 @@ object XGBoost extends Serializable {
      evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
    (Booster, Map[String, Array[Float]]) = {
    logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
    val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
      buildXGBRuntimeParams
    val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
    val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
    val sc = trainingData.sparkContext
    val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
      checkpointPath)
@@ -510,13 +531,14 @@ object XGBoost extends Serializable {
      val parallelismTracker = new SparkParallelismTracker(sc,
        xgbExecParams.timeoutRequestWorkers,
        xgbExecParams.numWorkers)
      val rabitEnv = tracker.getWorkerEnvs

      tracker.getWorkerEnvs().putAll(xgbRabitParams)
      val boostersAndMetrics = if (hasGroup) {
        trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
          checkpointRound, prevBooster, evalSetsMap)
        trainForRanking(transformedTrainingData.left.get, xgbExecParams,
          tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
      } else {
        trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
          checkpointRound, prevBooster, evalSetsMap)
        trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
          tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
      }
      val sparkJobThread = new Thread() {
        override def run() {
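After the putAll above, the environment handed to every worker combines the tracker's own connection entries with the Rabit settings built on the driver. A rough, illustrative picture of its contents; keys such as DMLC_TRACKER_URI come from the tracker implementation and are examples, not an exhaustive list:

// Rough picture only; actual contents depend on the tracker implementation.
val workerEnv: java.util.Map[String, String] = tracker.getWorkerEnvs()
// workerEnv now holds, for example:
//   DMLC_TRACKER_URI           -> <driver host>   (from the tracker itself)
//   DMLC_TRACKER_PORT          -> <tracker port>  (from the tracker itself)
//   rabit_reduce_ring_mincount -> "32768"         (from buildRabitParams)
//   rabit_timeout_sec          -> "1800"          (from buildRabitParams)
//   DMLC_WORKER_CONNECT_RETRY  -> "5"             (from buildRabitParams)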
@@ -50,7 +50,7 @@ class XGBoostClassifier (
  def this(xgboostParams: Map[String, Any]) = this(
    Identifiable.randomUID("xgbc"), xgboostParams)

  XGBoostToMLlibParams(xgboostParams)
  XGBoost2MLlibParams(xgboostParams)

  def setWeightCol(value: String): this.type = set(weightCol, value)
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
import org.apache.spark.ml.param.shared.HasWeightCol

private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
  with BoosterParams with ParamMapFuncs with NonParamVariables {
  with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {

  def needDeterministicRepartitioning: Boolean = {
    getCheckpointPath.nonEmpty && getCheckpointInterval > 0
@@ -54,7 +54,7 @@ class XGBoostRegressor (
  def this(xgboostParams: Map[String, Any]) = this(
    Identifiable.randomUID("xgbr"), xgboostParams)

  XGBoostToMLlibParams(xgboostParams)
  XGBoost2MLlibParams(xgboostParams)

  def setWeightCol(value: String): this.type = set(weightCol, value)
@@ -241,7 +241,7 @@ trait HasNumClass extends Params {

private[spark] trait ParamMapFuncs extends Params {

  def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
  def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
    for ((paramName, paramValue) <- xgboostParams) {
      if ((paramName == "booster" && paramValue != "gbtree") ||
        (paramName == "updater" && paramValue != "grow_histmaker,prune" &&
@@ -0,0 +1,40 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param._

private[spark] trait RabitParams extends Params {
  /**
   * Rabit parameters passed through Rabit.init into the native layer:
   * rabit_ring_reduce_threshold - minimal count threshold to enable ring-based allreduce operation
   * rabit_timeout - wait interval before exit after rabit observes failures; set -1 to disable
   * dmlc_worker_connect_retry - number of retries to connect to the tracker
   * dmlc_worker_stop_process_on_error - exit the process when rabit sees an assert/error
   */
  final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
    "threshold count to enable allreduce/broadcast with ring based topology",
    ParamValidators.gtEq(1))

  final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
    "timeout threshold after rabit observed failures")

  final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
    "number of retry worker do before fail", ParamValidators.gtEq(1))

  setDefault(rabitRingReduceThreshold -> (32 << 10), rabitConnectRetry -> 5, rabitTimeout -> -1)
}
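The trait only declares the parameters and their defaults; this hunk adds no setters. If a concrete class mixing in RabitParams wanted to expose them in the usual Spark ML style, a sketch could look like the following (the class and setter names are illustrative and not part of this change; it must live in the ml.dmlc.xgboost4j.scala.spark package because RabitParams is private[spark]):

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable

import ml.dmlc.xgboost4j.scala.spark.params.RabitParams

// Illustrative only: shows the conventional Spark ML setter pattern on top of
// the RabitParams trait added above.
class RabitConfigDemo(override val uid: String) extends RabitParams {
  def this() = this(Identifiable.randomUID("rabitCfgDemo"))

  def setRabitRingReduceThreshold(value: Int): this.type = set(rabitRingReduceThreshold, value)

  override def copy(extra: ParamMap): RabitConfigDemo = defaultCopy(extra)
}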