From c428a93adcb345f9a08c784aeecf7f7bd8c534f0 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Wed, 2 Mar 2016 22:27:55 -0800
Subject: [PATCH] [JVM-PKG] add distributed test simple case

---
 jvm-packages/test_distributed.sh              |  9 +++
 .../ml/dmlc/xgboost4j/demo/DistTrain.java     | 60 +++++++++++++++++++
 .../main/java/ml/dmlc/xgboost4j/XGBoost.java  |  4 +-
 3 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 jvm-packages/test_distributed.sh
 create mode 100644 jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java

diff --git a/jvm-packages/test_distributed.sh b/jvm-packages/test_distributed.sh
new file mode 100644
index 000000000..7b5515b49
--- /dev/null
+++ b/jvm-packages/test_distributed.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+# Simple script to test distributed version, to be deleted later.
+# Run it from the jvm-packages directory: the cd below is relative to the caller's cwd.
+cd xgboost4j-demo || exit 1
+../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=3 java -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain
+status=$?
+# Restore the directory, then report the job's exit status instead of cd's.
+cd ..
+exit $status
diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java
new file mode 100644
index 000000000..30a8ba85f
--- /dev/null
+++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/DistTrain.java
@@ -0,0 +1,60 @@
+package ml.dmlc.xgboost4j.demo;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import ml.dmlc.xgboost4j.Rabit;
+import ml.dmlc.xgboost4j.Booster;
+import ml.dmlc.xgboost4j.DMatrix;
+import ml.dmlc.xgboost4j.XGBoost;
+import ml.dmlc.xgboost4j.XGBoostError;
+
+/**
+ * Distributed training example, used to quickly test distributed training.
+ *
+ * <p>Initializes Rabit, loads the agaricus demo data, trains for two rounds
+ * while evaluating on the train and test sets, then shuts Rabit down.
+ *
+ * @author tqchen
+ */
+public class DistTrain {
+
+  /**
+   * Entry point executed by each worker process.
+   *
+   * @param args command line arguments; not read
+   * @throws IOException if an xgboost4j call reports an I/O failure
+   * @throws XGBoostError if an xgboost4j call fails
+   */
+  public static void main(String[] args) throws IOException, XGBoostError {
+    // always initialize rabit module before training.
+    Rabit.init(new HashMap<String, String>());
+
+    // load file from text file, also binary buffer generated by xgboost4j
+    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+
+    Map<String, Object> params = new HashMap<String, Object>();
+    params.put("eta", 1.0);
+    params.put("max_depth", 2);
+    params.put("silent", 1);
+    params.put("objective", "binary:logistic");
+
+    // data sets to evaluate, keyed by the name they are registered under
+    Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
+    watches.put("train", trainMat);
+    watches.put("test", testMat);
+
+    // set round
+    int round = 2;
+
+    // train a boost model; the last two arguments are left null
+    Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
+
+    // always shutdown rabit module after training.
+    Rabit.shutdown();
+  }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
index 4214d8f13..8b7d3add0 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java
@@ -96,7 +96,9 @@ public class XGBoost {
         } else {
           evalInfo = booster.evalSet(evalMats, evalNames, iter);
         }
-        logger.info(evalInfo);
+        if (Rabit.getRank() == 0) {
+          Rabit.trackerPrint(evalInfo + '\n');
+        }
       }
       booster.saveRabitCheckpoint();
       version += 1;