Cache transformed data in QuantileDMatrix for efficiency. (#8666)
This commit is contained in:
@@ -407,21 +407,28 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_prefix:
|
||||
cache_prefix :
|
||||
Prefix to the cache files, only used in external memory. It can be either an
|
||||
URI or a file path.
|
||||
release_data :
|
||||
Whether the iterator should release the data during reset. Set it to True if the
|
||||
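data transformation (converting data to np.float32 type) is memory intensive. Set it to False to keep the transformed data cached when the transformation is computationally expensive.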
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cache_prefix: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self, cache_prefix: Optional[str] = None, release_data: bool = True
|
||||
) -> None:
|
||||
self.cache_prefix = cache_prefix
|
||||
|
||||
self._handle = _ProxyDMatrix()
|
||||
self._exception: Optional[Exception] = None
|
||||
self._enable_categorical = False
|
||||
self._allow_host = True
|
||||
self._release = release_data
|
||||
# Stage data in Python until reset or next is called to avoid data being freed.
|
||||
self._temporary_data: Optional[Tuple[Any, Any]] = None
|
||||
self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
|
||||
self._input_id: int = 0
|
||||
|
||||
def get_callbacks(
|
||||
self, allow_host: bool, enable_categorical: bool
|
||||
@@ -477,7 +484,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument
|
||||
"""A wrapper for user defined `reset` function."""
|
||||
# free the data
|
||||
self._temporary_data = None
|
||||
if self._release:
|
||||
self._temporary_data = None
|
||||
self._handle_exception(self.reset, None)
|
||||
|
||||
def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
|
||||
@@ -498,20 +506,25 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
) -> None:
|
||||
from .data import _proxy_transform, dispatch_proxy_set_data
|
||||
|
||||
new, cat_codes, feature_names, feature_types = _proxy_transform(
|
||||
data,
|
||||
feature_names,
|
||||
feature_types,
|
||||
self._enable_categorical,
|
||||
)
|
||||
# Reduce the amount of transformation that's needed for QuantileDMatrix.
|
||||
if self._temporary_data is not None and id(data) == self._input_id:
|
||||
new, cat_codes, feature_names, feature_types = self._temporary_data
|
||||
else:
|
||||
new, cat_codes, feature_names, feature_types = _proxy_transform(
|
||||
data,
|
||||
feature_names,
|
||||
feature_types,
|
||||
self._enable_categorical,
|
||||
)
|
||||
# Stage the data; meta info is copied inside C++ MetaInfo.
|
||||
self._temporary_data = (new, cat_codes)
|
||||
self._temporary_data = (new, cat_codes, feature_names, feature_types)
|
||||
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
|
||||
self.proxy.set_info(
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
**kwargs,
|
||||
)
|
||||
self._input_id = id(data)
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(input_data), 0)
|
||||
|
||||
|
||||
@@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.it = 0 # pylint: disable=invalid-name
|
||||
super().__init__()
|
||||
|
||||
# Keeping the transformed data cached does not necessarily increase memory usage,
|
||||
# as the data transformation itself might use memory.
|
||||
super().__init__(release_data=False)
|
||||
|
||||
def next(self, input_data: Callable) -> int:
|
||||
if self.it == 1:
|
||||
|
||||
Reference in New Issue
Block a user