Release data in cache. (#10286)

This commit is contained in:
Jiaming Yuan 2024-05-14 14:20:19 +08:00 committed by GitHub
parent f1f69ff10e
commit ca1d04bcb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 46 additions and 39 deletions

View File

@ -504,8 +504,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
cache_prefix : cache_prefix :
Prefix to the cache files, only used in external memory. Prefix to the cache files, only used in external memory.
release_data : release_data :
Whether the iterator should release the data during reset. Set it to True if the Whether the iterator should release the data during iteration. Set it to True if
data transformation (converting data to np.float32 type) is expensive. the data transformation (converting data to np.float32 type) is memory
intensive. Otherwise, if the transformation is computation intensive then we can
keep the cache.
""" """
@ -517,15 +519,12 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
self._handle = _ProxyDMatrix() self._handle = _ProxyDMatrix()
self._exception: Optional[Exception] = None self._exception: Optional[Exception] = None
self._enable_categorical = False self._enable_categorical = False
self._allow_host = True
self._release = release_data self._release = release_data
# Stage data in Python until reset or next is called to avoid data being freed. # Stage data in Python until reset or next is called to avoid data being freed.
self._temporary_data: Optional[TransformedData] = None self._temporary_data: Optional[TransformedData] = None
self._data_ref: Optional[weakref.ReferenceType] = None self._data_ref: Optional[weakref.ReferenceType] = None
def get_callbacks( def get_callbacks(self, enable_categorical: bool) -> Tuple[Callable, Callable]:
self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]:
"""Get callback functions for iterating in C. This is an internal function.""" """Get callback functions for iterating in C. This is an internal function."""
assert hasattr(self, "cache_prefix"), "__init__ is not called." assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)( self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
@ -535,7 +534,6 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
ctypes.c_int, ctypes.c_int,
ctypes.c_void_p, ctypes.c_void_p,
)(self._next_wrapper) )(self._next_wrapper)
self._allow_host = allow_host
self._enable_categorical = enable_categorical self._enable_categorical = enable_categorical
return self._reset_callback, self._next_callback return self._reset_callback, self._next_callback
@ -624,7 +622,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
) )
# Stage the data, meta info are copied inside C++ MetaInfo. # Stage the data, meta info are copied inside C++ MetaInfo.
self._temporary_data = (new, cat_codes, feature_names, feature_types) self._temporary_data = (new, cat_codes, feature_names, feature_types)
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host) dispatch_proxy_set_data(self.proxy, new, cat_codes)
self.proxy.set_info( self.proxy.set_info(
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
@ -632,6 +630,9 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
) )
self._data_ref = ref self._data_ref = ref
# Release the data before next batch is loaded.
if self._release:
self._temporary_data = None
# pylint: disable=not-callable # pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0) return self._handle_exception(lambda: self.next(input_data), 0)
@ -911,7 +912,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
} }
args_cstr = from_pystr_to_cstr(json.dumps(args)) args_cstr = from_pystr_to_cstr(json.dumps(args))
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(True, enable_categorical) reset_callback, next_callback = it.get_callbacks(enable_categorical)
ret = _LIB.XGDMatrixCreateFromCallback( ret = _LIB.XGDMatrixCreateFromCallback(
None, None,
it.proxy.handle, it.proxy.handle,
@ -1437,37 +1438,37 @@ class _ProxyDMatrix(DMatrix):
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))
def _set_data_from_cuda_interface(self, data: DataType) -> None: def _ref_data_from_cuda_interface(self, data: DataType) -> None:
"""Set data from CUDA array interface.""" """Reference data from CUDA array interface."""
interface = data.__cuda_array_interface__ interface = data.__cuda_array_interface__
interface_str = bytes(json.dumps(interface), "utf-8") interface_str = bytes(json.dumps(interface), "utf-8")
_check_call( _check_call(
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
) )
def _set_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None: def _ref_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None:
"""Set data from CUDA columnar format.""" """Reference data from CUDA columnar format."""
from .data import _cudf_array_interfaces from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data, cat_codes) interfaces_str = _cudf_array_interfaces(data, cat_codes)
_check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str))
def _set_data_from_array(self, data: np.ndarray) -> None: def _ref_data_from_array(self, data: np.ndarray) -> None:
"""Set data from numpy array.""" """Reference data from numpy array."""
from .data import _array_interface from .data import _array_interface
_check_call( _check_call(
_LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data)) _LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
) )
def _set_data_from_pandas(self, data: DataType) -> None: def _ref_data_from_pandas(self, data: DataType) -> None:
"""Set data from a pandas DataFrame. The input is a PandasTransformed instance.""" """Reference data from a pandas DataFrame. The input is a PandasTransformed instance."""
_check_call( _check_call(
_LIB.XGProxyDMatrixSetDataColumnar(self.handle, data.array_interface()) _LIB.XGProxyDMatrixSetDataColumnar(self.handle, data.array_interface())
) )
def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None: def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
"""Set data from scipy csr""" """Reference data from scipy csr."""
from .data import _array_interface from .data import _array_interface
_LIB.XGProxyDMatrixSetDataCSR( _LIB.XGProxyDMatrixSetDataCSR(
@ -1609,7 +1610,7 @@ class QuantileDMatrix(DMatrix):
it = SingleBatchInternalIter(data=data, **meta) it = SingleBatchInternalIter(data=data, **meta)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(True, enable_categorical) reset_callback, next_callback = it.get_callbacks(enable_categorical)
if it.cache_prefix is not None: if it.cache_prefix is not None:
raise ValueError( raise ValueError(
"QuantileDMatrix doesn't cache data, remove the cache_prefix " "QuantileDMatrix doesn't cache data, remove the cache_prefix "

View File

@ -616,7 +616,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
assert isinstance(self._label_upper_bound, types) assert isinstance(self._label_upper_bound, types)
self._iter = 0 # set iterator to 0 self._iter = 0 # set iterator to 0
super().__init__() super().__init__(release_data=True)
def _get(self, attr: str) -> Optional[Any]: def _get(self, attr: str) -> Optional[Any]:
if getattr(self, attr) is not None: if getattr(self, attr) is not None:

View File

@ -1467,7 +1467,6 @@ def dispatch_proxy_set_data(
proxy: _ProxyDMatrix, proxy: _ProxyDMatrix,
data: DataType, data: DataType,
cat_codes: Optional[list], cat_codes: Optional[list],
allow_host: bool,
) -> None: ) -> None:
"""Dispatch for QuantileDMatrix.""" """Dispatch for QuantileDMatrix."""
if not _is_cudf_ser(data) and not _is_pandas_series(data): if not _is_cudf_ser(data) and not _is_pandas_series(data):
@ -1475,33 +1474,30 @@ def dispatch_proxy_set_data(
if _is_cudf_df(data): if _is_cudf_df(data):
# pylint: disable=W0212 # pylint: disable=W0212
proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes)) proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes))
return return
if _is_cudf_ser(data): if _is_cudf_ser(data):
# pylint: disable=W0212 # pylint: disable=W0212
proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes)) proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes))
return return
if _is_cupy_alike(data): if _is_cupy_alike(data):
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212
return return
if _is_dlpack(data): if _is_dlpack(data):
data = _transform_dlpack(data) data = _transform_dlpack(data)
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212
return return
# Host
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
if not allow_host:
raise err
if isinstance(data, PandasTransformed): if isinstance(data, PandasTransformed):
proxy._set_data_from_pandas(data) # pylint: disable=W0212 proxy._ref_data_from_pandas(data) # pylint: disable=W0212
return return
if _is_np_array_like(data): if _is_np_array_like(data):
_check_data_shape(data) _check_data_shape(data)
proxy._set_data_from_array(data) # pylint: disable=W0212 proxy._ref_data_from_array(data) # pylint: disable=W0212
return return
if is_scipy_csr(data): if is_scipy_csr(data):
proxy._set_data_from_csr(data) # pylint: disable=W0212 proxy._ref_data_from_csr(data) # pylint: disable=W0212
return return
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
raise err raise err

View File

@ -77,7 +77,7 @@ class PartIter(DataIter):
self._data = data self._data = data
self._kwargs = kwargs self._kwargs = kwargs
super().__init__() super().__init__(release_data=True)
def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]: def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]:
if not data: if not data:

View File

@ -160,9 +160,11 @@ def test_data_iterator(
class IterForCacheTest(xgb.DataIter): class IterForCacheTest(xgb.DataIter):
def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None: def __init__(
self, x: np.ndarray, y: np.ndarray, w: np.ndarray, release_data: bool
) -> None:
self.kwargs = {"data": x, "label": y, "weight": w} self.kwargs = {"data": x, "label": y, "weight": w}
super().__init__(release_data=False) super().__init__(release_data=release_data)
def next(self, input_data: Callable) -> int: def next(self, input_data: Callable) -> int:
if self.it == 1: if self.it == 1:
@ -181,7 +183,9 @@ def test_data_cache() -> None:
n_samples_per_batch = 16 n_samples_per_batch = 16
data = make_batches(n_samples_per_batch, n_features, n_batches, False) data = make_batches(n_samples_per_batch, n_features, n_batches, False)
batches = [v[0] for v in data] batches = [v[0] for v in data]
it = IterForCacheTest(*batches)
# Test with a cache.
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=False)
transform = xgb.data._proxy_transform transform = xgb.data._proxy_transform
called = 0 called = 0
@ -196,6 +200,12 @@ def test_data_cache() -> None:
assert it._data_ref is weakref.ref(batches[0]) assert it._data_ref is weakref.ref(batches[0])
assert called == 1 assert called == 1
# Test without a cache.
called = 0
it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=True)
xgb.QuantileDMatrix(it)
assert called == 4
xgb.data._proxy_transform = transform xgb.data._proxy_transform = transform