revise the RabitTracker Impl

delete FileUtil class

fix bugs
This commit is contained in:
CodingCat 2015-12-22 03:29:20 -06:00
parent 0df2ed80c8
commit 10a1517502
8 changed files with 168 additions and 140 deletions

1
.gitignore vendored
View File

@ -78,5 +78,4 @@ tags
*.iml
*.class
target
*.swp

@ -1 +1 @@
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0

0
jvm-packages/test_distributed.sh Normal file → Executable file
View File

View File

@ -20,16 +20,17 @@ public class DistTrain {
private Map<String, String> envs = null;
private class Worker implements Runnable {
private int worker_id;
Worker(int worker_id) {
this.worker_id = worker_id;
private final int workerId;
Worker(int workerId) {
this.workerId = workerId;
}
public void run() {
try {
Map<String, String> worker_env = new HashMap<String, String>(envs);
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString());
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
// always initialize rabit module before training.
Rabit.init(worker_env);
@ -44,7 +45,6 @@ public class DistTrain {
params.put("nthread", 2);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
@ -63,15 +63,16 @@ public class DistTrain {
}
}
void start(int nworker) throws IOException, XGBoostError, InterruptedException {
RabitTracker tracker = new RabitTracker(nworker);
tracker.start();
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
RabitTracker tracker = new RabitTracker(nWorkers);
if (tracker.start()) {
envs = tracker.getWorkerEnvs();
for (int i = 0; i < nworker; ++i) {
for (int i = 0; i < nWorkers; ++i) {
new Thread(new Worker(i)).start();
}
tracker.waitFor();
}
}
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
new DistTrain().start(Integer.parseInt(args[0]));

View File

@ -1,78 +0,0 @@
package ml.dmlc.xgboost4j;
import java.io.*;
import java.io.IOException;
/**
* Auxiliary utils to
*/
class FileUtil {
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp;
}
}

View File

@ -21,8 +21,6 @@ import java.lang.reflect.Field;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.FileUtil;
/**
* class to load native library
@ -64,9 +62,78 @@ class NativeLibLoader {
* three characters
*/
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
File temp = FileUtil.createTempFileFromResource(path);
String temp = createTempFileFromResource(path);
// Finally, load the library
System.load(temp.getAbsolutePath());
System.load(temp);
}
/**
* Create a temp file that copies the resource from current JAR archive
* <p/>
* The file from JAR is copied into system temp file.
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
*/
static String createTempFileFromResource(String path) throws
IOException, IllegalArgumentException {
// Obtain filename from path
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp.getAbsolutePath();
}
/**
@ -84,8 +151,9 @@ class NativeLibLoader {
try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
loadLibraryFromJar(libraryFromJar);
} catch (IOException e1) {
throw e1;
} catch (IOException ioe) {
logger.error("failed to load library from both native path and jar");
throw ioe;
}
}
}

View File

@ -5,6 +5,7 @@ package ml.dmlc.xgboost4j;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -16,45 +17,40 @@ public class RabitTracker {
// Maybe per tracker logger?
private static final Log logger = LogFactory.getLog(RabitTracker.class);
// tracker python file.
private static File tracker_py = null;
private static String tracker_py = null;
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int num_workers;
// child process
private Process process = null;
// logger thread
private Thread logger_thread = null;
//load native library
static {
try {
initTrackerPy();
} catch (IOException ex) {
logger.error("load tracker library failed.");
logger.error(ex);
}
}
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
/**
* Tracker logger that logs output from tracker.
*/
private class TrackerLogger implements Runnable {
private class TrackerProcessLogger implements Runnable {
public void run() {
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
BufferedReader reader = new BufferedReader(new InputStreamReader(
trackerProcess.get().getErrorStream()));
String line;
try {
while ((line = reader.readLine()) != null) {
logger.info(line);
trackerProcessLogger.info(line);
}
} catch (IOException ex) {
logger.error(ex.toString());
trackerProcessLogger.error(ex.toString());
}
}
}
private static synchronized void initTrackerPy() throws IOException {
tracker_py = FileUtil.createTempFileFromResource("/tracker.py");
private void initTrackerPy() throws IOException {
try {
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
} catch (IOException ioe) {
logger.trace("cannot access tracker python script");
throw ioe;
}
}
@ -70,29 +66,71 @@ public class RabitTracker {
return envs;
}
public void start() throws IOException {
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() +
" --num-workers=" + new Integer(num_workers).toString());
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
private void loadEnvs(InputStream ins) throws IOException {
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
break;
}
String []sep = line.split("=");
String[] sep = line.split("=");
if (sep.length == 2) {
envs.put(sep[0], sep[1]);
}
}
logger.debug("Tracker started, with env=" + envs.toString());
// also start a tracker logger
logger_thread = new Thread(new TrackerLogger());
logger_thread.setDaemon(true);
logger_thread.start();
reader.close();
} catch (IOException ioe){
logger.error("cannot get runtime configuration from tracker process");
ioe.printStackTrace();
throw ioe;
}
}
public void waitFor() throws InterruptedException {
process.waitFor();
private boolean startTrackerProcess() {
try {
initTrackerPy();
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
" --num-workers=" + String.valueOf(num_workers)));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {
ioe.printStackTrace();
return false;
}
}
private void stop() {
if (trackerProcess.get() != null) {
trackerProcess.get().destroy();
}
}
public boolean start() {
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);
logger_thread.start();
return true;
} else {
logger.error("FAULT: failed to start tracker process");
stop();
return false;
}
}
public void waitFor() {
try {
trackerProcess.get().waitFor();
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
stop();
} catch (InterruptedException e) {
// we should not get here as RabitTracker is accessed in the main thread
e.printStackTrace();
logger.error("the RabitTracker thread is terminated unexpectedly");
}
}
}

2
rabit

@ -1 +1 @@
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0