From 10a1517502c6493ab366af9c631f702f05cb050a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 22 Dec 2015 03:29:20 -0600 Subject: [PATCH] revise the RabitTracker Impl delete FileUtil class fix bugs --- .gitignore | 1 - dmlc-core | 2 +- jvm-packages/test_distributed.sh | 0 .../ml/dmlc/xgboost4j/demo/DistTrain.java | 25 ++-- .../main/java/ml/dmlc/xgboost4j/FileUtil.java | 78 ------------ .../ml/dmlc/xgboost4j/NativeLibLoader.java | 80 +++++++++++- .../java/ml/dmlc/xgboost4j/RabitTracker.java | 120 ++++++++++++------ rabit | 2 +- 8 files changed, 168 insertions(+), 140 deletions(-) mode change 100644 => 100755 jvm-packages/test_distributed.sh delete mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java diff --git a/.gitignore b/.gitignore index ae196175c..5533356f5 100644 --- a/.gitignore +++ b/.gitignore @@ -78,5 +78,4 @@ tags *.iml *.class target - *.swp diff --git a/dmlc-core b/dmlc-core index 3f6ff43d3..71360023d 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f +Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0 diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh old mode 100644 new mode 100755 diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java index 3cff4bd79..fdbbc6599 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java @@ -20,16 +20,17 @@ public class DistTrain { private Map envs = null; private class Worker implements Runnable { - private int worker_id; - Worker(int worker_id) { - this.worker_id = worker_id; + private final int workerId; + + Worker(int workerId) { + this.workerId = workerId; } public void run() { try { Map worker_env = new HashMap(envs); - worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString()); + worker_env.put("DMLC_TASK_ID", String.valueOf(workerId)); // always initialize rabit module before training. Rabit.init(worker_env); @@ -44,7 +45,6 @@ public class DistTrain { params.put("nthread", 2); params.put("objective", "binary:logistic"); - HashMap watches = new HashMap(); watches.put("train", trainMat); watches.put("test", testMat); @@ -63,14 +63,15 @@ public class DistTrain { } } - void start(int nworker) throws IOException, XGBoostError, InterruptedException { - RabitTracker tracker = new RabitTracker(nworker); - tracker.start(); - envs = tracker.getWorkerEnvs(); - for (int i = 0; i < nworker; ++i) { - new Thread(new Worker(i)).start(); + void start(int nWorkers) throws IOException, XGBoostError, InterruptedException { + RabitTracker tracker = new RabitTracker(nWorkers); + if (tracker.start()) { + envs = tracker.getWorkerEnvs(); + for (int i = 0; i < nWorkers; ++i) { + new Thread(new Worker(i)).start(); + } + tracker.waitFor(); } - tracker.waitFor(); } public static void main(String[] args) throws IOException, XGBoostError, InterruptedException { diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java deleted file mode 100644 index 4b535bd2f..000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/FileUtil.java +++ /dev/null @@ -1,78 +0,0 @@ -package ml.dmlc.xgboost4j; - - -import java.io.*; -import java.io.IOException; - -/** - * Auxiliary utils to - */ -class FileUtil { - /** - * Create a temp file that copies the resource from current JAR archive - *

- * The file from JAR is copied into system temp file. - * The temporary file is deleted after exiting. - * Method uses String as filename because the pathname is "abstract", not system-dependent. - *

- * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to - * {@code path}. - * @param path Path to the resources in the jar - * @return The created temp file. - * @throws IOException - * @throws IllegalArgumentException - */ - static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException { - // Obtain filename from path - if (!path.startsWith("/")) { - throw new IllegalArgumentException("The path has to be absolute (start with '/')."); - } - - String[] parts = path.split("/"); - String filename = (parts.length > 1) ? parts[parts.length - 1] : null; - - // Split filename to prexif and suffix (extension) - String prefix = ""; - String suffix = null; - if (filename != null) { - parts = filename.split("\\.", 2); - prefix = parts[0]; - suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-) - } - - // Check if the filename is okay - if (filename == null || prefix.length() < 3) { - throw new IllegalArgumentException("The filename has to be at least 3 characters long."); - } - // Prepare temporary file - File temp = File.createTempFile(prefix, suffix); - temp.deleteOnExit(); - - if (!temp.exists()) { - throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); - } - - // Prepare buffer for data copying - byte[] buffer = new byte[1024]; - int readBytes; - - // Open and check input stream - InputStream is = NativeLibLoader.class.getResourceAsStream(path); - if (is == null) { - throw new FileNotFoundException("File " + path + " was not found inside JAR."); - } - - // Open output stream and copy data between source file in JAR and the temporary file - OutputStream os = new FileOutputStream(temp); - try { - while ((readBytes = is.read(buffer)) != -1) { - os.write(buffer, 0, readBytes); - } - } finally { - // If read/write fails, close streams safely before throwing an exception - os.close(); - is.close(); - } - return temp; - } -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java index 01d846f62..796eb28aa 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java @@ -21,8 +21,6 @@ import java.lang.reflect.Field; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import ml.dmlc.xgboost4j.FileUtil; - /** * class to load native library @@ -64,9 +62,78 @@ class NativeLibLoader { * three characters */ private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ - File temp = FileUtil.createTempFileFromResource(path); + String temp = createTempFileFromResource(path); // Finally, load the library - System.load(temp.getAbsolutePath()); + System.load(temp); + } + + /** + * Create a temp file that copies the resource from current JAR archive + *

+ * The file from JAR is copied into system temp file. + * The temporary file is deleted after exiting. + * Method uses String as filename because the pathname is "abstract", not system-dependent. + *

+ * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to + * {@code path}. + * @param path Path to the resources in the jar + * @return The created temp file. + * @throws IOException + * @throws IllegalArgumentException + */ + static String createTempFileFromResource(String path) throws + IOException, IllegalArgumentException { + // Obtain filename from path + if (!path.startsWith("/")) { + throw new IllegalArgumentException("The path has to be absolute (start with '/')."); + } + + String[] parts = path.split("/"); + String filename = (parts.length > 1) ? parts[parts.length - 1] : null; + + // Split filename to prexif and suffix (extension) + String prefix = ""; + String suffix = null; + if (filename != null) { + parts = filename.split("\\.", 2); + prefix = parts[0]; + suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-) + } + + // Check if the filename is okay + if (filename == null || prefix.length() < 3) { + throw new IllegalArgumentException("The filename has to be at least 3 characters long."); + } + // Prepare temporary file + File temp = File.createTempFile(prefix, suffix); + temp.deleteOnExit(); + + if (!temp.exists()) { + throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); + } + + // Prepare buffer for data copying + byte[] buffer = new byte[1024]; + int readBytes; + + // Open and check input stream + InputStream is = NativeLibLoader.class.getResourceAsStream(path); + if (is == null) { + throw new FileNotFoundException("File " + path + " was not found inside JAR."); + } + + // Open output stream and copy data between source file in JAR and the temporary file + OutputStream os = new FileOutputStream(temp); + try { + while ((readBytes = is.read(buffer)) != -1) { + os.write(buffer, 0, readBytes); + } + } finally { + // If read/write fails, close streams safely before throwing an exception + os.close(); + is.close(); + } + return temp.getAbsolutePath(); } /** @@ -84,8 +151,9 @@ class NativeLibLoader { try { String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); loadLibraryFromJar(libraryFromJar); - } catch (IOException e1) { - throw e1; + } catch (IOException ioe) { + logger.error("failed to load library from both native path and jar"); + throw ioe; } } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java index 793943fff..1f1c70251 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/RabitTracker.java @@ -5,6 +5,7 @@ package ml.dmlc.xgboost4j; import java.io.*; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -16,45 +17,40 @@ public class RabitTracker { // Maybe per tracker logger? private static final Log logger = LogFactory.getLog(RabitTracker.class); // tracker python file. - private static File tracker_py = null; + private static String tracker_py = null; // environment variable to be pased. private Map envs = new HashMap(); // number of workers to be submitted. private int num_workers; - // child process - private Process process = null; - // logger thread - private Thread logger_thread = null; - - //load native library - static { - try { - initTrackerPy(); - } catch (IOException ex) { - logger.error("load tracker library failed."); - logger.error(ex); - } - } + private AtomicReference trackerProcess = new AtomicReference(); /** * Tracker logger that logs output from tracker. */ - private class TrackerLogger implements Runnable { + private class TrackerProcessLogger implements Runnable { public void run() { - BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class); + BufferedReader reader = new BufferedReader(new InputStreamReader( + trackerProcess.get().getErrorStream())); String line; try { while ((line = reader.readLine()) != null) { - logger.info(line); + trackerProcessLogger.info(line); } } catch (IOException ex) { - logger.error(ex.toString()); + trackerProcessLogger.error(ex.toString()); } } } - private static synchronized void initTrackerPy() throws IOException { - tracker_py = FileUtil.createTempFileFromResource("/tracker.py"); + private void initTrackerPy() throws IOException { + try { + tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py"); + } catch (IOException ioe) { + logger.trace("cannot access tracker python script"); + throw ioe; + } } @@ -70,29 +66,71 @@ public class RabitTracker { return envs; } - public void start() throws IOException { - process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() + - " --num-workers=" + new Integer(num_workers).toString()); - BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); - assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); - String line; - while ((line = reader.readLine()) != null) { - if (line.trim().equals("DMLC_TRACKER_ENV_END")) { - break; - } - String []sep = line.split("="); - if (sep.length == 2) { - envs.put(sep[0], sep[1]); + private void loadEnvs(InputStream ins) throws IOException { + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(ins)); + assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); + String line; + while ((line = reader.readLine()) != null) { + if (line.trim().equals("DMLC_TRACKER_ENV_END")) { + break; + } + String[] sep = line.split("="); + if (sep.length == 2) { + envs.put(sep[0], sep[1]); + } } + reader.close(); + } catch (IOException ioe){ + logger.error("cannot get runtime configuration from tracker process"); + ioe.printStackTrace(); + throw ioe; } - logger.debug("Tracker started, with env=" + envs.toString()); - // also start a tracker logger - logger_thread = new Thread(new TrackerLogger()); - logger_thread.setDaemon(true); - logger_thread.start(); } - public void waitFor() throws InterruptedException { - process.waitFor(); + private boolean startTrackerProcess() { + try { + initTrackerPy(); + trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py + + " --num-workers=" + String.valueOf(num_workers))); + loadEnvs(trackerProcess.get().getInputStream()); + return true; + } catch (IOException ioe) { + ioe.printStackTrace(); + return false; + } + } + + private void stop() { + if (trackerProcess.get() != null) { + trackerProcess.get().destroy(); + } + } + + public boolean start() { + if (startTrackerProcess()) { + logger.debug("Tracker started, with env=" + envs.toString()); + // also start a tracker logger + Thread logger_thread = new Thread(new TrackerProcessLogger()); + logger_thread.setDaemon(true); + logger_thread.start(); + return true; + } else { + logger.error("FAULT: failed to start tracker process"); + stop(); + return false; + } + } + + public void waitFor() { + try { + trackerProcess.get().waitFor(); + logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue()); + stop(); + } catch (InterruptedException e) { + // we should not get here as RabitTracker is accessed in the main thread + e.printStackTrace(); + logger.error("the RabitTracker thread is terminated unexpectedly"); + } } } diff --git a/rabit b/rabit index be50e7b63..1392e9f3d 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043 +Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0