Skip related tests when sklearn is not installed. (#4791)
This commit is contained in:
@@ -1,12 +1,9 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import xgboost
|
||||
import sys
|
||||
import os
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
import unittest
|
||||
import itertools
|
||||
import glob
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
@@ -36,6 +33,7 @@ def test_ranking_with_unweighted_data():
|
||||
auc_rec = evals_result['train']['aucpr']
|
||||
assert all(p <= q for p, q in zip(auc_rec, auc_rec[1:]))
|
||||
|
||||
|
||||
def test_ranking_with_weighted_data():
|
||||
Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
|
||||
Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
|
||||
@@ -82,6 +80,7 @@ class TestRanking(unittest.TestCase):
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
# download the test data
|
||||
cls.dpath = 'demo/rank/'
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
@@ -91,7 +90,8 @@ class TestRanking(unittest.TestCase):
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
|
||||
x_train, y_train, qid_train, x_test, y_test, qid_test, x_valid, y_valid, qid_valid = load_svmlight_files(
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
|
||||
Reference in New Issue
Block a user