[jvm-packages] update rabit, surface new changes to spark, add parity and failure tests (#4966)

* [phase 1] expose sets of rabit configurations to spark layer

* add back mutable import

* disable ring_mincount till https://github.com/dmlc/rabit/pull/106d

* Revert "disable ring_mincount till https://github.com/dmlc/rabit/pull/106d"

This reverts commit 65e95a98e24f5eb53c6ba9ef9b2379524258984d.

* apply latest rabit

* fix build error

* apply https://github.com/dmlc/xgboost/pull/4880

* downgrade cmake in rabit

* point to rabit with DMLC_ROOT fix

* relative path of rabit install prefix

* split rabit parameters to another trait

* misc

* misc

* Delete .classpath

* Delete .classpath

* Delete .classpath

* Update XGBoostClassifier.scala

* Update XGBoostRegressor.scala

* Update GeneralParams.scala

* Update GeneralParams.scala

* Update GeneralParams.scala

* Update GeneralParams.scala

* Delete .classpath

* Update RabitParams.scala

* Update .gitignore

* Update .gitignore

* apply rabitParams to training

* use string as rabit parameter value type

* cleanup

* add rabitEnv check

* point to dmlc/rabit

* per feedback

* update private scope

* misc

* update rabit

* add rabit_timeout, fix failing test.

* split tests

* allow build jvm with rabit mock

* pass mock failures to rabit with test

* add mock error test and gracefully handle rabit assertion error test

* split mvn test

* remove sign for test

* update rabit

* build jvm_packages with rabit mock

* point back to dmlc/rabit

* per feedback, update scala header

* cleanup pom

* per feedback

* try fix lint

* fix lint

* per feedback, remove bootstrap_cache

* per feedback 2

* try replace dev profile with passing mvn property

* fix build error

* remove mvn property and replace with env setting to build test jar

* per feedback

* revert copyright headlines, point to dmlc/rabit

* revert python lint

* remove multiple failure test case as retry is not enabled in spark

* Update core.py

* Update core.py

* per feedback, style fix
This commit is contained in:
Chen Qin
2019-11-01 14:21:19 -07:00
committed by Nan Zhu
parent a37691428f
commit b29b8c2f34
15 changed files with 232 additions and 51 deletions

View File

@@ -21,11 +21,11 @@ import java.nio.file.Files
import scala.collection.{AbstractIterator, mutable}
import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
@@ -221,6 +221,24 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
/**
 * Builds the Rabit settings derived from the user-supplied parameter map.
 *
 * Keys produced:
 *  - `rabit_reduce_ring_mincount`: value of `rabit_ring_reduce_threshold` (default `32 << 10`)
 *  - `rabit_debug`: `"true"` only when `verbosity` equals 3
 *  - `rabit_timeout`: `"true"` when a non-negative `rabit_timeout` was supplied
 *  - `rabit_timeout_sec`: the supplied `rabit_timeout` when it is non-negative, else `"1800"`
 *  - `DMLC_WORKER_CONNECT_RETRY`: value of `dmlc_worker_connect_retry` (default 5)
 *
 * @return a map of Rabit setting name to its string-encoded value
 * @throws NumberFormatException if `verbosity` or `rabit_timeout` does not render as an integer
 */
private[spark] def buildRabitParams : Map[String, String] = {
  // Parse the timeout once; a negative value (the default, -1) means the timeout is disabled.
  val rabitTimeout = overridedParams.getOrElse("rabit_timeout", -1).toString.toInt
  val timeoutEnabled = rabitTimeout >= 0
  Map(
    "rabit_reduce_ring_mincount" ->
      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
    "rabit_debug" ->
      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
    "rabit_timeout" -> timeoutEnabled.toString,
    // Emit the plain number: calling toString on the Option returned by Map.get
    // would produce "Some(<n>)" instead of "<n>".
    "rabit_timeout_sec" -> (if (timeoutEnabled) rabitTimeout.toString else "1800"),
    "DMLC_WORKER_CONNECT_RETRY" ->
      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
  )
}
}
/**
@@ -320,7 +338,9 @@ object XGBoost extends Serializable {
s" ${TaskContext.getPartitionId()}")
}
val taskId = TaskContext.getPartitionId().toString
val attempt = TaskContext.get().attemptNumber.toString
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
try {
@@ -333,7 +353,7 @@ object XGBoost extends Serializable {
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch {
case xgbException: XGBoostError =>
logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
@@ -490,8 +510,9 @@ object XGBoost extends Serializable {
evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()):
(Booster, Map[String, Array[Float]]) = {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext).
buildXGBRuntimeParams
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
checkpointPath)
@@ -510,13 +531,14 @@ object XGBoost extends Serializable {
val parallelismTracker = new SparkParallelismTracker(sc,
xgbExecParams.timeoutRequestWorkers,
xgbExecParams.numWorkers)
val rabitEnv = tracker.getWorkerEnvs
tracker.getWorkerEnvs().putAll(xgbRabitParams)
val boostersAndMetrics = if (hasGroup) {
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv,
checkpointRound, prevBooster, evalSetsMap)
trainForRanking(transformedTrainingData.left.get, xgbExecParams,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
} else {
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
checkpointRound, prevBooster, evalSetsMap)
trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
}
val sparkJobThread = new Thread() {
override def run() {

View File

@@ -50,7 +50,7 @@ class XGBoostClassifier (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbc"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
XGBoost2MLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.spark.params._
import org.apache.spark.ml.param.shared.HasWeightCol
private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
with BoosterParams with ParamMapFuncs with NonParamVariables {
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = {
getCheckpointPath.nonEmpty && getCheckpointInterval > 0

View File

@@ -54,7 +54,7 @@ class XGBoostRegressor (
def this(xgboostParams: Map[String, Any]) = this(
Identifiable.randomUID("xgbr"), xgboostParams)
XGBoostToMLlibParams(xgboostParams)
XGBoost2MLlibParams(xgboostParams)
def setWeightCol(value: String): this.type = set(weightCol, value)

View File

@@ -241,7 +241,7 @@ trait HasNumClass extends Params {
private[spark] trait ParamMapFuncs extends Params {
def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
if ((paramName == "booster" && paramValue != "gbtree") ||
(paramName == "updater" && paramValue != "grow_histmaker,prune" &&

View File

@@ -0,0 +1,40 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
/**
 * Rabit parameters passed through Rabit.Init into native layer:
 *  - rabit_ring_reduce_threshold - minimal threshold to enable ring based allreduce operation
 *  - rabit_timeout - wait interval before exit after rabit observed failures; set -1 to disable
 *  - dmlc_worker_connect_retry - number of retries to tracker
 *
 * Note: dmlc_worker_stop_process_on_error (exit process when rabit sees an assert/error)
 * is not declared as a Param in this trait.
 */
private[spark] trait RabitParams extends Params {

  /** Threshold count to enable allreduce/broadcast with ring based topology; must be >= 1. */
  final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
    "threshold count to enable allreduce/broadcast with ring based topology",
    ParamValidators.gtEq(1))

  // Declared as vals (like rabitRingReduceThreshold) so each Param is created once
  // instead of a fresh IntParam being allocated on every access.

  /** Timeout threshold after rabit observed failures; the default of -1 disables it. */
  final val rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
    "timeout threshold after rabit observed failures")

  /**
   * Number of retries a worker does before failing; must be >= 1.
   * The Param is registered under the name "dmlcWorkerConnectRetry".
   */
  final val rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
    "number of retry worker do before fail", ParamValidators.gtEq(1))

  setDefault(rabitRingReduceThreshold -> (32 << 10), rabitConnectRetry -> 5, rabitTimeout -> -1)
}