Cache transformed in QuantileDMatrix for efficiency. (#8666)

This commit is contained in:
Jiaming Yuan
2023-01-17 06:02:40 +08:00
committed by GitHub
parent 06ba285f71
commit 247946a875
3 changed files with 56 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Callable, Dict, List
import numpy as np
import pytest
@@ -153,3 +153,30 @@ def test_data_iterator(
run_data_iterator(
n_samples_per_batch, n_features, n_batches, "hist", subsample, False
)
class IterForCacheTest(xgb.DataIter):
    """Data iterator that yields exactly one batch per pass.

    The same ``x``/``y``/``w`` objects are handed to ``input_data`` on every
    pass, so a caller can compare object identities afterwards.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None:
        """Store the single batch and initialise the base iterator.

        Parameters
        ----------
        x :
            Passed to ``input_data`` as ``data``.
        y :
            Passed to ``input_data`` as ``label``.
        w :
            Passed to ``input_data`` as ``weight``.
        """
        self.kwargs = {"data": x, "label": y, "weight": w}
        # Initialise the batch counter up front so that ``next`` does not
        # raise AttributeError when it is invoked before ``reset``.
        self.it = 0
        # NOTE(review): ``release_data=False`` presumably asks the base class
        # to keep the input around between passes -- confirm against the
        # ``DataIter`` documentation.
        super().__init__(release_data=False)

    def next(self, input_data: Callable) -> int:
        """Feed the stored batch to ``input_data`` once per pass.

        Returns ``0`` when the batch has already been delivered in the current
        pass, otherwise delivers it and returns ``1``.
        """
        if self.it == 1:
            return 0
        self.it += 1
        input_data(**self.kwargs)
        return 1

    def reset(self) -> None:
        """Rewind the iterator so the batch is delivered again on the next pass."""
        self.it = 0
def test_data_cache() -> None:
    """Check that the iterator records the identity of the input data.

    A single batch is generated, wrapped in ``IterForCacheTest`` and consumed
    by ``QuantileDMatrix``.  Afterwards the iterator's ``_input_id`` must be
    the ``id`` of the very object that was supplied as ``data``.
    """
    batch_count = 1
    feature_count = 2
    samples_per_batch = 16
    generated = make_batches(samples_per_batch, feature_count, batch_count, False)
    # Take the first (and only) batch out of every sequence that
    # ``make_batches`` returned; the first entry is forwarded as ``data``.
    first_batches = [component[0] for component in generated]
    cache_iter = IterForCacheTest(*first_batches)

    # Construction alone drives the iterator; the matrix itself is not needed.
    xgb.QuantileDMatrix(cache_iter)

    assert cache_iter._input_id == id(first_batches[0])