Cache transformed data in QuantileDMatrix for efficiency. (#8666)
commit 247946a875
parent 06ba285f71
@@ -407,21 +407,28 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
 
     Parameters
     ----------
-    cache_prefix:
+    cache_prefix :
         Prefix to the cache files, only used in external memory. It can be either an
         URI or a file path.
+    release_data :
+        Whether the iterator should release the data during reset. Set it to True if the
+        data transformation (converting data to np.float32 type) is expensive.
 
     """
 
-    def __init__(self, cache_prefix: Optional[str] = None) -> None:
+    def __init__(
+        self, cache_prefix: Optional[str] = None, release_data: bool = True
+    ) -> None:
         self.cache_prefix = cache_prefix
 
         self._handle = _ProxyDMatrix()
         self._exception: Optional[Exception] = None
         self._enable_categorical = False
         self._allow_host = True
+        self._release = release_data
         # Stage data in Python until reset or next is called to avoid data being free.
-        self._temporary_data: Optional[Tuple[Any, Any]] = None
+        self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
+        self._input_id: int = 0
 
     def get_callbacks(
         self, allow_host: bool, enable_categorical: bool
@@ -477,6 +484,7 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
     def _reset_wrapper(self, this: None) -> None:  # pylint: disable=unused-argument
         """A wrapper for user defined `reset` function."""
         # free the data
+        if self._release:
             self._temporary_data = None
         self._handle_exception(self.reset, None)
 
@@ -498,6 +506,10 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
     ) -> None:
         from .data import _proxy_transform, dispatch_proxy_set_data
 
+        # Reduce the amount of transformation that's needed for QuantileDMatrix.
+        if self._temporary_data is not None and id(data) == self._input_id:
+            new, cat_codes, feature_names, feature_types = self._temporary_data
+        else:
             new, cat_codes, feature_names, feature_types = _proxy_transform(
                 data,
                 feature_names,
@@ -505,13 +517,14 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
                 self._enable_categorical,
             )
         # Stage the data, meta info are copied inside C++ MetaInfo.
-        self._temporary_data = (new, cat_codes)
+        self._temporary_data = (new, cat_codes, feature_names, feature_types)
         dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
         self.proxy.set_info(
             feature_names=feature_names,
             feature_types=feature_types,
             **kwargs,
         )
+        self._input_id = id(data)
         # pylint: disable=not-callable
         return self._handle_exception(lambda: self.next(input_data), 0)
 
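The two hunks above carry the core of the change: the staged tuple now keeps the transformed batch together with its feature names/types, and `id(data)` is remembered so that a later pass handing back the very same object skips `_proxy_transform` entirely. Below is a minimal standalone sketch of that id()-based staging pattern; the class and helper names are illustrative only and are not XGBoost internals.

from typing import Any, Optional, Tuple

import numpy as np


class TransformCache:
    """Sketch: reuse an expensive per-batch transformation across passes."""

    def __init__(self, release_data: bool = True) -> None:
        # Staged (transformed data, extra metadata) from the previous call.
        self._staged: Optional[Tuple[np.ndarray, Any]] = None
        # id() of the raw object that produced the staged result.
        self._input_id: int = 0
        self._release = release_data

    def transform(self, data: np.ndarray) -> np.ndarray:
        # Same object as last time and still staged: reuse it.
        if self._staged is not None and id(data) == self._input_id:
            return self._staged[0]
        # Otherwise run the (potentially expensive) conversion and stage it.
        transformed = data.astype(np.float32)
        self._staged = (transformed, None)
        self._input_id = id(data)
        return transformed

    def reset(self) -> None:
        # Mirrors _reset_wrapper: only drop the stage when releasing is enabled.
        if self._release:
            self._staged = None


cache = TransformCache(release_data=False)
batch = np.arange(6, dtype=np.float64).reshape(3, 2)
first = cache.transform(batch)
cache.reset()                    # keeps the stage because release_data=False
second = cache.transform(batch)  # same object again, no second conversion
assert first is second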
@@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
     def __init__(self, **kwargs: Any) -> None:
         self.kwargs = kwargs
         self.it = 0  # pylint: disable=invalid-name
-        super().__init__()
+        # This does not necessarily increase memory usage as the data transformation
+        # might use memory.
+        super().__init__(release_data=False)
 
     def next(self, input_data: Callable) -> int:
         if self.it == 1:
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Callable, Dict, List
 
 import numpy as np
 import pytest
@@ -153,3 +153,30 @@ def test_data_iterator(
     run_data_iterator(
         n_samples_per_batch, n_features, n_batches, "hist", subsample, False
     )
+
+
+class IterForCacheTest(xgb.DataIter):
+    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
+        self.kwargs = {"data": x, "label": y, "weight": w}
+        super().__init__(release_data=False)
+
+    def next(self, input_data: Callable) -> int:
+        if self.it == 1:
+            return 0
+        self.it += 1
+        input_data(**self.kwargs)
+        return 1
+
+    def reset(self) -> None:
+        self.it = 0
+
+
+def test_data_cache() -> None:
+    n_batches = 1
+    n_features = 2
+    n_samples_per_batch = 16
+    data = make_batches(n_samples_per_batch, n_features, n_batches, False)
+    batches = [v[0] for v in data]
+    it = IterForCacheTest(*batches)
+    xgb.QuantileDMatrix(it)
+    assert it._input_id == id(batches[0])
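For context, here is a usage sketch in the spirit of the new `test_data_cache`: a single-batch iterator that opts out of releasing its staged data via `release_data=False`, so when `QuantileDMatrix` walks the iterator more than once during sketching, the same arrays are not converted again. The class name, shapes, and synthetic data below are made up for illustration.

from typing import Callable

import numpy as np
import xgboost as xgb


class SingleBatchIter(xgb.DataIter):
    """Yield one in-memory batch per pass and keep the staged transform alive."""

    def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
        self.x = x
        self.y = y
        self.it = 0
        super().__init__(release_data=False)

    def next(self, input_data: Callable) -> int:
        if self.it == 1:
            return 0  # end of the single-batch pass
        self.it += 1
        input_data(data=self.x, label=self.y)
        return 1

    def reset(self) -> None:
        self.it = 0


rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8)).astype(np.float32)
y = rng.normal(size=256)

it = SingleBatchIter(X, y)
Xy = xgb.QuantileDMatrix(it)
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=5)

Keeping `release_data=True` as the default preserves the old memory behaviour for external-memory iterators, while `SingleBatchInternalIter` and iterators like the one above trade that memory for skipping the repeated np.float32 conversion.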