Fix loading DMatrix binary in distributed env. (#8149)

- Try to load DMatrix binary before trying to parse text input.
- Remove some unmaintained code.
This commit is contained in:
Jiaming Yuan
2022-08-10 22:53:16 +08:00
committed by GitHub
parent 8fc60b31bc
commit 446d536c23
3 changed files with 71 additions and 80 deletions

View File

@@ -118,9 +118,10 @@ def make_categorical(
def generate_array(
with_weights: bool = False
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
Optional[xgb.dask._DaskCollection]]:
with_weights: bool = False,
) -> Tuple[
xgb.dask._DataT, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
]:
chunk_size = 20
rng = da.random.RandomState(1994)
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
@@ -1265,6 +1266,50 @@ def test_dask_iteration_range(client: "Client"):
class TestWithDask:
def test_dmatrix_binary(self, client: "Client") -> None:
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
X, y = tm.make_categorical(100, 4, 4, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
path = os.path.join(tmpdir, f"{rank}.bin")
Xy.save_binary(path)
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
path = os.path.join(tmpdir, f"{rank}.bin")
Xy = xgb.DMatrix(path)
assert Xy.num_row() == 100
assert Xy.num_col() == 4
with tempfile.TemporaryDirectory() as tmpdir:
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = []
for w in workers:
# same argument for each worker, must set pure to False otherwise dask
# will try to reuse the result from the first worker and hang waiting
# for it.
f = client.submit(
save_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
)
futures.append(f)
client.gather(futures)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = []
for w in workers:
f = client.submit(
load_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
)
futures.append(f)
client.gather(futures)
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
def test_global_config(
self,