Cache transformed data in QuantileDMatrix for efficiency. (#8666)

This commit is contained in:
Jiaming Yuan 2023-01-17 06:02:40 +08:00 committed by GitHub
parent 06ba285f71
commit 247946a875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 13 deletions

View File

@ -407,21 +407,28 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
Parameters
----------
cache_prefix :
Prefix to the cache files, only used in external memory. It can be either an
URI or a file path.
release_data :
Whether the iterator should release the data during reset. Set it to True if the
data transformation (converting data to np.float32 type) is expensive.
"""
def __init__(
    self, cache_prefix: Optional[str] = None, release_data: bool = True
) -> None:
    """Initialize the iterator state.

    Parameters
    ----------
    cache_prefix :
        Prefix to the cache files, only used in external memory. It can be
        either an URI or a file path.
    release_data :
        Whether the iterator should release the staged data during reset.
        Set it to True if the data transformation (converting data to
        np.float32 type) is expensive.
    """
    self.cache_prefix = cache_prefix

    self._handle = _ProxyDMatrix()
    self._exception: Optional[Exception] = None
    self._enable_categorical = False
    self._allow_host = True
    self._release = release_data
    # Stage data in Python until reset or next is called, to avoid the staged
    # objects being freed while still referenced from the C++ side.
    self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
    # id() of the last raw input; used to skip re-transforming the same
    # object when the staged transform is kept across iterations.
    self._input_id: int = 0
def get_callbacks(
self, allow_host: bool, enable_categorical: bool
@ -477,7 +484,8 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument
"""A wrapper for user defined `reset` function."""
# free the data
self._temporary_data = None
if self._release:
self._temporary_data = None
self._handle_exception(self.reset, None)
def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
@ -498,20 +506,25 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
) -> None:
from .data import _proxy_transform, dispatch_proxy_set_data
new, cat_codes, feature_names, feature_types = _proxy_transform(
data,
feature_names,
feature_types,
self._enable_categorical,
)
# Reduce the amount of transformation that's needed for QuantileDMatrix.
if self._temporary_data is not None and id(data) == self._input_id:
new, cat_codes, feature_names, feature_types = self._temporary_data
else:
new, cat_codes, feature_names, feature_types = _proxy_transform(
data,
feature_names,
feature_types,
self._enable_categorical,
)
# Stage the data, meta info are copied inside C++ MetaInfo.
self._temporary_data = (new, cat_codes)
self._temporary_data = (new, cat_codes, feature_names, feature_types)
dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host)
self.proxy.set_info(
feature_names=feature_names,
feature_types=feature_types,
**kwargs,
)
self._input_id = id(data)
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0)

View File

@ -1174,7 +1174,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
def __init__(self, **kwargs: Any) -> None:
    """Store the single batch and disable data release on reset.

    All keyword arguments are forwarded to ``input_data`` on each ``next``
    call.
    """
    self.kwargs = kwargs
    self.it = 0  # pylint: disable=invalid-name
    # Keeping the cached transform does not necessarily increase peak memory
    # usage, as the data transformation itself might use memory anyway.
    super().__init__(release_data=False)
def next(self, input_data: Callable) -> int:
if self.it == 1:

View File

@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Callable, Dict, List
import numpy as np
import pytest
@ -153,3 +153,30 @@ def test_data_iterator(
run_data_iterator(
n_samples_per_batch, n_features, n_batches, "hist", subsample, False
)
class IterForCacheTest(xgb.DataIter):
    """Single-batch iterator that keeps its transformed data cached.

    Used to verify that ``QuantileDMatrix`` reuses the staged transform
    (``release_data=False``) instead of converting the same input twice.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
        self.kwargs = {"data": x, "label": y, "weight": w}
        # Initialize the batch counter explicitly so `next` works even if
        # `reset` is not invoked before the first iteration.
        self.it = 0  # pylint: disable=invalid-name
        super().__init__(release_data=False)

    def next(self, input_data: Callable) -> int:
        # Only a single batch is provided.
        if self.it == 1:
            return 0
        self.it += 1
        input_data(**self.kwargs)
        return 1

    def reset(self) -> None:
        self.it = 0
def test_data_cache() -> None:
    """The iterator's staged transform should survive QuantileDMatrix construction."""
    # A single batch of 16 samples with 2 features.
    first_of_each = [group[0] for group in make_batches(16, 2, 1, False)]
    it = IterForCacheTest(*first_of_each)
    xgb.QuantileDMatrix(it)
    # The cached transform must correspond to the data batch last fed in.
    assert it._input_id == id(first_of_each[0])