Sklearn: validation set weights (#2354)
* Add option to use weights when evaluating metrics in validation sets * Add test for validation-set weights functionality * simplify case with no weights for test sets * fix lint issues
This commit is contained in:
parent
71e226120a
commit
480e3fd764
@ -215,7 +215,8 @@ class XGBModel(XGBModelBase):
|
|||||||
return xgb_params
|
return xgb_params
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||||
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||||
|
sample_weight_eval_set=None):
|
||||||
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
||||||
"""
|
"""
|
||||||
Fit the gradient boosting model
|
Fit the gradient boosting model
|
||||||
@ -231,6 +232,9 @@ class XGBModel(XGBModelBase):
|
|||||||
eval_set : list, optional
|
eval_set : list, optional
|
||||||
A list of (X, y) tuple pairs to use as a validation set for
|
A list of (X, y) tuple pairs to use as a validation set for
|
||||||
early-stopping
|
early-stopping
|
||||||
|
sample_weight_eval_set : list, optional
|
||||||
|
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||||
|
instance weights on the i-th validation set.
|
||||||
eval_metric : str, callable, optional
|
eval_metric : str, callable, optional
|
||||||
If a str, should be a built-in evaluation metric to use. See
|
If a str, should be a built-in evaluation metric to use. See
|
||||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||||
@ -263,9 +267,14 @@ class XGBModel(XGBModelBase):
|
|||||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
|
trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
|
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
evals = list(DMatrix(x[0], label=x[1], missing=self.missing,
|
if sample_weight_eval_set is None:
|
||||||
nthread=self.n_jobs) for x in eval_set)
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
|
evals = list(
|
||||||
|
DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
|
||||||
|
weight=sample_weight_eval_set[i], nthread=self.n_jobs)
|
||||||
|
for i in range(len(eval_set)))
|
||||||
evals = list(zip(evals, ["validation_{}".format(i) for i in
|
evals = list(zip(evals, ["validation_{}".format(i) for i in
|
||||||
range(len(evals))]))
|
range(len(evals))]))
|
||||||
else:
|
else:
|
||||||
@ -408,7 +417,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
random_state, seed, missing, **kwargs)
|
random_state, seed, missing, **kwargs)
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||||
early_stopping_rounds=None, verbose=True, xgb_model=None):
|
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||||
|
sample_weight_eval_set=None):
|
||||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||||
"""
|
"""
|
||||||
Fit gradient boosting classifier
|
Fit gradient boosting classifier
|
||||||
@ -424,6 +434,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
eval_set : list, optional
|
eval_set : list, optional
|
||||||
A list of (X, y) pairs to use as a validation set for
|
A list of (X, y) pairs to use as a validation set for
|
||||||
early-stopping
|
early-stopping
|
||||||
|
sample_weight_eval_set : list, optional
|
||||||
|
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||||
|
instance weights on the i-th validation set.
|
||||||
eval_metric : str, callable, optional
|
eval_metric : str, callable, optional
|
||||||
If a str, should be a built-in evaluation metric to use. See
|
If a str, should be a built-in evaluation metric to use. See
|
||||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||||
@ -478,11 +491,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
training_labels = self._le.transform(y)
|
training_labels = self._le.transform(y)
|
||||||
|
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
# TODO: use sample_weight if given?
|
if sample_weight_eval_set is None:
|
||||||
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
evals = list(
|
evals = list(
|
||||||
DMatrix(x[0], label=self._le.transform(x[1]),
|
DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
|
||||||
missing=self.missing, nthread=self.n_jobs)
|
missing=self.missing, weight=sample_weight_eval_set[i],
|
||||||
for x in eval_set
|
nthread=self.n_jobs)
|
||||||
|
for i in range(len(eval_set))
|
||||||
)
|
)
|
||||||
nevals = len(evals)
|
nevals = len(evals)
|
||||||
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
||||||
|
|||||||
@ -370,3 +370,91 @@ def test_sklearn_clone():
|
|||||||
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
||||||
clf.n_jobs = -1
|
clf.n_jobs = -1
|
||||||
clone(clf)
|
clone(clf)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_weights_xgbmodel():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
from sklearn.datasets import make_hastie_10_2
|
||||||
|
|
||||||
|
# prepare training and test data
|
||||||
|
X, y = make_hastie_10_2(n_samples=2000, random_state=42)
|
||||||
|
labels, y = np.unique(y, return_inverse=True)
|
||||||
|
X_train, X_test = X[:1600], X[1600:]
|
||||||
|
y_train, y_test = y[:1600], y[1600:]
|
||||||
|
|
||||||
|
# instantiate model
|
||||||
|
param_dist = {'objective': 'binary:logistic', 'n_estimators': 2,
|
||||||
|
'random_state': 123}
|
||||||
|
clf = xgb.sklearn.XGBModel(**param_dist)
|
||||||
|
|
||||||
|
# train it using instance weights only in the training set
|
||||||
|
weights_train = np.random.choice([1, 2], len(X_train))
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
sample_weight=weights_train,
|
||||||
|
eval_set=[(X_test, y_test)],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=False)
|
||||||
|
|
||||||
|
# evaluate logloss metric on test set *without* using weights
|
||||||
|
evals_result_without_weights = clf.evals_result()
|
||||||
|
logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
|
||||||
|
|
||||||
|
# now use weights for the test set
|
||||||
|
np.random.seed(0)
|
||||||
|
weights_test = np.random.choice([1, 2], len(X_test))
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
sample_weight=weights_train,
|
||||||
|
eval_set=[(X_test, y_test)],
|
||||||
|
sample_weight_eval_set=[weights_test],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=False)
|
||||||
|
evals_result_with_weights = clf.evals_result()
|
||||||
|
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
|
||||||
|
|
||||||
|
# check that the logloss in the test set is actually different when using weights
|
||||||
|
# than when not using them
|
||||||
|
assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_weights_xgbclassifier():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
from sklearn.datasets import make_hastie_10_2
|
||||||
|
|
||||||
|
# prepare training and test data
|
||||||
|
X, y = make_hastie_10_2(n_samples=2000, random_state=42)
|
||||||
|
labels, y = np.unique(y, return_inverse=True)
|
||||||
|
X_train, X_test = X[:1600], X[1600:]
|
||||||
|
y_train, y_test = y[:1600], y[1600:]
|
||||||
|
|
||||||
|
# instantiate model
|
||||||
|
param_dist = {'objective': 'binary:logistic', 'n_estimators': 2,
|
||||||
|
'random_state': 123}
|
||||||
|
clf = xgb.sklearn.XGBClassifier(**param_dist)
|
||||||
|
|
||||||
|
# train it using instance weights only in the training set
|
||||||
|
weights_train = np.random.choice([1, 2], len(X_train))
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
sample_weight=weights_train,
|
||||||
|
eval_set=[(X_test, y_test)],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=False)
|
||||||
|
|
||||||
|
# evaluate logloss metric on test set *without* using weights
|
||||||
|
evals_result_without_weights = clf.evals_result()
|
||||||
|
logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
|
||||||
|
|
||||||
|
# now use weights for the test set
|
||||||
|
np.random.seed(0)
|
||||||
|
weights_test = np.random.choice([1, 2], len(X_test))
|
||||||
|
clf.fit(X_train, y_train,
|
||||||
|
sample_weight=weights_train,
|
||||||
|
eval_set=[(X_test, y_test)],
|
||||||
|
sample_weight_eval_set=[weights_test],
|
||||||
|
eval_metric='logloss',
|
||||||
|
verbose=False)
|
||||||
|
evals_result_with_weights = clf.evals_result()
|
||||||
|
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
|
||||||
|
|
||||||
|
# check that the logloss in the test set is actually different when using weights
|
||||||
|
# than when not using them
|
||||||
|
assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user