Fix loading DMatrix binary in distributed env. (#8149)
- Try to load DMatrix binary before trying to parse text input.
- Remove some unmaintained code.
This commit is contained in:
@@ -118,9 +118,10 @@ def make_categorical(
|
||||
|
||||
|
||||
def generate_array(
|
||||
with_weights: bool = False
|
||||
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
|
||||
Optional[xgb.dask._DaskCollection]]:
|
||||
with_weights: bool = False,
|
||||
) -> Tuple[
|
||||
xgb.dask._DataT, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
|
||||
]:
|
||||
chunk_size = 20
|
||||
rng = da.random.RandomState(1994)
|
||||
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
|
||||
@@ -1265,6 +1266,50 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
    """Save and reload a binary DMatrix on each dask worker.

    In a first phase every worker writes a DMatrix to ``<rank>.bin`` inside
    a shared temporary directory, where ``rank`` is the worker's rabit
    rank.  In a second phase, run under a freshly obtained set of rabit
    arguments, every worker loads the file matching its rank and checks
    the row and column counts.
    """

    def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
        """Write a categorical DMatrix to ``<rank>.bin`` under ``tmpdir``."""
        with xgb.dask.RabitContext(rabit_args):
            rank = xgb.rabit.get_rank()
            features, labels = tm.make_categorical(100, 4, 4, False)
            dmatrix = xgb.DMatrix(features, labels, enable_categorical=True)
            dmatrix.save_binary(os.path.join(tmpdir, f"{rank}.bin"))

    def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
        """Load ``<rank>.bin`` from ``tmpdir`` and verify its shape."""
        with xgb.dask.RabitContext(rabit_args):
            rank = xgb.rabit.get_rank()
            dmatrix = xgb.DMatrix(os.path.join(tmpdir, f"{rank}.bin"))
            assert dmatrix.num_row() == 100
            assert dmatrix.num_col() == 4

    with tempfile.TemporaryDirectory() as tmpdir:
        workers = _get_client_workers(client)
        n_workers = len(workers)

        # Phase 1: each worker saves its own binary file.
        rabit_args = client.sync(xgb.dask._get_rabit_args, n_workers, None, client)
        # Every submission carries identical arguments, so ``pure=False`` is
        # required: otherwise dask would try to reuse the first worker's
        # result for the others and hang waiting for it.
        save_futures = [
            client.submit(
                save_dmatrix, rabit_args, tmpdir, workers=[worker], pure=False
            )
            for worker in workers
        ]
        client.gather(save_futures)

        # Phase 2: each worker loads the file for its rank, using a new set
        # of rabit arguments.
        rabit_args = client.sync(xgb.dask._get_rabit_args, n_workers, None, client)
        load_futures = [
            client.submit(
                load_dmatrix, rabit_args, tmpdir, workers=[worker], pure=False
            )
            for worker in workers
        ]
        client.gather(load_futures)
|
||||
|
||||
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
||||
def test_global_config(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user