[dask] Mitigate non-deterministic test. (#8077)
parent 7a6b711eb8
commit 2365f82750
@@ -1,39 +1,45 @@
 """Copyright 2019-2022 XGBoost contributors"""
-from pathlib import Path
-import asyncio
-import json
-import os
-import pickle
-import socket
-import testing as tm
-import pytest
-import xgboost as xgb
-import sys
-import numpy as np
-import scipy
-import json
-from typing import List, Tuple, Dict, Optional, Type, Any
-import asyncio
-from functools import partial
-from concurrent.futures import ThreadPoolExecutor
-import tempfile
-from sklearn.datasets import make_classification
-import sklearn
-import hypothesis
-from hypothesis import given, settings, note, HealthCheck
-from test_updaters import hist_parameter_strategy, exact_parameter_strategy
-from test_with_sklearn import run_feature_weights, run_data_initialization
-from sklearn.datasets import make_regression
+import asyncio
+import json
+import os
+import pickle
+import socket
+import subprocess
+import sys
+import tempfile
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from itertools import starmap
+from math import ceil
+from operator import attrgetter, getitem
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import hypothesis
+import numpy as np
+import pytest
+import scipy
+import sklearn
+import testing as tm
+from hypothesis import HealthCheck, given, note, settings
+from sklearn.datasets import make_classification, make_regression
+from test_predict import verify_leaf_output
+from test_updaters import exact_parameter_strategy, hist_parameter_strategy
+from test_with_sklearn import run_data_initialization, run_feature_weights
+from xgboost.data import _is_cudf_df
+
+import xgboost as xgb
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
 if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
 
-from distributed import LocalCluster, Client
-import dask
-import dask.dataframe as dd
-import dask.array as da
-from xgboost.dask import DaskDMatrix
+import dask
+import dask.array as da
+import dask.dataframe as dd
+from distributed import Client, LocalCluster
+from toolz import sliding_window  # dependency of dask
+from xgboost.dask import DaskDMatrix
 
 dask.config.set({"distributed.scheduler.allowed-failures": False})
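The new itertools, math, operator, and toolz imports support the partition-slicing helper added in the next hunk. As orientation, here is a minimal, self-contained sketch of that slicing idiom; the sizes (10 partitions over 3 workers) are hypothetical, not taken from the test suite, and only the variable names mirror deterministic_persist_per_worker:

from itertools import starmap
from math import ceil
from toolz import sliding_window

npartitions, n_workers = 10, 3  # hypothetical sizes
subpartition_size = ceil(npartitions / n_workers)  # 4 partitions per worker
divisions = range(0, npartitions + subpartition_size, subpartition_size)  # 0, 4, 8, 12
slices = list(starmap(slice, sliding_window(2, divisions)))
print(slices)  # [slice(0, 4, None), slice(4, 8, None), slice(8, 12, None)]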
@@ -125,6 +131,63 @@ def generate_array(
     return X, y, None
 
 
+def deterministic_persist_per_worker(df, client):
+    # Got this script from https://github.com/dmlc/xgboost/issues/7927
+    # Query workers
+    n_workers = len(client.cluster.workers)
+    workers = map(attrgetter("worker_address"), client.cluster.workers.values())
+
+    # Slice data into roughly equal partitions
+    subpartition_size = ceil(df.npartitions / n_workers)
+    subpartition_divisions = range(
+        0, df.npartitions + subpartition_size, subpartition_size
+    )
+    subpartition_slices = starmap(slice, sliding_window(2, subpartition_divisions))
+    subpartitions = map(partial(getitem, df.partitions), subpartition_slices)
+
+    # Persist each subpartition on each worker
+    # Rebuild dataframe from persisted subpartitions
+    df2 = dd.concat(
+        [
+            sp.persist(workers=w, allow_other_workers=False)
+            for sp, w in zip(subpartitions, workers)
+        ]
+    )
+
+    return df2
+
+
+def deterministic_repartition(
+    client: Client,
+    X: dd.DataFrame,
+    y: dd.Series,
+    m: Optional[Union[dd.DataFrame, dd.Series]],
+) -> Tuple[dd.DataFrame, dd.Series, Optional[Union[dd.DataFrame, dd.Series]]]:
+    # force repartition the data to avoid non-deterministic result
+    if any(X.map_partitions(lambda x: _is_cudf_df(x)).compute()):
+        # dask_cudf seems to be doing fine for now
+        return X, y, m
+
+    X["_y"] = y
+    if m is not None:
+        if isinstance(m, dd.DataFrame):
+            m_columns = m.columns
+            X = dd.concat([X, m], join="outer", axis=1)
+        else:
+            m_columns = ["_m"]
+            X["_m"] = m
+
+    X = deterministic_persist_per_worker(X, client)
+
+    y = X["_y"]
+    X = X[X.columns.difference(["_y"])]
+    if m is not None:
+        m = X[m_columns]
+        X = X[X.columns.difference(m_columns)]
+
+    return X, y, m
+
+
 def test_from_dask_dataframe() -> None:
     with LocalCluster(n_workers=kWorkers, dashboard_address=":0") as cluster:
         with Client(cluster) as client:
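For anyone trying the new helper outside the suite, a hedged usage sketch; the two-worker LocalCluster and the toy frame are illustration-only assumptions, not part of this commit. Each contiguous slice of partitions is pinned to a single worker, so repeated runs reproduce the same data placement:

import dask.dataframe as dd
import pandas as pd
from distributed import Client, LocalCluster

if __name__ == "__main__":
    with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            # Four partitions over two workers: two contiguous partitions each.
            df = dd.from_pandas(pd.DataFrame({"a": range(8)}), npartitions=4)
            # Assumes deterministic_persist_per_worker from this diff is in scope.
            df = deterministic_persist_per_worker(df, client)
            print(client.who_has())  # stable partition-to-worker mapping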
@@ -347,22 +410,25 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
 
 
 def run_boost_from_prediction_multi_class(
-    X: xgb.dask._DaskCollection,
-    y: xgb.dask._DaskCollection,
+    X: dd.DataFrame,
+    y: dd.Series,
     tree_method: str,
     client: "Client",
 ) -> None:
     model_0 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=4, tree_method=tree_method, max_bin=768
     )
+    X, y, _ = deterministic_repartition(client, X, y, None)
     model_0.fit(X=X, y=y)
     margin = xgb.dask.inplace_predict(
         client, model_0.get_booster(), X, predict_type="margin"
     )
+    margin.columns = [f"m_{i}" for i in range(margin.shape[1])]
 
     model_1 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=4, tree_method=tree_method, max_bin=768
     )
+    X, y, margin = deterministic_repartition(client, X, y, margin)
     model_1.fit(X=X, y=y, base_margin=margin)
     predictions_1 = xgb.dask.predict(
         client,
@@ -374,6 +440,7 @@ def run_boost_from_prediction_multi_class(
     model_2 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=8, tree_method=tree_method, max_bin=768
     )
+    X, y, _ = deterministic_repartition(client, X, y, None)
     model_2.fit(X=X, y=y)
     predictions_2 = xgb.dask.inplace_predict(
         client, model_2.get_booster(), X, predict_type="margin"
@@ -391,40 +458,45 @@ def run_boost_from_prediction_multi_class(
 
 
 def run_boost_from_prediction(
-    X: xgb.dask._DaskCollection,
-    y: xgb.dask._DaskCollection,
+    X: dd.DataFrame,
+    y: dd.Series,
     tree_method: str,
     client: "Client",
 ) -> None:
-    X = client.persist(X)
-    y = client.persist(y)
+    X, y = client.persist([X, y])
 
     model_0 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=4, tree_method=tree_method, max_bin=512
     )
+    X, y, _ = deterministic_repartition(client, X, y, None)
     model_0.fit(X=X, y=y)
-    margin = model_0.predict(X, output_margin=True)
+    margin: dd.Series = model_0.predict(X, output_margin=True)
 
     model_1 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=4, tree_method=tree_method, max_bin=512
     )
+    X, y, margin = deterministic_repartition(client, X, y, margin)
     model_1.fit(X=X, y=y, base_margin=margin)
-    predictions_1 = model_1.predict(X, base_margin=margin)
+    X, y, margin = deterministic_repartition(client, X, y, margin)
+    predictions_1: dd.Series = model_1.predict(X, base_margin=margin)
 
     cls_2 = xgb.dask.DaskXGBClassifier(
         learning_rate=0.3, n_estimators=8, tree_method=tree_method, max_bin=512
     )
+    X, y, _ = deterministic_repartition(client, X, y, None)
     cls_2.fit(X=X, y=y)
-    predictions_2 = cls_2.predict(X)
+    predictions_2: dd.Series = cls_2.predict(X)
 
     assert np.all(predictions_1.compute() == predictions_2.compute())
 
     margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
+    X, y, margin = deterministic_repartition(client, X, y, margin)
     margined.fit(
         X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
     )
 
     unmargined = xgb.dask.DaskXGBClassifier(n_estimators=4)
+    X, y, margin = deterministic_repartition(client, X, y, margin)
     unmargined.fit(X=X, y=y, eval_set=[(X, y)], base_margin=margin)
 
     margined_res = margined.evals_result()["validation_0"]["logloss"]
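The intent of threading deterministic_repartition through these fits is that identical training runs see identical partition placement and therefore build identical models. A hedged, standalone sketch of that determinism check; the dataset, cluster, and function name here are hypothetical, and the helper from this diff is assumed to be in scope:

import dask.dataframe as dd
import numpy as np
import xgboost as xgb
from distributed import Client, LocalCluster
from sklearn.datasets import make_classification

def check_fit_is_deterministic(client: Client) -> None:
    X_np, y_np = make_classification(n_samples=1000, n_features=8, random_state=0)
    X = dd.from_array(X_np, chunksize=100)  # 10 partitions
    y = dd.from_array(y_np, chunksize=100)
    X, y, _ = deterministic_repartition(client, X, y, None)  # helper from this diff
    preds = []
    for _ in range(2):
        clf = xgb.dask.DaskXGBClassifier(n_estimators=4, tree_method="hist")
        clf.fit(X, y)
        preds.append(clf.predict(X).compute())
    # With fixed placement, both runs should aggregate partial histograms in
    # the same order, so predictions should match exactly.
    assert np.all(preds[0] == preds[1])

if __name__ == "__main__":
    with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            check_fit_is_deterministic(client)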