[jvm-packages] Allow supression of Rabit output in Booster::train in xgboost4j (#4262)
* Make train in xgboost4j respect print params Previously no setting in params argument of Booster::train would prevent the Rabit.trackerPrint call. This can fill up a lot of screen space in the case that many folds are being trained. * Setting "silent" in this map to "true", "True", a non-zero integer, or a string that can be parsed to such an int will prevent printing. * Setting "verbose_eval" to "False" or "false" will prevent printing. * Setting "verbose_eval" to an int (or a String parseable to an int) n will result in printing every n steps, or no printing is n is zero. This is to match the python behaviour described here: https://www.kaggle.com/c/rossmann-store-sales/discussion/17499 * Fixed 'slient' typo in xgboost4j test * private access on two methods
This commit is contained in:
parent
45c89a6792
commit
b374e0a7ab
@ -228,8 +228,10 @@ public class XGBoost {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Rabit.getRank() == 0) {
|
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
|
||||||
Rabit.trackerPrint(evalInfo + '\n');
|
if (shouldPrint(params, iter)){
|
||||||
|
Rabit.trackerPrint(evalInfo + '\n');
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
booster.saveRabitCheckpoint();
|
booster.saveRabitCheckpoint();
|
||||||
@ -237,6 +239,47 @@ public class XGBoost {
|
|||||||
return booster;
|
return booster;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Integer tryGetIntFromObject(Object o) {
|
||||||
|
if (o instanceof Integer) {
|
||||||
|
return (int)o;
|
||||||
|
} else if (o instanceof String) {
|
||||||
|
try {
|
||||||
|
return Integer.parseInt((String)o);
|
||||||
|
} catch (NumberFormatException e) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean shouldPrint(Map<String, Object> params, int iter) {
|
||||||
|
Object silent = params.get("silent");
|
||||||
|
Integer silentInt = tryGetIntFromObject(silent);
|
||||||
|
if (silent != null) {
|
||||||
|
if (silent.equals("true") || silent.equals("True")
|
||||||
|
|| (silentInt != null && silentInt != 0)) {
|
||||||
|
return false; // "silent" will stop printing, otherwise go look at "verbose_eval"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Object verboseEval = params.get("verbose_eval");
|
||||||
|
Integer verboseEvalInt = tryGetIntFromObject(verboseEval);
|
||||||
|
if (verboseEval == null) {
|
||||||
|
return true; // Default to printing evalInfo
|
||||||
|
} else if (verboseEval.equals("false") || verboseEval.equals("False")) {
|
||||||
|
return false;
|
||||||
|
} else if (verboseEvalInt != null) {
|
||||||
|
if (verboseEvalInt == 0) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return iter % verboseEvalInt == 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return true; // Don't understand the option, default to printing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
|
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
|
||||||
return iter - bestIteration >= earlyStoppingRounds;
|
return iter - bestIteration >= earlyStoppingRounds;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
|
|
||||||
test("cross validation") {
|
test("cross validation") {
|
||||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
|
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
|
||||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||||
val round = 2
|
val round = 2
|
||||||
val nfold = 5
|
val nfold = 5
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user