[jvm-packages] add hostIp and python exec for rabit tracker (#7808)
This commit is contained in:
@@ -30,6 +30,8 @@ public class RabitTracker implements IRabitTracker {
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int numWorkers;
|
||||
private String hostIp = "";
|
||||
private String pythonExec = "";
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
|
||||
static {
|
||||
@@ -85,6 +87,13 @@ public class RabitTracker implements IRabitTracker {
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
public RabitTracker(int numWorkers, String hostIp, String pythonExec)
|
||||
throws XGBoostError {
|
||||
this(numWorkers);
|
||||
this.hostIp = hostIp;
|
||||
this.pythonExec = pythonExec;
|
||||
}
|
||||
|
||||
public void uncaughtException(Thread t, Throwable e) {
|
||||
logger.error("Uncaught exception thrown by worker:", e);
|
||||
try {
|
||||
@@ -126,12 +135,34 @@ public class RabitTracker implements IRabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
/** visible for testing */
|
||||
public String getRabitTrackerCommand() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (pythonExec == null || pythonExec.isEmpty()) {
|
||||
sb.append("python ");
|
||||
} else {
|
||||
sb.append(pythonExec + " ");
|
||||
}
|
||||
sb.append(" " + tracker_py + " ");
|
||||
sb.append(" --log-level=DEBUG" + " ");
|
||||
sb.append(" --num-workers=" + numWorkers + " ");
|
||||
|
||||
// we first check the property then check the parameter
|
||||
String hostIpFromProperties = trackerProperties.getHostIp();
|
||||
if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) {
|
||||
logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties");
|
||||
sb.append(" --host-ip=" + hostIpFromProperties + " ");
|
||||
} else if (hostIp != null & !hostIp.isEmpty()) {
|
||||
logger.debug("Using the parametr host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=" + hostIp + " ");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
String trackerExecString = this.addTrackerProperties("python " + tracker_py +
|
||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));
|
||||
|
||||
trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
|
||||
String cmd = getRabitTrackerCommand();
|
||||
trackerProcess.set(Runtime.getRuntime().exec(cmd));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
@@ -140,18 +171,6 @@ public class RabitTracker implements IRabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
private String addTrackerProperties(String trackerExecString) {
|
||||
StringBuilder sb = new StringBuilder(trackerExecString);
|
||||
String hostIp = trackerProperties.getHostIp();
|
||||
|
||||
if(hostIp != null && !hostIp.isEmpty()){
|
||||
logger.debug("Using provided host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=").append(hostIp);
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
|
||||
Reference in New Issue
Block a user