Control logging for early stopping using shouldPrint() (#7326)

This commit is contained in:
nicovdijk 2021-10-21 06:12:06 +02:00 committed by GitHub
parent 8d7c6366d7
commit 74bab6e504
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2014 by Contributors Copyright (c) 2014,2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -21,7 +21,6 @@ import java.util.*;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
/** /**
* trainer for xgboost * trainer for xgboost
@ -253,13 +252,14 @@ public class XGBoost {
booster.setAttr("best_score", String.valueOf(bestScore)); booster.setAttr("best_score", String.valueOf(bestScore));
} }
} }
if (earlyStoppingRounds > 0) {
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) { if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
if (shouldPrint(params, iter)) {
Rabit.trackerPrint(String.format( Rabit.trackerPrint(String.format(
"early stopping after %d rounds away from the best iteration", "early stopping after %d rounds away from the best iteration",
earlyStoppingRounds)); earlyStoppingRounds
break; ));
} }
break;
} }
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) { if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){ if (shouldPrint(params, iter)){
@ -352,6 +352,9 @@ public class XGBoost {
} }
static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) { static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
if (earlyStoppingRounds <= 0) {
return false;
}
return iter - bestIteration >= earlyStoppingRounds; return iter - bestIteration >= earlyStoppingRounds;
} }