Cache transformed in QuantileDMatrix for efficiency. (#8666)

This commit is contained in:
Jiaming Yuan 2023-01-17 06:02:40 +08:00 committed by GitHub
parent 06ba285f71
commit 247946a875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 13 deletions

View File

@ -410,18 +410,25 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
cache_prefix : cache_prefix :
Prefix to the cache files, only used in external memory. It can be either an Prefix to the cache files, only used in external memory. It can be either an
URI or a file path. URI or a file path.
release_data :
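Whether the iterator should release the data during reset. Set it to True if the
data transformation (converting data to np.float32 type) is memory intensive; set
it to False to keep the transformed data cached across resets when the
transformation is computationally expensive.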
""" """
def __init__(self, cache_prefix: Optional[str] = None) -> None: def __init__(
self, cache_prefix: Optional[str] = None, release_data: bool = True
) -> None:
self.cache_prefix = cache_prefix self.cache_prefix = cache_prefix
self._handle = _ProxyDMatrix() self._handle = _ProxyDMatrix()
self._exception: Optional[Exception] = None self._exception: Optional[Exception] = None
self._enable_categorical = False self._enable_categorical = False
self._allow_host = True self._allow_host = True
self._release = release_data
# Stage data in Python until reset or next is called to avoid data being free. # Stage data in Python until reset or next is called to avoid data being free.
self._temporary_data: Optional[Tuple[Any, Any]] = None self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
self._input_id: int = 0
def get_callbacks( def get_callbacks(
self, allow_host: bool, enable_categorical: bool self, allow_host: bool, enable_categorical: bool
@ -477,6 +484,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument
"""A wrapper for user defined `reset` function.""" """A wrapper for user defined `reset` function."""
# free the data # free the data
if self._release:
self._temporary_data = None self._temporary_data = None
self._handle_exception(self.reset, None) self._handle_exception(self.reset, None)
@ -498,6 +506,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
) -> None: ) -> None:
from .data import _proxy_transform, dispatch_proxy_set_data from .data import _proxy_transform, dispatch_proxy_set_data
# Reduce the amount of transformation that's needed for QuantileDMatrix.
if self._temporary_data is not None and id(data) == self._input_id:
new, cat_codes, feature_names, feature_types = self._temporary_data
else:
new, cat_codes, feature_names, feature_types = _proxy_transform( new, cat_codes, feature_names, feature_types = _proxy_transform(
data, data,
feature_names, feature_names,
@ -505,13 +517,14 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
self._enable_categorical, self._enable_categorical,
) )
# Stage the data, meta info are copied inside C++ MetaInfo. # Stage the data, meta info are copied inside C++ MetaInfo.
self._temporary_data = (new, cat_codes) self._temporary_data = (new, cat_codes, feature_names, feature_types)
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host) dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
self.proxy.set_info( self.proxy.set_info(
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
**kwargs, **kwargs,
) )
self._input_id = id(data)
# pylint: disable=not-callable # pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0) return self._handle_exception(lambda: self.next(input_data), 0)

View File

@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs self.kwargs = kwargs
self.it = 0 # pylint: disable=invalid-name self.it = 0 # pylint: disable=invalid-name
super().__init__()
# Keeping the transformed data cached does not necessarily increase memory usage,
# as the data transformation itself might use memory.
super().__init__(release_data=False)
def next(self, input_data: Callable) -> int: def next(self, input_data: Callable) -> int:
if self.it == 1: if self.it == 1:

View File

@ -1,4 +1,4 @@
from typing import Dict, List from typing import Callable, Dict, List
import numpy as np import numpy as np
import pytest import pytest
@ -153,3 +153,30 @@ def test_data_iterator(
run_data_iterator( run_data_iterator(
n_samples_per_batch, n_features, n_batches, "hist", subsample, False n_samples_per_batch, n_features, n_batches, "hist", subsample, False
) )
class IterForCacheTest(xgb.DataIter):
    """Single-batch iterator used to exercise the transformed-data cache.

    The iterator is constructed with ``release_data=False`` and feeds the same
    ``x``/``y``/``w`` objects on every pass.

    Parameters
    ----------
    x :
        Feature matrix passed to ``input_data`` as ``data``.
    y :
        Labels passed to ``input_data`` as ``label``.
    w :
        Sample weights passed to ``input_data`` as ``weight``.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
        self.kwargs = {"data": x, "label": y, "weight": w}
        # Initialise the batch counter here so that `next` is safe to call even
        # if `reset` has not been invoked yet.
        self.it = 0  # pylint: disable=invalid-name
        super().__init__(release_data=False)

    def next(self, input_data: Callable) -> int:
        """Feed the single batch once per pass; return 0 when exhausted."""
        if self.it == 1:
            return 0
        self.it += 1
        input_data(**self.kwargs)
        return 1

    def reset(self) -> None:
        """Rewind the iterator to the beginning of the (single) batch."""
        self.it = 0
def test_data_cache() -> None:
    """Check that the iterator records the identity of the input it was fed.

    A single batch is generated, wrapped in ``IterForCacheTest`` and consumed by
    ``QuantileDMatrix``; afterwards the iterator's ``_input_id`` must equal the
    ``id`` of the feature matrix that was supplied.
    """
    n_samples_per_batch = 16
    n_features = 2
    n_batches = 1

    generated = make_batches(n_samples_per_batch, n_features, n_batches, False)
    # Keep only the first (and only) batch of every generated component.
    first_batch = [component[0] for component in generated]

    cache_iter = IterForCacheTest(*first_batch)
    xgb.QuantileDMatrix(cache_iter)

    expected_id = id(first_batch[0])
    assert cache_iter._input_id == expected_id