Parameterize host-ip to pass to tracker.py (#2831)
This commit is contained in:
parent
602b34ab91
commit
0ad20f8fe0
@ -27,6 +27,7 @@ public class RabitTracker implements IRabitTracker {
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static String tracker_py = null;
|
||||
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
@ -129,8 +130,10 @@ public class RabitTracker implements IRabitTracker {
|
||||
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
|
||||
String trackerExecString = this.addTrackerProperties("python " + tracker_py +
|
||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));
|
||||
|
||||
trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
@ -139,6 +142,18 @@ public class RabitTracker implements IRabitTracker {
|
||||
}
|
||||
}
|
||||
|
||||
private String addTrackerProperties(String trackerExecString) {
|
||||
StringBuilder sb = new StringBuilder(trackerExecString);
|
||||
String hostIp = trackerProperties.getHostIp();
|
||||
|
||||
if(hostIp != null && !hostIp.isEmpty()){
|
||||
logger.debug("Using provided host-ip: " + hostIp);
|
||||
sb.append(" --host-ip=").append(hostIp);
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
|
||||
@ -0,0 +1,56 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.URL;
|
||||
import java.util.Properties;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
public class TrackerProperties {
|
||||
private static String PROPERTIES_FILENAME = "xgboost-tracker.properties";
|
||||
private static String HOST_IP = "host-ip";
|
||||
|
||||
private static final Log logger = LogFactory.getLog(TrackerProperties.class);
|
||||
private static TrackerProperties instance = new TrackerProperties();
|
||||
|
||||
private Properties properties;
|
||||
|
||||
private TrackerProperties() {
|
||||
this.properties = new Properties();
|
||||
|
||||
InputStream inputStream = null;
|
||||
|
||||
try {
|
||||
URL propertiesFileURL =
|
||||
Thread.currentThread().getContextClassLoader().getResource(PROPERTIES_FILENAME);
|
||||
if (propertiesFileURL != null){
|
||||
inputStream = propertiesFileURL.openStream();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.warn("Could not load " + PROPERTIES_FILENAME + " file. ", e);
|
||||
}
|
||||
|
||||
if(inputStream != null){
|
||||
try {
|
||||
properties.load(inputStream);
|
||||
logger.debug("Loaded properties from external source");
|
||||
} catch (IOException e) {
|
||||
logger.error("Error loading tracker properties file. Skipping and using defaults. ", e);
|
||||
}
|
||||
try {
|
||||
inputStream.close();
|
||||
} catch (IOException e) {
|
||||
// ignore exception
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static TrackerProperties getInstance() {
|
||||
return instance;
|
||||
}
|
||||
|
||||
public String getHostIp(){
|
||||
return this.properties.getProperty(HOST_IP);
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user