diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 888d501db..58b9b2500 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -27,6 +27,7 @@ public class RabitTracker implements IRabitTracker { private static final Log logger = LogFactory.getLog(RabitTracker.class); // tracker python file. private static String tracker_py = null; + private static TrackerProperties trackerProperties = TrackerProperties.getInstance(); // environment variable to be pased. private Map envs = new HashMap(); // number of workers to be submitted. @@ -129,8 +130,10 @@ public class RabitTracker implements IRabitTracker { private boolean startTrackerProcess() { try { - trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py + - " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers))); + String trackerExecString = this.addTrackerProperties("python " + tracker_py + + " --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)); + + trackerProcess.set(Runtime.getRuntime().exec(trackerExecString)); loadEnvs(trackerProcess.get().getInputStream()); return true; } catch (IOException ioe) { @@ -139,6 +142,18 @@ public class RabitTracker implements IRabitTracker { } } + private String addTrackerProperties(String trackerExecString) { + StringBuilder sb = new StringBuilder(trackerExecString); + String hostIp = trackerProperties.getHostIp(); + + if(hostIp != null && !hostIp.isEmpty()){ + logger.debug("Using provided host-ip: " + hostIp); + sb.append(" --host-ip=").append(hostIp); + } + + return sb.toString(); + } + public void stop() { if (trackerProcess.get() != null) { trackerProcess.get().destroy(); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/TrackerProperties.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/TrackerProperties.java new file mode 100644 index 000000000..45a6b1e06 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/TrackerProperties.java @@ -0,0 +1,56 @@ +package ml.dmlc.xgboost4j.java; + +import java.io.*; +import java.net.URL; +import java.util.Properties; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public class TrackerProperties { + private static String PROPERTIES_FILENAME = "xgboost-tracker.properties"; + private static String HOST_IP = "host-ip"; + + private static final Log logger = LogFactory.getLog(TrackerProperties.class); + private static TrackerProperties instance = new TrackerProperties(); + + private Properties properties; + + private TrackerProperties() { + this.properties = new Properties(); + + InputStream inputStream = null; + + try { + URL propertiesFileURL = + Thread.currentThread().getContextClassLoader().getResource(PROPERTIES_FILENAME); + if (propertiesFileURL != null){ + inputStream = propertiesFileURL.openStream(); + } + } catch (IOException e) { + logger.warn("Could not load " + PROPERTIES_FILENAME + " file. ", e); + } + + if(inputStream != null){ + try { + properties.load(inputStream); + logger.debug("Loaded properties from external source"); + } catch (IOException e) { + logger.error("Error loading tracker properties file. Skipping and using defaults. ", e); + } + try { + inputStream.close(); + } catch (IOException e) { + // ignore exception + } + } + } + + public static TrackerProperties getInstance() { + return instance; + } + + public String getHostIp(){ + return this.properties.getProperty(HOST_IP); + } +}