Remove warnings in tests. (#6554)
This commit is contained in:
@@ -2,6 +2,7 @@ import xgboost as xgb
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import pytest
|
||||
import os
|
||||
|
||||
rng = np.random.RandomState(1337)
|
||||
|
||||
@@ -34,8 +35,8 @@ class TestTrainingContinuation:
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
digits_2class = load_digits(2)
|
||||
digits_5class = load_digits(5)
|
||||
digits_2class = load_digits(n_class=2)
|
||||
digits_5class = load_digits(n_class=5)
|
||||
|
||||
X_2class = digits_2class['data']
|
||||
y_2class = digits_2class['target']
|
||||
@@ -85,6 +86,8 @@ class TestTrainingContinuation:
|
||||
assert ntrees_03a == 10
|
||||
assert ntrees_03b == 10
|
||||
|
||||
os.remove('xgb_tc.model')
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
Reference in New Issue
Block a user