Sklearn: validation set weights (#2354)

* Add option to use weights when evaluating metrics in validation sets

* Add test for validation-set weights functionality

* Simplify the case with no weights for test sets

* Fix lint issues
Author: pdavalo, 2018-05-23 19:06:20 -05:00
Committed by: Philip Hyunsu Cho
Parent: 71e226120a
Commit: 480e3fd764
2 changed files with 111 additions and 8 deletions
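The new keyword mirrors the existing sample_weight argument but applies to the
validation sets: one weight vector per (X, y) pair in eval_set. A minimal usage
sketch (the data and variable names below are illustrative, not from this
commit):

    # Hypothetical usage of the sample_weight_eval_set parameter added here.
    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X_train, y_train = rng.rand(200, 5), rng.randint(2, size=200)
    X_val, y_val = rng.rand(100, 5), rng.randint(2, size=100)
    w_val = rng.choice([1, 2], len(X_val))  # one weight per validation row

    clf = xgb.XGBClassifier(objective='binary:logistic', n_estimators=2)
    clf.fit(X_train, y_train,
            eval_set=[(X_val, y_val)],
            sample_weight_eval_set=[w_val],  # one weight array per eval_set pair
            eval_metric='logloss',
            verbose=False)
    print(clf.evals_result()['validation_0']['logloss'])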

File: python-package/xgboost/sklearn.py

@@ -215,7 +215,8 @@ class XGBModel(XGBModelBase):
         return xgb_params

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None):
+            early_stopping_rounds=None, verbose=True, xgb_model=None,
+            sample_weight_eval_set=None):
         # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
         """
         Fit the gradient boosting model
@@ -231,6 +232,9 @@ class XGBModel(XGBModelBase):
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as a validation set for
             early-stopping
+        sample_weight_eval_set : list, optional
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
+            instance weights on the i-th validation set.
         eval_metric : str, callable, optional
             If a str, should be a built-in evaluation metric to use. See
             doc/parameter.md. If callable, a custom evaluation metric. The call
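Concretely, the i-th entry of sample_weight_eval_set weights the rows of the
i-th (X, y) pair in eval_set. A hedged shape sketch with two validation sets
(the names are made up for illustration):

    # Illustrative only: the weight lists run parallel to eval_set entries.
    import numpy as np

    X1, y1 = np.random.rand(30, 4), np.random.randint(2, size=30)
    X2, y2 = np.random.rand(20, 4), np.random.randint(2, size=20)
    eval_set = [(X1, y1), (X2, y2)]
    sample_weight_eval_set = [np.ones(30), 2 * np.ones(20)]  # one array per pair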
@@ -263,9 +267,14 @@ class XGBModel(XGBModelBase):
         trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)

         evals_result = {}
         if eval_set is not None:
-            evals = list(DMatrix(x[0], label=x[1], missing=self.missing,
-                                 nthread=self.n_jobs) for x in eval_set)
+            if sample_weight_eval_set is None:
+                sample_weight_eval_set = [None] * len(eval_set)
+            evals = list(
+                DMatrix(eval_set[i][0], label=eval_set[i][1], missing=self.missing,
+                        weight=sample_weight_eval_set[i], nthread=self.n_jobs)
+                for i in range(len(eval_set)))
             evals = list(zip(evals, ["validation_{}".format(i) for i in
                                      range(len(evals))]))
         else:
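Each validation set is now materialized as a DMatrix that carries its weight
vector, so the booster's per-iteration evaluation runs over weighted instances.
A sketch of the equivalent low-level construction (DMatrix's weight argument is
the public API; the data here is illustrative):

    # Attach per-row weights to an evaluation DMatrix directly.
    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1)
    X_val = rng.rand(100, 5)
    y_val = rng.randint(2, size=100)
    w_val = rng.choice([1, 2], len(X_val))

    dval = xgb.DMatrix(X_val, label=y_val, weight=w_val)
    # metrics evaluated on dval (e.g. 'logloss') aggregate over these weights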
@@ -408,7 +417,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                             random_state, seed, missing, **kwargs)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None):
+            early_stopping_rounds=None, verbose=True, xgb_model=None,
+            sample_weight_eval_set=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit gradient boosting classifier
@@ -424,6 +434,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         eval_set : list, optional
             A list of (X, y) pairs to use as a validation set for
             early-stopping
+        sample_weight_eval_set : list, optional
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
+            instance weights on the i-th validation set.
         eval_metric : str, callable, optional
             If a str, should be a built-in evaluation metric to use. See
             doc/parameter.md. If callable, a custom evaluation metric. The call
@@ -478,11 +491,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         training_labels = self._le.transform(y)

         if eval_set is not None:
-            # TODO: use sample_weight if given?
+            if sample_weight_eval_set is None:
+                sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
-                DMatrix(x[0], label=self._le.transform(x[1]),
-                        missing=self.missing, nthread=self.n_jobs)
-                for x in eval_set
+                DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
+                        missing=self.missing, weight=sample_weight_eval_set[i],
+                        nthread=self.n_jobs)
+                for i in range(len(eval_set))
             )
             nevals = len(evals)
             eval_names = ["validation_{}".format(i) for i in range(nevals)]
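For element-wise metrics such as logloss, xgboost aggregates per-instance
losses as a weighted mean, which is why attaching weights changes the reported
value. A rough numpy restatement of that aggregation (my paraphrase, not
xgboost's actual implementation):

    # Weighted logloss as a weighted mean of per-instance losses (paraphrase).
    import numpy as np

    def weighted_logloss(y, p, w, eps=1e-15):
        p = np.clip(p, eps, 1 - eps)              # guard against log(0)
        loss = -(y * np.log(p) + (1 - y) * np.log(1 - p))
        return np.sum(w * loss) / np.sum(w)       # weighted mean, not plain mean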

File: tests/python/test_with_sklearn.py

@@ -370,3 +370,91 @@ def test_sklearn_clone():
     clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
     clf.n_jobs = -1
     clone(clf)
+
+
+def test_validation_weights_xgbmodel():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import make_hastie_10_2
+
+    # prepare training and test data
+    X, y = make_hastie_10_2(n_samples=2000, random_state=42)
+    labels, y = np.unique(y, return_inverse=True)
+    X_train, X_test = X[:1600], X[1600:]
+    y_train, y_test = y[:1600], y[1600:]
+
+    # instantiate model
+    param_dist = {'objective': 'binary:logistic', 'n_estimators': 2,
+                  'random_state': 123}
+    clf = xgb.sklearn.XGBModel(**param_dist)
+
+    # train it using instance weights only in the training set
+    weights_train = np.random.choice([1, 2], len(X_train))
+    clf.fit(X_train, y_train,
+            sample_weight=weights_train,
+            eval_set=[(X_test, y_test)],
+            eval_metric='logloss',
+            verbose=False)
+
+    # evaluate logloss metric on test set *without* using weights
+    evals_result_without_weights = clf.evals_result()
+    logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
+
+    # now use weights for the test set
+    np.random.seed(0)
+    weights_test = np.random.choice([1, 2], len(X_test))
+    clf.fit(X_train, y_train,
+            sample_weight=weights_train,
+            eval_set=[(X_test, y_test)],
+            sample_weight_eval_set=[weights_test],
+            eval_metric='logloss',
+            verbose=False)
+    evals_result_with_weights = clf.evals_result()
+    logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
+
+    # check that the logloss in the test set is actually different when using
+    # weights than when not using them
+    assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
+
+
+def test_validation_weights_xgbclassifier():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import make_hastie_10_2
+
+    # prepare training and test data
+    X, y = make_hastie_10_2(n_samples=2000, random_state=42)
+    labels, y = np.unique(y, return_inverse=True)
+    X_train, X_test = X[:1600], X[1600:]
+    y_train, y_test = y[:1600], y[1600:]
+
+    # instantiate model
+    param_dist = {'objective': 'binary:logistic', 'n_estimators': 2,
+                  'random_state': 123}
+    clf = xgb.sklearn.XGBClassifier(**param_dist)
+
+    # train it using instance weights only in the training set
+    weights_train = np.random.choice([1, 2], len(X_train))
+    clf.fit(X_train, y_train,
+            sample_weight=weights_train,
+            eval_set=[(X_test, y_test)],
+            eval_metric='logloss',
+            verbose=False)
+
+    # evaluate logloss metric on test set *without* using weights
+    evals_result_without_weights = clf.evals_result()
+    logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
+
+    # now use weights for the test set
+    np.random.seed(0)
+    weights_test = np.random.choice([1, 2], len(X_test))
+    clf.fit(X_train, y_train,
+            sample_weight=weights_train,
+            eval_set=[(X_test, y_test)],
+            sample_weight_eval_set=[weights_test],
+            eval_metric='logloss',
+            verbose=False)
+    evals_result_with_weights = clf.evals_result()
+    logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
+
+    # check that the logloss in the test set is actually different when using
+    # weights than when not using them
+    assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
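With n_estimators=2 the tests log exactly two logloss values per run, hence the
indices [0, 1] in the assertions. If desired, the final weighted value can be
cross-checked against scikit-learn (a hedged sketch that assumes clf, X_test,
y_test, and weights_test from the test above; small numerical drift from
xgboost's internal probability clipping is possible):

    # Hedged cross-check of the last weighted eval entry via sklearn.
    from sklearn.metrics import log_loss

    p_test = clf.predict(X_test)  # binary:logistic => predicted probabilities
    manual = log_loss(y_test, p_test, sample_weight=weights_test)
    # manual should be close to logloss_with_weights[-1]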