Enable parameter validation for skl. (#5477)

This commit is contained in:
Jiaming Yuan
2020-04-03 10:23:58 +08:00
committed by GitHub
parent d0b86c75d9
commit c218d8ffbf
3 changed files with 43 additions and 12 deletions

View File

@@ -1098,6 +1098,7 @@ class DeviceQuantileDMatrix(DMatrix):
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
self.handle = handle
class Booster(object):
# pylint: disable=too-many-public-methods
"""A Booster of XGBoost.
@@ -1129,10 +1130,12 @@ class Booster(object):
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
ctypes.byref(self.handle)))
params = params or {}
if isinstance(params, list):
params.append(('validate_parameters', True))
else:
params['validate_parameters'] = True
if isinstance(params, dict) and \
'validate_parameters' not in params.keys():
params['validate_parameters'] = 1
self.set_param(params or {})
if (params is not None) and ('booster' in params):
self.booster = params['booster']

View File

@@ -210,7 +210,7 @@ class XGBModel(XGBModelBase):
missing=np.nan, num_parallel_tree=None,
monotone_constraints=None, interaction_constraints=None,
importance_type="gain", gpu_id=None,
validate_parameters=False, **kwargs):
validate_parameters=None, **kwargs):
if not SKLEARN_INSTALLED:
raise XGBoostError(
'sklearn needs to be installed in order to use this module')
@@ -242,9 +242,6 @@ class XGBModel(XGBModelBase):
self.interaction_constraints = interaction_constraints
self.importance_type = importance_type
self.gpu_id = gpu_id
# Parameter validation is not working with Scikit-Learn interface, as
# it passes all paraemters into XGBoost core, whether they are used or
# not.
self.validate_parameters = validate_parameters
def get_booster(self):
@@ -340,9 +337,16 @@ class XGBModel(XGBModelBase):
return params
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
return xgb_params
"""Get xgboost specific parameters."""
params = self.get_params()
# Parameters that should not go into native learner.
wrapper_specific = {
'importance_type', 'kwargs', 'missing', 'n_estimators'}
filtered = dict()
for k, v in params.items():
if k not in wrapper_specific:
filtered[k] = v
return filtered
def get_num_boosting_rounds(self):
"""Gets the number of xgboost boosting rounds."""
@@ -540,7 +544,8 @@ class XGBModel(XGBModelBase):
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
evals_result[val[0]][evals_result_key] = val[1][
evals_result_key]
self.evals_result_ = evals_result
if early_stopping_rounds is not None: