[dask] Test for data initializaton. (#6226)
This commit is contained in:
parent
2443275891
commit
b05073bda5
@ -326,10 +326,10 @@ class DaskDMatrix:
|
|||||||
self.partition_order[part.key] = i
|
self.partition_order[part.key] = i
|
||||||
|
|
||||||
key_to_partition = {part.key: part for part in parts}
|
key_to_partition = {part.key: part for part in parts}
|
||||||
who_has = await client.scheduler.who_has(
|
who_has = await client.scheduler.who_has(keys=[part.key for part in parts])
|
||||||
keys=[part.key for part in parts])
|
|
||||||
|
|
||||||
worker_map = defaultdict(list)
|
worker_map = defaultdict(list)
|
||||||
|
|
||||||
for key, workers in who_has.items():
|
for key, workers in who_has.items():
|
||||||
worker_map[next(iter(workers))].append(key_to_partition[key])
|
worker_map[next(iter(workers))].append(key_to_partition[key])
|
||||||
|
|
||||||
@ -651,9 +651,9 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
|||||||
'The evaluation history is returned as result of training.')
|
'The evaluation history is returned as result of training.')
|
||||||
|
|
||||||
workers = list(_get_client_workers(client).keys())
|
workers = list(_get_client_workers(client).keys())
|
||||||
rabit_args = await _get_rabit_args(workers, client)
|
_rabit_args = await _get_rabit_args(workers, client)
|
||||||
|
|
||||||
def dispatched_train(worker_addr, dtrain_ref, evals_ref):
|
def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
|
||||||
'''Perform training on a single worker. A local function prevents pickling.
|
'''Perform training on a single worker. A local function prevents pickling.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
@ -699,8 +699,13 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
|||||||
if evals:
|
if evals:
|
||||||
evals = [(e.create_fn_args(), name) for e, name in evals]
|
evals = [(e.create_fn_args(), name) for e, name in evals]
|
||||||
|
|
||||||
|
# Note for function purity:
|
||||||
|
# XGBoost is deterministic in most of the cases, which means train function is
|
||||||
|
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
|
||||||
|
# We haven't been able to do a full verification so here we keep pure to be False.
|
||||||
futures = client.map(dispatched_train,
|
futures = client.map(dispatched_train,
|
||||||
workers,
|
workers,
|
||||||
|
[_rabit_args] * len(workers),
|
||||||
[dtrain.create_fn_args()] * len(workers),
|
[dtrain.create_fn_args()] * len(workers),
|
||||||
[evals] * len(workers),
|
[evals] * len(workers),
|
||||||
pure=False,
|
pure=False,
|
||||||
|
|||||||
@ -369,6 +369,7 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
|
|||||||
* from user input data. Second is duplicated sketching entries, which is generated by
|
* from user input data. Second is duplicated sketching entries, which is generated by
|
||||||
* prunning or merging. We preserve the first type and remove the second type.
|
* prunning or merging. We preserve the first type and remove the second type.
|
||||||
*/
|
*/
|
||||||
|
timer_.Start(__func__);
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
|
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|||||||
@ -17,20 +17,23 @@ if sys.platform.startswith("win"):
|
|||||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from distributed import LocalCluster, Client
|
from distributed import LocalCluster, Client, get_client
|
||||||
from distributed.utils_test import client, loop, cluster_fixture
|
from distributed.utils_test import client, loop, cluster_fixture
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
from xgboost.dask import DaskDMatrix
|
from xgboost.dask import DaskDMatrix
|
||||||
|
import dask
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LocalCluster = None
|
LocalCluster = None
|
||||||
Client = None
|
Client = None
|
||||||
|
get_client = None
|
||||||
client = None
|
client = None
|
||||||
loop = None
|
loop = None
|
||||||
cluster_fixture = None
|
cluster_fixture = None
|
||||||
dd = None
|
dd = None
|
||||||
da = None
|
da = None
|
||||||
DaskDMatrix = None
|
DaskDMatrix = None
|
||||||
|
dask = None
|
||||||
|
|
||||||
kRows = 1000
|
kRows = 1000
|
||||||
kCols = 10
|
kCols = 10
|
||||||
@ -142,7 +145,7 @@ def test_boost_from_prediction(tree_method):
|
|||||||
y_ = dd.from_array(y, chunksize=100)
|
y_ = dd.from_array(y, chunksize=100)
|
||||||
|
|
||||||
with LocalCluster(n_workers=4) as cluster:
|
with LocalCluster(n_workers=4) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as _:
|
||||||
model_0 = xgb.dask.DaskXGBClassifier(
|
model_0 = xgb.dask.DaskXGBClassifier(
|
||||||
learning_rate=0.3,
|
learning_rate=0.3,
|
||||||
random_state=123,
|
random_state=123,
|
||||||
@ -744,3 +747,39 @@ class TestDaskCallbacks:
|
|||||||
assert hasattr(booster, 'best_score')
|
assert hasattr(booster, 'best_score')
|
||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
|
|
||||||
|
def test_data_initialization(self):
|
||||||
|
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||||
|
generate unnecessary copies of data.
|
||||||
|
|
||||||
|
'''
|
||||||
|
with LocalCluster(n_workers=2) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
X, y = generate_array()
|
||||||
|
n_partitions = X.npartitions
|
||||||
|
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
workers = list(xgb.dask._get_client_workers(client).keys())
|
||||||
|
rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
|
||||||
|
n_workers = len(workers)
|
||||||
|
|
||||||
|
def worker_fn(worker_addr, data_ref):
|
||||||
|
with xgb.dask.RabitContext(rabit_args):
|
||||||
|
local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
|
||||||
|
assert local_dtrain.num_row() == kRows / n_workers
|
||||||
|
|
||||||
|
futures = client.map(
|
||||||
|
worker_fn, workers, [m.create_fn_args()] * len(workers),
|
||||||
|
pure=False, workers=workers)
|
||||||
|
client.gather(futures)
|
||||||
|
|
||||||
|
has_what = client.has_what()
|
||||||
|
cnt = 0
|
||||||
|
data = set()
|
||||||
|
for k, v in has_what.items():
|
||||||
|
for d in v:
|
||||||
|
cnt += 1
|
||||||
|
data.add(d)
|
||||||
|
|
||||||
|
assert len(data) == cnt
|
||||||
|
# Subtract the on disk resource from each worker
|
||||||
|
assert cnt - n_workers == n_partitions
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user