Enforce correct data shape. (#5191)
* Fix syncing DMatrix columns. * notes for tree method. * Enable feature validation for all interfaces except for jvm. * Better tests for boosting from predictions. * Disable validation on JVM.
This commit is contained in:
@@ -5,6 +5,7 @@ import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -697,21 +698,37 @@ def test_XGBClassifier_resume():
|
||||
assert log_loss1 > log_loss2
|
||||
|
||||
|
||||
def test_boost_from_prediction():
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.predict(X, output_margin=True)
|
||||
class TestBoostFromPrediction(unittest.TestCase):
|
||||
def run_boost_from_prediction(self, tree_method):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
model_0 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4,
|
||||
tree_method=tree_method)
|
||||
model_0.fit(X=X, y=y)
|
||||
margin = model_0.predict(X, output_margin=True)
|
||||
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.predict(X, base_margin=margin)
|
||||
model_1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=4,
|
||||
tree_method=tree_method)
|
||||
model_1.fit(X=X, y=y, base_margin=margin)
|
||||
predictions_1 = model_1.predict(X, base_margin=margin)
|
||||
|
||||
cls_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2 = cls_2.predict(X, base_margin=margin)
|
||||
assert np.all(predictions_1 == predictions_2)
|
||||
cls_2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8,
|
||||
tree_method=tree_method)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2 = cls_2.predict(X)
|
||||
assert np.all(predictions_1 == predictions_2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_boost_from_prediction_hist(self):
|
||||
self.run_boost_from_prediction('hist')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_boost_from_prediction_approx(self):
|
||||
self.run_boost_from_prediction('approx')
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_boost_from_prediction_exact(self):
|
||||
self.run_boost_from_prediction('exact')
|
||||
|
||||
Reference in New Issue
Block a user