[breaking] [jvm-packages] Remove scala-implemented tracker. (#9045)

This commit is contained in:
Jiaming Yuan
2023-04-20 16:29:35 +08:00
committed by GitHub
parent 42d100de18
commit 564df59204
8 changed files with 9 additions and 1585 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
def apply(): TrackerConf = TrackerConf(0L)
}
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
tracker
}

View File

@@ -22,7 +22,6 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.FunSuite
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val arr = iter.toArray
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
Communicator.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Communicator.init(trackerEnvs)
Thread.sleep(1000)
Communicator.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(