Cache transformed data in QuantileDMatrix for efficiency. (#8666)
This commit is contained in:
parent
06ba285f71
commit
247946a875
@ -407,21 +407,28 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_prefix:
|
||||
cache_prefix :
|
||||
Prefix to the cache files, only used in external memory. It can be either an
|
||||
URI or a file path.
|
||||
release_data :
|
||||
Whether the iterator should release the data during reset. Set it to True if the
|
||||
data transformation (converting data to np.float32 type) is expensive.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cache_prefix: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self, cache_prefix: Optional[str] = None, release_data: bool = True
|
||||
) -> None:
|
||||
self.cache_prefix = cache_prefix
|
||||
|
||||
self._handle = _ProxyDMatrix()
|
||||
self._exception: Optional[Exception] = None
|
||||
self._enable_categorical = False
|
||||
self._allow_host = True
|
||||
self._release = release_data
|
||||
# Stage data in Python until reset or next is called to avoid data being free.
|
||||
self._temporary_data: Optional[Tuple[Any, Any]] = None
|
||||
self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
|
||||
self._input_id: int = 0
|
||||
|
||||
def get_callbacks(
|
||||
self, allow_host: bool, enable_categorical: bool
|
||||
@ -477,6 +484,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
def _reset_wrapper(self, this: None) -> None:  # pylint: disable=unused-argument
    """A wrapper for user defined `reset` function.

    Parameters
    ----------
    this :
        Unused placeholder for the C callback's opaque handle argument.
    """
    # free the data
    # Drop the staged (transformed) batch unless the user constructed the
    # iterator with release_data=False to keep it cached across passes.
    if self._release:
        self._temporary_data = None
    # Route the user's reset() through the exception-capturing helper so
    # errors raised in Python surface instead of crossing the C boundary.
    self._handle_exception(self.reset, None)
|
||||
|
||||
@ -498,6 +506,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
) -> None:
|
||||
from .data import _proxy_transform, dispatch_proxy_set_data
|
||||
|
||||
# Reduce the amount of transformation that's needed for QuantileDMatrix.
|
||||
if self._temporary_data is not None and id(data) == self._input_id:
|
||||
new, cat_codes, feature_names, feature_types = self._temporary_data
|
||||
else:
|
||||
new, cat_codes, feature_names, feature_types = _proxy_transform(
|
||||
data,
|
||||
feature_names,
|
||||
@ -505,13 +517,14 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
self._enable_categorical,
|
||||
)
|
||||
# Stage the data, meta info are copied inside C++ MetaInfo.
|
||||
self._temporary_data = (new, cat_codes)
|
||||
self._temporary_data = (new, cat_codes, feature_names, feature_types)
|
||||
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
|
||||
self.proxy.set_info(
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
**kwargs,
|
||||
)
|
||||
self._input_id = id(data)
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(input_data), 0)
|
||||
|
||||
|
||||
@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.it = 0 # pylint: disable=invalid-name
|
||||
super().__init__()
|
||||
|
||||
# This does not necessarily increase memory usage as the data transformation
|
||||
# might use memory.
|
||||
super().__init__(release_data=False)
|
||||
|
||||
def next(self, input_data: Callable) -> int:
|
||||
if self.it == 1:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -153,3 +153,30 @@ def test_data_iterator(
|
||||
run_data_iterator(
|
||||
n_samples_per_batch, n_features, n_batches, "hist", subsample, False
|
||||
)
|
||||
|
||||
|
||||
class IterForCacheTest(xgb.DataIter):
    """A single-batch data iterator used to verify that transformed input is
    cached by the ``DataIter`` when ``release_data=False``."""

    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
        self.kwargs = {"data": x, "label": y, "weight": w}
        # Fix: initialize the batch counter here.  Relying on ``reset`` to
        # create ``self.it`` would raise AttributeError if ``next`` is
        # invoked before the first reset.
        self.it = 0  # pylint: disable=invalid-name
        # Keep the transformed data alive so the cache/input-id check in the
        # test can observe it.
        super().__init__(release_data=False)

    def next(self, input_data: Callable) -> int:
        """Yield the single batch once per pass; return 0 when exhausted."""
        if self.it == 1:
            return 0
        self.it += 1
        input_data(**self.kwargs)
        return 1

    def reset(self) -> None:
        """Rewind to the beginning of the (single-batch) stream."""
        self.it = 0
|
||||
|
||||
|
||||
def test_data_cache() -> None:
    """The iterator's input should be cached and recognized by object id."""
    n_batches = 1
    n_features = 2
    n_samples_per_batch = 16
    raw = make_batches(n_samples_per_batch, n_features, n_batches, False)
    batches = [entry[0] for entry in raw]
    it = IterForCacheTest(*batches)
    xgb.QuantileDMatrix(it)
    # The iterator must have recorded the id of the exact array it was fed.
    # pylint: disable=protected-access
    assert it._input_id == id(batches[0])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user