Sklearn: validation set weights (#2354)
* Add option to use weights when evaluating metrics in validation sets * Add test for validation-set weights functionality * simplify case with no weights for test sets * fix lint issues
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
71e226120a
commit
480e3fd764
@@ -215,7 +215,8 @@ class XGBModel(XGBModelBase):
|
||||
return xgb_params
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None):
|
||||
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
||||
"""
|
||||
Fit the gradient boosting model
|
||||
@@ -231,6 +232,9 @@ class XGBModel(XGBModelBase):
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as a validation set for
|
||||
early-stopping
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||
@@ -263,9 +267,14 @@ class XGBModel(XGBModelBase):
|
||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
if eval_set is not None:
|
||||
evals = list(DMatrix(x[0], label=x[1], missing=self.missing,
|
||||
nthread=self.n_jobs) for x in eval_set)
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
evals = list(
|
||||
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
|
||||
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
|
||||
for i in range(len(eval_set)))
|
||||
evals = list(zip(evals, ["validation_{}".format(i) for i in
|
||||
range(len(evals))]))
|
||||
else:
|
||||
@@ -408,7 +417,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
random_state, seed, missing, **kwargs)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""
|
||||
Fit gradient boosting classifier
|
||||
@@ -424,6 +434,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
eval_set : list, optional
|
||||
A list of (X, y) pairs to use as a validation set for
|
||||
early-stopping
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||
@@ -478,11 +491,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
training_labels = self._le.transform(y)
|
||||
|
||||
if eval_set is not None:
|
||||
# TODO: use sample_weight if given?
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
evals = list(
|
||||
DMatrix(x[0], label=self._le.transform(x[1]),
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
for x in eval_set
|
||||
DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
|
||||
missing=self.missing, weight=sample_weight_eval_set[i],
|
||||
nthread=self.n_jobs)
|
||||
for i in range(len(eval_set))
|
||||
)
|
||||
nevals = len(evals)
|
||||
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
||||
|
||||
Reference in New Issue
Block a user