revise the RabitTracker Impl

delete FileUtil class

fix bugs
This commit is contained in:
CodingCat 2015-12-22 03:29:20 -06:00
parent 0df2ed80c8
commit 10a1517502
8 changed files with 168 additions and 140 deletions

1
.gitignore vendored
View File

@ -78,5 +78,4 @@ tags
*.iml *.iml
*.class *.class
target target
*.swp *.swp

@ -1 +1 @@
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0

0
jvm-packages/test_distributed.sh Normal file → Executable file
View File

View File

@ -20,16 +20,17 @@ public class DistTrain {
private Map<String, String> envs = null; private Map<String, String> envs = null;
private class Worker implements Runnable { private class Worker implements Runnable {
private int worker_id; private final int workerId;
Worker(int worker_id) {
this.worker_id = worker_id; Worker(int workerId) {
this.workerId = workerId;
} }
public void run() { public void run() {
try { try {
Map<String, String> worker_env = new HashMap<String, String>(envs); Map<String, String> worker_env = new HashMap<String, String>(envs);
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString()); worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
// always initialize rabit module before training. // always initialize rabit module before training.
Rabit.init(worker_env); Rabit.init(worker_env);
@ -44,7 +45,6 @@ public class DistTrain {
params.put("nthread", 2); params.put("nthread", 2);
params.put("objective", "binary:logistic"); params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>(); HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat); watches.put("train", trainMat);
watches.put("test", testMat); watches.put("test", testMat);
@ -63,15 +63,16 @@ public class DistTrain {
} }
} }
void start(int nworker) throws IOException, XGBoostError, InterruptedException { void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
RabitTracker tracker = new RabitTracker(nworker); RabitTracker tracker = new RabitTracker(nWorkers);
tracker.start(); if (tracker.start()) {
envs = tracker.getWorkerEnvs(); envs = tracker.getWorkerEnvs();
for (int i = 0; i < nworker; ++i) { for (int i = 0; i < nWorkers; ++i) {
new Thread(new Worker(i)).start(); new Thread(new Worker(i)).start();
} }
tracker.waitFor(); tracker.waitFor();
} }
}
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException { public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
new DistTrain().start(Integer.parseInt(args[0])); new DistTrain().start(Integer.parseInt(args[0]));

View File

@ -1,78 +0,0 @@
package ml.dmlc.xgboost4j;
import java.io.*;
import java.io.IOException;
/**
* Auxiliary utils to
*/
class FileUtil {
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp;
}
}

View File

@ -21,8 +21,6 @@ import java.lang.reflect.Field;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.FileUtil;
/** /**
* class to load native library * class to load native library
@ -64,9 +62,78 @@ class NativeLibLoader {
* three characters * three characters
*/ */
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
File temp = FileUtil.createTempFileFromResource(path); String temp = createTempFileFromResource(path);
// Finally, load the library // Finally, load the library
System.load(temp.getAbsolutePath()); System.load(temp);
}
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static String createTempFileFromResource(String path) throws
IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp.getAbsolutePath();
} }
/** /**
@ -84,8 +151,9 @@ class NativeLibLoader {
try { try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
loadLibraryFromJar(libraryFromJar); loadLibraryFromJar(libraryFromJar);
} catch (IOException e1) { } catch (IOException ioe) {
throw e1; logger.error("failed to load library from both native path and jar");
throw ioe;
} }
} }
} }

View File

@ -5,6 +5,7 @@ package ml.dmlc.xgboost4j;
import java.io.*; import java.io.*;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -16,45 +17,40 @@ public class RabitTracker {
// Maybe per tracker logger? // Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class); private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file. // tracker python file.
private static File tracker_py = null; private static String tracker_py = null;
// environment variable to be pased. // environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>(); private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted. // number of workers to be submitted.
private int num_workers; private int num_workers;
// child process private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
private Process process = null;
// logger thread
private Thread logger_thread = null;
//load native library
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
}
/** /**
* Tracker logger that logs output from tracker. * Tracker logger that logs output from tracker.
*/ */
private class TrackerLogger implements Runnable { private class TrackerProcessLogger implements Runnable {
public void run() { public void run() {
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line; String line;
try { try {
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
logger.info(line); trackerProcessLogger.info(line);
} }
} catch (IOException ex) { } catch (IOException ex) {
logger.error(ex.toString()); trackerProcessLogger.error(ex.toString());
} }
} }
} }
private static synchronized void initTrackerPy() throws IOException { private void initTrackerPy() throws IOException {
tracker_py = FileUtil.createTempFileFromResource("/tracker.py"); try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
} }
@ -70,29 +66,71 @@ public class RabitTracker {
return envs; return envs;
} }
public void start() throws IOException { private void loadEnvs(InputStream ins) throws IOException {
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() + try {
" --num-workers=" + new Integer(num_workers).toString()); BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line; String line;
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) { if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break; break;
} }
String []sep = line.split("="); String[] sep = line.split("=");
if (sep.length == 2) { if (sep.length == 2) {
envs.put(sep[0], sep[1]); envs.put(sep[0], sep[1]);
} }
} }
logger.debug("Tracker started, with env=" + envs.toString()); reader.close();
// also start a tracker logger } catch (IOException ioe){
logger_thread = new Thread(new TrackerLogger()); logger.error("cannot get runtime configuration from tracker process");
logger_thread.setDaemon(true); ioe.printStackTrace();
logger_thread.start(); throw ioe;
}
} }
public void waitFor() throws InterruptedException { private boolean startTrackerProcess() {
process.waitFor(); try {
initTrackerPy();
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
" --num-workers=" + String.valueOf(num_workers)));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
private void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start() {
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public void waitFor() {
try {
trackerProcess.get().waitFor();
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
stop();
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
} }
} }

2
rabit

@ -1 +1 @@
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043 Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0