Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -54,9 +54,9 @@ public class XGBoost {
|
||||
|
||||
private final Map<String, Object> params;
|
||||
private final int round;
|
||||
private final Map<String, String> workerEnvs;
|
||||
private final Map<String, Object> workerEnvs;
|
||||
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, Object> workerEnvs) {
|
||||
this.params = params;
|
||||
this.round = round;
|
||||
this.workerEnvs = workerEnvs;
|
||||
@@ -174,9 +174,9 @@ public class XGBoost {
|
||||
int numBoostRound) throws Exception {
|
||||
final RabitTracker tracker =
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start(0L)) {
|
||||
if (tracker.start()) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
|
||||
Reference in New Issue
Block a user