Release data in cache. (#10286)
This commit is contained in:
@@ -160,9 +160,11 @@ def test_data_iterator(
|
||||
|
||||
|
||||
class IterForCacheTest(xgb.DataIter):
|
||||
def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
|
||||
def __init__(
|
||||
self, x: np.ndarray, y: np.ndarray, w: np.ndarray, release_data: bool
|
||||
) -> None:
|
||||
self.kwargs = {"data": x, "label": y, "weight": w}
|
||||
super().__init__(release_data=False)
|
||||
super().__init__(release_data=release_data)
|
||||
|
||||
def next(self, input_data: Callable) -> int:
|
||||
if self.it == 1:
|
||||
@@ -181,7 +183,9 @@ def test_data_cache() -> None:
|
||||
n_samples_per_batch = 16
|
||||
data = make_batches(n_samples_per_batch, n_features, n_batches, False)
|
||||
batches = [v[0] for v in data]
|
||||
it = IterForCacheTest(*batches)
|
||||
|
||||
# Test with a cache.
|
||||
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=False)
|
||||
transform = xgb.data._proxy_transform
|
||||
|
||||
called = 0
|
||||
@@ -196,6 +200,12 @@ def test_data_cache() -> None:
|
||||
assert it._data_ref is weakref.ref(batches[0])
|
||||
assert called == 1
|
||||
|
||||
# Test without a cache.
|
||||
called = 0
|
||||
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=True)
|
||||
xgb.QuantileDMatrix(it)
|
||||
assert called == 4
|
||||
|
||||
xgb.data._proxy_transform = transform
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user