Typehint for Sklearn. (#6799)
This commit is contained in:
@@ -60,25 +60,25 @@ class TestInteractionConstraints:
|
||||
def test_interaction_constraints_feature_names(self):
|
||||
with pytest.raises(ValueError):
|
||||
constraints = [('feature_0', 'feature_1')]
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
interaction_constraints=constraints)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
constraints = [('feature_0', 'feature_3')]
|
||||
feature_names = ['feature_0', 'feature_1', 'feature_2']
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
feature_names=feature_names,
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
feature_names=feature_names,
|
||||
interaction_constraints=constraints)
|
||||
|
||||
|
||||
constraints = [('feature_0', 'feature_1')]
|
||||
feature_names = ['feature_0', 'feature_1', 'feature_2']
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
feature_names=feature_names,
|
||||
self.run_interaction_constraints(tree_method='exact',
|
||||
feature_names=feature_names,
|
||||
interaction_constraints=constraints)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def training_accuracy(self, tree_method):
|
||||
"""Test accuracy, reused by GPU tests."""
|
||||
from sklearn.metrics import accuracy_score
|
||||
dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
|
||||
dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
|
||||
@@ -101,11 +101,6 @@ class TestInteractionConstraints:
|
||||
pred_dtest = (bst.predict(dtest) < 0.5)
|
||||
assert accuracy_score(dtest.get_label(), pred_dtest) < 0.1
|
||||
|
||||
def test_hist_training_accuracy(self):
|
||||
self.training_accuracy(tree_method='hist')
|
||||
|
||||
def test_exact_training_accuracy(self):
|
||||
self.training_accuracy(tree_method='exact')
|
||||
|
||||
def test_approx_training_accuracy(self):
|
||||
self.training_accuracy(tree_method='approx')
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||
def test_hist_training_accuracy(self, tree_method):
|
||||
self.training_accuracy(tree_method=tree_method)
|
||||
|
||||
Reference in New Issue
Block a user