Release data in cache. (#10286)

This commit is contained in:
Jiaming Yuan
2024-05-14 14:20:19 +08:00
committed by GitHub
parent f1f69ff10e
commit ca1d04bcb7
5 changed files with 46 additions and 39 deletions

View File

@@ -160,9 +160,11 @@ def test_data_iterator(
class IterForCacheTest(xgb.DataIter):
def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
def __init__(
self, x: np.ndarray, y: np.ndarray, w: np.ndarray, release_data: bool
) -> None:
self.kwargs = {"data": x, "label": y, "weight": w}
super().__init__(release_data=False)
super().__init__(release_data=release_data)
def next(self, input_data: Callable) -> int:
if self.it == 1:
@@ -181,7 +183,9 @@ def test_data_cache() -> None:
n_samples_per_batch = 16
data = make_batches(n_samples_per_batch, n_features, n_batches, False)
batches = [v[0] for v in data]
it = IterForCacheTest(*batches)
# Test with a cache.
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=False)
transform = xgb.data._proxy_transform
called = 0
@@ -196,6 +200,12 @@ def test_data_cache() -> None:
assert it._data_ref is weakref.ref(batches[0])
assert called == 1
# Test without a cache.
called = 0
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=True)
xgb.QuantileDMatrix(it)
assert called == 4
xgb.data._proxy_transform = transform