diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 06ea2eb4b..b6a173bd6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -228,8 +228,10 @@ public class XGBoost { break; } } - if (Rabit.getRank() == 0) { - Rabit.trackerPrint(evalInfo + '\n'); + if (Rabit.getRank() == 0 && shouldPrint(params, iter)) { + if (shouldPrint(params, iter)){ + Rabit.trackerPrint(evalInfo + '\n'); + } } } booster.saveRabitCheckpoint(); @@ -237,6 +239,47 @@ public class XGBoost { return booster; } + private static Integer tryGetIntFromObject(Object o) { + if (o instanceof Integer) { + return (int)o; + } else if (o instanceof String) { + try { + return Integer.parseInt((String)o); + } catch (NumberFormatException e) { + return null; + } + } else { + return null; + } + } + + private static boolean shouldPrint(Map params, int iter) { + Object silent = params.get("silent"); + Integer silentInt = tryGetIntFromObject(silent); + if (silent != null) { + if (silent.equals("true") || silent.equals("True") + || (silentInt != null && silentInt != 0)) { + return false; // "silent" will stop printing, otherwise go look at "verbose_eval" + } + } + + Object verboseEval = params.get("verbose_eval"); + Integer verboseEvalInt = tryGetIntFromObject(verboseEval); + if (verboseEval == null) { + return true; // Default to printing evalInfo + } else if (verboseEval.equals("false") || verboseEval.equals("False")) { + return false; + } else if (verboseEvalInt != null) { + if (verboseEvalInt == 0) { + return false; + } else { + return iter % verboseEvalInt == 0; + } + } else { + return true; // Don't understand the option, default to printing + } + } + static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) { return iter - bestIteration >= earlyStoppingRounds; } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala index a53a5cd29..adea1b1ec 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite { test("cross validation") { val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") - val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6", + val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6", "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap val round = 2 val nfold = 5