Use weakref instead of id for DataIter cache. (#9445)

- Fix case where Python reuses id from freed objects.
- Small optimization to column matrix with QDM by using `realloc` instead of copying data.
This commit is contained in:
Jiaming Yuan 2023-08-10 00:40:06 +08:00 committed by GitHub
parent d495a180d8
commit f05a23b41c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 193 additions and 63 deletions

View File

@ -22,7 +22,10 @@ import xgboost
def make_batches( def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str, n_samples_per_batch: int,
n_features: int,
n_batches: int,
tmpdir: str,
) -> List[Tuple[str, str]]: ) -> List[Tuple[str, str]]:
files: List[Tuple[str, str]] = [] files: List[Tuple[str, str]] = []
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
@ -38,6 +41,7 @@ def make_batches(
class Iterator(xgboost.DataIter): class Iterator(xgboost.DataIter):
"""A custom iterator for loading files in batches.""" """A custom iterator for loading files in batches."""
def __init__(self, file_paths: List[Tuple[str, str]]): def __init__(self, file_paths: List[Tuple[str, str]]):
self._file_paths = file_paths self._file_paths = file_paths
self._it = 0 self._it = 0

View File

@ -23,12 +23,16 @@ Core Data Structure
:show-inheritance: :show-inheritance:
.. autoclass:: xgboost.QuantileDMatrix .. autoclass:: xgboost.QuantileDMatrix
:members:
:show-inheritance: :show-inheritance:
.. autoclass:: xgboost.Booster .. autoclass:: xgboost.Booster
:members: :members:
:show-inheritance: :show-inheritance:
.. autoclass:: xgboost.DataIter
:members:
:show-inheritance:
Learning API Learning API
------------ ------------

View File

@ -8,7 +8,9 @@ from typing import (
Callable, Callable,
Dict, Dict,
List, List,
Optional,
Sequence, Sequence,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -20,8 +22,6 @@ import numpy as np
DataType = Any DataType = Any
# xgboost accepts some other possible types in practice due to historical reason, which is
# lesser tested. For now we encourage users to pass a simple list of string.
FeatureInfo = Sequence[str] FeatureInfo = Sequence[str]
FeatureNames = FeatureInfo FeatureNames = FeatureInfo
FeatureTypes = FeatureInfo FeatureTypes = FeatureInfo
@ -97,6 +97,13 @@ else:
ctypes._Pointer, ctypes._Pointer,
] ]
# The second arg is actually Optional[List[cudf.Series]], skipped for easier type check.
# The cudf Series is the obtained cat codes, preserved in the `DataIter` to prevent it
# being freed.
TransformedData = Tuple[
Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes]
]
# template parameter # template parameter
_T = TypeVar("_T") _T = TypeVar("_T")
_F = TypeVar("_F", bound=Callable[..., Any]) _F = TypeVar("_F", bound=Callable[..., Any])

View File

@ -9,6 +9,7 @@ import os
import re import re
import sys import sys
import warnings import warnings
import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from enum import IntEnum, unique from enum import IntEnum, unique
@ -51,6 +52,7 @@ from ._typing import (
FeatureTypes, FeatureTypes,
ModelIn, ModelIn,
NumpyOrCupy, NumpyOrCupy,
TransformedData,
c_bst_ulong, c_bst_ulong,
) )
from .compat import PANDAS_INSTALLED, DataFrame, py_str from .compat import PANDAS_INSTALLED, DataFrame, py_str
@ -486,7 +488,16 @@ def _prediction_output(
class DataIter(ABC): # pylint: disable=too-many-instance-attributes class DataIter(ABC): # pylint: disable=too-many-instance-attributes
"""The interface for user defined data iterator. """The interface for user defined data iterator. The iterator facilitates
distributed training, :py:class:`QuantileDMatrix`, and external memory support using
:py:class:`DMatrix`. Most of the time, users don't need to interact with this class
directly.
.. note::
The class caches some intermediate results using the `data` input (predictor
`X`) as key. Don't repeat the `X` for multiple batches with different metadata
(like `label`); make a copy if necessary.
Parameters Parameters
---------- ----------
@ -510,13 +521,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
self._allow_host = True self._allow_host = True
self._release = release_data self._release = release_data
# Stage data in Python until reset or next is called to avoid data being freed. # Stage data in Python until reset or next is called to avoid data being freed.
self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None self._temporary_data: Optional[TransformedData] = None
self._input_id: int = 0 self._data_ref: Optional[weakref.ReferenceType] = None
def get_callbacks( def get_callbacks(
self, allow_host: bool, enable_categorical: bool self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]: ) -> Tuple[Callable, Callable]:
"""Get callback functions for iterating in C.""" """Get callback functions for iterating in C. This is an internal function."""
assert hasattr(self, "cache_prefix"), "__init__ is not called." assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)( self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
self._reset_wrapper self._reset_wrapper
@ -591,7 +602,19 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
from .data import _proxy_transform, dispatch_proxy_set_data from .data import _proxy_transform, dispatch_proxy_set_data
# Reduce the amount of transformation that's needed for QuantileDMatrix. # Reduce the amount of transformation that's needed for QuantileDMatrix.
if self._temporary_data is not None and id(data) == self._input_id: #
# To construct the QDM, one needs 4 iterations on CPU, or 2 iterations on
# GPU. If the QDM has only one batch of input (most of the cases), we can
# avoid transforming the data repeatedly.
try:
ref = weakref.ref(data)
except TypeError:
ref = None
if (
self._temporary_data is not None
and ref is not None
and ref is self._data_ref
):
new, cat_codes, feature_names, feature_types = self._temporary_data new, cat_codes, feature_names, feature_types = self._temporary_data
else: else:
new, cat_codes, feature_names, feature_types = _proxy_transform( new, cat_codes, feature_names, feature_types = _proxy_transform(
@ -608,7 +631,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
feature_types=feature_types, feature_types=feature_types,
**kwargs, **kwargs,
) )
self._input_id = id(data) self._data_ref = ref
# pylint: disable=not-callable # pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0) return self._handle_exception(lambda: self.next(input_data), 0)
@ -1134,7 +1157,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
testing purposes. If this is a quantized DMatrix then quantized values are testing purposes. If this is a quantized DMatrix then quantized values are
returned instead of input values. returned instead of input values.
.. versionadded:: 1.7.0 .. versionadded:: 1.7.0
""" """
indptr = np.empty(self.num_row() + 1, dtype=np.uint64) indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
@ -1155,7 +1178,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return ret return ret
def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]: def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get quantile cuts for quantization.""" """Get quantile cuts for quantization.
.. versionadded:: 2.0.0
"""
n_features = self.num_col() n_features = self.num_col()
c_sindptr = ctypes.c_char_p() c_sindptr = ctypes.c_char_p()

View File

@ -5,7 +5,7 @@ import ctypes
import json import json
import os import os
import warnings import warnings
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast
import numpy as np import numpy as np
@ -17,6 +17,7 @@ from ._typing import (
FloatCompatible, FloatCompatible,
NumpyDType, NumpyDType,
PandasDType, PandasDType,
TransformedData,
c_bst_ulong, c_bst_ulong,
) )
from .compat import DataFrame, lazy_isinstance from .compat import DataFrame, lazy_isinstance
@ -1268,12 +1269,7 @@ def _proxy_transform(
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
enable_categorical: bool, enable_categorical: bool,
) -> Tuple[ ) -> TransformedData:
Union[bool, ctypes.c_void_p, np.ndarray],
Optional[list],
Optional[FeatureNames],
Optional[FeatureTypes],
]:
if _is_cudf_df(data) or _is_cudf_ser(data): if _is_cudf_df(data) or _is_cudf_ser(data):
return _transform_cudf_df( return _transform_cudf_df(
data, feature_names, feature_types, enable_categorical data, feature_names, feature_types, enable_categorical

View File

@ -230,7 +230,7 @@ class IteratorForTest(xgb.core.DataIter):
def as_arrays( def as_arrays(
self, self,
) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, ArrayLike]: ) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
if isinstance(self.X[0], sparse.csr_matrix): if isinstance(self.X[0], sparse.csr_matrix):
X = sparse.vstack(self.X, format="csr") X = sparse.vstack(self.X, format="csr")
else: else:
@ -244,7 +244,12 @@ class IteratorForTest(xgb.core.DataIter):
def make_batches( def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
X = [] X = []
y = [] y = []
@ -255,10 +260,11 @@ def make_batches(
rng = cupy.random.RandomState(1994) rng = cupy.random.RandomState(1994)
else: else:
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
for _ in range(n_batches): for i in range(n_batches):
_X = rng.randn(n_samples_per_batch, n_features) n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_y = rng.randn(n_samples_per_batch) _X = rng.randn(n_samples, n_features)
_w = rng.uniform(low=0, high=1, size=n_samples_per_batch) _y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X) X.append(_X)
y.append(_y) y.append(_y)
w.append(_w) w.append(_w)

View File

@ -9,12 +9,12 @@
#define XGBOOST_COMMON_COLUMN_MATRIX_H_ #define XGBOOST_COMMON_COLUMN_MATRIX_H_
#include <algorithm> #include <algorithm>
#include <cstddef> // for size_t #include <cstddef> // for size_t, byte
#include <cstdint> // for uint8_t #include <cstdint> // for uint8_t
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> // for move #include <type_traits> // for enable_if_t, is_same_v, is_signed_v
#include <vector> #include <utility> // for move
#include "../data/adapter.h" #include "../data/adapter.h"
#include "../data/gradient_index.h" #include "../data/gradient_index.h"
@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxT> {
*/ */
template <typename BinIdxT, bool any_missing> template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> { class DenseColumnIter : public Column<BinIdxT> {
public:
using ByteType = bool;
private: private:
using Base = Column<BinIdxT>; using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */ /* flags for missing values in dense columns */
@ -153,8 +150,17 @@ class ColumnMatrix {
* @brief A bit set for indicating whether an element in a dense column is missing. * @brief A bit set for indicating whether an element in a dense column is missing.
*/ */
struct MissingIndicator { struct MissingIndicator {
LBitField32 missing; using BitFieldT = LBitField32;
RefResourceView<std::uint32_t> storage; using T = typename BitFieldT::value_type;
BitFieldT missing;
RefResourceView<T> storage;
static_assert(std::is_same_v<T, std::uint32_t>);
/**
 * @brief Produce the value used to fill newly allocated indicator storage.
 *
 * @tparam U Any type for which std::is_signed_v is false (used here with the bit field's
 *           value type and with std::byte); signed types are rejected at compile time.
 * @param init When true the result has every bit set, otherwise every bit cleared.
 */
template <typename U>
[[nodiscard]] static std::enable_if_t<!std::is_signed_v<U>, U> InitValue(bool init) {
  if (init) {
    return ~U{0};
  }
  return U{0};
}
MissingIndicator() = default; MissingIndicator() = default;
/** /**
@ -163,7 +169,7 @@ class ColumnMatrix {
*/ */
MissingIndicator(std::size_t n_elements, bool init) { MissingIndicator(std::size_t n_elements, bool init) {
auto m_size = missing.ComputeStorageSize(n_elements); auto m_size = missing.ComputeStorageSize(n_elements);
storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0}); storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
this->InitView(); this->InitView();
} }
/** @brief Set the i^th element to be a valid element (instead of missing). */ /** @brief Set the i^th element to be a valid element (instead of missing). */
@ -181,11 +187,12 @@ class ColumnMatrix {
if (m_size == storage.size()) { if (m_size == storage.size()) {
return; return;
} }
// grow the storage
auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
CHECK(resource);
resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};
auto new_storage =
common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
storage = std::move(new_storage);
this->InitView(); this->InitView();
} }
}; };
@ -210,7 +217,6 @@ class ColumnMatrix {
} }
public: public:
using ByteType = bool;
// get number of features // get number of features
[[nodiscard]] bst_feature_t GetNumFeature() const { [[nodiscard]] bst_feature_t GetNumFeature() const {
return static_cast<bst_feature_t>(type_.size()); return static_cast<bst_feature_t>(type_.size());
@ -408,6 +414,7 @@ class ColumnMatrix {
// IO procedures for external memory. // IO procedures for external memory.
[[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base); [[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
[[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const; [[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
[[nodiscard]] MissingIndicator const& Missing() const { return missing_; }
private: private:
RefResourceView<std::uint8_t> index_; RefResourceView<std::uint8_t> index_;

View File

@ -10,7 +10,7 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <algorithm> // for min #include <algorithm> // for min, fill_n, copy_n
#include <array> // for array #include <array> // for array
#include <cstddef> // for byte, size_t #include <cstddef> // for byte, size_t
#include <cstdlib> // for malloc, realloc, free #include <cstdlib> // for malloc, realloc, free
@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
* @param n_bytes The new size. * @param n_bytes The new size.
*/ */
template <bool force_malloc = false> template <bool force_malloc = false>
void Resize(std::size_t n_bytes) { void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
// realloc(ptr, 0) works, but is deprecated. // realloc(ptr, 0) works, but is deprecated.
if (n_bytes == 0) { if (n_bytes == 0) {
this->Clear(); this->Clear();
@ -236,7 +236,7 @@ class MallocResource : public ResourceHandler {
std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr)); std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
} }
// default initialize // default initialize
std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_); std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
// free the old ptr if malloc is used. // free the old ptr if malloc is used.
if (need_copy) { if (need_copy) {
this->Clear(); this->Clear();

View File

@ -42,6 +42,7 @@ class LintersPaths:
"demo/guide-python/feature_weights.py", "demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py", "demo/guide-python/sklearn_parallel.py",
"demo/guide-python/spark_estimator_examples.py", "demo/guide-python/spark_estimator_examples.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/individual_trees.py", "demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py", "demo/guide-python/quantile_regression.py",
"demo/guide-python/multioutput_regression.py", "demo/guide-python/multioutput_regression.py",

View File

@ -2,15 +2,26 @@
* Copyright 2018-2023 by XGBoost Contributors * Copyright 2018-2023 by XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/base.h> // for bst_bin_t
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, Met...
#include "../../../src/common/column_matrix.h" #include <cstddef> // for size_t
#include "../helpers.h" #include <cstdint> // for int32_t, uint16_t, uint8_t
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr, __shared_ptr_access, allo...
#include <type_traits> // for remove_reference_t
#include "../../../src/common/column_matrix.h" // for ColumnMatrix, Column, DenseColumnIter
#include "../../../src/common/hist_util.h" // for DispatchBinType, BinTypeSize, Index
#include "../../../src/common/ref_resource_view.h" // for RefResourceView
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
#include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" // for RandomDataGenerator, NumpyArrayIterFo...
namespace xgboost { namespace xgboost::common {
namespace common { TEST(ColumnMatrix, Basic) {
TEST(DenseColumn, Test) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1, int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1, static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -22,7 +33,7 @@ TEST(DenseColumn, Test) {
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false}; GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) { for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest()); column_matrix.InitFromSparse(page, gmat, sparse_thresh, ctx.Threads());
} }
ASSERT_GE(column_matrix.GetTypeSize(), last); ASSERT_GE(column_matrix.GetTypeSize(), last);
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize); ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
@ -59,7 +70,7 @@ void CheckSparseColumn(SparseColumnIter<BinIdxType>* p_col, const GHistIndexMatr
} }
} }
TEST(SparseColumn, Test) { TEST(ColumnMatrix, SparseColumn) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1, int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1, static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -69,7 +80,7 @@ TEST(SparseColumn, Test) {
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false}; GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) { for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest()); column_matrix.InitFromSparse(page, gmat, 1.0, ctx.Threads());
} }
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype); using T = decltype(dtype);
@ -90,7 +101,7 @@ void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
} }
} }
TEST(DenseColumnWithMissing, Test) { TEST(ColumnMatrix, DenseColumnWithMissing) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1, int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1, static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -100,7 +111,7 @@ TEST(DenseColumnWithMissing, Test) {
GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false); GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) { for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest()); column_matrix.InitFromSparse(page, gmat, 0.2, ctx.Threads());
} }
ASSERT_TRUE(column_matrix.AnyMissing()); ASSERT_TRUE(column_matrix.AnyMissing());
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
@ -110,5 +121,29 @@ TEST(DenseColumnWithMissing, Test) {
}); });
} }
} }
} // namespace common
} // namespace xgboost TEST(ColumnMatrix, GrowMissing) {
float sparsity = 0.5;
NumpyArrayIterForTest iter(sparsity);
auto n_threads = 0;
bst_bin_t n_bins = 16;
BatchParam batch{n_bins, tree::TrainParam::DftSparseThreshold()};
Context ctx;
auto m = std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(),
n_threads, n_bins);
for (auto const& page : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
auto const& column_matrix = page.Transpose();
auto const& missing = column_matrix.Missing();
auto n = NumpyArrayIterForTest::Rows() * NumpyArrayIterForTest::Cols();
auto expected = std::remove_reference_t<decltype(missing)>::BitFieldT::ComputeStorageSize(n);
auto got = missing.storage.size();
ASSERT_EQ(expected, got);
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
auto col = column_matrix.DenseColumn<T, true>(0);
CheckColumWithMissingValue(col, page);
});
}
}
} // namespace xgboost::common

View File

@ -119,6 +119,20 @@ TEST(IO, Resource) {
for (std::size_t i = n; i < 2 * n; ++i) { for (std::size_t i = n; i < 2 * n; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0); ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0);
} }
ptr = malloc_resource->DataAs<std::uint8_t>();
std::fill_n(ptr, malloc_resource->Size(), 7);
if (force_malloc) {
malloc_resource->Resize<true>(n * 3, std::byte{3});
} else {
malloc_resource->Resize<false>(n * 3, std::byte{3});
}
for (std::size_t i = 0; i < n * 2; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 7);
}
for (std::size_t i = n * 2; i < n * 3; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 3);
}
}; };
test_malloc_resize(true); test_malloc_resize(true);
test_malloc_resize(false); test_malloc_resize(false);

View File

@ -12,8 +12,7 @@
#include "../helpers.h" #include "../helpers.h"
#include "xgboost/data.h" // DMatrix #include "xgboost/data.h" // DMatrix
namespace xgboost { namespace xgboost::data {
namespace data {
TEST(IterativeDMatrix, Ref) { TEST(IterativeDMatrix, Ref) {
Context ctx; Context ctx;
TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>( TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
@ -21,7 +20,7 @@ TEST(IterativeDMatrix, Ref) {
} }
TEST(IterativeDMatrix, IsDense) { TEST(IterativeDMatrix, IsDense) {
int n_bins = 16; bst_bin_t n_bins = 16;
auto test = [n_bins](float sparsity) { auto test = [n_bins](float sparsity) {
NumpyArrayIterForTest iter(sparsity); NumpyArrayIterForTest iter(sparsity);
auto n_threads = 0; auto n_threads = 0;
@ -38,5 +37,4 @@ TEST(IterativeDMatrix, IsDense) {
test(0.1); test(0.1);
test(1.0); test(1.0);
} }
} // namespace data } // namespace xgboost::data
} // namespace xgboost

View File

@ -1,4 +1,5 @@
from typing import Callable, Dict, List import weakref
from typing import Any, Callable, Dict, List
import numpy as np import numpy as np
import pytest import pytest
@ -179,5 +180,18 @@ def test_data_cache() -> None:
data = make_batches(n_samples_per_batch, n_features, n_batches, False) data = make_batches(n_samples_per_batch, n_features, n_batches, False)
batches = [v[0] for v in data] batches = [v[0] for v in data]
it = IterForCacheTest(*batches) it = IterForCacheTest(*batches)
transform = xgb.data._proxy_transform
called = 0
def mock(*args: Any, **kwargs: Any) -> Any:
nonlocal called
called += 1
return transform(*args, **kwargs)
xgb.data._proxy_transform = mock
xgb.QuantileDMatrix(it) xgb.QuantileDMatrix(it)
assert it._input_id == id(batches[0]) assert it._data_ref is weakref.ref(batches[0])
assert called == 1
xgb.data._proxy_transform = transform

View File

@ -103,12 +103,29 @@ class TestQuantileDMatrix:
*make_batches_sparse( *make_batches_sparse(
n_samples_per_batch, n_features, n_batches, sparsity n_samples_per_batch, n_features, n_batches, sparsity
), ),
None None,
) )
Xy = xgb.QuantileDMatrix(it) Xy = xgb.QuantileDMatrix(it)
assert Xy.num_row() == n_samples_per_batch * n_batches assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features assert Xy.num_col() == n_features
def test_different_size(self) -> None:
    """Check a QuantileDMatrix built from unequal-sized batches.

    The matrix assembled through the iterator must report the expected number
    of rows and compare equal to one built directly from the arrays returned
    by ``as_arrays``.
    """
    n_samples_per_batch, n_features, n_batches = 317, 8, 7
    batches = make_batches(
        n_samples_per_batch, n_features, n_batches, False, vary_size=True
    )
    it = IteratorForTest(*batches, cache=None)

    from_iter = xgb.QuantileDMatrix(it)
    # 7 * 317 + 10 * (0 + 1 + ... + 6) == 2429
    assert from_iter.num_row() == 2429

    X, y, w = it.as_arrays()
    from_arrays = xgb.QuantileDMatrix(X, y, weight=w)
    assert predictor_equal(from_iter, from_arrays)
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9]) @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
def test_training(self, sparsity: float) -> None: def test_training(self, sparsity: float) -> None:
n_samples_per_batch = 317 n_samples_per_batch = 317
@ -123,7 +140,7 @@ class TestQuantileDMatrix:
*make_batches_sparse( *make_batches_sparse(
n_samples_per_batch, n_features, n_batches, sparsity n_samples_per_batch, n_features, n_batches, sparsity
), ),
None None,
) )
parameters = {"tree_method": "hist", "max_bin": 256} parameters = {"tree_method": "hist", "max_bin": 256}