Skip related tests when sklearn is not installed. (#4791)
This commit is contained in:
parent
fba298fecb
commit
6e6216ad67
@ -2,7 +2,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xgboost
|
import xgboost
|
||||||
import unittest
|
import unittest
|
||||||
from sklearn.metrics import accuracy_score
|
import testing as tm
|
||||||
|
import pytest
|
||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
@ -49,7 +50,9 @@ class TestInteractionConstraints(unittest.TestCase):
|
|||||||
diff2 = preds[2] - preds[1]
|
diff2 = preds[2] - preds[1]
|
||||||
assert np.all(np.abs(diff2 - diff2[0]) < 1e-4)
|
assert np.all(np.abs(diff2 - diff2[0]) < 1e-4)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_training_accuracy(self, tree_method='hist'):
|
def test_training_accuracy(self, tree_method='hist'):
|
||||||
|
from sklearn.metrics import accuracy_score
|
||||||
dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
|
dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
|
||||||
dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
|
dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
|
||||||
params = {
|
params = {
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import unittest
|
import unittest
|
||||||
from sklearn.metrics import accuracy_score
|
import testing as tm
|
||||||
|
import pytest
|
||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
|
|
||||||
|
|
||||||
def is_increasing(y):
|
def is_increasing(y):
|
||||||
return np.count_nonzero(np.diff(y) < 0.0) == 0
|
return np.count_nonzero(np.diff(y) < 0.0) == 0
|
||||||
|
|
||||||
@ -100,7 +102,9 @@ class TestMonotoneConstraints(unittest.TestCase):
|
|||||||
|
|
||||||
assert is_correctly_constrained(constrained_hist_method)
|
assert is_correctly_constrained(constrained_hist_method)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_training_accuracy(self):
|
def test_training_accuracy(self):
|
||||||
|
from sklearn.metrics import accuracy_score
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
|
||||||
params = {'eta': 1, 'max_depth': 6, 'objective': 'binary:logistic',
|
params = {'eta': 1, 'max_depth': 6, 'objective': 'binary:logistic',
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.sparse import csr_matrix
|
from scipy.sparse import csr_matrix
|
||||||
import xgboost
|
import xgboost
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
from sklearn.datasets import load_svmlight_files
|
|
||||||
import unittest
|
import unittest
|
||||||
import itertools
|
import itertools
|
||||||
import glob
|
|
||||||
import shutil
|
import shutil
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import zipfile
|
import zipfile
|
||||||
@ -36,6 +33,7 @@ def test_ranking_with_unweighted_data():
|
|||||||
auc_rec = evals_result['train']['aucpr']
|
auc_rec = evals_result['train']['aucpr']
|
||||||
assert all(p <= q for p, q in zip(auc_rec, auc_rec[1:]))
|
assert all(p <= q for p, q in zip(auc_rec, auc_rec[1:]))
|
||||||
|
|
||||||
|
|
||||||
def test_ranking_with_weighted_data():
|
def test_ranking_with_weighted_data():
|
||||||
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
|
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
|
||||||
Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
|
Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
|
||||||
@ -82,6 +80,7 @@ class TestRanking(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Download and setup the test fixtures
|
Download and setup the test fixtures
|
||||||
"""
|
"""
|
||||||
|
from sklearn.datasets import load_svmlight_files
|
||||||
# download the test data
|
# download the test data
|
||||||
cls.dpath = 'demo/rank/'
|
cls.dpath = 'demo/rank/'
|
||||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||||
@ -91,7 +90,8 @@ class TestRanking(unittest.TestCase):
|
|||||||
with zipfile.ZipFile(target, 'r') as f:
|
with zipfile.ZipFile(target, 'r') as f:
|
||||||
f.extractall(path=cls.dpath)
|
f.extractall(path=cls.dpath)
|
||||||
|
|
||||||
x_train, y_train, qid_train, x_test, y_test, qid_test, x_valid, y_valid, qid_valid = load_svmlight_files(
|
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||||
|
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user