From f05a23b41cb9b3e55fe9f53c3ca26203c1d42748 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 10 Aug 2023 00:40:06 +0800
Subject: [PATCH] Use `weakref` instead of `id` for `DataIter` cache. (#9445)

- Fix the case where Python reuses the id of a freed object.
- Small optimization to the column matrix used by the QuantileDMatrix: grow the
  storage with `realloc` instead of allocating a new buffer and copying the data.
---
 demo/guide-python/external_memory.py       |  6 ++-
 doc/python/python_api.rst                  |  4 ++
 python-package/xgboost/_typing.py          | 11 +++-
 python-package/xgboost/core.py             | 43 ++++++++++++---
 python-package/xgboost/data.py             | 10 ++--
 python-package/xgboost/testing/__init__.py | 18 ++++---
 src/common/column_matrix.h                 | 35 ++++++++-----
 src/common/io.h                            |  6 +--
 tests/ci_build/lint_python.py              |  1 +
 tests/cpp/common/test_column_matrix.cc     | 61 +++++++++++++++++-----
 tests/cpp/common/test_io.cc                | 14 +++++
 tests/cpp/data/test_iterative_dmatrix.cc   |  8 ++-
 tests/python/test_data_iterator.py         | 18 ++++++-
 tests/python/test_quantile_dmatrix.py      | 21 +++++++-
 14 files changed, 193 insertions(+), 63 deletions(-)

diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py
index fdaa9dab9..6d789486e 100644
--- a/demo/guide-python/external_memory.py
+++ b/demo/guide-python/external_memory.py
@@ -22,7 +22,10 @@ import xgboost


 def make_batches(
-    n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str,
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    tmpdir: str,
 ) -> List[Tuple[str, str]]:
     files: List[Tuple[str, str]] = []
     rng = np.random.RandomState(1994)
@@ -38,6 +41,7 @@ def make_batches(

 class Iterator(xgboost.DataIter):
     """A custom iterator for loading files in batches."""
+
     def __init__(self, file_paths: List[Tuple[str, str]]):
         self._file_paths = file_paths
         self._it = 0
diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst
index 0cbf63456..38b22a994 100644
--- a/doc/python/python_api.rst
+++ b/doc/python/python_api.rst
@@ -23,12 +23,16 @@ Core Data Structure
    :show-inheritance:

 .. autoclass:: xgboost.QuantileDMatrix
+   :members:
    :show-inheritance:

 .. autoclass:: xgboost.Booster
    :members:
    :show-inheritance:

+.. autoclass:: xgboost.DataIter
+   :members:
+   :show-inheritance:

 Learning API
 ------------
diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py
index 39952aca9..a36757a81 100644
--- a/python-package/xgboost/_typing.py
+++ b/python-package/xgboost/_typing.py
@@ -8,7 +8,9 @@ from typing import (
     Callable,
     Dict,
     List,
+    Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -20,8 +22,6 @@ import numpy as np

 DataType = Any

-# xgboost accepts some other possible types in practice due to historical reason, which is
-# lesser tested. For now we encourage users to pass a simple list of string.
 FeatureInfo = Sequence[str]
 FeatureNames = FeatureInfo
 FeatureTypes = FeatureInfo
@@ -97,6 +97,13 @@ else:
         ctypes._Pointer,
     ]

+# The second arg is actually Optional[List[cudf.Series]], skipped here for easier type
+# checking. The cudf Series holds the obtained cat codes, preserved in the `DataIter`
+# to prevent it from being freed.
+TransformedData = Tuple[
+    Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes]
+]
+
 # template parameter
 _T = TypeVar("_T")
 _F = TypeVar("_F", bound=Callable[..., Any])
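A standalone Python sketch (editorial, not part of the patch) of the failure mode the first bullet in the commit message describes: CPython may recycle the `id` of a freed object for a new allocation, so an `id`-keyed cache can report a false hit, whereas a weak reference either returns the live referent or `None`:

    import weakref

    class Batch:
        """Stand-in for a user-supplied data batch; any weakref-able object works."""

    a = Batch()
    cache_key = id(a)
    del a        # the integer id may now be recycled by the allocator...
    b = Batch()  # ...possibly by the very next allocation
    # `id(b) == cache_key` can be True here even though `b` is an unrelated
    # object, which is exactly how a stale cache entry is mistaken for a hit.

    ref = weakref.ref(Batch())  # the referent dies right after this call
    assert ref() is None        # a dead weakref is detectable, unlike a stale id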
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 14a96f117..fbedfd7fb 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+import weakref
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from enum import IntEnum, unique
@@ -51,6 +52,7 @@ from ._typing import (
     FeatureTypes,
     ModelIn,
     NumpyOrCupy,
+    TransformedData,
     c_bst_ulong,
 )
 from .compat import PANDAS_INSTALLED, DataFrame, py_str
@@ -486,7 +488,16 @@ def _prediction_output(

 class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
-    """The interface for user defined data iterator.
+    """The interface for user defined data iterator. The iterator facilitates
+    distributed training, :py:class:`QuantileDMatrix`, and external memory support using
+    :py:class:`DMatrix`. Most of the time, users don't need to interact with this class
+    directly.
+
+    .. note::
+
+        The class caches some intermediate results using the `data` input (predictor
+        `X`) as the cache key. Don't reuse the same `X` for multiple batches with
+        different meta data (like `label`); make a copy if necessary.

     Parameters
     ----------
@@ -510,13 +521,13 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
         self._allow_host = True
         self._release = release_data
         # Stage data in Python until reset or next is called to avoid the data being freed.
-        self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
-        self._input_id: int = 0
+        self._temporary_data: Optional[TransformedData] = None
+        self._data_ref: Optional[weakref.ReferenceType] = None

     def get_callbacks(
         self, allow_host: bool, enable_categorical: bool
     ) -> Tuple[Callable, Callable]:
-        """Get callback functions for iterating in C."""
+        """Get callback functions for iterating in C. This is an internal function."""
         assert hasattr(self, "cache_prefix"), "__init__ is not called."
         self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
             self._reset_wrapper
@@ -591,7 +602,19 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
         from .data import _proxy_transform, dispatch_proxy_set_data

         # Reduce the amount of transformation that's needed for QuantileDMatrix.
-        if self._temporary_data is not None and id(data) == self._input_id:
+        #
+        # To construct the QDM, one needs 4 iterations on the CPU, or 2 iterations on
+        # the GPU. If the QDM has only one batch of input (the most common case), we
+        # can avoid transforming the data repeatedly.
+        try:
+            ref = weakref.ref(data)
+        except TypeError:
+            ref = None
+        if (
+            self._temporary_data is not None
+            and ref is not None
+            and ref is self._data_ref
+        ):
             new, cat_codes, feature_names, feature_types = self._temporary_data
         else:
             new, cat_codes, feature_names, feature_types = _proxy_transform(
@@ -608,7 +631,7 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
             feature_types=feature_types,
             **kwargs,
         )
-        self._input_id = id(data)
+        self._data_ref = ref
         # pylint: disable=not-callable
         return self._handle_exception(lambda: self.next(input_data), 0)

@@ -1134,7 +1157,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-m
         testing purposes. If this is a quantized DMatrix then quantized values are
         returned instead of input values.

         .. versionadded:: 1.7.0
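To make the cached path above concrete, here is a minimal single-batch iterator (an illustrative sketch; `SingleBatch` and its attributes are hypothetical, while `next`, `reset`, and the `input_data` callback follow the `DataIter` interface shown in this diff). `QuantileDMatrix` iterates the data several times during construction, and with the weakref-keyed cache the batch is transformed only once:

    from typing import Callable

    import numpy as np
    import xgboost

    class SingleBatch(xgboost.DataIter):
        """Yield one in-memory batch, e.g. for QuantileDMatrix construction."""

        def __init__(self, data: np.ndarray, label: np.ndarray) -> None:
            self._data = data
            self._label = label
            self._it = 0
            super().__init__()

        def next(self, input_data: Callable) -> int:
            if self._it == 1:
                return 0  # signal the end of iteration
            input_data(data=self._data, label=self._label)
            self._it += 1
            return 1  # more batches to come

        def reset(self) -> None:
            self._it = 0

    rng = np.random.RandomState(2023)
    X, y = rng.randn(128, 4), rng.randn(128)
    Xy = xgboost.QuantileDMatrix(SingleBatch(X, y))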

         """
         indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
@@ -1155,7 +1178,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-m
         return ret

     def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
-        """Get quantile cuts for quantization."""
+        """Get quantile cuts for quantization.
+
+        .. versionadded:: 2.0.0
+
+        """
         n_features = self.num_col()

         c_sindptr = ctypes.c_char_p()
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 7864d541f..04bdc5739 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -5,7 +5,7 @@ import ctypes
 import json
 import os
 import warnings
-from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast

 import numpy as np

@@ -17,6 +17,7 @@ from ._typing import (
     FloatCompatible,
     NumpyDType,
     PandasDType,
+    TransformedData,
     c_bst_ulong,
 )
 from .compat import DataFrame, lazy_isinstance
@@ -1268,12 +1269,7 @@ def _proxy_transform(
     feature_names: Optional[FeatureNames],
     feature_types: Optional[FeatureTypes],
     enable_categorical: bool,
-) -> Tuple[
-    Union[bool, ctypes.c_void_p, np.ndarray],
-    Optional[list],
-    Optional[FeatureNames],
-    Optional[FeatureTypes],
-]:
+) -> TransformedData:
     if _is_cudf_df(data) or _is_cudf_ser(data):
         return _transform_cudf_df(
             data, feature_names, feature_types, enable_categorical
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 48809b46f..8a21b6085 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -230,7 +230,7 @@ class IteratorForTest(xgb.core.DataIter):

     def as_arrays(
         self,
-    ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, ArrayLike]:
+    ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
         if isinstance(self.X[0], sparse.csr_matrix):
             X = sparse.vstack(self.X, format="csr")
         else:
@@ -244,7 +244,12 @@ class IteratorForTest(xgb.core.DataIter):

 def make_batches(
-    n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    use_cupy: bool = False,
+    *,
+    vary_size: bool = False,
 ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
     X = []
     y = []
@@ -255,10 +260,11 @@ def make_batches(
         rng = cupy.random.RandomState(1994)
     else:
         rng = np.random.RandomState(1994)
-    for _ in range(n_batches):
-        _X = rng.randn(n_samples_per_batch, n_features)
-        _y = rng.randn(n_samples_per_batch)
-        _w = rng.uniform(low=0, high=1, size=n_samples_per_batch)
+    for i in range(n_batches):
+        n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
+        _X = rng.randn(n_samples, n_features)
+        _y = rng.randn(n_samples)
+        _w = rng.uniform(low=0, high=1, size=n_samples)
         X.append(_X)
         y.append(_y)
         w.append(_w)
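A short usage sketch of the new `vary_size` flag (assumes a development checkout where `xgboost.testing` and its helpers are importable): batch `i` receives `n_samples_per_batch + i * 10` rows, which is what the `test_different_size` case added later in this patch relies on:

    from xgboost.testing import IteratorForTest, make_batches

    X, y, w = make_batches(
        n_samples_per_batch=32, n_features=4, n_batches=3, vary_size=True
    )
    assert [batch.shape[0] for batch in X] == [32, 42, 52]
    it = IteratorForTest(X, y, w, cache=None)  # ready for xgb.QuantileDMatrix(it)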
diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 78361744d..0862c21ad 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -9,12 +9,12 @@
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_

 #include <dmlc/endian.h>
-#include <cstddef>  // for size_t
+#include <cstddef>  // for size_t, byte
 #include <cstdint>  // for uint8_t
 #include <limits>
 #include <memory>
-#include <utility>  // for move
-#include <vector>
+#include <type_traits>  // for enable_if_t, is_same_v, is_signed_v
+#include <utility>      // for move

 #include "../data/adapter.h"
 #include "../data/gradient_index.h"
@@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxType> {
  */
 template <typename BinIdxType, bool any_missing>
 class DenseColumnIter : public Column<BinIdxType> {
- public:
-  using ByteType = bool;
-
  private:
   using Base = Column<BinIdxType>;
   /* flags for missing values in dense columns */
@@ -153,8 +150,17 @@ class ColumnMatrix {
    * @brief A bit set for indicating whether an element in a dense column is missing.
    */
   struct MissingIndicator {
-    LBitField32 missing;
-    RefResourceView<std::uint32_t> storage;
+    using BitFieldT = LBitField32;
+    using T = typename BitFieldT::value_type;
+
+    BitFieldT missing;
+    RefResourceView<T> storage;
+    static_assert(std::is_same_v<T, std::uint32_t>);
+
+    template <typename U>
+    [[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
+      return init ? ~U{0} : U{0};
+    }

     MissingIndicator() = default;
     /**
@@ -163,7 +169,7 @@ class ColumnMatrix {
      */
     MissingIndicator(std::size_t n_elements, bool init) {
       auto m_size = missing.ComputeStorageSize(n_elements);
-      storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
+      storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
       this->InitView();
     }
     /** @brief Set the i^th element to be a valid element (instead of missing). */
@@ -181,11 +187,12 @@ class ColumnMatrix {
       if (m_size == storage.size()) {
         return;
       }
+      // grow the storage
+      auto resource = std::dynamic_pointer_cast<MallocResource>(storage.Resource());
+      CHECK(resource);
+      resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
+      storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};

-      auto new_storage =
-          common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
-      std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
-      storage = std::move(new_storage);
       this->InitView();
     }
   };
@@ -210,7 +217,6 @@ class ColumnMatrix {
   }

  public:
-  using ByteType = bool;
   // get number of features
   [[nodiscard]] bst_feature_t GetNumFeature() const {
     return static_cast<bst_feature_t>(type_.size());
   }
@@ -408,6 +414,7 @@ class ColumnMatrix {
   // IO procedures for external memory.
   [[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
   [[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
+  [[nodiscard]] MissingIndicator const& Missing() const { return missing_; }

  private:
   RefResourceView<std::uint8_t> index_;
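The `GrowTo` change above swaps an allocate-and-copy for an in-place `realloc`. A rough Python model of the bit-set semantics (editorial sketch, not part of the patch; names mirror the C++ code): the flags live in 32-bit words, and growth keeps the old words untouched while only initializing the new ones:

    import math

    WORD_BITS = 32  # matches LBitField32

    def compute_storage_size(n_elements: int) -> int:
        """Number of 32-bit words needed to hold one flag per element."""
        return math.ceil(n_elements / WORD_BITS)

    class MissingIndicator:
        def __init__(self, n_elements: int, init: bool) -> None:
            word = 0xFFFFFFFF if init else 0
            self.storage = [word] * compute_storage_size(n_elements)

        def set_valid(self, i: int) -> None:
            # clear the i-th bit: element i is present, not missing
            self.storage[i // WORD_BITS] &= ~(1 << (i % WORD_BITS))

        def grow_to(self, n_elements: int, init: bool) -> None:
            m_size = compute_storage_size(n_elements)
            if m_size == len(self.storage):
                return
            word = 0xFFFFFFFF if init else 0
            # realloc-like growth: the old prefix survives, only new words are filled
            self.storage.extend([word] * (m_size - len(self.storage)))

    m = MissingIndicator(40, init=True)    # all elements start as missing
    m.set_valid(3)
    m.grow_to(100, init=True)              # old bits preserved across growth
    assert ((m.storage[0] >> 3) & 1) == 0  # element 3 is still marked valid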
diff --git a/src/common/io.h b/src/common/io.h
index baf518aa5..95971abae 100644
--- a/src/common/io.h
+++ b/src/common/io.h
@@ -10,7 +10,7 @@
 #include <dmlc/io.h>
 #include <xgboost/string_view.h>

-#include <algorithm>  // for min
+#include <algorithm>  // for min, fill_n, copy_n
 #include <array>      // for array
 #include <cstddef>    // for byte, size_t
 #include <cstdlib>    // for malloc, realloc, free
@@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
    * @param n_bytes The new size.
    */
   template <bool force_malloc = false>
-  void Resize(std::size_t n_bytes) {
+  void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
     // realloc(ptr, 0) works, but is deprecated.
     if (n_bytes == 0) {
       this->Clear();
@@ -236,7 +236,7 @@ class MallocResource : public ResourceHandler {
       std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
     }
     // default initialize
-    std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
+    std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
     // free the old ptr if malloc is used.
     if (need_copy) {
       this->Clear();
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 9749a8485..e6cfb462b 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -42,6 +42,7 @@ class LintersPaths:
         "demo/guide-python/feature_weights.py",
         "demo/guide-python/sklearn_parallel.py",
         "demo/guide-python/spark_estimator_examples.py",
+        "demo/guide-python/external_memory.py",
         "demo/guide-python/individual_trees.py",
         "demo/guide-python/quantile_regression.py",
         "demo/guide-python/multioutput_regression.py",
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 8b8df4861..8445dc466 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -2,15 +2,26 @@
 * Copyright 2018-2023 by XGBoost Contributors
 */
 #include <gtest/gtest.h>
+#include <xgboost/base.h>     // for bst_bin_t
+#include <xgboost/context.h>  // for Context
+#include <xgboost/data.h>     // for BatchIterator, BatchSet, DMatrix, Met...

-#include "../../../src/common/column_matrix.h"
-#include "../helpers.h"
+#include <cstddef>      // for size_t
+#include <cstdint>      // for int32_t, uint16_t, uint8_t
+#include <limits>       // for numeric_limits
+#include <memory>       // for shared_ptr, __shared_ptr_access, allo...
+#include <type_traits>  // for remove_reference_t

-namespace xgboost {
-namespace common {
-
-TEST(DenseColumn, Test) {
+#include "../../../src/common/column_matrix.h"      // for ColumnMatrix, Column, DenseColumnIter
+#include "../../../src/common/hist_util.h"          // for DispatchBinType, BinTypeSize, Index
+#include "../../../src/common/ref_resource_view.h"  // for RefResourceView
+#include "../../../src/data/gradient_index.h"       // for GHistIndexMatrix
+#include "../../../src/data/iterative_dmatrix.h"    // for IterativeDMatrix
+#include "../../../src/tree/param.h"                // for TrainParam
+#include "../helpers.h"                             // for RandomDataGenerator, NumpyArrayIterFo...
+
+namespace xgboost::common {
+TEST(ColumnMatrix, Basic) {
   int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<std::uint8_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 2};
@@ -22,7 +33,7 @@ TEST(DenseColumn, Test) {
     GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
     ColumnMatrix column_matrix;
     for (auto const& page : dmat->GetBatches<SparsePage>()) {
-      column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
+      column_matrix.InitFromSparse(page, gmat, sparse_thresh, ctx.Threads());
     }
     ASSERT_GE(column_matrix.GetTypeSize(), last);
     ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
@@ -59,7 +70,7 @@ void CheckSparseColumn(SparseColumnIter<BinIdxType>* p_col, const GHistIndexMatr
   }
 }

-TEST(SparseColumn, Test) {
+TEST(ColumnMatrix, SparseColumn) {
   int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<std::uint8_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 2};
@@ -69,7 +80,7 @@ TEST(SparseColumn, Test) {
   GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
   ColumnMatrix column_matrix;
   for (auto const& page : dmat->GetBatches<SparsePage>()) {
-    column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
+    column_matrix.InitFromSparse(page, gmat, 1.0, ctx.Threads());
   }
   common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
     using T = decltype(dtype);
@@ -90,7 +101,7 @@ void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
   }
 }

-TEST(DenseColumnWithMissing, Test) {
+TEST(ColumnMatrix, DenseColumnWithMissing) {
   int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<std::uint8_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<std::uint16_t>::max()) + 2};
@@ -100,7 +111,7 @@ TEST(DenseColumnWithMissing, Test) {
   GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
   ColumnMatrix column_matrix;
   for (auto const& page : dmat->GetBatches<SparsePage>()) {
-    column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
+    column_matrix.InitFromSparse(page, gmat, 0.2, ctx.Threads());
   }
   ASSERT_TRUE(column_matrix.AnyMissing());
   DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
@@ -110,5 +121,29 @@ TEST(DenseColumnWithMissing, Test) {
     });
   }
 }
-}  // namespace common
-}  // namespace xgboost
+
+TEST(ColumnMatrix, GrowMissing) {
+  float sparsity = 0.5;
+  NumpyArrayIterForTest iter(sparsity);
+  auto n_threads = 0;
+  bst_bin_t n_bins = 16;
+  BatchParam batch{n_bins, tree::TrainParam::DftSparseThreshold()};
+  Context ctx;
+  auto m = std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                    std::numeric_limits<float>::quiet_NaN(),
+                                                    n_threads, n_bins);
+  for (auto const& page : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
+    auto const& column_matrix = page.Transpose();
+    auto const& missing = column_matrix.Missing();
+    auto n = NumpyArrayIterForTest::Rows() * NumpyArrayIterForTest::Cols();
+    auto expected = std::remove_reference_t<decltype(missing)>::BitFieldT::ComputeStorageSize(n);
+    auto got = missing.storage.size();
+    ASSERT_EQ(expected, got);
+    DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
+      using T = decltype(dtype);
+      auto col = column_matrix.DenseColumn<T, true>(0);
+      CheckColumWithMissingValue(col, page);
+    });
+  }
+}
+}  // namespace xgboost::common
diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc
index 986e58c5a..8bc12698b 100644
--- a/tests/cpp/common/test_io.cc
+++ b/tests/cpp/common/test_io.cc
@@ -119,6 +119,20 @@ TEST(IO, Resource) {
     for (std::size_t i = n; i < 2 * n; ++i) {
       ASSERT_EQ(malloc_resource->DataAs<std::int8_t>()[i], 0);
     }
+
+    ptr = malloc_resource->DataAs<std::int8_t>();
+    std::fill_n(ptr, malloc_resource->Size(), 7);
+    if (force_malloc) {
+      malloc_resource->Resize<true>(n * 3, std::byte{3});
+    } else {
+      malloc_resource->Resize(n * 3, std::byte{3});
+    }
+    for (std::size_t i = 0; i < n * 2; ++i) {
+      ASSERT_EQ(malloc_resource->DataAs<std::int8_t>()[i], 7);
+    }
+    for (std::size_t i = n * 2; i < n * 3; ++i) {
+      ASSERT_EQ(malloc_resource->DataAs<std::int8_t>()[i], 3);
+    }
   };
   test_malloc_resize(true);
   test_malloc_resize(false);
diff --git a/tests/cpp/data/test_iterative_dmatrix.cc b/tests/cpp/data/test_iterative_dmatrix.cc
index 74a69e109..b760b967d 100644
--- a/tests/cpp/data/test_iterative_dmatrix.cc
+++ b/tests/cpp/data/test_iterative_dmatrix.cc
@@ -12,8 +12,7 @@
 #include "../helpers.h"
 #include "xgboost/data.h"  // DMatrix

-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 TEST(IterativeDMatrix, Ref) {
   Context ctx;
   TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
@@ -21,7 +20,7 @@ TEST(IterativeDMatrix, Ref) {
 }

 TEST(IterativeDMatrix, IsDense) {
-  int n_bins = 16;
+  bst_bin_t n_bins = 16;
   auto test = [n_bins](float sparsity) {
     NumpyArrayIterForTest iter(sparsity);
     auto n_threads = 0;
@@ -38,5 +37,4 @@ TEST(IterativeDMatrix, IsDense) {
   test(0.1);
   test(1.0);
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py
index 24c117f15..e6bdfd2e7 100644
--- a/tests/python/test_data_iterator.py
+++ b/tests/python/test_data_iterator.py
@@ -1,4 +1,5 @@
-from typing import Callable, Dict, List
+import weakref
+from typing import Any, Callable, Dict, List

 import numpy as np
 import pytest
@@ -179,5 +180,18 @@ def test_data_cache() -> None:
     data = make_batches(n_samples_per_batch, n_features, n_batches, False)
     batches = [v[0] for v in data]
     it = IterForCacheTest(*batches)
+    transform = xgb.data._proxy_transform
+
+    called = 0
+
+    def mock(*args: Any, **kwargs: Any) -> Any:
+        nonlocal called
+        called += 1
+        return transform(*args, **kwargs)
+
+    xgb.data._proxy_transform = mock
     xgb.QuantileDMatrix(it)
-    assert it._input_id == id(batches[0])
+    assert it._data_ref is weakref.ref(batches[0])
+    assert called == 1
+
+    xgb.data._proxy_transform = transform
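The identity assertion in this test (`it._data_ref is weakref.ref(batches[0])`) leans on a CPython detail worth noting: an object keeps at most one callback-free weak reference, and later `weakref.ref` calls hand back that same object:

    import weakref

    class Payload:
        """Any weakref-able object; numpy arrays behave the same way."""

    obj = Payload()
    assert weakref.ref(obj) is weakref.ref(obj)  # the same weakref object in CPython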
diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py
index 8ee00b8c0..28a7eb37a 100644
--- a/tests/python/test_quantile_dmatrix.py
+++ b/tests/python/test_quantile_dmatrix.py
@@ -103,12 +103,29 @@ class TestQuantileDMatrix:
             *make_batches_sparse(
                 n_samples_per_batch, n_features, n_batches, sparsity
             ),
-            None
+            None,
         )
         Xy = xgb.QuantileDMatrix(it)
         assert Xy.num_row() == n_samples_per_batch * n_batches
         assert Xy.num_col() == n_features

+    def test_different_size(self) -> None:
+        n_samples_per_batch = 317
+        n_features = 8
+        n_batches = 7
+
+        it = IteratorForTest(
+            *make_batches(
+                n_samples_per_batch, n_features, n_batches, False, vary_size=True
+            ),
+            cache=None,
+        )
+        Xy = xgb.QuantileDMatrix(it)
+        assert Xy.num_row() == 2429
+        X, y, w = it.as_arrays()
+        Xy1 = xgb.QuantileDMatrix(X, y, weight=w)
+        assert predictor_equal(Xy, Xy1)
+
     @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
     def test_training(self, sparsity: float) -> None:
         n_samples_per_batch = 317
@@ -123,7 +140,7 @@ class TestQuantileDMatrix:
             *make_batches_sparse(
                 n_samples_per_batch, n_features, n_batches, sparsity
             ),
-            None
+            None,
         )

         parameters = {"tree_method": "hist", "max_bin": 256}
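For reference, where the magic number 2429 in `test_different_size` comes from: with `vary_size=True`, batch `i` contributes `n_samples_per_batch + i * 10` rows:

    n_samples_per_batch, n_batches = 317, 7
    total = sum(n_samples_per_batch + i * 10 for i in range(n_batches))
    assert total == 2429  # 317 * 7 + 10 * (0 + 1 + ... + 6)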