[jvm-packages] refine tracker (#10313)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2024-05-23 12:46:21 +08:00
committed by GitHub
parent 966dc81788
commit 932d7201f9
8 changed files with 71 additions and 92 deletions

View File

@@ -53,6 +53,12 @@
<version>${scalatest.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
<build>

View File

@@ -7,7 +7,7 @@ import java.util.Map;
*
* - start(timeout): Start the tracker awaiting for worker connections, with a given
* timeout value (in seconds).
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
* milliseconds.
*
@@ -21,21 +21,8 @@ import java.util.Map;
* brokers connections between workers.
*/
public interface ITracker extends Thread.UncaughtExceptionHandler {
enum TrackerStatus {
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
private int statusCode;
TrackerStatus(int statusCode) {
this.statusCode = statusCode;
}
public int getStatusCode() {
return this.statusCode;
}
}
Map<String, Object> workerArgs() throws XGBoostError;
Map<String, Object> getWorkerArgs() throws XGBoostError;
boolean start() throws XGBoostError;

View File

@@ -1,3 +1,19 @@
/*
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.Map;
@@ -10,14 +26,12 @@ import org.apache.commons.logging.LogFactory;
/**
* Java implementation of the Rabit tracker to coordinate distributed workers.
*
* The tracker must be started on driver node before running distributed jobs.
*/
public class RabitTracker implements ITracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
private long handle = 0;
private Thread tracker_daemon;
private Thread trackerDaemon;
public RabitTracker(int numWorkers) throws XGBoostError {
this(numWorkers, "");
@@ -44,7 +58,7 @@ public class RabitTracker implements ITracker {
} catch (InterruptedException ex) {
logger.error(ex);
} finally {
this.tracker_daemon.interrupt();
this.trackerDaemon.interrupt();
}
}
@@ -52,16 +66,14 @@ public class RabitTracker implements ITracker {
* Get environments that can be used to pass to worker.
* @return The environment settings.
*/
public Map<String, Object> workerArgs() throws XGBoostError {
public Map<String, Object> getWorkerArgs() throws XGBoostError {
// fixme: timeout
String[] args = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
ObjectMapper mapper = new ObjectMapper();
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
};
Map<String, Object> config;
try {
config = mapper.readValue(args[0], typeRef);
config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException ex) {
throw new XGBoostError("Failed to get worker arguments.", ex);
}
@@ -74,18 +86,18 @@ public class RabitTracker implements ITracker {
public boolean start() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
this.tracker_daemon = new Thread(() -> {
this.trackerDaemon = new Thread(() -> {
try {
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
waitFor(0);
} catch (XGBoostError ex) {
logger.error(ex);
return; // exit the thread
}
});
this.tracker_daemon.setDaemon(true);
this.tracker_daemon.start();
this.trackerDaemon.setDaemon(true);
this.trackerDaemon.start();
return this.tracker_daemon.isAlive();
return this.trackerDaemon.isAlive();
}
public void waitFor(long timeout) throws XGBoostError {