Add period to evaluation monitor. (#6348)
commit 184e2eac7d
parent d411f98d26
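
In brief: EvaluationMonitor gains a period argument (default 1), and train forwards an integer verbose_eval as that period, throttling how often evaluation lines are emitted. A minimal usage sketch with made-up data (per the new test below, verbose_eval=2 over 10 rounds yields 10 // 2 printed evaluation lines):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 10)
    y = np.random.randint(0, 2, size=100)
    dtrain = xgb.DMatrix(X, label=y)

    # The integer becomes EvaluationMonitor(period=2): evaluation lines
    # are emitted every other round instead of every round.
    xgb.train({'objective': 'binary:logistic'}, dtrain,
              num_boost_round=10,
              evals=[(dtrain, 'Train')],
              verbose_eval=2)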
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -583,12 +583,18 @@ class EvaluationMonitor(TrainingCallback):
         Extra user defined metric.
     rank : int
         Which worker should be used for printing the result.
+    period : int
+        How many epoches between printing.
     show_stdv : bool
         Used in cv to show standard deviation. Users should not specify it.
     '''
-    def __init__(self, rank=0, show_stdv=False):
+    def __init__(self, rank=0, period=1, show_stdv=False):
         self.printer_rank = rank
         self.show_stdv = show_stdv
+        self.period = period
+        assert period > 0
+        # last error message, useful when early stopping and period are used together.
+        self._lastest = None
         super().__init__()
 
     def _fmt_metric(self, data, metric, score, std):
@@ -601,6 +607,7 @@ class EvaluationMonitor(TrainingCallback):
     def after_iteration(self, model, epoch, evals_log):
         if not evals_log:
             return False
+
         msg = f'[{epoch}]'
         if rabit.get_rank() == self.printer_rank:
             for data, metric in evals_log.items():
@@ -613,9 +620,20 @@ class EvaluationMonitor(TrainingCallback):
                         stdv = None
                     msg += self._fmt_metric(data, metric_name, score, stdv)
             msg += '\n'
-            rabit.tracker_print(msg)
+
+            if (epoch % self.period) != 0:
+                rabit.tracker_print(msg)
+                self._lastest = None
+            else:
+                # There is skipped message
+                self._lastest = msg
         return False
 
+    def after_training(self, model):
+        if rabit.get_rank() == self.printer_rank and self._lastest is not None:
+            rabit.tracker_print(self._lastest)
+        return model
+
 
 class TrainingCheckPoint(TrainingCallback):
     '''Checkpointing operation.
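For reference, the printing schedule implied by the condition above, as a plain-Python sketch (rabit and the callback plumbing are elided): epochs that are multiples of the period, including epoch 0, are the ones skipped and stashed, which matches what the new test in test_callback.py asserts.

    def printed_epochs(rounds, period):
        # Mirrors after_iteration: a message is printed immediately when
        # (epoch % period) != 0; otherwise it is stashed in _lastest and
        # flushed by after_training once training ends.
        return [epoch for epoch in range(rounds) if epoch % period != 0]

    assert printed_epochs(10, 2) == [1, 3, 5, 7, 9]  # rounds // 2 messages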
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -92,7 +92,8 @@ def _train_internal(params, dtrain,
         assert all(isinstance(c, callback.TrainingCallback)
                    for c in callbacks), "You can't mix new and old callback styles."
         if verbose_eval:
-            callbacks.append(callback.EvaluationMonitor())
+            verbose_eval = 1 if verbose_eval is True else verbose_eval
+            callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
         if early_stopping_rounds:
             callbacks.append(callback.EarlyStopping(
                 rounds=early_stopping_rounds, maximize=maximize))
@@ -485,7 +486,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
         assert all(isinstance(c, callback.TrainingCallback)
                    for c in callbacks), "You can't mix new and old callback styles."
         if isinstance(verbose_eval, bool) and verbose_eval:
-            callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
+            verbose_eval = 1 if verbose_eval is True else verbose_eval
+            callbacks.append(callback.EvaluationMonitor(period=verbose_eval,
+                                                        show_stdv=show_stdv))
         if early_stopping_rounds:
             callbacks.append(callback.EarlyStopping(
                 rounds=early_stopping_rounds, maximize=maximize))
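Note that cv only takes this path for a boolean verbose_eval, normalising it to a period of 1 before attaching the monitor; an integer verbose_eval does not reach this branch, so the integer-period behaviour applies to train. An illustrative call, reusing the dtrain from the earlier sketch:

    # Only a boolean verbose_eval attaches the monitor here; it is
    # forwarded as EvaluationMonitor(period=1, show_stdv=show_stdv).
    xgb.cv({'objective': 'binary:logistic'}, dtrain,
           num_boost_round=5, nfold=3, verbose_eval=True, show_stdv=True)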
--- a/tests/python-gpu/load_pickle.py
+++ b/tests/python-gpu/load_pickle.py
@@ -11,7 +11,7 @@ import sys
 from test_gpu_pickling import build_dataset, model_path, load_pickle
 
 sys.path.append("tests/python")
-import test_basic as tb
+import testing as tm
 
 
 class TestLoadPickle(unittest.TestCase):
@@ -61,7 +61,7 @@ class TestLoadPickle(unittest.TestCase):
         rng = np.random.RandomState(1994)
         X = rng.randn(10, 10)
         y = rng.randn(10)
-        with tb.captured_output() as (out, err):
+        with tm.captured_output() as (out, err):
             # Test no thrust exception is thrown
             with pytest.raises(xgb.core.XGBoostError):
                 xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(X, y))
--- a/tests/python-gpu/test_gpu_pickling.py
+++ b/tests/python-gpu/test_gpu_pickling.py
@@ -7,13 +7,12 @@ import os
 import sys
 import json
 import pytest
+import xgboost as xgb
+from xgboost import XGBClassifier
 
 sys.path.append("tests/python")
 import testing as tm
 
-import xgboost as xgb
-
-from xgboost import XGBClassifier
 
 model_path = './model.pkl'
 
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -1,7 +1,4 @@
 # -*- coding: utf-8 -*-
-import sys
-from contextlib import contextmanager
-from io import StringIO
 import numpy as np
 import os
 import xgboost as xgb
@@ -9,29 +6,12 @@ import unittest
 import json
 from pathlib import Path
 import tempfile
+import testing as tm
 
 dpath = 'demo/data/'
 rng = np.random.RandomState(1994)
 
 
-@contextmanager
-def captured_output():
-    """Reassign stdout temporarily in order to test printed statements
-    Taken from:
-    https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
-
-    Also works for pytest.
-
-    """
-    new_out, new_err = StringIO(), StringIO()
-    old_out, old_err = sys.stdout, sys.stderr
-    try:
-        sys.stdout, sys.stderr = new_out, new_err
-        yield sys.stdout, sys.stderr
-    finally:
-        sys.stdout, sys.stderr = old_out, old_err
-
-
 class TestBasic(unittest.TestCase):
     def test_compat(self):
         from xgboost.compat import lazy_isinstance
@@ -181,7 +161,6 @@ class TestBasic(unittest.TestCase):
         assert dm.num_row() == row
         assert dm.num_col() == cols
 
-
     def test_cv(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
         params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
@@ -236,7 +215,7 @@ class TestBasic(unittest.TestCase):
             print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
 
         # Run cross validation and capture standard out to test callback result
-        with captured_output() as (out, err):
+        with tm.captured_output() as (out, err):
             xgb.cv(
                 params, dm, num_boost_round=1, folds=folds, callbacks=[cb],
                 as_pandas=False
@@ -257,7 +236,6 @@ class TestBasicPathLike(unittest.TestCase):
         assert dtrain.num_row() == 6513
         assert dtrain.num_col() == 127
 
-
     def test_DMatrix_save_to_path(self):
         """Saving to a binary file using pathlib from a DMatrix."""
         data = np.random.randn(100, 2)
--- a/tests/python/test_callback.py
+++ b/tests/python/test_callback.py
@@ -34,10 +34,27 @@ class TestCallbacks(unittest.TestCase):
                   num_boost_round=rounds,
                   evals_result=evals_result,
                   verbose_eval=True)
-        print('evals_result:', evals_result)
         assert len(evals_result['Train']['error']) == rounds
         assert len(evals_result['Valid']['error']) == rounds
 
+        with tm.captured_output() as (out, err):
+            xgb.train({'objective': 'binary:logistic',
+                       'eval_metric': 'error'}, D_train,
+                      evals=[(D_train, 'Train'), (D_valid, 'Valid')],
+                      num_boost_round=rounds,
+                      evals_result=evals_result,
+                      verbose_eval=2)
+            output: str = out.getvalue().strip()
+
+        pos = 0
+        msg = 'Train-error'
+        for i in range(rounds // 2):
+            pos = output.find('Train-error', pos)
+            assert pos != -1
+            pos += len(msg)
+
+        assert output.find('Train-error', pos) == -1
+
     def test_early_stopping(self):
         D_train = xgb.DMatrix(self.X_train, self.y_train)
         D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -2,7 +2,6 @@ import collections
 import importlib.util
 import numpy as np
 import xgboost as xgb
-from xgboost.sklearn import XGBoostLabelEncoder
 import testing as tm
 import tempfile
 import os
@@ -11,8 +10,6 @@ import pytest
 import unittest
 import json
 
-from test_basic import captured_output
-
 rng = np.random.RandomState(1994)
 
 pytestmark = pytest.mark.skipif(**tm.no_sklearn())
@@ -872,7 +869,7 @@ def test_parameter_validation():
     reg = xgb.XGBRegressor(foo='bar', verbosity=1)
     X = np.random.randn(10, 10)
     y = np.random.randn(10)
-    with captured_output() as (out, err):
+    with tm.captured_output() as (out, err):
         reg.fit(X, y)
         output = out.getvalue().strip()
 
@@ -882,7 +879,7 @@ def test_parameter_validation():
                            importance_type='gain', verbosity=1)
     X = np.random.randn(10, 10)
     y = np.random.randn(10)
-    with captured_output() as (out, err):
+    with tm.captured_output() as (out, err):
         reg.fit(X, y)
         output = out.getvalue().strip()
 
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -1,6 +1,8 @@
 # coding: utf-8
 import os
-import platform
+import sys
+from contextlib import contextmanager
+from io import StringIO
 from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
 from xgboost.compat import DASK_INSTALLED
 import pytest
@@ -281,6 +283,24 @@ class DirectoryExcursion:
                 os.remove(f)
 
 
+@contextmanager
+def captured_output():
+    """Reassign stdout temporarily in order to test printed statements
+    Taken from:
+    https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
+
+    Also works for pytest.
+
+    """
+    new_out, new_err = StringIO(), StringIO()
+    old_out, old_err = sys.stdout, sys.stderr
+    try:
+        sys.stdout, sys.stderr = new_out, new_err
+        yield sys.stdout, sys.stderr
+    finally:
+        sys.stdout, sys.stderr = old_out, old_err
+
+
 CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
 PROJECT_ROOT = os.path.normpath(
     os.path.join(CURDIR, os.path.pardir, os.path.pardir))
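With captured_output now living in tests/python/testing.py, test modules reach it through the tm alias used throughout the suite. A minimal usage sketch:

    import testing as tm  # tests/python/testing.py, already on sys.path in the suite

    with tm.captured_output() as (out, err):
        print('hello')
    assert out.getvalue().strip() == 'hello'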