Add base margin to sklearn interface. (#5151)

Author: Jiaming Yuan, 2019-12-24 09:43:41 +08:00 (committed by GitHub)
parent 1d0ca49761
commit 0202e04a8e
5 changed files with 103 additions and 61 deletions

demo/guide-python/boost_from_prediction.py

@@ -1,5 +1,4 @@
 #!/usr/bin/python
-import numpy as np
 import xgboost as xgb

 dtrain = xgb.DMatrix('../data/agaricus.txt.train')
@@ -8,18 +7,19 @@ watchlist = [(dtest, 'eval'), (dtrain, 'train')]

 ###
 # advanced: start from a initial base prediction
 #
-print ('start running example to start from a initial prediction')
+print('start running example to start from a initial prediction')
 # specify parameters via map, definition are same as c++ version
-param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
+param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
 # train xgboost for 1 round
 bst = xgb.train(param, dtrain, 1, watchlist)
-# Note: we need the margin value instead of transformed prediction in set_base_margin
-# do predict with output_margin=True, will always give you margin values before logistic transformation
+# Note: we need the margin value instead of transformed prediction in
+# set_base_margin
+# do predict with output_margin=True, will always give you margin values
+# before logistic transformation
 ptrain = bst.predict(dtrain, output_margin=True)
 ptest = bst.predict(dtest, output_margin=True)
 dtrain.set_base_margin(ptrain)
 dtest.set_base_margin(ptest)

 print('this is result of running from initial prediction')
 bst = xgb.train(param, dtrain, 1, watchlist)
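
The demo works because predict(..., output_margin=True) returns the raw scores before the logistic transformation, and set_base_margin expects exactly those raw scores. A minimal sketch of that relationship for binary:logistic, reusing the demo's data path (illustrative only, not part of the commit):

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix('../data/agaricus.txt.train')
    bst = xgb.train({'objective': 'binary:logistic'}, dtrain, 1)

    margin = bst.predict(dtrain, output_margin=True)  # raw scores
    prob = bst.predict(dtrain)                        # transformed probabilities

    # The transformed output is sigmoid(margin); feeding prob into
    # set_base_margin would silently bias subsequent training.
    assert np.allclose(prob, 1.0 / (1.0 + np.exp(-margin)))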

python-package/xgboost/core.py

@@ -434,9 +434,11 @@ class DMatrix(object):
     _feature_names = None  # for previous version's pickle
     _feature_types = None

-    def __init__(self, data, label=None, missing=None,
-                 weight=None, silent=False,
-                 feature_names=None, feature_types=None,
+    def __init__(self, data, label=None, weight=None, base_margin=None,
+                 missing=None,
+                 silent=False,
+                 feature_names=None,
+                 feature_types=None,
                  nthread=None):
         """Parameters
         ----------
@@ -492,6 +494,7 @@ class DMatrix(object):
         label = _maybe_pandas_label(label)
         label = _maybe_dt_array(label)
         weight = _maybe_dt_array(weight)
+        base_margin = _maybe_dt_array(base_margin)

         if isinstance(data, (STRING_TYPES, os_PathLike)):
             handle = ctypes.c_void_p()
@@ -518,19 +521,11 @@
                              ' {}'.format(type(data).__name__))

         if label is not None:
-            if isinstance(label, np.ndarray):
-                self.set_label_npy2d(label)
-            elif _use_columnar_initializer(label):
-                self.set_interface_info('label', label)
-            else:
-                self.set_label(label)
+            self.set_label(label)
         if weight is not None:
-            if isinstance(weight, np.ndarray):
-                self.set_weight_npy2d(weight)
-            elif _use_columnar_initializer(label):
-                self.set_interface_info('weight', weight)
-            else:
-                self.set_weight(weight)
+            self.set_weight(weight)
+        if base_margin is not None:
+            self.set_base_margin(base_margin)

         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -792,6 +787,11 @@ class DMatrix(object):
         label: array like
             The label information to be set into DMatrix
         """
-        self.set_float_info('label', label)
+        if isinstance(label, np.ndarray):
+            self.set_label_npy2d(label)
+        elif _use_columnar_initializer(label):
+            self.set_interface_info('label', label)
+        else:
+            self.set_float_info('label', label)

     def set_label_npy2d(self, label):
@@ -820,6 +820,11 @@
             data points within each group, so it doesn't make sense to assign
             weights to individual data points.
         """
-        self.set_float_info('weight', weight)
+        if isinstance(weight, np.ndarray):
+            self.set_weight_npy2d(weight)
+        elif _use_columnar_initializer(weight):
+            self.set_interface_info('weight', weight)
+        else:
+            self.set_float_info('weight', weight)

     def set_weight_npy2d(self, weight):
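
With base_margin accepted by the constructor, the margin can now be supplied in one step instead of via a separate setter call. A minimal sketch of the two equivalent forms, on synthetic data (array names are illustrative):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 10)
    y = np.random.randint(0, 2, size=100)
    margin = np.zeros(100, dtype=np.float32)  # e.g. a previous model's raw output

    # New in this commit: pass the margin at construction time ...
    dtrain = xgb.DMatrix(X, label=y, base_margin=margin)

    # ... which is equivalent to setting it afterwards.
    dtrain_2 = xgb.DMatrix(X, label=y)
    dtrain_2.set_base_margin(margin)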

python-package/xgboost/sklearn.py

@@ -288,9 +288,9 @@ class XGBModel(XGBModelBase):
         self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)

-    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None, callbacks=None):
+    def fit(self, X, y, sample_weight=None, base_margin=None,
+            eval_set=None, eval_metric=None, early_stopping_rounds=None,
+            verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
         # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model
@@ -302,6 +302,8 @@ class XGBModel(XGBModelBase):
             Labels
         sample_weight : array_like
             instance weights
+        base_margin : array_like
+            global bias for each instance.
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
@@ -346,14 +348,10 @@ class XGBModel(XGBModelBase):
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
-        if sample_weight is not None:
-            trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
-                                   missing=self.missing,
-                                   nthread=self.n_jobs)
-        else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing,
-                                   nthread=self.n_jobs)
+        trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                               base_margin=base_margin,
+                               missing=self.missing,
+                               nthread=self.n_jobs)

         evals_result = {}
         if eval_set is not None:
@@ -404,7 +402,8 @@ class XGBModel(XGBModelBase):
             self.best_ntree_limit = self._Booster.best_ntree_limit
         return self

-    def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
+    def predict(self, data, output_margin=False, ntree_limit=None,
+                validate_features=True, base_margin=None):
         """
         Predict with `data`.
@@ -442,7 +441,8 @@ class XGBModel(XGBModelBase):
         prediction : numpy array
         """
         # pylint: disable=missing-docstring,invalid-name
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         # get ntree_limit to use - if none specified, default to
         # best_ntree_limit if defined, otherwise 0.
         if ntree_limit is None:
@@ -621,7 +621,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

-    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
+    def fit(self, X, y, sample_weight=None, base_margin=None,
+            eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
@@ -675,12 +676,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             raise ValueError(
                 'Please reshape the input data X into 2-dimensional matrix.')
         self._features_count = X.shape[1]

-        if sample_weight is not None:
-            train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
-                                    missing=self.missing, nthread=self.n_jobs)
-        else:
-            train_dmatrix = DMatrix(X, label=training_labels,
-                                    missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing, nthread=self.n_jobs)

         self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
@@ -706,7 +703,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
                                                'Fit gradient boosting classifier', 1)

-    def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
+    def predict(self, data, output_margin=False, ntree_limit=None,
+                validate_features=True, base_margin=None):
         """
         Predict with `data`.
@@ -729,7 +727,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         Parameters
         ----------
-        data : DMatrix
+        data : array_like
             The dmatrix storing the input.
         output_margin : bool
             Whether to output the raw untransformed margin value.
@@ -743,7 +741,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         -------
         prediction : numpy array
         """
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
@@ -761,7 +760,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             column_indexes[class_probs > 0.5] = 1
         return self._le.inverse_transform(column_indexes)

-    def predict_proba(self, data, ntree_limit=None, validate_features=True):
+    def predict_proba(self, data, ntree_limit=None, validate_features=True,
+                      base_margin=None):
         """
         Predict the probability of each `data` example being of a given class.
@@ -787,7 +787,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         prediction : numpy array
             a numpy array with the probability of each data example being of a given class.
         """
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
@@ -1045,7 +1046,8 @@ class XGBRanker(XGBModel):
         if "rank:" not in self.objective:
             raise ValueError("please use XGBRanker for ranking task")

-    def fit(self, X, y, group, sample_weight=None, eval_set=None,
+    def fit(self, X, y, group, sample_weight=None, base_margin=None,
+            eval_set=None,
             sample_weight_eval_set=None, eval_group=None, eval_metric=None,
             early_stopping_rounds=None, verbose=False, xgb_model=None,
             callbacks=None):
@@ -1072,6 +1074,8 @@ class XGBRanker(XGBModel):
             data points within each group, so it doesn't make sense to assign
             weights to individual data points.
+        base_margin : array_like
+            Global bias for each instance.
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
@@ -1138,14 +1142,10 @@ class XGBRanker(XGBModel):
             ret.set_group(group)
             return ret

-        if sample_weight is not None:
-            train_dmatrix = _dmat_init(
-                group, data=X, label=y, weight=sample_weight,
-                missing=self.missing, nthread=self.n_jobs)
-        else:
-            train_dmatrix = _dmat_init(
-                group, data=X, label=y,
-                missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix.set_group(group)

         evals_result = {}
@@ -1192,9 +1192,11 @@ class XGBRanker(XGBModel):
         return self

-    def predict(self, data, output_margin=False, ntree_limit=0, validate_features=True):
+    def predict(self, data, output_margin=False,
+                ntree_limit=0, validate_features=True, base_margin=None):

-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
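
Taken together, these wrapper changes make continued training possible through the sklearn interface. A sketch of the intended workflow on synthetic data; note that the estimator does not store the margin, so unlike the native API (where it lives on the DMatrix) it must be passed explicitly to every predict/predict_proba call:

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(200, 5)
    y = np.random.randint(0, 2, size=200)

    # Train a first model and extract its raw (margin) predictions.
    model_0 = xgb.XGBClassifier(n_estimators=4)
    model_0.fit(X, y)
    margin = model_0.predict(X, output_margin=True)

    # Continue boosting from that margin with a second model.
    model_1 = xgb.XGBClassifier(n_estimators=4)
    model_1.fit(X, y, base_margin=margin)

    # The margin has to be supplied again at prediction time.
    proba = model_1.predict_proba(X, base_margin=margin)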

tests/python/test_basic_models.py

@@ -132,6 +132,21 @@ class TestModels(unittest.TestCase):
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=eta_decay)
         assert isinstance(bst, xgb.core.Booster)

+    def test_boost_from_prediction(self):
+        # Re-construct dtrain here to avoid modification
+        margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        predt_0 = bst.predict(margined, output_margin=True)
+        margined.set_base_margin(predt_0)
+        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        predt_1 = bst.predict(margined)
+        assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
+
+        bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
+        predt_2 = bst.predict(dtrain)
+        assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
+
     def test_custom_objective(self):
         param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]

tests/python/test_with_sklearn.py

@@ -695,3 +695,23 @@ def test_XGBClassifier_resume():
     assert np.any(pred1 != pred2)
     assert log_loss1 > log_loss2
+
+
+def test_boost_from_prediction():
+    from sklearn.datasets import load_breast_cancer
+    X, y = load_breast_cancer(return_X_y=True)
+
+    model_0 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=4)
+    model_0.fit(X=X, y=y)
+    margin = model_0.predict(X, output_margin=True)
+
+    model_1 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=4)
+    model_1.fit(X=X, y=y, base_margin=margin)
+    predictions_1 = model_1.predict(X, base_margin=margin)
+
+    cls_2 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=8)
+    cls_2.fit(X=X, y=y)
+    predictions_2 = cls_2.predict(X, base_margin=margin)
+    assert np.all(predictions_1 == predictions_2)
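
Both new tests rest on the same additive identity: every boosting round fits a tree against the current margin, so n rounds started from an m-round model's margin rebuild the same ensemble as m+n rounds trained in one go. Schematically, with bias denoting the raw global base score and f_k the k-th tree:

    raw_{m+n}(x) = bias + f_1(x) + ... + f_{m+n}(x)
                 = [bias + f_1(x) + ... + f_m(x)] + f_{m+1}(x) + ... + f_{m+n}(x)
                 = margin_m(x) + f_{m+1}(x) + ... + f_{m+n}(x)

Exact agreement assumes deterministic training, which the tests arrange with a fixed random_state, no subsampling, and a single tree_method.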