[jvm-packages] refine tracker (#10313)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Bobby Wang authored 2024-05-23 12:46:21 +08:00, committed by GitHub
parent 966dc81788
commit 932d7201f9
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 71 additions and 92 deletions

View File

@@ -35,16 +35,17 @@
     <maven.compiler.target>1.8</maven.compiler.target>
     <flink.version>1.19.0</flink.version>
     <junit.version>4.13.2</junit.version>
-    <spark.version>3.4.1</spark.version>
-    <spark.version.gpu>3.4.1</spark.version.gpu>
+    <spark.version>3.5.1</spark.version>
+    <spark.version.gpu>3.5.1</spark.version.gpu>
+    <fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
     <scala.version>2.12.18</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
     <hadoop.version>3.4.0</hadoop.version>
     <maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
     <log.capi.invocation>OFF</log.capi.invocation>
     <use.cuda>OFF</use.cuda>
-    <cudf.version>23.12.1</cudf.version>
-    <spark.rapids.version>23.12.1</spark.rapids.version>
+    <cudf.version>24.04.0</cudf.version>
+    <spark.rapids.version>24.04.0</spark.rapids.version>
     <cudf.classifier>cuda12</cudf.classifier>
     <scalatest.version>3.2.18</scalatest.version>
     <scala-collection-compat.version>2.12.0</scala-collection-compat.version>
@@ -489,11 +490,6 @@
       <artifactId>kryo</artifactId>
       <version>5.6.0</version>
     </dependency>
-    <dependency>
-      <groupId>com.fasterxml.jackson.core</groupId>
-      <artifactId>jackson-databind</artifactId>
-      <version>2.14.2</version>
-    </dependency>
     <dependency>
       <groupId>commons-logging</groupId>
       <artifactId>commons-logging</artifactId>

View File

@@ -176,7 +176,7 @@ public class XGBoost {
         new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
     if (tracker.start()) {
       return dtrain
-          .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
+          .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
           .reduce((x, y) -> x)
           .collect()
           .get(0);

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2021-2022 by Contributors Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct} import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
.groupBy(groupName) .groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list") .agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
implicit val encoder = RowEncoder(schema) implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
// Expand the grouped rows after repartition // Expand the grouped rows after repartition
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => { repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] { new Iterator[Row] {
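The encoder change is forced by the Spark 3.4 to 3.5 bump: RowEncoder.apply(schema) is gone, and RowEncoder.encoderFor now returns an AgnosticEncoder that must be wrapped explicitly. A minimal sketch of the new construction, assuming Spark 3.5.x on the classpath; the one-column schema is a toy:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.{FloatType, StructField, StructType}

object RowEncoderMigration {
  val schema: StructType = StructType(Seq(StructField("margin", FloatType)))

  // Spark 3.4:  implicit val encoder = RowEncoder(schema)
  // Spark 3.5:  encoderFor yields an AgnosticEncoder[Row]; wrap it for Datasets.
  implicit val encoder: ExpressionEncoder[Row] =
    ExpressionEncoder(RowEncoder.encoderFor(schema, false)) // lenient = false
}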

View File

@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
     xgbExecParam.setRawParamMap(overridedParams)
     xgbExecParam
   }
-
-  private[spark] def buildRabitParams : Map[String, String] = Map(
-    "rabit_reduce_ring_mincount" ->
-      overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
-    "rabit_debug" ->
-      (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
-    "rabit_timeout" ->
-      (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
-    "rabit_timeout_sec" -> {
-      if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
-        overridedParams.get("rabit_timeout").toString
-      } else {
-        "1800"
-      }
-    },
-    "DMLC_WORKER_CONNECT_RETRY" ->
-      overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
-  )
 }

 /**
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
     }
   }

-  /** visiable for testing */
-  private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
-    val tracker: ITracker = new RabitTracker(
-      nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
-    tracker
-  }
-
-  private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
-    val tracker = getTracker(nWorkers, trackerConf)
+  // Executes the provided code block inside a tracker and then stops the tracker
+  private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
+    val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
     require(tracker.start(), "FAULT: Failed to start tracker")
-    tracker
+    try {
+      block(tracker)
+    } finally {
+      tracker.stop()
+    }
   }

 /**
@@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
-    val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
-    val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
+    val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams

-    val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+    val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
       val checkpointManager = new ExternalCheckpointManager(
         checkpointParam.checkpointPath,
         FileSystem.get(sc.hadoopConfiguration))
-      checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+      checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
       checkpointManager.loadCheckpointAsScalaBooster()
     }.orNull

     // Get the training data RDD and the cachedRDD
-    val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
+    val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)

     try {
-      // Train for every ${savingRound} rounds and save the partially completed booster
-      val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
-      val (booster, metrics) = try {
-        tracker.workerArgs().putAll(xgbRabitParams)
-        val rabitEnv = tracker.workerArgs
-        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
+      val (booster, metrics) = withTracker(
+        runtimeParams.numWorkers,
+        runtimeParams.trackerConf
+      ) { tracker =>
+        val rabitEnv = tracker.getWorkerArgs()
+        val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
           var optionWatches: Option[() => Watches] = None

           // take the first Watches to train
@@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel {
             optionWatches = Some(iter.next())
           }
-          optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
-            xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
-            .getOrElse(throw new RuntimeException("No Watches to train"))
-        }}
+          optionWatches.map { buildWatches =>
+            buildDistributedBooster(buildWatches,
+              runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
+          }.getOrElse(throw new RuntimeException("No Watches to train"))
+        }

-        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
+        val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
           boostersAndMetrics)
         // The repartition step is to make training stage as ShuffleMapStage, so that when one
         // of the training task fails the training stage can retry. ResultStage won't retry when
         // it fails.
         val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
         (booster, metrics)
-      } finally {
-        tracker.stop()
       }

       // we should delete the checkpoint directory after a successful training
-      xgbExecParams.checkpointParam.foreach {
+      runtimeParams.checkpointParam.foreach {
         cpParam =>
-          if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+          if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
             val checkpointManager = new ExternalCheckpointManager(
               cpParam.checkpointPath,
               FileSystem.get(sc.hadoopConfiguration))
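The withTracker helper replaces the getTracker/startTracker pair with a loan pattern, which is what lets the old finally { tracker.stop() } in the run path disappear: cleanup now lives in one place. A self-contained sketch of the same pattern, where Startable is a hypothetical stand-in for ITracker so the snippet compiles without the JNI-backed classes:

trait Startable {
  def start(): Boolean
  def stop(): Unit
}

object LoanPatternSketch {
  def withResource[R <: Startable, T](resource: R)(block: R => T): T = {
    require(resource.start(), "FAULT: Failed to start resource")
    try {
      block(resource) // caller's code runs against a started resource
    } finally {
      resource.stop() // guaranteed cleanup, even when block throws
    }
  }
}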

View File

@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
     val tracker = new RabitTracker(numWorkers)
     tracker.start()
-    val trackerEnvs = tracker.workerArgs
+    val trackerEnvs = tracker.getWorkerArgs
     val workerCount: Int = numWorkers
     /*
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
     val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
     val tracker = new RabitTracker(numWorkers)
     tracker.start()
-    val trackerEnvs = tracker.workerArgs
+    val trackerEnvs = tracker.getWorkerArgs
     val workerCount: Int = numWorkers

View File

@@ -53,6 +53,12 @@
       <version>${scalatest.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <version>${fasterxml.jackson.version}</version>
+      <scope>provided</scope>
+    </dependency>
   </dependencies>

   <build>

View File

@@ -7,7 +7,7 @@ import java.util.Map;
  *
  * - start(timeout): Start the tracker awaiting for worker connections, with a given
  *   timeout value (in seconds).
- * - workerArgs(): Return the arguments needed to initialize Rabit clients.
+ * - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
  * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
  *   milliseconds.
  *
@@ -21,21 +21,8 @@ import java.util.Map;
  * brokers connections between workers.
  */
 public interface ITracker extends Thread.UncaughtExceptionHandler {
-  enum TrackerStatus {
-    SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
-
-    private int statusCode;
-
-    TrackerStatus(int statusCode) {
-      this.statusCode = statusCode;
-    }
-
-    public int getStatusCode() {
-      return this.statusCode;
-    }
-  }
-
-  Map<String, Object> workerArgs() throws XGBoostError;
+  Map<String, Object> getWorkerArgs() throws XGBoostError;

   boolean start() throws XGBoostError;

View File

@@ -1,3 +1,19 @@
+/*
+ Copyright (c) 2014-2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
 package ml.dmlc.xgboost4j.java;

 import java.util.Map;
@@ -10,14 +26,12 @@ import org.apache.commons.logging.LogFactory;
 /**
  * Java implementation of the Rabit tracker to coordinate distributed workers.
- *
- * The tracker must be started on driver node before running distributed jobs.
  */
 public class RabitTracker implements ITracker {
   // Maybe per tracker logger?
   private static final Log logger = LogFactory.getLog(RabitTracker.class);
   private long handle = 0;
-  private Thread tracker_daemon;
+  private Thread trackerDaemon;

   public RabitTracker(int numWorkers) throws XGBoostError {
     this(numWorkers, "");
@@ -44,7 +58,7 @@ public class RabitTracker implements ITracker {
     } catch (InterruptedException ex) {
       logger.error(ex);
     } finally {
-      this.tracker_daemon.interrupt();
+      this.trackerDaemon.interrupt();
     }
   }
@@ -52,16 +66,14 @@ public class RabitTracker implements ITracker {
    * Get environments that can be used to pass to worker.
    * @return The environment settings.
    */
-  public Map<String, Object> workerArgs() throws XGBoostError {
+  public Map<String, Object> getWorkerArgs() throws XGBoostError {
     // fixme: timeout
     String[] args = new String[1];
     XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
     ObjectMapper mapper = new ObjectMapper();
-    TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
-    };
     Map<String, Object> config;
     try {
-      config = mapper.readValue(args[0], typeRef);
+      config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
     } catch (JsonProcessingException ex) {
       throw new XGBoostError("Failed to get worker arguments.", ex);
     }
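getWorkerArgs() decodes the JSON the native tracker hands back; the refactor just inlines the anonymous TypeReference, which is how the map's type arguments survive erasure. A standalone sketch of that decoding step, with a fabricated payload (the real keys come from the native tracker):

import java.util.{Map => JMap}
import com.fasterxml.jackson.core.`type`.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper

object WorkerArgsDecodeSketch {
  def main(args: Array[String]): Unit = {
    val json = """{"dmlc_tracker_uri":"127.0.0.1","dmlc_tracker_port":9091}""" // illustrative
    val mapper = new ObjectMapper()
    // The anonymous subclass captures Map<String, Object> so Jackson can see it.
    val config: JMap[String, Object] =
      mapper.readValue(json, new TypeReference[JMap[String, Object]] {})
    println(config)
  }
}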
@@ -74,18 +86,18 @@ public class RabitTracker implements ITracker {
   public boolean start() throws XGBoostError {
     XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
-    this.tracker_daemon = new Thread(() -> {
+    this.trackerDaemon = new Thread(() -> {
       try {
-        XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
+        waitFor(0);
       } catch (XGBoostError ex) {
         logger.error(ex);
         return; // exit the thread
       }
     });
-    this.tracker_daemon.setDaemon(true);
-    this.tracker_daemon.start();
-    return this.tracker_daemon.isAlive();
+    this.trackerDaemon.setDaemon(true);
+    this.trackerDaemon.start();
+    return this.trackerDaemon.isAlive();
   }

   public void waitFor(long timeout) throws XGBoostError {
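start() now delegates the blocking wait to waitFor(0) on a daemon thread instead of repeating the JNI call. A plain-Scala sketch of that daemon-waiter pattern, with a hypothetical Waitable trait standing in for the JNI-backed tracker:

trait Waitable {
  def waitFor(timeoutMillis: Long): Unit // 0 means wait indefinitely
}

final class DaemonWaiter(target: Waitable) {
  @volatile private var daemon: Thread = _

  def start(): Boolean = {
    daemon = new Thread(() => target.waitFor(0))
    daemon.setDaemon(true) // a daemon thread never blocks JVM shutdown
    daemon.start()
    daemon.isAlive // mirrors what RabitTracker.start() returns
  }

  def interrupt(): Unit = if (daemon != null) daemon.interrupt()
}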