revise the RabitTracker Impl
delete FileUtil class fix bugs
This commit is contained in:
@@ -20,16 +20,17 @@ public class DistTrain {
|
||||
private Map<String, String> envs = null;
|
||||
|
||||
private class Worker implements Runnable {
|
||||
private int worker_id;
|
||||
Worker(int worker_id) {
|
||||
this.worker_id = worker_id;
|
||||
private final int workerId;
|
||||
|
||||
Worker(int workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public void run() {
|
||||
try {
|
||||
Map<String, String> worker_env = new HashMap<String, String>(envs);
|
||||
|
||||
worker_env.put("DMLC_TASK_ID", new Integer(worker_id).toString());
|
||||
worker_env.put("DMLC_TASK_ID", String.valueOf(workerId));
|
||||
// always initialize rabit module before training.
|
||||
Rabit.init(worker_env);
|
||||
|
||||
@@ -44,7 +45,6 @@ public class DistTrain {
|
||||
params.put("nthread", 2);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
@@ -63,14 +63,15 @@ public class DistTrain {
|
||||
}
|
||||
}
|
||||
|
||||
void start(int nworker) throws IOException, XGBoostError, InterruptedException {
|
||||
RabitTracker tracker = new RabitTracker(nworker);
|
||||
tracker.start();
|
||||
envs = tracker.getWorkerEnvs();
|
||||
for (int i = 0; i < nworker; ++i) {
|
||||
new Thread(new Worker(i)).start();
|
||||
void start(int nWorkers) throws IOException, XGBoostError, InterruptedException {
|
||||
RabitTracker tracker = new RabitTracker(nWorkers);
|
||||
if (tracker.start()) {
|
||||
envs = tracker.getWorkerEnvs();
|
||||
for (int i = 0; i < nWorkers; ++i) {
|
||||
new Thread(new Worker(i)).start();
|
||||
}
|
||||
tracker.waitFor();
|
||||
}
|
||||
tracker.waitFor();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError, InterruptedException {
|
||||
|
||||
Reference in New Issue
Block a user