Remove warnings in tests. (#6554)
This commit is contained in:
parent
8ad22bf4e7
commit
5e9e525223
@ -16,7 +16,7 @@ class TestEarlyStopping:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sklearn.cross_validation import train_test_split
|
from sklearn.cross_validation import train_test_split
|
||||||
|
|
||||||
digits = load_digits(2)
|
digits = load_digits(n_class=2)
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
y = digits['target']
|
y = digits['target']
|
||||||
X_train, X_test, y_train, y_test = train_test_split(X, y,
|
X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||||
@ -52,7 +52,7 @@ class TestEarlyStopping:
|
|||||||
def test_cv_early_stopping(self):
|
def test_cv_early_stopping(self):
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
|
|
||||||
digits = load_digits(2)
|
digits = load_digits(n_class=2)
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
y = digits['target']
|
y = digits['target']
|
||||||
dm = xgb.DMatrix(X, label=y)
|
dm = xgb.DMatrix(X, label=y)
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class TestEvalMetrics:
|
|||||||
from sklearn.cross_validation import train_test_split
|
from sklearn.cross_validation import train_test_split
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
|
|
||||||
digits = load_digits(2)
|
digits = load_digits(n_class=2)
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
y = digits['target']
|
y = digits['target']
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import xgboost as xgb
|
|||||||
import testing as tm
|
import testing as tm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import os
|
||||||
|
|
||||||
rng = np.random.RandomState(1337)
|
rng = np.random.RandomState(1337)
|
||||||
|
|
||||||
@ -34,8 +35,8 @@ class TestTrainingContinuation:
|
|||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
from sklearn.metrics import mean_squared_error
|
from sklearn.metrics import mean_squared_error
|
||||||
|
|
||||||
digits_2class = load_digits(2)
|
digits_2class = load_digits(n_class=2)
|
||||||
digits_5class = load_digits(5)
|
digits_5class = load_digits(n_class=5)
|
||||||
|
|
||||||
X_2class = digits_2class['data']
|
X_2class = digits_2class['data']
|
||||||
y_2class = digits_2class['target']
|
y_2class = digits_2class['target']
|
||||||
@ -85,6 +86,8 @@ class TestTrainingContinuation:
|
|||||||
assert ntrees_03a == 10
|
assert ntrees_03a == 10
|
||||||
assert ntrees_03b == 10
|
assert ntrees_03b == 10
|
||||||
|
|
||||||
|
os.remove('xgb_tc.model')
|
||||||
|
|
||||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user