Cleanup Python tests. (#7426)
This commit is contained in:
@@ -137,16 +137,13 @@ def test_from_dask_array() -> None:
|
||||
prediction = prediction.compute()
|
||||
|
||||
booster: xgb.Booster = result['booster']
|
||||
single_node_predt = booster.predict(
|
||||
xgb.DMatrix(X.compute())
|
||||
)
|
||||
single_node_predt = booster.predict(xgb.DMatrix(X.compute()))
|
||||
np.testing.assert_allclose(prediction, single_node_predt)
|
||||
|
||||
config = json.loads(booster.save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||
|
||||
from_arr = xgb.dask.predict(
|
||||
client, model=booster, data=X)
|
||||
from_arr = xgb.dask.predict(client, model=booster, data=X)
|
||||
|
||||
assert isinstance(from_arr, da.Array)
|
||||
assert np.all(single_node_predt == from_arr.compute())
|
||||
@@ -477,23 +474,6 @@ def test_dask_classifier(model: str, client: "Client") -> None:
|
||||
run_dask_classifier(X, y_bin, w, model, None, client, 2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_sklearn_grid_search(client: "Client") -> None:
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
X, y, _ = generate_array()
|
||||
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
|
||||
tree_method='hist')
|
||||
reg.client = client
|
||||
model = GridSearchCV(reg, {'max_depth': [2, 4],
|
||||
'n_estimators': [5, 10]},
|
||||
cv=2, verbose=1)
|
||||
model.fit(X, y)
|
||||
# Expect unique results for each parameter value This confirms
|
||||
# sklearn is able to successfully update the parameter
|
||||
means = model.cv_results_['mean_test_score']
|
||||
assert len(means) == len(set(means))
|
||||
|
||||
|
||||
def test_empty_dmatrix_training_continuation(client: "Client") -> None:
|
||||
kRows, kCols = 1, 97
|
||||
X = dd.from_array(np.random.randn(kRows, kCols))
|
||||
@@ -714,18 +694,11 @@ def test_auc(client: "Client") -> None:
|
||||
|
||||
# No test for Exact, as empty DMatrix handling are mostly for distributed
|
||||
# environment and Exact doesn't support it.
|
||||
def test_empty_dmatrix_hist() -> None:
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_empty_dmatrix(tree_method) -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'hist'}
|
||||
run_empty_dmatrix_reg(client, parameters)
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
|
||||
def test_empty_dmatrix_approx() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
parameters = {'tree_method': 'approx'}
|
||||
parameters = {'tree_method': tree_method}
|
||||
run_empty_dmatrix_reg(client, parameters)
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
@@ -1102,12 +1075,12 @@ class TestWithDask:
|
||||
os.remove(after_fname)
|
||||
|
||||
def run_updater_test(
|
||||
self,
|
||||
client: "Client",
|
||||
params: Dict,
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
tree_method: str
|
||||
self,
|
||||
client: "Client",
|
||||
params: Dict,
|
||||
num_rounds: int,
|
||||
dataset: tm.TestDataset,
|
||||
tree_method: str
|
||||
) -> None:
|
||||
params['tree_method'] = tree_method
|
||||
params = dataset.set_params(params)
|
||||
|
||||
Reference in New Issue
Block a user