Move metric configuration into booster. (#6504)

This commit is contained in:
Jiaming Yuan
2020-12-16 05:35:04 +08:00
committed by GitHub
parent d45c0d843b
commit 3c3f026ec1
3 changed files with 35 additions and 18 deletions

View File

@@ -1,11 +1,12 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals
# pylint: disable=too-many-lines, too-many-locals, no-self-use
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping
# pylint: enable=no-name-in-module,import-error
from typing import Dict, Union, List
import ctypes
import os
import re
@@ -1012,6 +1013,7 @@ class Booster(object):
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
ctypes.byref(self.handle)))
params = params or {}
params = self._configure_metrics(params.copy())
if isinstance(params, list):
params.append(('validate_parameters', True))
else:
@@ -1041,6 +1043,17 @@ class Booster(object):
else:
raise TypeError('Unknown type:', model_file)
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params
def __del__(self):
    """Release the native booster handle when the object is collected."""
    handle = getattr(self, 'handle', None)
    if handle is not None:
        # Free the C-side booster and drop the reference so a repeated
        # finalization is a no-op.
        _check_call(_LIB.XGBoosterFree(handle))
        self.handle = None

View File

@@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
for c in callbacks) or not callbacks
def _configure_metrics(params):
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params
def _train_internal(params, dtrain,
num_boost_round=10, evals=(),
obj=None, feval=None,
@@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks)
evals = list(evals)
params = _configure_metrics(params.copy())
bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0