revise the RabitTracker Impl
delete FileUtil class fix bugs
This commit is contained in:
parent
0df2ed80c8
commit
10a1517502
1
.gitignore
vendored
1
.gitignore
vendored
@ -78,5 +78,4 @@ tags
|
|||||||
*.iml
|
*.iml
|
||||||
*.class
|
*.class
|
||||||
target
|
target
|
||||||
|
|
||||||
*.swp
|
*.swp
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
|
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
|
||||||
0
jvm-packages/test_distributed.sh
Normal file → Executable file
0
jvm-packages/test_distributed.sh
Normal file → Executable file
@ -20,16 +20,17 @@ public class DistTrain {
|
|||||||
private Map<String, String> envs = null;
|
private Map<String, String> envs = null;
|
||||||
|
|
||||||
private class Worker implements Runnable {
|
private class Worker implements Runnable {
|
||||||
private int worker_id;
|
private final int workerId;
|
||||||
Worker(int worker_id) {
|
|
||||||
this.worker_id = worker_id;
|
Worker(int workerId) {
|
||||||
|
this.workerId = workerId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void run() {
|
public void run() {
|
||||||
try {
|
try {
|
||||||
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
||||||
|
|
||||||
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString());
|
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
||||||
// always initialize rabit module before training.
|
// always initialize rabit module before training.
|
||||||
Rabit.init(worker_env);
|
Rabit.init(worker_env);
|
||||||
|
|
||||||
@ -44,7 +45,6 @@ public class DistTrain {
|
|||||||
params.put("nthread", 2);
|
params.put("nthread", 2);
|
||||||
params.put("objective", "binary:logistic");
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
watches.put("train", trainMat);
|
watches.put("train", trainMat);
|
||||||
watches.put("test", testMat);
|
watches.put("test", testMat);
|
||||||
@ -63,14 +63,15 @@ public class DistTrain {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void start(int nworker) throws IOException, XGBoostError, InterruptedException {
|
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
||||||
RabitTracker tracker = new RabitTracker(nworker);
|
RabitTracker tracker = new RabitTracker(nWorkers);
|
||||||
tracker.start();
|
if (tracker.start()) {
|
||||||
envs = tracker.getWorkerEnvs();
|
envs = tracker.getWorkerEnvs();
|
||||||
for (int i = 0; i < nworker; ++i) {
|
for (int i = 0; i < nWorkers; ++i) {
|
||||||
new Thread(new Worker(i)).start();
|
new Thread(new Worker(i)).start();
|
||||||
|
}
|
||||||
|
tracker.waitFor();
|
||||||
}
|
}
|
||||||
tracker.waitFor();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
||||||
|
|||||||
@ -1,78 +0,0 @@
|
|||||||
package ml.dmlc.xgboost4j;
|
|
||||||
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Auxiliary utils to
|
|
||||||
*/
|
|
||||||
class FileUtil {
|
|
||||||
/**
|
|
||||||
* Create a temp file that copies the resource from current JAR archive
|
|
||||||
* <p/>
|
|
||||||
* The file from JAR is copied into system temp file.
|
|
||||||
* The temporary file is deleted after exiting.
|
|
||||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
|
||||||
* <p/>
|
|
||||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
|
||||||
* {@code path}.
|
|
||||||
* @param path Path to the resources in the jar
|
|
||||||
* @return The created temp file.
|
|
||||||
* @throws IOException
|
|
||||||
* @throws IllegalArgumentException
|
|
||||||
*/
|
|
||||||
static File createTempFileFromResource(String path) throws IOException, IllegalArgumentException {
|
|
||||||
// Obtain filename from path
|
|
||||||
if (!path.startsWith("/")) {
|
|
||||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
|
||||||
}
|
|
||||||
|
|
||||||
String[] parts = path.split("/");
|
|
||||||
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
|
||||||
|
|
||||||
// Split filename to prexif and suffix (extension)
|
|
||||||
String prefix = "";
|
|
||||||
String suffix = null;
|
|
||||||
if (filename != null) {
|
|
||||||
parts = filename.split("\\.", 2);
|
|
||||||
prefix = parts[0];
|
|
||||||
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the filename is okay
|
|
||||||
if (filename == null || prefix.length() < 3) {
|
|
||||||
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
|
||||||
}
|
|
||||||
// Prepare temporary file
|
|
||||||
File temp = File.createTempFile(prefix, suffix);
|
|
||||||
temp.deleteOnExit();
|
|
||||||
|
|
||||||
if (!temp.exists()) {
|
|
||||||
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare buffer for data copying
|
|
||||||
byte[] buffer = new byte[1024];
|
|
||||||
int readBytes;
|
|
||||||
|
|
||||||
// Open and check input stream
|
|
||||||
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
|
||||||
if (is == null) {
|
|
||||||
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open output stream and copy data between source file in JAR and the temporary file
|
|
||||||
OutputStream os = new FileOutputStream(temp);
|
|
||||||
try {
|
|
||||||
while ((readBytes = is.read(buffer)) != -1) {
|
|
||||||
os.write(buffer, 0, readBytes);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
// If read/write fails, close streams safely before throwing an exception
|
|
||||||
os.close();
|
|
||||||
is.close();
|
|
||||||
}
|
|
||||||
return temp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -21,8 +21,6 @@ import java.lang.reflect.Field;
|
|||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.FileUtil;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* class to load native library
|
* class to load native library
|
||||||
@ -64,9 +62,78 @@ class NativeLibLoader {
|
|||||||
* three characters
|
* three characters
|
||||||
*/
|
*/
|
||||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||||
File temp = FileUtil.createTempFileFromResource(path);
|
String temp = createTempFileFromResource(path);
|
||||||
// Finally, load the library
|
// Finally, load the library
|
||||||
System.load(temp.getAbsolutePath());
|
System.load(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a temp file that copies the resource from current JAR archive
|
||||||
|
* <p/>
|
||||||
|
* The file from JAR is copied into system temp file.
|
||||||
|
* The temporary file is deleted after exiting.
|
||||||
|
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||||
|
* <p/>
|
||||||
|
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||||
|
* {@code path}.
|
||||||
|
* @param path Path to the resources in the jar
|
||||||
|
* @return The created temp file.
|
||||||
|
* @throws IOException
|
||||||
|
* @throws IllegalArgumentException
|
||||||
|
*/
|
||||||
|
static String createTempFileFromResource(String path) throws
|
||||||
|
IOException, IllegalArgumentException {
|
||||||
|
// Obtain filename from path
|
||||||
|
if (!path.startsWith("/")) {
|
||||||
|
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] parts = path.split("/");
|
||||||
|
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
||||||
|
|
||||||
|
// Split filename to prexif and suffix (extension)
|
||||||
|
String prefix = "";
|
||||||
|
String suffix = null;
|
||||||
|
if (filename != null) {
|
||||||
|
parts = filename.split("\\.", 2);
|
||||||
|
prefix = parts[0];
|
||||||
|
suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the filename is okay
|
||||||
|
if (filename == null || prefix.length() < 3) {
|
||||||
|
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
|
||||||
|
}
|
||||||
|
// Prepare temporary file
|
||||||
|
File temp = File.createTempFile(prefix, suffix);
|
||||||
|
temp.deleteOnExit();
|
||||||
|
|
||||||
|
if (!temp.exists()) {
|
||||||
|
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare buffer for data copying
|
||||||
|
byte[] buffer = new byte[1024];
|
||||||
|
int readBytes;
|
||||||
|
|
||||||
|
// Open and check input stream
|
||||||
|
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
||||||
|
if (is == null) {
|
||||||
|
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open output stream and copy data between source file in JAR and the temporary file
|
||||||
|
OutputStream os = new FileOutputStream(temp);
|
||||||
|
try {
|
||||||
|
while ((readBytes = is.read(buffer)) != -1) {
|
||||||
|
os.write(buffer, 0, readBytes);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
// If read/write fails, close streams safely before throwing an exception
|
||||||
|
os.close();
|
||||||
|
is.close();
|
||||||
|
}
|
||||||
|
return temp.getAbsolutePath();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -84,8 +151,9 @@ class NativeLibLoader {
|
|||||||
try {
|
try {
|
||||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||||
loadLibraryFromJar(libraryFromJar);
|
loadLibraryFromJar(libraryFromJar);
|
||||||
} catch (IOException e1) {
|
} catch (IOException ioe) {
|
||||||
throw e1;
|
logger.error("failed to load library from both native path and jar");
|
||||||
|
throw ioe;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,7 @@ package ml.dmlc.xgboost4j;
|
|||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
@ -16,45 +17,40 @@ public class RabitTracker {
|
|||||||
// Maybe per tracker logger?
|
// Maybe per tracker logger?
|
||||||
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
private static final Log logger = LogFactory.getLog(RabitTracker.class);
|
||||||
// tracker python file.
|
// tracker python file.
|
||||||
private static File tracker_py = null;
|
private static String tracker_py = null;
|
||||||
// environment variable to be pased.
|
// environment variable to be pased.
|
||||||
private Map<String, String> envs = new HashMap<String, String>();
|
private Map<String, String> envs = new HashMap<String, String>();
|
||||||
// number of workers to be submitted.
|
// number of workers to be submitted.
|
||||||
private int num_workers;
|
private int num_workers;
|
||||||
// child process
|
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||||
private Process process = null;
|
|
||||||
// logger thread
|
|
||||||
private Thread logger_thread = null;
|
|
||||||
|
|
||||||
//load native library
|
|
||||||
static {
|
|
||||||
try {
|
|
||||||
initTrackerPy();
|
|
||||||
} catch (IOException ex) {
|
|
||||||
logger.error("load tracker library failed.");
|
|
||||||
logger.error(ex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tracker logger that logs output from tracker.
|
* Tracker logger that logs output from tracker.
|
||||||
*/
|
*/
|
||||||
private class TrackerLogger implements Runnable {
|
private class TrackerProcessLogger implements Runnable {
|
||||||
public void run() {
|
public void run() {
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
|
|
||||||
|
Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class);
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(
|
||||||
|
trackerProcess.get().getErrorStream()));
|
||||||
String line;
|
String line;
|
||||||
try {
|
try {
|
||||||
while ((line = reader.readLine()) != null) {
|
while ((line = reader.readLine()) != null) {
|
||||||
logger.info(line);
|
trackerProcessLogger.info(line);
|
||||||
}
|
}
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error(ex.toString());
|
trackerProcessLogger.error(ex.toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static synchronized void initTrackerPy() throws IOException {
|
private void initTrackerPy() throws IOException {
|
||||||
tracker_py = FileUtil.createTempFileFromResource("/tracker.py");
|
try {
|
||||||
|
tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py");
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
logger.trace("cannot access tracker python script");
|
||||||
|
throw ioe;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -70,29 +66,71 @@ public class RabitTracker {
|
|||||||
return envs;
|
return envs;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void start() throws IOException {
|
private void loadEnvs(InputStream ins) throws IOException {
|
||||||
process = Runtime.getRuntime().exec("python " + tracker_py.getAbsolutePath() +
|
try {
|
||||||
" --num-workers=" + new Integer(num_workers).toString());
|
BufferedReader reader = new BufferedReader(new InputStreamReader(ins));
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
|
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
||||||
assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START");
|
String line;
|
||||||
String line;
|
while ((line = reader.readLine()) != null) {
|
||||||
while ((line = reader.readLine()) != null) {
|
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
||||||
if (line.trim().equals("DMLC_TRACKER_ENV_END")) {
|
break;
|
||||||
break;
|
}
|
||||||
}
|
String[] sep = line.split("=");
|
||||||
String []sep = line.split("=");
|
if (sep.length == 2) {
|
||||||
if (sep.length == 2) {
|
envs.put(sep[0], sep[1]);
|
||||||
envs.put(sep[0], sep[1]);
|
}
|
||||||
}
|
}
|
||||||
|
reader.close();
|
||||||
|
} catch (IOException ioe){
|
||||||
|
logger.error("cannot get runtime configuration from tracker process");
|
||||||
|
ioe.printStackTrace();
|
||||||
|
throw ioe;
|
||||||
}
|
}
|
||||||
logger.debug("Tracker started, with env=" + envs.toString());
|
|
||||||
// also start a tracker logger
|
|
||||||
logger_thread = new Thread(new TrackerLogger());
|
|
||||||
logger_thread.setDaemon(true);
|
|
||||||
logger_thread.start();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void waitFor() throws InterruptedException {
|
private boolean startTrackerProcess() {
|
||||||
process.waitFor();
|
try {
|
||||||
|
initTrackerPy();
|
||||||
|
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||||
|
" --num-workers=" + String.valueOf(num_workers)));
|
||||||
|
loadEnvs(trackerProcess.get().getInputStream());
|
||||||
|
return true;
|
||||||
|
} catch (IOException ioe) {
|
||||||
|
ioe.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void stop() {
|
||||||
|
if (trackerProcess.get() != null) {
|
||||||
|
trackerProcess.get().destroy();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean start() {
|
||||||
|
if (startTrackerProcess()) {
|
||||||
|
logger.debug("Tracker started, with env=" + envs.toString());
|
||||||
|
// also start a tracker logger
|
||||||
|
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||||
|
logger_thread.setDaemon(true);
|
||||||
|
logger_thread.start();
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
logger.error("FAULT: failed to start tracker process");
|
||||||
|
stop();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void waitFor() {
|
||||||
|
try {
|
||||||
|
trackerProcess.get().waitFor();
|
||||||
|
logger.info("Tracker Process ends with exit code " + trackerProcess.get().exitValue());
|
||||||
|
stop();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
// we should not get here as RabitTracker is accessed in the main thread
|
||||||
|
e.printStackTrace();
|
||||||
|
logger.error("the RabitTracker thread is terminated unexpectedly");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
||||||
Loading…
x
Reference in New Issue
Block a user