Remove warnings in tests. (#6554)

This commit is contained in:
Jiaming Yuan 2020-12-31 13:41:18 +08:00 committed by GitHub
parent 8ad22bf4e7
commit 5e9e525223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View File

@ -16,7 +16,7 @@ class TestEarlyStopping:
except ImportError:
from sklearn.cross_validation import train_test_split
digits = load_digits(2)
digits = load_digits(n_class=2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y,
@ -52,7 +52,7 @@ class TestEarlyStopping:
def test_cv_early_stopping(self):
from sklearn.datasets import load_digits
digits = load_digits(2)
digits = load_digits(n_class=2)
X = digits['data']
y = digits['target']
dm = xgb.DMatrix(X, label=y)

View File

@ -63,7 +63,7 @@ class TestEvalMetrics:
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits
digits = load_digits(2)
digits = load_digits(n_class=2)
X = digits['data']
y = digits['target']

View File

@ -2,6 +2,7 @@ import xgboost as xgb
import testing as tm
import numpy as np
import pytest
import os
rng = np.random.RandomState(1337)
@ -34,8 +35,8 @@ class TestTrainingContinuation:
from sklearn.datasets import load_digits
from sklearn.metrics import mean_squared_error
digits_2class = load_digits(2)
digits_5class = load_digits(5)
digits_2class = load_digits(n_class=2)
digits_5class = load_digits(n_class=5)
X_2class = digits_2class['data']
y_2class = digits_2class['target']
@ -85,6 +86,8 @@ class TestTrainingContinuation:
assert ntrees_03a == 10
assert ntrees_03b == 10
os.remove('xgb_tc.model')
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
assert res1 == res2