[jvm-packages] refine tracker (#10313)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -53,6 +53,12 @@
|
||||
<version>${scalatest.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${fasterxml.jackson.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
||||
@@ -7,7 +7,7 @@ import java.util.Map;
|
||||
*
|
||||
* - start(timeout): Start the tracker awaiting for worker connections, with a given
|
||||
* timeout value (in seconds).
|
||||
* - workerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - getWorkerArgs(): Return the arguments needed to initialize Rabit clients.
|
||||
* - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout`
|
||||
* milliseconds.
|
||||
*
|
||||
@@ -21,21 +21,8 @@ import java.util.Map;
|
||||
* brokers connections between workers.
|
||||
*/
|
||||
public interface ITracker extends Thread.UncaughtExceptionHandler {
|
||||
enum TrackerStatus {
|
||||
SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3);
|
||||
|
||||
private int statusCode;
|
||||
|
||||
TrackerStatus(int statusCode) {
|
||||
this.statusCode = statusCode;
|
||||
}
|
||||
|
||||
public int getStatusCode() {
|
||||
return this.statusCode;
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, Object> workerArgs() throws XGBoostError;
|
||||
Map<String, Object> getWorkerArgs() throws XGBoostError;
|
||||
|
||||
boolean start() throws XGBoostError;
|
||||
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Map;
|
||||
@@ -10,14 +26,12 @@ import org.apache.commons.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* Java implementation of the Rabit tracker to coordinate distributed workers.
|
||||
*
|
||||
* The tracker must be started on driver node before running distributed jobs.
|
||||
*/
|
||||
public class RabitTracker implements ITracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
private long handle = 0;
|
||||
private Thread tracker_daemon;
|
||||
private Thread trackerDaemon;
|
||||
|
||||
public RabitTracker(int numWorkers) throws XGBoostError {
|
||||
this(numWorkers, "");
|
||||
@@ -44,7 +58,7 @@ public class RabitTracker implements ITracker {
|
||||
} catch (InterruptedException ex) {
|
||||
logger.error(ex);
|
||||
} finally {
|
||||
this.tracker_daemon.interrupt();
|
||||
this.trackerDaemon.interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,16 +66,14 @@ public class RabitTracker implements ITracker {
|
||||
* Get environments that can be used to pass to worker.
|
||||
* @return The environment settings.
|
||||
*/
|
||||
public Map<String, Object> workerArgs() throws XGBoostError {
|
||||
public Map<String, Object> getWorkerArgs() throws XGBoostError {
|
||||
// fixme: timeout
|
||||
String[] args = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args));
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {
|
||||
};
|
||||
Map<String, Object> config;
|
||||
try {
|
||||
config = mapper.readValue(args[0], typeRef);
|
||||
config = mapper.readValue(args[0], new TypeReference<Map<String, Object>>() {});
|
||||
} catch (JsonProcessingException ex) {
|
||||
throw new XGBoostError("Failed to get worker arguments.", ex);
|
||||
}
|
||||
@@ -74,18 +86,18 @@ public class RabitTracker implements ITracker {
|
||||
|
||||
public boolean start() throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle));
|
||||
this.tracker_daemon = new Thread(() -> {
|
||||
this.trackerDaemon = new Thread(() -> {
|
||||
try {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0));
|
||||
waitFor(0);
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return; // exit the thread
|
||||
}
|
||||
});
|
||||
this.tracker_daemon.setDaemon(true);
|
||||
this.tracker_daemon.start();
|
||||
this.trackerDaemon.setDaemon(true);
|
||||
this.trackerDaemon.start();
|
||||
|
||||
return this.tracker_daemon.isAlive();
|
||||
return this.trackerDaemon.isAlive();
|
||||
}
|
||||
|
||||
public void waitFor(long timeout) throws XGBoostError {
|
||||
|
||||
Reference in New Issue
Block a user