* [jvm-packages] add hostIp and python exec for rabit tracker (#7808) * Fix training continuation with categorical model. (#7810) * Make sure the task is initialized before construction of the tree updater. This is a quick fix meant to be backported to 1.6; for a full fix, we should pass the model param into the tree updater by reference instead. Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
3ee3b18a22
commit
816e788b29
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014,2021 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -46,8 +46,14 @@ import org.apache.spark.sql.SparkSession
|
|||||||
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||||
* in Scala without Python components, and with full support of timeouts.
|
* in Scala without Python components, and with full support of timeouts.
|
||||||
* The Scala implementation is currently experimental, use at your own risk.
|
* The Scala implementation is currently experimental, use at your own risk.
|
||||||
|
*
|
||||||
|
* @param hostIp The Rabit Tracker host IP address, which is only used by the Python implementation.
|
||||||
|
* This is only needed if the host IP cannot be automatically guessed.
|
||||||
|
* @param pythonExec The path of the Python executable for the Rabit Tracker,
|
||||||
|
* which is only used for python implementation.
|
||||||
*/
|
*/
|
||||||
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String )
|
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
|
||||||
|
hostIp: String = "", pythonExec: String = "")
|
||||||
|
|
||||||
object TrackerConf {
|
object TrackerConf {
|
||||||
def apply(): TrackerConf = TrackerConf(0L, "python")
|
def apply(): TrackerConf = TrackerConf(0L, "python")
|
||||||
@ -336,13 +342,18 @@ object XGBoost extends Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
/** visible for testing */
|
||||||
|
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||||
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||||
case "scala" => new RabitTracker(nWorkers)
|
case "scala" => new RabitTracker(nWorkers)
|
||||||
case "python" => new PyRabitTracker(nWorkers)
|
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
|
||||||
case _ => new PyRabitTracker(nWorkers)
|
case _ => new PyRabitTracker(nWorkers)
|
||||||
}
|
}
|
||||||
|
tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||||
|
val tracker = getTracker(nWorkers, trackerConf)
|
||||||
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
|
require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker")
|
||||||
tracker
|
tracker
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014-2022 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -24,11 +24,61 @@ import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
|
|||||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||||
|
import org.scalatest.{FunSuite}
|
||||||
import org.scalatest.{FunSuite, Ignore}
|
|
||||||
|
|
||||||
class RabitRobustnessSuite extends FunSuite with PerTest {
|
class RabitRobustnessSuite extends FunSuite with PerTest {
|
||||||
|
|
||||||
|
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
|
||||||
|
val classifier = new XGBoostClassifier(paramMap)
|
||||||
|
val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
|
||||||
|
xgbParamsFactory.buildXGBRuntimeParams
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test("Customize host ip and python exec for Rabit tracker") {
|
||||||
|
val hostIp = "192.168.22.111"
|
||||||
|
val pythonExec = "/usr/bin/python3"
|
||||||
|
|
||||||
|
val paramMap = Map(
|
||||||
|
"num_workers" -> numWorkers,
|
||||||
|
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
|
||||||
|
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||||
|
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||||
|
tracker match {
|
||||||
|
case pyTracker: PyRabitTracker =>
|
||||||
|
val cmd = pyTracker.getRabitTrackerCommand
|
||||||
|
assert(cmd.contains(hostIp))
|
||||||
|
assert(cmd.startsWith("python"))
|
||||||
|
case _ => assert(false, "expected python tracker implementation")
|
||||||
|
}
|
||||||
|
|
||||||
|
val paramMap1 = Map(
|
||||||
|
"num_workers" -> numWorkers,
|
||||||
|
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
|
||||||
|
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||||
|
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||||
|
tracker1 match {
|
||||||
|
case pyTracker: PyRabitTracker =>
|
||||||
|
val cmd = pyTracker.getRabitTrackerCommand
|
||||||
|
assert(cmd.startsWith(pythonExec))
|
||||||
|
assert(!cmd.contains(hostIp))
|
||||||
|
case _ => assert(false, "expected python tracker implementation")
|
||||||
|
}
|
||||||
|
|
||||||
|
val paramMap2 = Map(
|
||||||
|
"num_workers" -> numWorkers,
|
||||||
|
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
|
||||||
|
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||||
|
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||||
|
tracker2 match {
|
||||||
|
case pyTracker: PyRabitTracker =>
|
||||||
|
val cmd = pyTracker.getRabitTrackerCommand
|
||||||
|
assert(cmd.startsWith(pythonExec))
|
||||||
|
assert(cmd.contains(s" --host-ip=${hostIp}"))
|
||||||
|
case _ => assert(false, "expected python tracker implementation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("training with Scala-implemented Rabit tracker") {
|
test("training with Scala-implemented Rabit tracker") {
|
||||||
val eval = new EvalError()
|
val eval = new EvalError()
|
||||||
val training = buildDataFrame(Classification.train)
|
val training = buildDataFrame(Classification.train)
|
||||||
|
|||||||
@ -30,6 +30,8 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
private Map<String, String> envs = new HashMap<String, String>();
|
private Map<String, String> envs = new HashMap<String, String>();
|
||||||
// number of workers to be submitted.
|
// number of workers to be submitted.
|
||||||
private int numWorkers;
|
private int numWorkers;
|
||||||
|
private String hostIp = "";
|
||||||
|
private String pythonExec = "";
|
||||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||||
|
|
||||||
static {
|
static {
|
||||||
@ -85,6 +87,13 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
this.numWorkers = numWorkers;
|
this.numWorkers = numWorkers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||||
|
throws XGBoostError {
|
||||||
|
this(numWorkers);
|
||||||
|
this.hostIp = hostIp;
|
||||||
|
this.pythonExec = pythonExec;
|
||||||
|
}
|
||||||
|
|
||||||
public void uncaughtException(Thread t, Throwable e) {
|
public void uncaughtException(Thread t, Throwable e) {
|
||||||
logger.error("Uncaught exception thrown by worker:", e);
|
logger.error("Uncaught exception thrown by worker:", e);
|
||||||
try {
|
try {
|
||||||
@ -126,12 +135,34 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** visible for testing */
|
||||||
|
public String getRabitTrackerCommand() {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||||
|
sb.append("python ");
|
||||||
|
} else {
|
||||||
|
sb.append(pythonExec + " ");
|
||||||
|
}
|
||||||
|
sb.append(" " + tracker_py + " ");
|
||||||
|
sb.append(" --log-level=DEBUG" + " ");
|
||||||
|
sb.append(" --num-workers=" + numWorkers + " ");
|
||||||
|
|
||||||
|
// we first check the property then check the parameter
|
||||||
|
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||||
|
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||||
|
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||||
|
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||||
|
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||||
|
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||||
|
sb.append(" --host-ip=" + hostIp + " ");
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
private boolean startTrackerProcess() {
|
private boolean startTrackerProcess() {
|
||||||
try {
|
try {
|
||||||
String trackerExecString = this.addTrackerProperties("python " + tracker_py +
|
String cmd = getRabitTrackerCommand();
|
||||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));
|
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||||
|
|
||||||
trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
|
|
||||||
loadEnvs(trackerProcess.get().getInputStream());
|
loadEnvs(trackerProcess.get().getInputStream());
|
||||||
return true;
|
return true;
|
||||||
} catch (IOException ioe) {
|
} catch (IOException ioe) {
|
||||||
@ -140,18 +171,6 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private String addTrackerProperties(String trackerExecString) {
|
|
||||||
StringBuilder sb = new StringBuilder(trackerExecString);
|
|
||||||
String hostIp = trackerProperties.getHostIp();
|
|
||||||
|
|
||||||
if(hostIp != null && !hostIp.isEmpty()){
|
|
||||||
logger.debug("Using provided host-ip: " + hostIp);
|
|
||||||
sb.append(" --host-ip=").append(hostIp);
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void stop() {
|
public void stop() {
|
||||||
if (trackerProcess.get() != null) {
|
if (trackerProcess.get() != null) {
|
||||||
trackerProcess.get().destroy();
|
trackerProcess.get().destroy();
|
||||||
|
|||||||
@ -419,6 +419,7 @@ class LearnerConfiguration : public Learner {
|
|||||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
||||||
}
|
}
|
||||||
obj_->LoadConfig(objective_fn);
|
obj_->LoadConfig(objective_fn);
|
||||||
|
learner_model_param_.task = obj_->Task();
|
||||||
|
|
||||||
tparam_.booster = get<String>(gradient_booster["name"]);
|
tparam_.booster = get<String>(gradient_booster["name"]);
|
||||||
if (!gbm_) {
|
if (!gbm_) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user