[JVM-PKG] add distributed test simple case
This commit is contained in:
parent
5c9e50148a
commit
c428a93adc
5
jvm-packages/test_distributed.sh
Normal file
5
jvm-packages/test_distributed.sh
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Simple script to test distributed version, to be deleted later.
|
||||||
|
cd xgboost4j-demo
|
||||||
|
../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
|
||||||
|
cd ..
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
package ml.dmlc.xgboost4j.demo;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.Rabit;
|
||||||
|
import ml.dmlc.xgboost4j.Booster;
|
||||||
|
import ml.dmlc.xgboost4j.DMatrix;
|
||||||
|
import ml.dmlc.xgboost4j.XGBoost;
|
||||||
|
import ml.dmlc.xgboost4j.XGBoostError;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Distributed training example, used to quick test distributed training.
|
||||||
|
*
|
||||||
|
* @author tqchen
|
||||||
|
*/
|
||||||
|
public class DistTrain {
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, XGBoostError {
|
||||||
|
// always initialize rabit module before training.
|
||||||
|
Rabit.init(new HashMap<String, String>());
|
||||||
|
|
||||||
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
|
|
||||||
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
//set round
|
||||||
|
int round = 2;
|
||||||
|
|
||||||
|
//train a boost model
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
|
// always shutdown rabit module after training.
|
||||||
|
Rabit.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -96,7 +96,9 @@ public class XGBoost {
|
|||||||
} else {
|
} else {
|
||||||
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
evalInfo = booster.evalSet(evalMats, evalNames, iter);
|
||||||
}
|
}
|
||||||
logger.info(evalInfo);
|
if (Rabit.getRank() == 0) {
|
||||||
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
version += 1;
|
version += 1;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user