Use weakref instead of id for DataIter cache. (#9445)
- Fix case where Python reuses id from freed objects. - Small optimization to column matrix with QDM by using `realloc` instead of copying data.
This commit is contained in:
parent
d495a180d8
commit
f05a23b41c
@ -22,7 +22,10 @@ import xgboost
|
||||
|
||||
|
||||
def make_batches(
|
||||
n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str,
|
||||
n_samples_per_batch: int,
|
||||
n_features: int,
|
||||
n_batches: int,
|
||||
tmpdir: str,
|
||||
) -> List[Tuple[str, str]]:
|
||||
files: List[Tuple[str, str]] = []
|
||||
rng = np.random.RandomState(1994)
|
||||
@ -38,6 +41,7 @@ def make_batches(
|
||||
|
||||
class Iterator(xgboost.DataIter):
|
||||
"""A custom iterator for loading files in batches."""
|
||||
|
||||
def __init__(self, file_paths: List[Tuple[str, str]]):
|
||||
self._file_paths = file_paths
|
||||
self._it = 0
|
||||
|
||||
@ -23,12 +23,16 @@ Core Data Structure
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.QuantileDMatrix
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.Booster
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: xgboost.DataIter
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Learning API
|
||||
------------
|
||||
|
||||
@ -8,7 +8,9 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -20,8 +22,6 @@ import numpy as np
|
||||
|
||||
DataType = Any
|
||||
|
||||
# xgboost accepts some other possible types in practice due to historical reason, which is
|
||||
# lesser tested. For now we encourage users to pass a simple list of string.
|
||||
FeatureInfo = Sequence[str]
|
||||
FeatureNames = FeatureInfo
|
||||
FeatureTypes = FeatureInfo
|
||||
@ -97,6 +97,13 @@ else:
|
||||
ctypes._Pointer,
|
||||
]
|
||||
|
||||
# The second arg is actually Optional[List[cudf.Series]], skipped for easier type check.
|
||||
# The cudf Series is the obtained cat codes, preserved in the `DataIter` to prevent it
|
||||
# being freed.
|
||||
TransformedData = Tuple[
|
||||
Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes]
|
||||
]
|
||||
|
||||
# template parameter
|
||||
_T = TypeVar("_T")
|
||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
||||
|
||||
@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from enum import IntEnum, unique
|
||||
@ -51,6 +52,7 @@ from ._typing import (
|
||||
FeatureTypes,
|
||||
ModelIn,
|
||||
NumpyOrCupy,
|
||||
TransformedData,
|
||||
c_bst_ulong,
|
||||
)
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, py_str
|
||||
@ -486,7 +488,16 @@ def _prediction_output(
|
||||
|
||||
|
||||
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
"""The interface for user defined data iterator.
|
||||
"""The interface for user defined data iterator. The iterator facilitates
|
||||
distributed training, :py:class:`QuantileDMatrix`, and external memory support using
|
||||
:py:class:`DMatrix`. Most of time, users don't need to interact with this class
|
||||
directly.
|
||||
|
||||
.. note::
|
||||
|
||||
The class caches some intermediate results using the `data` input (predictor
|
||||
`X`) as key. Don't repeat the `X` for multiple batches with different meta data
|
||||
(like `label`), make a copy if necessary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -510,13 +521,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
self._allow_host = True
|
||||
self._release = release_data
|
||||
# Stage data in Python until reset or next is called to avoid data being free.
|
||||
self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
|
||||
self._input_id: int = 0
|
||||
self._temporary_data: Optional[TransformedData] = None
|
||||
self._data_ref: Optional[weakref.ReferenceType] = None
|
||||
|
||||
def get_callbacks(
|
||||
self, allow_host: bool, enable_categorical: bool
|
||||
) -> Tuple[Callable, Callable]:
|
||||
"""Get callback functions for iterating in C."""
|
||||
"""Get callback functions for iterating in C. This is an internal function."""
|
||||
assert hasattr(self, "cache_prefix"), "__init__ is not called."
|
||||
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
|
||||
self._reset_wrapper
|
||||
@ -591,7 +602,19 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
from .data import _proxy_transform, dispatch_proxy_set_data
|
||||
|
||||
# Reduce the amount of transformation that's needed for QuantileDMatrix.
|
||||
if self._temporary_data is not None and id(data) == self._input_id:
|
||||
#
|
||||
# To construct the QDM, one needs 4 iterations on CPU, or 2 iterations on
|
||||
# GPU. If the QDM has only one batch of input (most of the cases), we can
|
||||
# avoid transforming the data repeatly.
|
||||
try:
|
||||
ref = weakref.ref(data)
|
||||
except TypeError:
|
||||
ref = None
|
||||
if (
|
||||
self._temporary_data is not None
|
||||
and ref is not None
|
||||
and ref is self._data_ref
|
||||
):
|
||||
new, cat_codes, feature_names, feature_types = self._temporary_data
|
||||
else:
|
||||
new, cat_codes, feature_names, feature_types = _proxy_transform(
|
||||
@ -608,7 +631,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
feature_types=feature_types,
|
||||
**kwargs,
|
||||
)
|
||||
self._input_id = id(data)
|
||||
self._data_ref = ref
|
||||
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(input_data), 0)
|
||||
@ -1134,7 +1157,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
testing purposes. If this is a quantized DMatrix then quantized values are
|
||||
returned instead of input values.
|
||||
|
||||
.. versionadded:: 1.7.0
|
||||
.. versionadded:: 1.7.0
|
||||
|
||||
"""
|
||||
indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
|
||||
@ -1155,7 +1178,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
return ret
|
||||
|
||||
def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get quantile cuts for quantization."""
|
||||
"""Get quantile cuts for quantization.
|
||||
|
||||
.. versionadded:: 2.0.0
|
||||
|
||||
"""
|
||||
n_features = self.num_col()
|
||||
|
||||
c_sindptr = ctypes.c_char_p()
|
||||
|
||||
@ -5,7 +5,7 @@ import ctypes
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
|
||||
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -17,6 +17,7 @@ from ._typing import (
|
||||
FloatCompatible,
|
||||
NumpyDType,
|
||||
PandasDType,
|
||||
TransformedData,
|
||||
c_bst_ulong,
|
||||
)
|
||||
from .compat import DataFrame, lazy_isinstance
|
||||
@ -1268,12 +1269,7 @@ def _proxy_transform(
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
enable_categorical: bool,
|
||||
) -> Tuple[
|
||||
Union[bool, ctypes.c_void_p, np.ndarray],
|
||||
Optional[list],
|
||||
Optional[FeatureNames],
|
||||
Optional[FeatureTypes],
|
||||
]:
|
||||
) -> TransformedData:
|
||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||
return _transform_cudf_df(
|
||||
data, feature_names, feature_types, enable_categorical
|
||||
|
||||
@ -230,7 +230,7 @@ class IteratorForTest(xgb.core.DataIter):
|
||||
|
||||
def as_arrays(
|
||||
self,
|
||||
) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, ArrayLike]:
|
||||
) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
|
||||
if isinstance(self.X[0], sparse.csr_matrix):
|
||||
X = sparse.vstack(self.X, format="csr")
|
||||
else:
|
||||
@ -244,7 +244,12 @@ class IteratorForTest(xgb.core.DataIter):
|
||||
|
||||
|
||||
def make_batches(
|
||||
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
|
||||
n_samples_per_batch: int,
|
||||
n_features: int,
|
||||
n_batches: int,
|
||||
use_cupy: bool = False,
|
||||
*,
|
||||
vary_size: bool = False,
|
||||
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
||||
X = []
|
||||
y = []
|
||||
@ -255,10 +260,11 @@ def make_batches(
|
||||
rng = cupy.random.RandomState(1994)
|
||||
else:
|
||||
rng = np.random.RandomState(1994)
|
||||
for _ in range(n_batches):
|
||||
_X = rng.randn(n_samples_per_batch, n_features)
|
||||
_y = rng.randn(n_samples_per_batch)
|
||||
_w = rng.uniform(low=0, high=1, size=n_samples_per_batch)
|
||||
for i in range(n_batches):
|
||||
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
|
||||
_X = rng.randn(n_samples, n_features)
|
||||
_y = rng.randn(n_samples)
|
||||
_w = rng.uniform(low=0, high=1, size=n_samples)
|
||||
X.append(_X)
|
||||
y.append(_y)
|
||||
w.append(_w)
|
||||
|
||||
@ -9,12 +9,12 @@
|
||||
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstddef> // for size_t, byte
|
||||
#include <cstdint> // for uint8_t
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility> // for move
|
||||
#include <vector>
|
||||
#include <type_traits> // for enable_if_t, is_same_v, is_signed_v
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/gradient_index.h"
|
||||
@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxT> {
|
||||
*/
|
||||
template <typename BinIdxT, bool any_missing>
|
||||
class DenseColumnIter : public Column<BinIdxT> {
|
||||
public:
|
||||
using ByteType = bool;
|
||||
|
||||
private:
|
||||
using Base = Column<BinIdxT>;
|
||||
/* flags for missing values in dense columns */
|
||||
@ -153,8 +150,17 @@ class ColumnMatrix {
|
||||
* @brief A bit set for indicating whether an element in a dense column is missing.
|
||||
*/
|
||||
struct MissingIndicator {
|
||||
LBitField32 missing;
|
||||
RefResourceView<std::uint32_t> storage;
|
||||
using BitFieldT = LBitField32;
|
||||
using T = typename BitFieldT::value_type;
|
||||
|
||||
BitFieldT missing;
|
||||
RefResourceView<T> storage;
|
||||
static_assert(std::is_same_v<T, std::uint32_t>);
|
||||
|
||||
template <typename U>
|
||||
[[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
|
||||
return init ? ~U{0} : U{0};
|
||||
}
|
||||
|
||||
MissingIndicator() = default;
|
||||
/**
|
||||
@ -163,7 +169,7 @@ class ColumnMatrix {
|
||||
*/
|
||||
MissingIndicator(std::size_t n_elements, bool init) {
|
||||
auto m_size = missing.ComputeStorageSize(n_elements);
|
||||
storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
|
||||
storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
|
||||
this->InitView();
|
||||
}
|
||||
/** @brief Set the i^th element to be a valid element (instead of missing). */
|
||||
@ -181,11 +187,12 @@ class ColumnMatrix {
|
||||
if (m_size == storage.size()) {
|
||||
return;
|
||||
}
|
||||
// grow the storage
|
||||
auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
|
||||
CHECK(resource);
|
||||
resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
|
||||
storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};
|
||||
|
||||
auto new_storage =
|
||||
common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
|
||||
std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
|
||||
storage = std::move(new_storage);
|
||||
this->InitView();
|
||||
}
|
||||
};
|
||||
@ -210,7 +217,6 @@ class ColumnMatrix {
|
||||
}
|
||||
|
||||
public:
|
||||
using ByteType = bool;
|
||||
// get number of features
|
||||
[[nodiscard]] bst_feature_t GetNumFeature() const {
|
||||
return static_cast<bst_feature_t>(type_.size());
|
||||
@ -408,6 +414,7 @@ class ColumnMatrix {
|
||||
// IO procedures for external memory.
|
||||
[[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
|
||||
[[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
|
||||
[[nodiscard]] MissingIndicator const& Missing() const { return missing_; }
|
||||
|
||||
private:
|
||||
RefResourceView<std::uint8_t> index_;
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#include <dmlc/io.h>
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <algorithm> // for min
|
||||
#include <algorithm> // for min, fill_n, copy_n
|
||||
#include <array> // for array
|
||||
#include <cstddef> // for byte, size_t
|
||||
#include <cstdlib> // for malloc, realloc, free
|
||||
@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
|
||||
* @param n_bytes The new size.
|
||||
*/
|
||||
template <bool force_malloc = false>
|
||||
void Resize(std::size_t n_bytes) {
|
||||
void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
|
||||
// realloc(ptr, 0) works, but is deprecated.
|
||||
if (n_bytes == 0) {
|
||||
this->Clear();
|
||||
@ -236,7 +236,7 @@ class MallocResource : public ResourceHandler {
|
||||
std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
|
||||
}
|
||||
// default initialize
|
||||
std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
|
||||
std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
|
||||
// free the old ptr if malloc is used.
|
||||
if (need_copy) {
|
||||
this->Clear();
|
||||
|
||||
@ -42,6 +42,7 @@ class LintersPaths:
|
||||
"demo/guide-python/feature_weights.py",
|
||||
"demo/guide-python/sklearn_parallel.py",
|
||||
"demo/guide-python/spark_estimator_examples.py",
|
||||
"demo/guide-python/external_memory.py",
|
||||
"demo/guide-python/individual_trees.py",
|
||||
"demo/guide-python/quantile_regression.py",
|
||||
"demo/guide-python/multioutput_regression.py",
|
||||
|
||||
@ -2,15 +2,26 @@
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h> // for bst_bin_t
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, Met...
|
||||
|
||||
#include "../../../src/common/column_matrix.h"
|
||||
#include "../helpers.h"
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, uint16_t, uint8_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <memory> // for shared_ptr, __shared_ptr_access, allo...
|
||||
#include <type_traits> // for remove_reference_t
|
||||
|
||||
#include "../../../src/common/column_matrix.h" // for ColumnMatrix, Column, DenseColumnIter
|
||||
#include "../../../src/common/hist_util.h" // for DispatchBinType, BinTypeSize, Index
|
||||
#include "../../../src/common/ref_resource_view.h" // for RefResourceView
|
||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h" // for RandomDataGenerator, NumpyArrayIterFo...
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(DenseColumn, Test) {
|
||||
namespace xgboost::common {
|
||||
TEST(ColumnMatrix, Basic) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
@ -22,7 +33,7 @@ TEST(DenseColumn, Test) {
|
||||
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, ctx.Threads());
|
||||
}
|
||||
ASSERT_GE(column_matrix.GetTypeSize(), last);
|
||||
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
|
||||
@ -59,7 +70,7 @@ void CheckSparseColumn(SparseColumnIter<BinIdxType>* p_col, const GHistIndexMatr
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SparseColumn, Test) {
|
||||
TEST(ColumnMatrix, SparseColumn) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
@ -69,7 +80,7 @@ TEST(SparseColumn, Test) {
|
||||
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, ctx.Threads());
|
||||
}
|
||||
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
@ -90,7 +101,7 @@ void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
|
||||
}
|
||||
}
|
||||
|
||||
TEST(DenseColumnWithMissing, Test) {
|
||||
TEST(ColumnMatrix, DenseColumnWithMissing) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
@ -100,7 +111,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, ctx.Threads());
|
||||
}
|
||||
ASSERT_TRUE(column_matrix.AnyMissing());
|
||||
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
@ -110,5 +121,29 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
TEST(ColumnMatrix, GrowMissing) {
|
||||
float sparsity = 0.5;
|
||||
NumpyArrayIterForTest iter(sparsity);
|
||||
auto n_threads = 0;
|
||||
bst_bin_t n_bins = 16;
|
||||
BatchParam batch{n_bins, tree::TrainParam::DftSparseThreshold()};
|
||||
Context ctx;
|
||||
auto m = std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
n_threads, n_bins);
|
||||
for (auto const& page : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||
auto const& column_matrix = page.Transpose();
|
||||
auto const& missing = column_matrix.Missing();
|
||||
auto n = NumpyArrayIterForTest::Rows() * NumpyArrayIterForTest::Cols();
|
||||
auto expected = std::remove_reference_t<decltype(missing)>::BitFieldT::ComputeStorageSize(n);
|
||||
auto got = missing.storage.size();
|
||||
ASSERT_EQ(expected, got);
|
||||
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
auto col = column_matrix.DenseColumn<T, true>(0);
|
||||
CheckColumWithMissingValue(col, page);
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::common
|
||||
|
||||
@ -119,6 +119,20 @@ TEST(IO, Resource) {
|
||||
for (std::size_t i = n; i < 2 * n; ++i) {
|
||||
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0);
|
||||
}
|
||||
|
||||
ptr = malloc_resource->DataAs<std::uint8_t>();
|
||||
std::fill_n(ptr, malloc_resource->Size(), 7);
|
||||
if (force_malloc) {
|
||||
malloc_resource->Resize<true>(n * 3, std::byte{3});
|
||||
} else {
|
||||
malloc_resource->Resize<false>(n * 3, std::byte{3});
|
||||
}
|
||||
for (std::size_t i = 0; i < n * 2; ++i) {
|
||||
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 7);
|
||||
}
|
||||
for (std::size_t i = n * 2; i < n * 3; ++i) {
|
||||
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 3);
|
||||
}
|
||||
};
|
||||
test_malloc_resize(true);
|
||||
test_malloc_resize(false);
|
||||
|
||||
@ -12,8 +12,7 @@
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/data.h" // DMatrix
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
TEST(IterativeDMatrix, Ref) {
|
||||
Context ctx;
|
||||
TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
|
||||
@ -21,7 +20,7 @@ TEST(IterativeDMatrix, Ref) {
|
||||
}
|
||||
|
||||
TEST(IterativeDMatrix, IsDense) {
|
||||
int n_bins = 16;
|
||||
bst_bin_t n_bins = 16;
|
||||
auto test = [n_bins](float sparsity) {
|
||||
NumpyArrayIterForTest iter(sparsity);
|
||||
auto n_threads = 0;
|
||||
@ -38,5 +37,4 @@ TEST(IterativeDMatrix, IsDense) {
|
||||
test(0.1);
|
||||
test(1.0);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Callable, Dict, List
|
||||
import weakref
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -179,5 +180,18 @@ def test_data_cache() -> None:
|
||||
data = make_batches(n_samples_per_batch, n_features, n_batches, False)
|
||||
batches = [v[0] for v in data]
|
||||
it = IterForCacheTest(*batches)
|
||||
transform = xgb.data._proxy_transform
|
||||
|
||||
called = 0
|
||||
|
||||
def mock(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal called
|
||||
called += 1
|
||||
return transform(*args, **kwargs)
|
||||
|
||||
xgb.data._proxy_transform = mock
|
||||
xgb.QuantileDMatrix(it)
|
||||
assert it._input_id == id(batches[0])
|
||||
assert it._data_ref is weakref.ref(batches[0])
|
||||
assert called == 1
|
||||
|
||||
xgb.data._proxy_transform = transform
|
||||
|
||||
@ -103,12 +103,29 @@ class TestQuantileDMatrix:
|
||||
*make_batches_sparse(
|
||||
n_samples_per_batch, n_features, n_batches, sparsity
|
||||
),
|
||||
None
|
||||
None,
|
||||
)
|
||||
Xy = xgb.QuantileDMatrix(it)
|
||||
assert Xy.num_row() == n_samples_per_batch * n_batches
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
def test_different_size(self) -> None:
|
||||
n_samples_per_batch = 317
|
||||
n_features = 8
|
||||
n_batches = 7
|
||||
|
||||
it = IteratorForTest(
|
||||
*make_batches(
|
||||
n_samples_per_batch, n_features, n_batches, False, vary_size=True
|
||||
),
|
||||
cache=None,
|
||||
)
|
||||
Xy = xgb.QuantileDMatrix(it)
|
||||
assert Xy.num_row() == 2429
|
||||
X, y, w = it.as_arrays()
|
||||
Xy1 = xgb.QuantileDMatrix(X, y, weight=w)
|
||||
assert predictor_equal(Xy, Xy1)
|
||||
|
||||
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
|
||||
def test_training(self, sparsity: float) -> None:
|
||||
n_samples_per_batch = 317
|
||||
@ -123,7 +140,7 @@ class TestQuantileDMatrix:
|
||||
*make_batches_sparse(
|
||||
n_samples_per_batch, n_features, n_batches, sparsity
|
||||
),
|
||||
None
|
||||
None,
|
||||
)
|
||||
|
||||
parameters = {"tree_method": "hist", "max_bin": 256}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user