Add base margin to sklearn interface. (#5151)
parent 1d0ca49761
commit 0202e04a8e
@@ -1,5 +1,4 @@
 #!/usr/bin/python
-import numpy as np
 import xgboost as xgb
 
 dtrain = xgb.DMatrix('../data/agaricus.txt.train')
@@ -8,18 +7,19 @@ watchlist = [(dtest, 'eval'), (dtrain, 'train')]
 ###
 # advanced: start from a initial base prediction
 #
-print ('start running example to start from a initial prediction')
+print('start running example to start from a initial prediction')
 # specify parameters via map, definition are same as c++ version
-param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic'}
+param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
 # train xgboost for 1 round
 bst = xgb.train(param, dtrain, 1, watchlist)
-# Note: we need the margin value instead of transformed prediction in set_base_margin
-# do predict with output_margin=True, will always give you margin values before logistic transformation
+# Note: we need the margin value instead of transformed prediction in
+# set_base_margin
+# do predict with output_margin=True, will always give you margin values
+# before logistic transformation
 ptrain = bst.predict(dtrain, output_margin=True)
 ptest = bst.predict(dtest, output_margin=True)
 dtrain.set_base_margin(ptrain)
 dtest.set_base_margin(ptest)
 
 print('this is result of running from initial prediction')
 bst = xgb.train(param, dtrain, 1, watchlist)
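
A brief aside, not part of the commit, reusing the bst and dtrain objects from the demo above: for binary:logistic the transformed prediction is the sigmoid of the raw margin, which is why set_base_margin needs the output of predict(..., output_margin=True) rather than the transformed probabilities. A minimal sketch:

import numpy as np

prob = bst.predict(dtrain)                        # transformed predictions
margin = bst.predict(dtrain, output_margin=True)  # raw scores before the logistic transform
# sketch: the transformed prediction is sigmoid(margin)
assert np.allclose(prob, 1.0 / (1.0 + np.exp(-margin)))
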
@@ -434,9 +434,11 @@ class DMatrix(object):
     _feature_names = None  # for previous version's pickle
     _feature_types = None
 
-    def __init__(self, data, label=None, missing=None,
-                 weight=None, silent=False,
-                 feature_names=None, feature_types=None,
+    def __init__(self, data, label=None, weight=None, base_margin=None,
+                 missing=None,
+                 silent=False,
+                 feature_names=None,
+                 feature_types=None,
                  nthread=None):
         """Parameters
         ----------
@@ -492,6 +494,7 @@ class DMatrix(object):
         label = _maybe_pandas_label(label)
         label = _maybe_dt_array(label)
         weight = _maybe_dt_array(weight)
+        base_margin = _maybe_dt_array(base_margin)
 
         if isinstance(data, (STRING_TYPES, os_PathLike)):
             handle = ctypes.c_void_p()
@@ -518,19 +521,11 @@ class DMatrix(object):
                 ' {}'.format(type(data).__name__))
 
         if label is not None:
-            if isinstance(label, np.ndarray):
-                self.set_label_npy2d(label)
-            elif _use_columnar_initializer(label):
-                self.set_interface_info('label', label)
-            else:
-                self.set_label(label)
+            self.set_label(label)
         if weight is not None:
-            if isinstance(weight, np.ndarray):
-                self.set_weight_npy2d(weight)
-            elif _use_columnar_initializer(label):
-                self.set_interface_info('weight', weight)
-            else:
-                self.set_weight(weight)
+            self.set_weight(weight)
+        if base_margin is not None:
+            self.set_base_margin(base_margin)
 
         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -792,7 +787,12 @@ class DMatrix(object):
         label: array like
             The label information to be set into DMatrix
         """
-        self.set_float_info('label', label)
+        if isinstance(label, np.ndarray):
+            self.set_label_npy2d(label)
+        elif _use_columnar_initializer(label):
+            self.set_interface_info('label', label)
+        else:
+            self.set_float_info('label', label)
 
     def set_label_npy2d(self, label):
         """Set label of dmatrix
@@ -820,7 +820,12 @@ class DMatrix(object):
             data points within each group, so it doesn't make sense to assign
             weights to individual data points.
         """
-        self.set_float_info('weight', weight)
+        if isinstance(weight, np.ndarray):
+            self.set_weight_npy2d(weight)
+        elif _use_columnar_initializer(weight):
+            self.set_interface_info('weight', weight)
+        else:
+            self.set_float_info('weight', weight)
 
     def set_weight_npy2d(self, weight):
         """ Set weight of each instance
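
A minimal usage sketch for the new DMatrix argument (not part of the commit; the data path is illustrative): passing base_margin at construction is equivalent to creating the matrix first and calling set_base_margin on it afterwards.

import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix('../data/agaricus.txt.train')
bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=1)
margin = bst.predict(dtrain, output_margin=True)

dm_a = xgb.DMatrix('../data/agaricus.txt.train', base_margin=margin)  # new one-step form
dm_b = xgb.DMatrix('../data/agaricus.txt.train')                      # old two-step form
dm_b.set_base_margin(margin)
assert np.allclose(dm_a.get_base_margin(), dm_b.get_base_margin())
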
@@ -288,9 +288,9 @@ class XGBModel(XGBModelBase):
             self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)
 
-    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None, callbacks=None):
+    def fit(self, X, y, sample_weight=None, base_margin=None,
+            eval_set=None, eval_metric=None, early_stopping_rounds=None,
+            verbose=True, xgb_model=None, sample_weight_eval_set=None, callbacks=None):
         # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model
 
@@ -302,6 +302,8 @@ class XGBModel(XGBModelBase):
             Labels
         sample_weight : array_like
            instance weights
+        base_margin : array_like
+            global bias for each instance.
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
@@ -346,14 +348,10 @@ class XGBModel(XGBModelBase):
 
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
-        if sample_weight is not None:
-            trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
-                                   missing=self.missing,
-                                   nthread=self.n_jobs)
-        else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing,
-                                   nthread=self.n_jobs)
+        trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                               base_margin=base_margin,
+                               missing=self.missing,
+                               nthread=self.n_jobs)
 
         evals_result = {}
 
         if eval_set is not None:
@@ -404,7 +402,8 @@ class XGBModel(XGBModelBase):
         self.best_ntree_limit = self._Booster.best_ntree_limit
         return self
 
-    def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
+    def predict(self, data, output_margin=False, ntree_limit=None,
+                validate_features=True, base_margin=None):
         """
         Predict with `data`.
 
@@ -442,7 +441,8 @@ class XGBModel(XGBModelBase):
         prediction : numpy array
         """
         # pylint: disable=missing-docstring,invalid-name
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         # get ntree_limit to use - if none specified, default to
         # best_ntree_limit if defined, otherwise 0.
         if ntree_limit is None:
@@ -621,7 +621,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                 base_score=base_score, random_state=random_state, missing=missing,
                 **kwargs)
 
-    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
+    def fit(self, X, y, sample_weight=None, base_margin=None,
+            eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
@@ -675,13 +676,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             raise ValueError(
                 'Please reshape the input data X into 2-dimensional matrix.')
         self._features_count = X.shape[1]
-
-        if sample_weight is not None:
-            train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
-                                    missing=self.missing, nthread=self.n_jobs)
-        else:
-            train_dmatrix = DMatrix(X, label=training_labels,
-                                    missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing, nthread=self.n_jobs)
 
         self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
                               evals=evals, early_stopping_rounds=early_stopping_rounds,
@@ -706,7 +703,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
                                                'Fit gradient boosting classifier', 1)
 
-    def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
+    def predict(self, data, output_margin=False, ntree_limit=None,
+                validate_features=True, base_margin=None):
         """
         Predict with `data`.
 
@@ -729,7 +727,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         Parameters
         ----------
-        data : DMatrix
+        data : array_like
             The dmatrix storing the input.
         output_margin : bool
             Whether to output the raw untransformed margin value.
@@ -743,7 +741,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         -------
         prediction : numpy array
         """
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
@@ -761,7 +760,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             column_indexes[class_probs > 0.5] = 1
         return self._le.inverse_transform(column_indexes)
 
-    def predict_proba(self, data, ntree_limit=None, validate_features=True):
+    def predict_proba(self, data, ntree_limit=None, validate_features=True,
+                      base_margin=None):
         """
         Predict the probability of each `data` example being of a given class.
 
@@ -787,7 +787,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         prediction : numpy array
             a numpy array with the probability of each data example being of a given class.
         """
-        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing, nthread=self.n_jobs)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
         class_probs = self.get_booster().predict(test_dmatrix,
@@ -1045,7 +1046,8 @@ class XGBRanker(XGBModel):
         if "rank:" not in self.objective:
             raise ValueError("please use XGBRanker for ranking task")
 
-    def fit(self, X, y, group, sample_weight=None, eval_set=None,
+    def fit(self, X, y, group, sample_weight=None, base_margin=None,
+            eval_set=None,
             sample_weight_eval_set=None, eval_group=None, eval_metric=None,
             early_stopping_rounds=None, verbose=False, xgb_model=None,
             callbacks=None):
@@ -1072,6 +1074,8 @@ class XGBRanker(XGBModel):
             data points within each group, so it doesn't make sense to assign
             weights to individual data points.
 
+        base_margin : array_like
+            Global bias for each instance.
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
@@ -1138,14 +1142,10 @@ class XGBRanker(XGBModel):
             ret.set_group(group)
             return ret
 
-        if sample_weight is not None:
-            train_dmatrix = _dmat_init(
-                group, data=X, label=y, weight=sample_weight,
-                missing=self.missing, nthread=self.n_jobs)
-        else:
-            train_dmatrix = _dmat_init(
-                group, data=X, label=y,
-                missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix.set_group(group)
 
         evals_result = {}
 
@@ -1192,9 +1192,11 @@ class XGBRanker(XGBModel):
 
         return self
 
-    def predict(self, data, output_margin=False, ntree_limit=0, validate_features=True):
+    def predict(self, data, output_margin=False,
+                ntree_limit=0, validate_features=True, base_margin=None):
 
-        test_dmatrix = DMatrix(data, missing=self.missing)
+        test_dmatrix = DMatrix(data, base_margin=base_margin,
+                               missing=self.missing)
         if ntree_limit is None:
             ntree_limit = getattr(self, "best_ntree_limit", 0)
 
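
A hedged usage sketch of the updated sklearn wrapper (not part of the commit; dataset and hyperparameters are illustrative): base_margin can now be supplied to fit, predict and predict_proba, and note that the margin has to be passed again at prediction time for the continued model, as in the tests that follow.

import xgboost as xgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

base = xgb.XGBClassifier(n_estimators=4, random_state=0)
base.fit(X, y)
margin = base.predict(X, output_margin=True)  # raw margin, not predict_proba output

refined = xgb.XGBClassifier(n_estimators=4, random_state=0)
refined.fit(X, y, base_margin=margin)         # continue boosting from the previous model
proba = refined.predict_proba(X, base_margin=margin)
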
@@ -132,6 +132,21 @@ class TestModels(unittest.TestCase):
         bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=eta_decay)
         assert isinstance(bst, xgb.core.Booster)
 
+    def test_boost_from_prediction(self):
+        # Re-construct dtrain here to avoid modification
+        margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        predt_0 = bst.predict(margined, output_margin=True)
+        margined.set_base_margin(predt_0)
+        bst = xgb.train({'tree_method': 'hist'}, margined, 1)
+        predt_1 = bst.predict(margined)
+
+        assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
+
+        bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
+        predt_2 = bst.predict(dtrain)
+        assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
+
     def test_custom_objective(self):
         param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@@ -695,3 +695,23 @@ def test_XGBClassifier_resume():
 
     assert np.any(pred1 != pred2)
     assert log_loss1 > log_loss2
+
+
+def test_boost_from_prediction():
+    from sklearn.datasets import load_breast_cancer
+    X, y = load_breast_cancer(return_X_y=True)
+    model_0 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=4)
+    model_0.fit(X=X, y=y)
+    margin = model_0.predict(X, output_margin=True)
+
+    model_1 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=4)
+    model_1.fit(X=X, y=y, base_margin=margin)
+    predictions_1 = model_1.predict(X, base_margin=margin)
+
+    cls_2 = xgb.XGBClassifier(
+        learning_rate=0.3, random_state=0, n_estimators=8)
+    cls_2.fit(X=X, y=y)
+    predictions_2 = cls_2.predict(X, base_margin=margin)
+    assert np.all(predictions_1 == predictions_2)