Move metric configuration into booster. (#6504)
This commit is contained in:
parent
d45c0d843b
commit
3c3f026ec1
@ -1,11 +1,12 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||||
# pylint: disable=too-many-lines, too-many-locals
|
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
import collections
|
import collections
|
||||||
# pylint: disable=no-name-in-module,import-error
|
# pylint: disable=no-name-in-module,import-error
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
# pylint: enable=no-name-in-module,import-error
|
# pylint: enable=no-name-in-module,import-error
|
||||||
|
from typing import Dict, Union, List
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@ -1012,6 +1013,7 @@ class Booster(object):
|
|||||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||||
ctypes.byref(self.handle)))
|
ctypes.byref(self.handle)))
|
||||||
params = params or {}
|
params = params or {}
|
||||||
|
params = self._configure_metrics(params.copy())
|
||||||
if isinstance(params, list):
|
if isinstance(params, list):
|
||||||
params.append(('validate_parameters', True))
|
params.append(('validate_parameters', True))
|
||||||
else:
|
else:
|
||||||
@ -1041,6 +1043,17 @@ class Booster(object):
|
|||||||
else:
|
else:
|
||||||
raise TypeError('Unknown type:', model_file)
|
raise TypeError('Unknown type:', model_file)
|
||||||
|
|
||||||
|
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
|
||||||
|
if isinstance(params, dict) and 'eval_metric' in params \
|
||||||
|
and isinstance(params['eval_metric'], list):
|
||||||
|
params = dict((k, v) for k, v in params.items())
|
||||||
|
eval_metrics = params['eval_metric']
|
||||||
|
params.pop("eval_metric", None)
|
||||||
|
params = list(params.items())
|
||||||
|
for eval_metric in eval_metrics:
|
||||||
|
params += [('eval_metric', eval_metric)]
|
||||||
|
return params
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, 'handle') and self.handle is not None:
|
if hasattr(self, 'handle') and self.handle is not None:
|
||||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||||
|
|||||||
@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
|
|||||||
for c in callbacks) or not callbacks
|
for c in callbacks) or not callbacks
|
||||||
|
|
||||||
|
|
||||||
def _configure_metrics(params):
|
|
||||||
if isinstance(params, dict) and 'eval_metric' in params \
|
|
||||||
and isinstance(params['eval_metric'], list):
|
|
||||||
params = dict((k, v) for k, v in params.items())
|
|
||||||
eval_metrics = params['eval_metric']
|
|
||||||
params.pop("eval_metric", None)
|
|
||||||
params = list(params.items())
|
|
||||||
for eval_metric in eval_metrics:
|
|
||||||
params += [('eval_metric', eval_metric)]
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def _train_internal(params, dtrain,
|
def _train_internal(params, dtrain,
|
||||||
num_boost_round=10, evals=(),
|
num_boost_round=10, evals=(),
|
||||||
obj=None, feval=None,
|
obj=None, feval=None,
|
||||||
@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
|
|||||||
"""internal training function"""
|
"""internal training function"""
|
||||||
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
||||||
evals = list(evals)
|
evals = list(evals)
|
||||||
params = _configure_metrics(params.copy())
|
|
||||||
|
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
nboost = 0
|
nboost = 0
|
||||||
|
|||||||
@ -57,6 +57,25 @@ class TestBasic:
|
|||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||||
|
|
||||||
|
def test_metric_config(self):
    # Metric configuration must happen inside Booster, so the raw list
    # `['error', 'auc']` is never handed down to the core library.
    dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
    dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
    config = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
              'objective': 'binary:logistic',
              'eval_metric': ['error', 'auc']}
    evals = [(dtest, 'eval'), (dtrain, 'train')]
    trained = xgb.train(config, dtrain, 2, evals)
    preds_before = trained.predict(dtrain)
    with tempfile.TemporaryDirectory() as tmpdir:
        model_path = os.path.join(tmpdir, 'model.json')
        trained.save_model(model_path)
        # Reloading with the same (list-valued) params must round-trip.
        reloaded = xgb.Booster(params=config, model_file=model_path)
        preds_after = reloaded.predict(dtrain)
        np.testing.assert_allclose(preds_before, preds_after)
|
||||||
|
|
||||||
def test_record_results(self):
|
def test_record_results(self):
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
@ -124,8 +143,8 @@ class TestBasic:
|
|||||||
|
|
||||||
dump2 = bst.get_dump(with_stats=True)
|
dump2 = bst.get_dump(with_stats=True)
|
||||||
assert dump2[0].count('\n') == 3, 'Expected 1 root and 2 leaves - 3 lines in dump.'
|
assert dump2[0].count('\n') == 3, 'Expected 1 root and 2 leaves - 3 lines in dump.'
|
||||||
assert (dump2[0].find('\n') > dump1[0].find('\n'),
|
msg = 'Expected more info when with_stats=True is given.'
|
||||||
'Expected more info when with_stats=True is given.')
|
assert dump2[0].find('\n') > dump1[0].find('\n'), msg
|
||||||
|
|
||||||
dump3 = bst.get_dump(dump_format="json")
|
dump3 = bst.get_dump(dump_format="json")
|
||||||
dump3j = json.loads(dump3[0])
|
dump3j = json.loads(dump3[0])
|
||||||
@ -248,13 +267,11 @@ class TestBasicPathLike:
|
|||||||
assert binary_path.exists()
|
assert binary_path.exists()
|
||||||
Path.unlink(binary_path)
|
Path.unlink(binary_path)
|
||||||
|
|
||||||
|
|
||||||
def test_Booster_init_invalid_path(self):
|
def test_Booster_init_invalid_path(self):
|
||||||
"""An invalid model_file path should raise XGBoostError."""
|
"""An invalid model_file path should raise XGBoostError."""
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
xgb.Booster(model_file=Path("invalidpath"))
|
xgb.Booster(model_file=Path("invalidpath"))
|
||||||
|
|
||||||
|
|
||||||
def test_Booster_save_and_load(self):
|
def test_Booster_save_and_load(self):
|
||||||
"""Saving and loading model files from paths."""
|
"""Saving and loading model files from paths."""
|
||||||
save_path = Path("saveload.model")
|
save_path = Path("saveload.model")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user