[dask] Test for data initializaton. (#6226)
This commit is contained in:
@@ -17,20 +17,23 @@ if sys.platform.startswith("win"):
|
||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||
|
||||
try:
|
||||
from distributed import LocalCluster, Client
|
||||
from distributed import LocalCluster, Client, get_client
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
import dask.array as da
|
||||
from xgboost.dask import DaskDMatrix
|
||||
import dask
|
||||
except ImportError:
|
||||
LocalCluster = None
|
||||
Client = None
|
||||
get_client = None
|
||||
client = None
|
||||
loop = None
|
||||
cluster_fixture = None
|
||||
dd = None
|
||||
da = None
|
||||
DaskDMatrix = None
|
||||
dask = None
|
||||
|
||||
kRows = 1000
|
||||
kCols = 10
|
||||
@@ -142,7 +145,7 @@ def test_boost_from_prediction(tree_method):
|
||||
y_ = dd.from_array(y, chunksize=100)
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
with Client(cluster) as _:
|
||||
model_0 = xgb.dask.DaskXGBClassifier(
|
||||
learning_rate=0.3,
|
||||
random_state=123,
|
||||
@@ -744,3 +747,39 @@ class TestDaskCallbacks:
|
||||
assert hasattr(booster, 'best_score')
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
def test_data_initialization(self):
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
generate unnecessary copies of data.
|
||||
|
||||
'''
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
n_partitions = X.npartitions
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
workers = list(xgb.dask._get_client_workers(client).keys())
|
||||
rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr, data_ref):
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
|
||||
assert local_dtrain.num_row() == kRows / n_workers
|
||||
|
||||
futures = client.map(
|
||||
worker_fn, workers, [m.create_fn_args()] * len(workers),
|
||||
pure=False, workers=workers)
|
||||
client.gather(futures)
|
||||
|
||||
has_what = client.has_what()
|
||||
cnt = 0
|
||||
data = set()
|
||||
for k, v in has_what.items():
|
||||
for d in v:
|
||||
cnt += 1
|
||||
data.add(d)
|
||||
|
||||
assert len(data) == cnt
|
||||
# Subtract the on disk resource from each worker
|
||||
assert cnt - n_workers == n_partitions
|
||||
|
||||
Reference in New Issue
Block a user