[jvm-packages] add hostIp and python exec for rabit tracker (#7808)

This commit is contained in:
Bobby Wang
2022-04-15 16:28:43 +08:00
committed by GitHub
parent 6f032b7152
commit 2d83b2ad8f
3 changed files with 103 additions and 23 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -46,8 +46,14 @@ import org.apache.spark.sql.SparkSession
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String )
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -336,13 +342,18 @@ object XGBoost extends Serializable {
}
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
tracker
}
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker = getTracker(nWorkers, trackerConf)
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
tracker
}