[JVM-PKG] add distributed test simple case

This commit is contained in:
tqchen
2016-03-02 22:27:55 -08:00
parent 5c9e50148a
commit c428a93adc
3 changed files with 57 additions and 1 deletions

View File

@@ -0,0 +1,49 @@
package ml.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Rabit;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
/**
* Distributed training example, used to quick test distributed training.
*
* @author tqchen
*/
public class DistTrain {
public static void main(String[] args) throws IOException, XGBoostError {
// always initialize rabit module before training.
Rabit.init(new HashMap<String, String>());
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
// always shutdown rabit module after training.
Rabit.shutdown();
}
}