diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 17afbe48d..684a178b9 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -35,16 +35,17 @@
1.8
1.19.0
4.13.2
- 3.4.1
- 3.4.1
+ 3.5.1
+ 3.5.1
+ 2.15.2
2.12.18
2.12
3.4.0
5
OFF
OFF
- 23.12.1
- 23.12.1
+ 24.04.0
+ 24.04.0
cuda12
3.2.18
2.12.0
@@ -489,11 +490,6 @@
<artifactId>kryo</artifactId>
<version>5.6.0</version>
</dependency>
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- <version>2.14.2</version>
- </dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
index 99608b927..a660bca88 100644
--- a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
+++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
@@ -176,7 +176,7 @@ public class XGBoost {
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start()) {
return dtrain
- .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
+ .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs()))
.reduce((x, y) -> x)
.collect()
.get(0);
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
index 7e83dc6f1..00c547aa8 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021-2022 by Contributors
+ Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
.groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
- implicit val encoder = RowEncoder(schema)
+ implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
// Expand the grouped rows after repartition
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] {
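
With the bump to Spark 3.5, the patch switches from RowEncoder(schema) to wrapping the AgnosticEncoder returned by RowEncoder.encoderFor(schema, false) in an ExpressionEncoder. A minimal, self-contained sketch of the same construction, assuming a local SparkSession; the schema and data here are illustrative and not part of this patch:

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    object EncoderSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("encoder-sketch").getOrCreate()
        val schema = StructType(Seq(StructField("id", IntegerType)))
        val df = spark.createDataFrame(
          spark.sparkContext.parallelize(Seq(Row(1), Row(2))), schema)

        // Spark 3.5 style: build the Row encoder from the agnostic encoder;
        // the second argument is the `lenient` flag, which the patch passes as false.
        implicit val encoder: ExpressionEncoder[Row] =
          ExpressionEncoder(RowEncoder.encoderFor(schema, false))

        // With the implicit encoder in scope, mapPartitions can return Rows,
        // mirroring the repartition-and-expand step above.
        df.mapPartitions(iter => iter.map(identity)).show()

        spark.stop()
      }
    }
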
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index e17c68355..10c4b5a72 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
xgbExecParam.setRawParamMap(overridedParams)
xgbExecParam
}
-
- private[spark] def buildRabitParams : Map[String, String] = Map(
- "rabit_reduce_ring_mincount" ->
- overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString,
- "rabit_debug" ->
- (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString,
- "rabit_timeout" ->
- (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString,
- "rabit_timeout_sec" -> {
- if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) {
- overridedParams.get("rabit_timeout").toString
- } else {
- "1800"
- }
- },
- "DMLC_WORKER_CONNECT_RETRY" ->
- overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString
- )
}
/**
@@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel {
}
}
- /** visiable for testing */
- private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
- val tracker: ITracker = new RabitTracker(
- nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
- tracker
- }
-
- private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = {
- val tracker = getTracker(nWorkers, trackerConf)
+ // Executes the provided code block inside a tracker and then stops the tracker
+ private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
+ val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
require(tracker.start(), "FAULT: Failed to start tracker")
- tracker
+ try {
+ block(tracker)
+ } finally {
+ tracker.stop()
+ }
}
/**
@@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
- val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
- val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
+ val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
- val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+ val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
val checkpointManager = new ExternalCheckpointManager(
checkpointParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
- checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+ checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
checkpointManager.loadCheckpointAsScalaBooster()
}.orNull
// Get the training data RDD and the cachedRDD
- val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams)
+ val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
try {
- // Train for every ${savingRound} rounds and save the partially completed booster
- val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
- val (booster, metrics) = try {
- tracker.workerArgs().putAll(xgbRabitParams)
- val rabitEnv = tracker.workerArgs
+ val (booster, metrics) = withTracker(
+ runtimeParams.numWorkers,
+ runtimeParams.trackerConf
+ ) { tracker =>
+ val rabitEnv = tracker.getWorkerArgs()
- val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => {
+ val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
var optionWatches: Option[() => Watches] = None
// take the first Watches to train
@@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel {
optionWatches = Some(iter.next())
}
- optionWatches.map { buildWatches => buildDistributedBooster(buildWatches,
- xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)}
- .getOrElse(throw new RuntimeException("No Watches to train"))
+ optionWatches.map { buildWatches =>
+ buildDistributedBooster(buildWatches,
+ runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
+ }.getOrElse(throw new RuntimeException("No Watches to train"))
+ }
- }}
-
- val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams,
+ val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
boostersAndMetrics)
// The repartition step is to make training stage as ShuffleMapStage, so that when one
// of the training task fails the training stage can retry. ResultStage won't retry when
// it fails.
val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
(booster, metrics)
- } finally {
- tracker.stop()
}
+
// we should delete the checkpoint directory after a successful training
- xgbExecParams.checkpointParam.foreach {
+ runtimeParams.checkpointParam.foreach {
cpParam =>
- if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
+ if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
val checkpointManager = new ExternalCheckpointManager(
cpParam.checkpointPath,
FileSystem.get(sc.hadoopConfiguration))
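
The startTracker helper and the explicit try/finally around tracker.stop() are folded into withTracker, a loan-pattern helper that guarantees the tracker is stopped even if the training block throws. A generic sketch of that pattern, detached from the XGBoost types above; Resource and withResource are illustrative names, not part of this patch:

    object LoanPatternSketch {
      // Illustrative stand-in for something that must be released, like the tracker.
      trait Resource {
        def close(): Unit
      }

      // Acquire a resource, lend it to the caller's block, and always release it,
      // mirroring the shape of withTracker above.
      def withResource[R <: Resource, T](acquire: => R)(block: R => T): T = {
        val resource = acquire
        try {
          block(resource)
        } finally {
          resource.close()
        }
      }

      def main(args: Array[String]): Unit = {
        val result = withResource(new Resource {
          override def close(): Unit = println("resource released")
        }) { _ => 42 }
        println(result) // prints "resource released", then 42
      }
    }
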
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
index 108053af5..d3f3901ad 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
@@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
val tracker = new RabitTracker(numWorkers)
tracker.start()
- val trackerEnvs = tracker. workerArgs
+ val trackerEnvs = tracker.getWorkerArgs
val workerCount: Int = numWorkers
/*
@@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
- val trackerEnvs = tracker.workerArgs
+ val trackerEnvs = tracker.getWorkerArgs
val workerCount: Int = numWorkers
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 5a83a400c..e1e750866 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -53,6 +53,12 @@
<version>${scalatest.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${fasterxml.jackson.version}</version>
+ <scope>provided</scope>
+ </dependency>
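
jackson-databind, previously declared in the parent jvm-packages/pom.xml, is now declared directly by xgboost4j with provided scope: RabitTracker.getWorkerArgs (below) deserializes the tracker's worker arguments from JSON with Jackson, and the hosting runtime (for example a Spark distribution) is expected to supply a compatible jackson-databind. A rough sketch of that kind of deserialization; the JSON payload and its keys are illustrative, not the real tracker output:

    import java.util.{Map => JMap}

    import com.fasterxml.jackson.core.`type`.TypeReference
    import com.fasterxml.jackson.databind.ObjectMapper

    object WorkerArgsJsonSketch {
      def main(args: Array[String]): Unit = {
        // Illustrative payload only; the real string comes from the native
        // TrackerWorkerArgs call in RabitTracker.getWorkerArgs.
        val json = """{"dmlc_tracker_uri": "127.0.0.1", "dmlc_tracker_port": 9091}"""

        val mapper = new ObjectMapper()
        // Deserialize into a java.util.Map, matching the Map-based tracker API.
        val workerArgs: JMap[String, AnyRef] =
          mapper.readValue(json, new TypeReference[JMap[String, AnyRef]] {})

        println(workerArgs) // {dmlc_tracker_uri=127.0.0.1, dmlc_tracker_port=9091}
      }
    }
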
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java
index 1bfef677d..84e535a26 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java
@@ -7,7 +7,7 @@ import java.util.Map;
*
* - start(timeout): Start the tracker awaiting for worker connections, with a given
* timeout value (in seconds).
- * - workerArgs(): Return the arguments needed to initialize Rabit clients.
+ * - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@@ -21,21 +21,8 @@ import java.util.Map;
* brokers connections between workers.
*/
public interface ITracker extends Thread.UncaughtExceptionHandler {
- enum TrackerStatus {
- SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
- private int statusCode;
-
- TrackerStatus(int statusCode) {
- this.statusCode = statusCode;
- }
-
- public int getStatusCode() {
- return this.statusCode;
- }
- }
-
- Map<String, Object> workerArgs() throws XGBoostError;
+ Map<String, Object> getWorkerArgs() throws XGBoostError;
boolean start() throws XGBoostError;
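
With the TrackerStatus enum gone and workerArgs() renamed, the interface reduces to a start / getWorkerArgs / waitFor / stop lifecycle driven from the driver. A driver-side sketch of that lifecycle, using the RabitTracker constructor exercised elsewhere in this patch; error handling and the actual distributed job are omitted:

    import ml.dmlc.xgboost4j.java.{ITracker, RabitTracker}

    object TrackerLifecycleSketch {
      def main(args: Array[String]): Unit = {
        val numWorkers = 2
        // The tracker brokers connections between distributed workers.
        val tracker: ITracker = new RabitTracker(numWorkers)
        require(tracker.start(), "FAULT: Failed to start tracker")
        try {
          // Arguments each worker needs in order to join the collective;
          // in xgboost4j-spark these are shipped to the executors.
          val workerArgs = tracker.getWorkerArgs()
          println(workerArgs)
          // ... run the distributed training job here ...
        } finally {
          tracker.stop()
        }
      }
    }
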
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
index 914a493cc..48b163a77 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
@@ -1,3 +1,19 @@
+/*
+ Copyright (c) 2014-2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
package ml.dmlc.xgboost4j.java;
import java.util.Map;
@@ -10,14 +26,12 @@ import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
- *
- * The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
private long handle = 0;
- private Thread tracker_daemon;
+ private Thread trackerDaemon;
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
@@ -44,7 +58,7 @@ public class RabitTracker implements ITracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
- this.tracker_daemon.interrupt();
+ this.trackerDaemon.interrupt();
}
}
@@ -52,16 +66,14 @@ public class RabitTracker implements ITracker {
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
- public Map<String, Object> workerArgs() throws XGBoostError {
+ public Map<String, Object> getWorkerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
- TypeReference