Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -54,9 +54,9 @@ public class XGBoost {
private final Map<String, Object> params;
private final int round;
private final Map<String, String> workerEnvs;
private final Map<String, Object> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
@@ -174,9 +174,9 @@ public class XGBoost {
int numBoostRound) throws Exception {
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
if (tracker.start()) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
.reduce((x, y) -> x)
.collect()
.get(0);