Pass scikit learn estimator checks for regressor. (#7130)

* Check data shape.
* Check labels.
This commit is contained in:
Jiaming Yuan
2021-08-03 18:58:20 +08:00
committed by GitHub
parent 8ee127469f
commit 8a84be37b8
7 changed files with 103 additions and 39 deletions

View File

@@ -330,3 +330,12 @@ class TestDMatrix:
with pytest.warns(UserWarning):
d = Data()
xgb.DMatrix(d)
from scipy import sparse
rng = np.random.RandomState(1994)
X = rng.rand(10, 10)
y = rng.rand(10)
X = sparse.dok_matrix(X)
Xy = xgb.DMatrix(X, y)
assert Xy.num_row() == 10
assert Xy.num_col() == 10