Add base margin to sklearn interface. (#5151)
This commit is contained in:
parent
1d0ca49761
commit
0202e04a8e
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/python
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
|
||||
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
|
||||
@ -8,18 +7,19 @@ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
###
|
||||
# advanced: start from a initial base prediction
|
||||
#
|
||||
print ('start running example to start from a initial prediction')
|
||||
print('start running example to start from a initial prediction')
|
||||
# specify parameters via map, definition are same as c++ version
|
||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
|
||||
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
# train xgboost for 1 round
|
||||
bst = xgb.train(param, dtrain, 1, watchlist)
|
||||
# Note: we need the margin value instead of transformed prediction in set_base_margin
|
||||
# do predict with output_margin=True, will always give you margin values before logistic transformation
|
||||
# Note: we need the margin value instead of transformed prediction in
|
||||
# set_base_margin
|
||||
# do predict with output_margin=True, will always give you margin values
|
||||
# before logistic transformation
|
||||
ptrain = bst.predict(dtrain, output_margin=True)
|
||||
ptest = bst.predict(dtest, output_margin=True)
|
||||
dtrain.set_base_margin(ptrain)
|
||||
dtest.set_base_margin(ptest)
|
||||
|
||||
|
||||
print('this is result of running from initial prediction')
|
||||
bst = xgb.train(param, dtrain, 1, watchlist)
|
||||
@ -434,9 +434,11 @@ class DMatrix(object):
|
||||
_feature_names = None # for previous version's pickle
|
||||
_feature_types = None
|
||||
|
||||
def __init__(self, data, label=None, missing=None,
|
||||
weight=None, silent=False,
|
||||
feature_names=None, feature_types=None,
|
||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread=None):
|
||||
"""Parameters
|
||||
----------
|
||||
@ -492,6 +494,7 @@ class DMatrix(object):
|
||||
label = _maybe_pandas_label(label)
|
||||
label = _maybe_dt_array(label)
|
||||
weight = _maybe_dt_array(weight)
|
||||
base_margin = _maybe_dt_array(base_margin)
|
||||
|
||||
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
||||
handle = ctypes.c_void_p()
|
||||
@ -518,19 +521,11 @@ class DMatrix(object):
|
||||
' {}'.format(type(data).__name__))
|
||||
|
||||
if label is not None:
|
||||
if isinstance(label, np.ndarray):
|
||||
self.set_label_npy2d(label)
|
||||
elif _use_columnar_initializer(label):
|
||||
self.set_interface_info('label', label)
|
||||
else:
|
||||
self.set_label(label)
|
||||
if weight is not None:
|
||||
if isinstance(weight, np.ndarray):
|
||||
self.set_weight_npy2d(weight)
|
||||
elif _use_columnar_initializer(label):
|
||||
self.set_interface_info('weight', weight)
|
||||
else:
|
||||
self.set_weight(weight)
|
||||
if base_margin is not None:
|
||||
self.set_base_margin(base_margin)
|
||||
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
@ -792,6 +787,11 @@ class DMatrix(object):
|
||||
label: array like
|
||||
The label information to be set into DMatrix
|
||||
"""
|
||||
if isinstance(label, np.ndarray):
|
||||
self.set_label_npy2d(label)
|
||||
elif _use_columnar_initializer(label):
|
||||
self.set_interface_info('label', label)
|
||||
else:
|
||||
self.set_float_info('label', label)
|
||||
|
||||
def set_label_npy2d(self, label):
|
||||
@ -820,6 +820,11 @@ class DMatrix(object):
|
||||
data points within each group, so it doesn't make sense to assign
|
||||
weights to individual data points.
|
||||
"""
|
||||
if isinstance(weight, np.ndarray):
|
||||
self.set_weight_npy2d(weight)
|
||||
elif _use_columnar_initializer(weight):
|
||||
self.set_interface_info('weight', weight)
|
||||
else:
|
||||
self.set_float_info('weight', weight)
|
||||
|
||||
def set_weight_npy2d(self, weight):
|
||||
|
||||
@ -288,9 +288,9 @@ class XGBModel(XGBModelBase):
|
||||
self._Booster = Booster({'n_jobs': self.n_jobs})
|
||||
self._Booster.load_model(fname)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, callbacks=None):
|
||||
def fit(self, X, y, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
||||
verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
|
||||
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
|
||||
"""Fit gradient boosting model
|
||||
|
||||
@ -302,6 +302,8 @@ class XGBModel(XGBModelBase):
|
||||
Labels
|
||||
sample_weight : array_like
|
||||
instance weights
|
||||
base_margin : array_like
|
||||
global bias for each instance.
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
@ -346,14 +348,10 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
||||
"""
|
||||
if sample_weight is not None:
|
||||
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
|
||||
trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing,
|
||||
nthread=self.n_jobs)
|
||||
else:
|
||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing,
|
||||
nthread=self.n_jobs)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
if eval_set is not None:
|
||||
@ -404,7 +402,8 @@ class XGBModel(XGBModelBase):
|
||||
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
|
||||
def predict(self, data, output_margin=False, ntree_limit=None,
|
||||
validate_features=True, base_margin=None):
|
||||
"""
|
||||
Predict with `data`.
|
||||
|
||||
@ -442,7 +441,8 @@ class XGBModel(XGBModelBase):
|
||||
prediction : numpy array
|
||||
"""
|
||||
# pylint: disable=missing-docstring,invalid-name
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
test_dmatrix = DMatrix(data, base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
# get ntree_limit to use - if none specified, default to
|
||||
# best_ntree_limit if defined, otherwise 0.
|
||||
if ntree_limit is None:
|
||||
@ -621,7 +621,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
base_score=base_score, random_state=random_state, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
def fit(self, X, y, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
@ -675,12 +676,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
raise ValueError(
|
||||
'Please reshape the input data X into 2-dimensional matrix.')
|
||||
self._features_count = X.shape[1]
|
||||
|
||||
if sample_weight is not None:
|
||||
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
else:
|
||||
train_dmatrix = DMatrix(X, label=training_labels,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
|
||||
self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
|
||||
@ -706,7 +703,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
|
||||
'Fit gradient boosting classifier', 1)
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
|
||||
def predict(self, data, output_margin=False, ntree_limit=None,
|
||||
validate_features=True, base_margin=None):
|
||||
"""
|
||||
Predict with `data`.
|
||||
|
||||
@ -729,7 +727,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : DMatrix
|
||||
data : array_like
|
||||
The dmatrix storing the input.
|
||||
output_margin : bool
|
||||
Whether to output the raw untransformed margin value.
|
||||
@ -743,7 +741,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
-------
|
||||
prediction : numpy array
|
||||
"""
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
test_dmatrix = DMatrix(data, base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
class_probs = self.get_booster().predict(test_dmatrix,
|
||||
@ -761,7 +760,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
column_indexes[class_probs > 0.5] = 1
|
||||
return self._le.inverse_transform(column_indexes)
|
||||
|
||||
def predict_proba(self, data, ntree_limit=None, validate_features=True):
|
||||
def predict_proba(self, data, ntree_limit=None, validate_features=True,
|
||||
base_margin=None):
|
||||
"""
|
||||
Predict the probability of each `data` example being of a given class.
|
||||
|
||||
@ -787,7 +787,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
prediction : numpy array
|
||||
a numpy array with the probability of each data example being of a given class.
|
||||
"""
|
||||
test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
|
||||
test_dmatrix = DMatrix(data, base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
class_probs = self.get_booster().predict(test_dmatrix,
|
||||
@ -1045,7 +1046,8 @@ class XGBRanker(XGBModel):
|
||||
if "rank:" not in self.objective:
|
||||
raise ValueError("please use XGBRanker for ranking task")
|
||||
|
||||
def fit(self, X, y, group, sample_weight=None, eval_set=None,
|
||||
def fit(self, X, y, group, sample_weight=None, base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None, eval_group=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=False, xgb_model=None,
|
||||
callbacks=None):
|
||||
@ -1072,6 +1074,8 @@ class XGBRanker(XGBModel):
|
||||
data points within each group, so it doesn't make sense to assign
|
||||
weights to individual data points.
|
||||
|
||||
base_margin : array_like
|
||||
Global bias for each instance.
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
@ -1138,14 +1142,10 @@ class XGBRanker(XGBModel):
|
||||
ret.set_group(group)
|
||||
return ret
|
||||
|
||||
if sample_weight is not None:
|
||||
train_dmatrix = _dmat_init(
|
||||
group, data=X, label=y, weight=sample_weight,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
else:
|
||||
train_dmatrix = _dmat_init(
|
||||
group, data=X, label=y,
|
||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix.set_group(group)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
@ -1192,9 +1192,11 @@ class XGBRanker(XGBModel):
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=0, validate_features=True):
|
||||
def predict(self, data, output_margin=False,
|
||||
ntree_limit=0, validate_features=True, base_margin=None):
|
||||
|
||||
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||
test_dmatrix = DMatrix(data, base_margin=base_margin,
|
||||
missing=self.missing)
|
||||
if ntree_limit is None:
|
||||
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||
|
||||
|
||||
@ -132,6 +132,21 @@ class TestModels(unittest.TestCase):
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=eta_decay)
|
||||
assert isinstance(bst, xgb.core.Booster)
|
||||
|
||||
def test_boost_from_prediction(self):
|
||||
# Re-construct dtrain here to avoid modification
|
||||
margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
predt_0 = bst.predict(margined, output_margin=True)
|
||||
margined.set_base_margin(predt_0)
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
predt_1 = bst.predict(margined)
|
||||
|
||||
assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
|
||||
|
||||
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
|
||||
predt_2 = bst.predict(dtrain)
|
||||
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
||||
|
||||
def test_custom_objective(self):
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
|
||||
@ -695,3 +695,23 @@ def test_XGBClassifier_resume():
|
||||
|
||||
assert np.any(pred1 != pred2)
|
||||
assert log_loss1 > log_loss2
|
||||
|
||||
|
||||
def test_boost_from_prediction():
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.predict(X, output_margin=True)
|
||||
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.predict(X, base_margin=margin)
|
||||
|
||||
cls_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2 = cls_2.predict(X, base_margin=margin)
|
||||
assert np.all(predictions_1 == predictions_2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user