Add base margin to sklearn interface. (#5151)
This commit is contained in:
@@ -132,6 +132,21 @@ class TestModels(unittest.TestCase):
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, learning_rates=eta_decay)
|
||||
assert isinstance(bst, xgb.core.Booster)
|
||||
|
||||
def test_boost_from_prediction(self):
|
||||
# Re-construct dtrain here to avoid modification
|
||||
margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
predt_0 = bst.predict(margined, output_margin=True)
|
||||
margined.set_base_margin(predt_0)
|
||||
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
|
||||
predt_1 = bst.predict(margined)
|
||||
|
||||
assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
|
||||
|
||||
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
|
||||
predt_2 = bst.predict(dtrain)
|
||||
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
|
||||
|
||||
def test_custom_objective(self):
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0}
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
|
||||
@@ -695,3 +695,23 @@ def test_XGBClassifier_resume():
|
||||
|
||||
assert np.any(pred1 != pred2)
|
||||
assert log_loss1 > log_loss2
|
||||
|
||||
|
||||
def test_boost_from_prediction():
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.predict(X, output_margin=True)
|
||||
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.predict(X, base_margin=margin)
|
||||
|
||||
cls_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2 = cls_2.predict(X, base_margin=margin)
|
||||
assert np.all(predictions_1 == predictions_2)
|
||||
|
||||
Reference in New Issue
Block a user