Revise the RabitTracker implementation
Delete the FileUtil class; fix bugs
This commit is contained in:
parent
0df2ed80c8
commit
10a1517502
1
.gitignore
vendored
1
.gitignore
vendored
@ -78,5 +78,4 @@ tags
|
||||
*.iml
|
||||
*.class
|
||||
target
|
||||
|
||||
*.swp
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
|
||||
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
|
||||
0
jvm-packages/test_distributed.sh
Normal file → Executable file
0
jvm-packages/test_distributed.sh
Normal file → Executable file
@ -20,16 +20,17 @@ public class DistTrain {
|
||||
private Map<String, String> envs = null;
|
||||
|
||||
private class Worker implements Runnable {
|
||||
private int worker_id;
|
||||
Worker(int worker_id) {
|
||||
this.worker_id = worker_id;
|
||||
private final int workerId;
|
||||
|
||||
Worker(int workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public void run() {
|
||||
try {
|
||||
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
||||
|
||||
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString());
|
||||
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
||||
// always initialize rabit module before training.
|
||||
Rabit.init(worker_env);
|
||||
|
||||
@ -44,7 +45,6 @@ public class DistTrain {
|
||||
params.put("nthread", 2);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
@ -63,14 +63,15 @@ public class DistTrain {
|
||||
}
|
||||
}
|
||||
|
||||
void start(int nworker) throws IOException, XGBoostError, InterruptedException {
|
||||
RabitTracker tracker = new RabitTracker(nworker);
|
||||
tracker.start();
|
||||
envs = tracker.getWorkerEnvs();
|
||||
for (int i = 0; i < nworker; ++i) {
|
||||
new Thread(new Worker(i)).start();
|
||||
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
||||
RabitTracker tracker = new RabitTracker(nWorkers);
|
||||
if (tracker.start()) {
|
||||
envs = tracker.getWorkerEnvs();
|
||||
for (int i = 0; i < nWorkers; ++i) {
|
||||
new Thread(new Worker(i)).start();
|
||||
}
|
||||
tracker.waitFor();
|
||||
}
|
||||
tracker.waitFor();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
||||
|
||||
@ -1,78 +0,0 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
|
||||
import java.io.*;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Auxiliary utils to copy a resource from the JAR into a temporary file
|
||||
*/
|
||||
class FileUtil {
|
||||
/**
|
||||
* Create a temp file that copies the resource from current JAR archive
|
||||
* <p/>
|
||||
* The file from JAR is copied into system temp file.
|
||||
* The temporary file is deleted after exiting.
|
||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||
* <p/>
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||
* {@code path}.
|
||||
* @param path Path to the resources in the jar
|
||||
* @return The created temp file.
|
||||
* @throws IOException
|
||||
* @throws IllegalArgumentException
|
||||
*/
|
||||
static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
|
||||
// Obtain filename from path
|
||||
if (!path.startsWith("/")) {
|
||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||
}
|
||||
|
||||
String[] parts = path.split("/");
|
||||
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
||||
|
||||
// Split filename into prefix and suffix (extension)
|
||||
String prefix = "";
|
||||
String suffix = null;
|
||||
if (filename != null) {
|
||||
parts = filename.split("\\.", 2);
|
||||
prefix = parts[0];
|
||||
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
|
||||
}
|
||||
|
||||
// Check if the filename is okay
|
||||
if (filename == null || prefix.length() < 3) {
|
||||
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
||||
}
|
||||
// Prepare temporary file
|
||||
File temp = File.createTempFile(prefix, suffix);
|
||||
temp.deleteOnExit();
|
||||
|
||||
if (!temp.exists()) {
|
||||
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
|
||||
}
|
||||
|
||||
// Prepare buffer for data copying
|
||||
byte[] buffer = new byte[1024];
|
||||
int readBytes;
|
||||
|
||||
// Open and check input stream
|
||||
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
||||
if (is == null) {
|
||||
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
||||
}
|
||||
|
||||
// Open output stream and copy data between source file in JAR and the temporary file
|
||||
OutputStream os = new FileOutputStream(temp);
|
||||
try {
|
||||
while ((readBytes = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, readBytes);
|
||||
}
|
||||
} finally {
|
||||
// If read/write fails, close streams safely before throwing an exception
|
||||
os.close();
|
||||
is.close();
|
||||
}
|
||||
return temp;
|
||||
}
|
||||
}
|
||||
@ -21,8 +21,6 @@ import java.lang.reflect.Field;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.FileUtil;
|
||||
|
||||
|
||||
/**
|
||||
* class to load native library
|
||||
@ -64,9 +62,78 @@ class NativeLibLoader {
|
||||
* three characters
|
||||
*/
|
||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||
File temp = FileUtil.createTempFileFromResource(path);
|
||||
String temp = createTempFileFromResource(path);
|
||||
// Finally, load the library
|
||||
System.load(temp.getAbsolutePath());
|
||||
System.load(temp);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a temp file that copies the resource from current JAR archive
|
||||
* <p/>
|
||||
* The file from JAR is copied into system temp file.
|
||||
* The temporary file is deleted after exiting.
|
||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||
* <p/>
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||
* {@code path}.
|
||||
* @param path Path to the resources in the jar
|
||||
* @return The created temp file.
|
||||
* @throws IOException
|
||||
* @throws IllegalArgumentException
|
||||
*/
|
||||
static String createTempFileFromResource(String path) throws
|
||||
IOException, IllegalArgumentException {
|
||||
// Obtain filename from path
|
||||
if (!path.startsWith("/")) {
|
||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||
}
|
||||
|
||||
String[] parts = path.split("/");
|
||||
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
||||
|
||||
// Split filename to prexif and suffix (extension)
|
||||
String prefix = "";
|
||||
String suffix = null;
|
||||
if (filename != null) {
|
||||
parts = filename.split("\\.", 2);
|
||||
prefix = parts[0];
|
||||
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
|
||||
}
|
||||
|
||||
// Check if the filename is okay
|
||||
if (filename == null || prefix.length() < 3) {
|
||||
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
||||
}
|
||||
// Prepare temporary file
|
||||
File temp = File.createTempFile(prefix, suffix);
|
||||
temp.deleteOnExit();
|
||||
|
||||
if (!temp.exists()) {
|
||||
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
|
||||
}
|
||||
|
||||
// Prepare buffer for data copying
|
||||
byte[] buffer = new byte[1024];
|
||||
int readBytes;
|
||||
|
||||
// Open and check input stream
|
||||
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
||||
if (is == null) {
|
||||
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
||||
}
|
||||
|
||||
// Open output stream and copy data between source file in JAR and the temporary file
|
||||
OutputStream os = new FileOutputStream(temp);
|
||||
try {
|
||||
while ((readBytes = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, readBytes);
|
||||
}
|
||||
} finally {
|
||||
// If read/write fails, close streams safely before throwing an exception
|
||||
os.close();
|
||||
is.close();
|
||||
}
|
||||
return temp.getAbsolutePath();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -84,8 +151,9 @@ class NativeLibLoader {
|
||||
try {
|
||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||
loadLibraryFromJar(libraryFromJar);
|
||||
} catch (IOException e1) {
|
||||
throw e1;
|
||||
} catch (IOException ioe) {
|
||||
logger.error("failed to load library from both native path and jar");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ package ml.dmlc.xgboost4j;
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@ -16,45 +17,40 @@ public class RabitTracker {
|
||||
// Maybe per tracker logger?
|
||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||
// tracker python file.
|
||||
private static File tracker_py = null;
|
||||
private static String tracker_py = null;
|
||||
// environment variables to be passed.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int num_workers;
|
||||
// child process
|
||||
private Process process = null;
|
||||
// logger thread
|
||||
private Thread logger_thread = null;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
initTrackerPy();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load tracker library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
|
||||
/**
|
||||
* Tracker logger that logs output from tracker.
|
||||
*/
|
||||
private class TrackerLogger implements Runnable {
|
||||
private class TrackerProcessLogger implements Runnable {
|
||||
public void run() {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
|
||||
|
||||
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||
trackerProcess.get().getErrorStream()));
|
||||
String line;
|
||||
try {
|
||||
while ((line = reader.readLine()) != null) {
|
||||
logger.info(line);
|
||||
trackerProcessLogger.info(line);
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
logger.error(ex.toString());
|
||||
trackerProcessLogger.error(ex.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static synchronized void initTrackerPy() throws IOException {
|
||||
tracker_py = FileUtil.createTempFileFromResource("/tracker.py");
|
||||
private void initTrackerPy() throws IOException {
|
||||
try {
|
||||
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||
} catch (IOException ioe) {
|
||||
logger.trace("cannot access tracker python script");
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -70,29 +66,71 @@ public class RabitTracker {
|
||||
return envs;
|
||||
}
|
||||
|
||||
public void start() throws IOException {
|
||||
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() +
|
||||
" --num-workers=" + new Integer(num_workers).toString());
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String []sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
private void loadEnvs(InputStream ins) throws IOException {
|
||||
try {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||
break;
|
||||
}
|
||||
String[] sep = line.split("=");
|
||||
if (sep.length == 2) {
|
||||
envs.put(sep[0], sep[1]);
|
||||
}
|
||||
}
|
||||
reader.close();
|
||||
} catch (IOException ioe){
|
||||
logger.error("cannot get runtime configuration from tracker process");
|
||||
ioe.printStackTrace();
|
||||
throw ioe;
|
||||
}
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
logger_thread = new Thread(new TrackerLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
}
|
||||
|
||||
public void waitFor() throws InterruptedException {
|
||||
process.waitFor();
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
initTrackerPy();
|
||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||
" --num-workers=" + String.valueOf(num_workers)));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
ioe.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private void stop() {
|
||||
if (trackerProcess.get() != null) {
|
||||
trackerProcess.get().destroy();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean start() {
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
logger_thread.start();
|
||||
return true;
|
||||
} else {
|
||||
logger.error("FAULT: failed to start tracker process");
|
||||
stop();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public void waitFor() {
|
||||
try {
|
||||
trackerProcess.get().waitFor();
|
||||
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
|
||||
stop();
|
||||
} catch (InterruptedException e) {
|
||||
// we should not get here as RabitTracker is accessed in the main thread
|
||||
e.printStackTrace();
|
||||
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2
rabit
2
rabit
@ -1 +1 @@
|
||||
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
||||
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
||||
Loading…
x
Reference in New Issue
Block a user