[jvm-packages] refine tracker (#10313)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
966dc81788
commit
932d7201f9
@ -35,16 +35,17 @@
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<flink.version>1.19.0</flink.version>
|
||||
<junit.version>4.13.2</junit.version>
|
||||
<spark.version>3.4.1</spark.version>
|
||||
<spark.version.gpu>3.4.1</spark.version.gpu>
|
||||
<spark.version>3.5.1</spark.version>
|
||||
<spark.version.gpu>3.5.1</spark.version.gpu>
|
||||
<fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
|
||||
<scala.version>2.12.18</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>3.4.0</hadoop.version>
|
||||
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
|
||||
<log.capi.invocation>OFF</log.capi.invocation>
|
||||
<use.cuda>OFF</use.cuda>
|
||||
<cudf.version>23.12.1</cudf.version>
|
||||
<spark.rapids.version>23.12.1</spark.rapids.version>
|
||||
<cudf.version>24.04.0</cudf.version>
|
||||
<spark.rapids.version>24.04.0</spark.rapids.version>
|
||||
<cudf.classifier>cuda12</cudf.classifier>
|
||||
<scalatest.version>3.2.18</scalatest.version>
|
||||
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
|
||||
@ -489,11 +490,6 @@
|
||||
<artifactId>kryo</artifactId>
|
||||
<version>5.6.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.14.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
|
||||
@ -176,7 +176,7 @@ public class XGBoost {
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start()) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
|
||||
import org.apache.spark.sql.functions.{col, collect_list, struct}
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||
@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
.groupBy(groupName)
|
||||
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
|
||||
|
||||
implicit val encoder = RowEncoder(schema)
|
||||
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
|
||||
// Expand the grouped rows after repartition
|
||||
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
|
||||
new Iterator[Row] {
|
||||
|
||||
@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
||||
xgbExecParam.setRawParamMap(overridedParams)
|
||||
xgbExecParam
|
||||
}
|
||||
|
||||
private[spark] def buildRabitParams : Map[String, String] = Map(
|
||||
"rabit_reduce_ring_mincount" ->
|
||||
overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
|
||||
"rabit_debug" ->
|
||||
(overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
|
||||
"rabit_timeout" ->
|
||||
(overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
|
||||
"rabit_timeout_sec" -> {
|
||||
if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
|
||||
overridedParams.get("rabit_timeout").toString
|
||||
} else {
|
||||
"1800"
|
||||
}
|
||||
},
|
||||
"DMLC_WORKER_CONNECT_RETRY" ->
|
||||
overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
|
||||
}
|
||||
}
|
||||
|
||||
/** visiable for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker: ITracker = new RabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
|
||||
tracker
|
||||
}
|
||||
|
||||
private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
|
||||
val tracker = getTracker(nWorkers, trackerConf)
|
||||
// Executes the provided code block inside a tracker and then stops the tracker
|
||||
private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
|
||||
val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
|
||||
require(tracker.start(), "FAULT: Failed to start tracker")
|
||||
tracker
|
||||
try {
|
||||
block(tracker)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
|
||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
|
||||
val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
|
||||
val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
|
||||
val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
checkpointParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
|
||||
checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
|
||||
checkpointManager.loadCheckpointAsScalaBooster()
|
||||
}.orNull
|
||||
|
||||
// Get the training data RDD and the cachedRDD
|
||||
val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
|
||||
val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
|
||||
|
||||
try {
|
||||
// Train for every ${savingRound} rounds and save the partially completed booster
|
||||
val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
val (booster, metrics) = try {
|
||||
tracker.workerArgs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.workerArgs
|
||||
val (booster, metrics) = withTracker(
|
||||
runtimeParams.numWorkers,
|
||||
runtimeParams.trackerConf
|
||||
) { tracker =>
|
||||
val rabitEnv = tracker.getWorkerArgs()
|
||||
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
|
||||
val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
|
||||
var optionWatches: Option[() => Watches] = None
|
||||
|
||||
// take the first Watches to train
|
||||
@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel {
|
||||
optionWatches = Some(iter.next())
|
||||
}
|
||||
|
||||
optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
|
||||
xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
|
||||
.getOrElse(throw new RuntimeException("No Watches to train"))
|
||||
optionWatches.map { buildWatches =>
|
||||
buildDistributedBooster(buildWatches,
|
||||
runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
|
||||
}.getOrElse(throw new RuntimeException("No Watches to train"))
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
|
||||
val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
|
||||
boostersAndMetrics)
|
||||
// The repartition step is to make training stage as ShuffleMapStage, so that when one
|
||||
// of the training task fails the training stage can retry. ResultStage won't retry when
|
||||
// it fails.
|
||||
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
|
||||
(booster, metrics)
|
||||
} finally {
|
||||
tracker.stop()
|
||||
}
|
||||
|
||||
// we should delete the checkpoint directory after a successful training
|
||||
xgbExecParams.checkpointParam.foreach {
|
||||
runtimeParams.checkpointParam.foreach {
|
||||
cpParam =>
|
||||
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||
if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
|
||||
val checkpointManager = new ExternalCheckpointManager(
|
||||
cpParam.checkpointPath,
|
||||
FileSystem.get(sc.hadoopConfiguration))
|
||||
|
||||
@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker. workerArgs
|
||||
val trackerEnvs = tracker.getWorkerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
/*
|
||||
@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
val tracker = new RabitTracker(numWorkers)
|
||||
tracker.start()
|
||||
val trackerEnvs = tracker.workerArgs
|
||||
val trackerEnvs = tracker.getWorkerArgs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
|
||||
|
||||
@ -53,6 +53,12 @@
|
||||
<version>${scalatest.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${fasterxml.jackson.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
||||
@ -7,7 +7,7 @@ import java.util.Map;
|
||||
*
|
||||
* - start(timeout): Start the tracker awaiting for worker connections, with a given
|
||||
* timeout value (in seconds).
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@ -21,21 +21,8 @@ import java.util.Map;
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
private int statusCode;
|
||||
|
||||
TrackerStatus(int statusCode) {
|
||||
this.statusCode = statusCode;
|
||||
}
|
||||
|
||||
public int getStatusCode() {
|
||||
return this.statusCode;
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
Map<String, Object> getWorkerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
@ -10,14 +26,12 @@ import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
private Thread trackerDaemon;
|
||||
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
@ -44,7 +58,7 @@ public class RabitTracker implements ITracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
this.tracker_daemon.interrupt();
|
||||
this.trackerDaemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,16 +66,14 @@ public class RabitTracker implements ITracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
public Map<String, Object> getWorkerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
@ -74,18 +86,18 @@ public class RabitTracker implements ITracker {
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
this.trackerDaemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
waitFor(0);
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
this.trackerDaemon.setDaemon(true);
|
||||
this.trackerDaemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
return this.trackerDaemon.isAlive();
|
||||
}
|
||||
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user