Parameterize host-ip to pass to tracker.py (#2831)

This commit is contained in:
avinocur 2017-11-29 16:14:34 -03:00 committed by Tianqi Chen
parent 602b34ab91
commit 0ad20f8fe0
2 changed files with 73 additions and 2 deletions

View File

@ -27,6 +27,7 @@ public class RabitTracker implements IRabitTracker {
private static final Log logger = LogFactory.getLog(RabitTracker.class); private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file. // tracker python file.
private static String tracker_py = null; private static String tracker_py = null;
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
// environment variable to be pased. // environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>(); private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted. // number of workers to be submitted.
@ -129,8 +130,10 @@ public class RabitTracker implements IRabitTracker {
private boolean startTrackerProcess() { private boolean startTrackerProcess() {
try { try {
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py + String trackerExecString = this.addTrackerProperties("python " + tracker_py +
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers))); " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));
trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
loadEnvs(trackerProcess.get().getInputStream()); loadEnvs(trackerProcess.get().getInputStream());
return true; return true;
} catch (IOException ioe) { } catch (IOException ioe) {
@ -139,6 +142,18 @@ public class RabitTracker implements IRabitTracker {
} }
} }
private String addTrackerProperties(String trackerExecString) {
StringBuilder sb = new StringBuilder(trackerExecString);
String hostIp = trackerProperties.getHostIp();
if(hostIp != null && !hostIp.isEmpty()){
logger.debug("Using provided host-ip: " + hostIp);
sb.append(" --host-ip=").append(hostIp);
}
return sb.toString();
}
public void stop() { public void stop() {
if (trackerProcess.get() != null) { if (trackerProcess.get() != null) {
trackerProcess.get().destroy(); trackerProcess.get().destroy();

View File

@ -0,0 +1,56 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.net.URL;
import java.util.Properties;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
public class TrackerProperties {
private static String PROPERTIES_FILENAME = "xgboost-tracker.properties";
private static String HOST_IP = "host-ip";
private static final Log logger = LogFactory.getLog(TrackerProperties.class);
private static TrackerProperties instance = new TrackerProperties();
private Properties properties;
private TrackerProperties() {
this.properties = new Properties();
InputStream inputStream = null;
try {
URL propertiesFileURL =
Thread.currentThread().getContextClassLoader().getResource(PROPERTIES_FILENAME);
if (propertiesFileURL != null){
inputStream = propertiesFileURL.openStream();
}
} catch (IOException e) {
logger.warn("Could not load " + PROPERTIES_FILENAME + " file. ", e);
}
if(inputStream != null){
try {
properties.load(inputStream);
logger.debug("Loaded properties from external source");
} catch (IOException e) {
logger.error("Error loading tracker properties file. Skipping and using defaults. ", e);
}
try {
inputStream.close();
} catch (IOException e) {
// ignore exception
}
}
}
public static TrackerProperties getInstance() {
return instance;
}
public String getHostIp(){
return this.properties.getProperty(HOST_IP);
}
}