[dask] Reduce the flakiness of tests. (#10678)

This commit is contained in:
Jiaming Yuan 2024-08-06 06:04:10 +08:00 committed by GitHub
parent 35b1cdb365
commit 6ccf116601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 50 deletions

View File

@ -221,16 +221,12 @@ class TestDistributedGPU:
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).to_backend("cudf")
y = dd.from_array(y_, chunksize=100).to_backend("cudf")
divisions = copy(X.divisions)
run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client, divisions)
run_boost_from_prediction(X, y, "hist", "cuda", local_cuda_client)
X_, y_ = load_iris(return_X_y=True)
X = dd.from_array(X_, chunksize=50).to_backend("cudf")
y = dd.from_array(y_, chunksize=50).to_backend("cudf")
divisions = copy(X.divisions)
run_boost_from_prediction_multi_class(
X, y, "hist", "cuda", local_cuda_client, divisions
)
run_boost_from_prediction_multi_class(X, y, "hist", "cuda", local_cuda_client)
def test_init_estimation(self, local_cuda_client: Client) -> None:
check_init_estimation("hist", "cuda", local_cuda_client)

View File

@ -50,7 +50,7 @@ pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
import dask
import dask.array as da
import dask.dataframe as dd
from distributed import Client, LocalCluster, Nanny, Worker
from distributed import Client, LocalCluster, Nanny, Worker, wait
from distributed.utils_test import async_poll_for, gen_cluster
from toolz import sliding_window # dependency of dask
@ -146,28 +146,6 @@ def generate_array(
return X, y, None
Margin = TypeVar("Margin", dd.DataFrame, dd.Series, None)
def deterministic_repartition(
client: Client,
X: dd.DataFrame,
y: dd.Series,
m: Margin,
divisions,
) -> Tuple[dd.DataFrame, dd.Series, Margin]:
"""Try to repartition the dataframes according to ``divisions``.

This doesn't guarantee reproducibility.

Parameters
----------
client :
    Dask client. It is accepted for the callers' convenience but is not used
    inside this function.
X :
    Feature dataframe to repartition.
y :
    Label series to repartition.
m :
    Optional margin; it is repartitioned only when it is not None.
divisions :
    Divisions forwarded to ``dd.repartition`` for every input.

Returns
-------
The repartitioned ``(X, y, margin)`` tuple; ``margin`` is None when ``m`` is None.
"""
# Every input is repartitioned with the same divisions and ``force=True``; a
# missing margin is passed through as None.
X, y, margin = (
dd.repartition(X, divisions=divisions, force=True),
dd.repartition(y, divisions=divisions, force=True),
dd.repartition(m, divisions=divisions, force=True) if m is not None else None,
)
return X, y, margin
@pytest.mark.parametrize("to_frame", [True, False])
def test_xgbclassifier_classes_type_and_value(to_frame: bool, client: "Client"):
X, y = make_classification(n_samples=1000, n_features=4, random_state=123)
@ -429,7 +407,6 @@ def run_boost_from_prediction_multi_class(
tree_method: str,
device: str,
client: "Client",
divisions: List[int],
) -> None:
model_0 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3,
@ -438,7 +415,6 @@ def run_boost_from_prediction_multi_class(
max_bin=768,
device=device,
)
X, y, _ = deterministic_repartition(client, X, y, None, divisions)
model_0.fit(X=X, y=y, eval_set=[(X, y)])
margin = xgb.dask.inplace_predict(
client, model_0.get_booster(), X, predict_type="margin"
@ -452,7 +428,6 @@ def run_boost_from_prediction_multi_class(
max_bin=768,
device=device,
)
X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
model_1.fit(
X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
)
@ -470,7 +445,6 @@ def run_boost_from_prediction_multi_class(
max_bin=768,
device=device,
)
X, y, _ = deterministic_repartition(client, X, y, None, divisions)
model_2.fit(X=X, y=y, eval_set=[(X, y)])
predictions_2 = xgb.dask.inplace_predict(
client, model_2.get_booster(), X, predict_type="margin"
@ -493,7 +467,6 @@ def run_boost_from_prediction(
tree_method: str,
device: str,
client: "Client",
divisions: List[int],
) -> None:
X, y = client.persist([X, y])
@ -504,7 +477,6 @@ def run_boost_from_prediction(
max_bin=512,
device=device,
)
X, y, _ = deterministic_repartition(client, X, y, None, divisions)
model_0.fit(X=X, y=y, eval_set=[(X, y)])
margin: dd.Series = model_0.predict(X, output_margin=True)
@ -515,11 +487,9 @@ def run_boost_from_prediction(
max_bin=512,
device=device,
)
X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
model_1.fit(
X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
)
X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
predictions_1: dd.Series = model_1.predict(X, base_margin=margin)
model_2 = xgb.dask.DaskXGBClassifier(
@ -529,7 +499,6 @@ def run_boost_from_prediction(
max_bin=512,
device=device,
)
X, y, _ = deterministic_repartition(client, X, y, None, divisions)
model_2.fit(X=X, y=y, eval_set=[(X, y)])
predictions_2: dd.Series = model_2.predict(X)
@ -541,13 +510,11 @@ def run_boost_from_prediction(
np.testing.assert_allclose(logloss_concat, logloss_2, rtol=1e-4)
margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
margined.fit(
X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
)
unmargined = xgb.dask.DaskXGBClassifier(n_estimators=4)
X, y, margin = deterministic_repartition(client, X, y, margin, divisions)
unmargined.fit(X=X, y=y, eval_set=[(X, y)], base_margin=margin)
margined_res = margined.evals_result()["validation_0"]["logloss"]
@ -560,18 +527,28 @@ def run_boost_from_prediction(
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
def test_boost_from_prediction(tree_method: str) -> None:
from sklearn.datasets import load_breast_cancer, load_digits
n_threads = os.cpu_count()
assert n_threads is not None
# This test has strict reproducibility requirements. However, Dask is free to move
# partitions between workers and modify the partitions' size during the test. Given
# the lack of control over the partitioning logic, here we use a single worker as a
# workaround.
n_workers = 1
with LocalCluster(
n_workers=n_workers, threads_per_worker=n_threads // n_workers
) as cluster:
with Client(cluster) as client:
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=200), dd.from_array(y_, chunksize=200)
divisions = copy(X.divisions)
run_boost_from_prediction(X, y, tree_method, "cpu", client, divisions)
run_boost_from_prediction(X, y, tree_method, "cpu", client)
X_, y_ = load_digits(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
divisions = copy(X.divisions)
run_boost_from_prediction_multi_class(X, y, tree_method, "cpu", client, divisions)
run_boost_from_prediction_multi_class(X, y, tree_method, "cpu", client)
def test_inplace_predict(client: "Client") -> None:
@ -1574,7 +1551,6 @@ class TestWithDask:
def test_empty_quantile_dmatrix(self, client: Client) -> None:
X, y = make_categorical(client, 2, 30, 13)
X_valid, y_valid = make_categorical(client, 10000, 30, 13)
divisions = copy(X_valid.divisions)
Xy = xgb.dask.DaskQuantileDMatrix(client, X, y, enable_categorical=True)
Xy_valid = xgb.dask.DaskQuantileDMatrix(