parent
bce7ca313c
commit
7109c6c1f2
@ -1,11 +1,12 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||
# pylint: disable=too-many-lines, too-many-locals
|
||||
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||
"""Core XGBoost Library."""
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Dict, Union, List
|
||||
import ctypes
|
||||
import os
|
||||
import re
|
||||
@ -1012,6 +1013,7 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||
ctypes.byref(self.handle)))
|
||||
params = params or {}
|
||||
params = self._configure_metrics(params.copy())
|
||||
if isinstance(params, list):
|
||||
params.append(('validate_parameters', True))
|
||||
else:
|
||||
@ -1041,6 +1043,17 @@ class Booster(object):
|
||||
else:
|
||||
raise TypeError('Unknown type:', model_file)
|
||||
|
||||
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'handle') and self.handle is not None:
|
||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||
|
||||
@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
|
||||
for c in callbacks) or not callbacks
|
||||
|
||||
|
||||
def _configure_metrics(params):
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
num_boost_round=10, evals=(),
|
||||
obj=None, feval=None,
|
||||
@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
||||
evals = list(evals)
|
||||
params = _configure_metrics(params.copy())
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
|
||||
@ -57,6 +57,25 @@ class TestBasic:
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def test_metric_config(self):
|
||||
# Make sure that the metric configuration happens in booster so the
|
||||
# string `['error', 'auc']` doesn't get passed down to core.
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
'objective': 'binary:logistic', 'eval_metric': ['error', 'auc']}
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
booster = xgb.train(param, dtrain, num_round, watchlist)
|
||||
predt_0 = booster.predict(dtrain)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'model.json')
|
||||
booster.save_model(path)
|
||||
|
||||
booster = xgb.Booster(params=param, model_file=path)
|
||||
predt_1 = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
@ -124,8 +143,8 @@ class TestBasic:
|
||||
|
||||
dump2 = bst.get_dump(with_stats=True)
|
||||
assert dump2[0].count('\n') == 3, 'Expected 1 root and 2 leaves - 3 lines in dump.'
|
||||
assert (dump2[0].find('\n') > dump1[0].find('\n'),
|
||||
'Expected more info when with_stats=True is given.')
|
||||
msg = 'Expected more info when with_stats=True is given.'
|
||||
assert dump2[0].find('\n') > dump1[0].find('\n'), msg
|
||||
|
||||
dump3 = bst.get_dump(dump_format="json")
|
||||
dump3j = json.loads(dump3[0])
|
||||
@ -248,13 +267,11 @@ class TestBasicPathLike:
|
||||
assert binary_path.exists()
|
||||
Path.unlink(binary_path)
|
||||
|
||||
|
||||
def test_Booster_init_invalid_path(self):
|
||||
"""An invalid model_file path should raise XGBoostError."""
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
xgb.Booster(model_file=Path("invalidpath"))
|
||||
|
||||
|
||||
def test_Booster_save_and_load(self):
|
||||
"""Saving and loading model files from paths."""
|
||||
save_path = Path("saveload.model")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user