@@ -1,11 +1,12 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||
# pylint: disable=too-many-lines, too-many-locals
|
||||
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||
"""Core XGBoost Library."""
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Dict, Union, List
|
||||
import ctypes
|
||||
import os
|
||||
import re
|
||||
@@ -1012,6 +1013,7 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||
ctypes.byref(self.handle)))
|
||||
params = params or {}
|
||||
params = self._configure_metrics(params.copy())
|
||||
if isinstance(params, list):
|
||||
params.append(('validate_parameters', True))
|
||||
else:
|
||||
@@ -1041,6 +1043,17 @@ class Booster(object):
|
||||
else:
|
||||
raise TypeError('Unknown type:', model_file)
|
||||
|
||||
def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'handle') and self.handle is not None:
|
||||
_check_call(_LIB.XGBoosterFree(self.handle))
|
||||
|
||||
@@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
|
||||
for c in callbacks) or not callbacks
|
||||
|
||||
|
||||
def _configure_metrics(params):
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
num_boost_round=10, evals=(),
|
||||
obj=None, feval=None,
|
||||
@@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else copy.copy(callbacks)
|
||||
evals = list(evals)
|
||||
params = _configure_metrics(params.copy())
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
|
||||
Reference in New Issue
Block a user