[EM] Basic distributed test for external memory. (#10492)

This commit is contained in:
Jiaming Yuan
2024-07-06 01:15:20 +08:00
committed by GitHub
parent 513d7a7d84
commit 00264eb72b
3 changed files with 93 additions and 3 deletions

View File

@@ -248,13 +248,14 @@ class IteratorForTest(xgb.core.DataIter):
return X, y, w
def make_batches(
def make_batches( # pylint: disable=too-many-arguments,too-many-locals
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
random_state: int = 1994,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
X = []
y = []
@@ -262,9 +263,9 @@ def make_batches(
if use_cupy:
import cupy
rng = cupy.random.RandomState(1994)
rng = cupy.random.RandomState(random_state)
else:
rng = np.random.RandomState(1994)
rng = np.random.RandomState(random_state)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)