Parameterize host-ip to pass to tracker.py (#2831)
This commit is contained in:
parent
602b34ab91
commit
0ad20f8fe0
@ -27,6 +27,7 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||||
// tracker python file.
|
// tracker python file.
|
||||||
private static String tracker_py = null;
|
private static String tracker_py = null;
|
||||||
|
private static TrackerProperties trackerProperties = TrackerProperties.getInstance();
|
||||||
// environment variable to be pased.
|
// environment variable to be pased.
|
||||||
private Map<String, String> envs = new HashMap<String, String>();
|
private Map<String, String> envs = new HashMap<String, String>();
|
||||||
// number of workers to be submitted.
|
// number of workers to be submitted.
|
||||||
@ -129,8 +130,10 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
|
|
||||||
private boolean startTrackerProcess() {
|
private boolean startTrackerProcess() {
|
||||||
try {
|
try {
|
||||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
String trackerExecString = this.addTrackerProperties("python " + tracker_py +
|
||||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
|
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers));
|
||||||
|
|
||||||
|
trackerProcess.set(Runtime.getRuntime().exec(trackerExecString));
|
||||||
loadEnvs(trackerProcess.get().getInputStream());
|
loadEnvs(trackerProcess.get().getInputStream());
|
||||||
return true;
|
return true;
|
||||||
} catch (IOException ioe) {
|
} catch (IOException ioe) {
|
||||||
@ -139,6 +142,18 @@ public class RabitTracker implements IRabitTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String addTrackerProperties(String trackerExecString) {
|
||||||
|
StringBuilder sb = new StringBuilder(trackerExecString);
|
||||||
|
String hostIp = trackerProperties.getHostIp();
|
||||||
|
|
||||||
|
if(hostIp != null && !hostIp.isEmpty()){
|
||||||
|
logger.debug("Using provided host-ip: " + hostIp);
|
||||||
|
sb.append(" --host-ip=").append(hostIp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
public void stop() {
|
public void stop() {
|
||||||
if (trackerProcess.get() != null) {
|
if (trackerProcess.get() != null) {
|
||||||
trackerProcess.get().destroy();
|
trackerProcess.get().destroy();
|
||||||
|
|||||||
@ -0,0 +1,56 @@
|
|||||||
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.util.Properties;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
|
public class TrackerProperties {
|
||||||
|
private static String PROPERTIES_FILENAME = "xgboost-tracker.properties";
|
||||||
|
private static String HOST_IP = "host-ip";
|
||||||
|
|
||||||
|
private static final Log logger = LogFactory.getLog(TrackerProperties.class);
|
||||||
|
private static TrackerProperties instance = new TrackerProperties();
|
||||||
|
|
||||||
|
private Properties properties;
|
||||||
|
|
||||||
|
private TrackerProperties() {
|
||||||
|
this.properties = new Properties();
|
||||||
|
|
||||||
|
InputStream inputStream = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
URL propertiesFileURL =
|
||||||
|
Thread.currentThread().getContextClassLoader().getResource(PROPERTIES_FILENAME);
|
||||||
|
if (propertiesFileURL != null){
|
||||||
|
inputStream = propertiesFileURL.openStream();
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.warn("Could not load " + PROPERTIES_FILENAME + " file. ", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(inputStream != null){
|
||||||
|
try {
|
||||||
|
properties.load(inputStream);
|
||||||
|
logger.debug("Loaded properties from external source");
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.error("Error loading tracker properties file. Skipping and using defaults. ", e);
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
inputStream.close();
|
||||||
|
} catch (IOException e) {
|
||||||
|
// ignore exception
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrackerProperties getInstance() {
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getHostIp(){
|
||||||
|
return this.properties.getProperty(HOST_IP);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user