[EM] Basic distributed test for external memory. (#10492)
This commit is contained in:
@@ -248,13 +248,14 @@ class IteratorForTest(xgb.core.DataIter):
|
||||
return X, y, w
|
||||
|
||||
|
||||
def make_batches(
|
||||
def make_batches( # pylint: disable=too-many-arguments,too-many-locals
|
||||
n_samples_per_batch: int,
|
||||
n_features: int,
|
||||
n_batches: int,
|
||||
use_cupy: bool = False,
|
||||
*,
|
||||
vary_size: bool = False,
|
||||
random_state: int = 1994,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
X = []
|
||||
y = []
|
||||
@@ -262,9 +263,9 @@ def make_batches(
|
||||
if use_cupy:
|
||||
import cupy
|
||||
|
||||
rng = cupy.random.RandomState(1994)
|
||||
rng = cupy.random.RandomState(random_state)
|
||||
else:
|
||||
rng = np.random.RandomState(1994)
|
||||
rng = np.random.RandomState(random_state)
|
||||
for i in range(n_batches):
|
||||
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
|
||||
_X = rng.randn(n_samples, n_features)
|
||||
|
||||
Reference in New Issue
Block a user