diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b5bafe453..cd542ba70 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -407,21 +407,28 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes Parameters ---------- - cache_prefix: + cache_prefix : Prefix to the cache files, only used in external memory. It can be either an URI or a file path. + release_data : + Whether the iterator should release the data during reset. Set it to True if the + data transformation (converting data to np.float32 type) is memory intensive. """ - def __init__(self, cache_prefix: Optional[str] = None) -> None: + def __init__( + self, cache_prefix: Optional[str] = None, release_data: bool = True + ) -> None: self.cache_prefix = cache_prefix self._handle = _ProxyDMatrix() self._exception: Optional[Exception] = None self._enable_categorical = False self._allow_host = True + self._release = release_data # Stage data in Python until reset or next is called to avoid data being free. 
- self._temporary_data: Optional[Tuple[Any, Any]] = None + self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None + self._input_id: int = 0 def get_callbacks( self, allow_host: bool, enable_categorical: bool @@ -477,7 +484,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument """A wrapper for user defined `reset` function.""" # free the data - self._temporary_data = None + if self._release: + self._temporary_data = None self._handle_exception(self.reset, None) def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument @@ -498,20 +506,25 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes ) -> None: from .data import _proxy_transform, dispatch_proxy_set_data - new, cat_codes, feature_names, feature_types = _proxy_transform( - data, - feature_names, - feature_types, - self._enable_categorical, - ) + # Reduce the amount of transformation that's needed for QuantileDMatrix. + if self._temporary_data is not None and id(data) == self._input_id: + new, cat_codes, feature_names, feature_types = self._temporary_data + else: + new, cat_codes, feature_names, feature_types = _proxy_transform( + data, + feature_names, + feature_types, + self._enable_categorical, + ) # Stage the data, meta info are copied inside C++ MetaInfo. 
- self._temporary_data = (new, cat_codes) + self._temporary_data = (new, cat_codes, feature_names, feature_types) dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, **kwargs, ) + self._input_id = id(data) # pylint: disable=not-callable return self._handle_exception(lambda: self.next(input_data), 0) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index d22769a3a..026b1c6ea 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs self.it = 0 # pylint: disable=invalid-name - super().__init__() + + # Keeping the staged data does not necessarily increase peak memory usage: + # releasing it would force a re-transformation, which itself allocates memory. + super().__init__(release_data=False) def next(self, input_data: Callable) -> int: if self.it == 1: return 0 self.it += 1 input_data(**self.kwargs) return 1 diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index cf81288e8..4b4258a21 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Callable, Dict, List import numpy as np import pytest @@ -153,3 +153,30 @@ def test_data_iterator( run_data_iterator( n_samples_per_batch, n_features, n_batches, "hist", subsample, False ) + + +class IterForCacheTest(xgb.DataIter): + def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None: + self.kwargs = {"data": x, "label": y, "weight": w} + super().__init__(release_data=False) + + def next(self, input_data: Callable) -> int: + if self.it == 1: + return 0 + self.it += 1 + input_data(**self.kwargs) + return 1 + + def reset(self) -> None: + self.it = 0 + + +def test_data_cache() -> None: + n_batches = 1 + n_features = 2 + n_samples_per_batch = 16 + data = make_batches(n_samples_per_batch, n_features, 
n_batches, False) + batches = [v[0] for v in data] + it = IterForCacheTest(*batches) + xgb.QuantileDMatrix(it) + assert it._input_id == id(batches[0])