Use weakref instead of id for DataIter cache. (#9445)

- Fix case where Python reuses id from freed objects.
- Small optimization to column matrix with QDM by using `realloc` instead of copying data.
This commit is contained in:
Jiaming Yuan 2023-08-10 00:40:06 +08:00 committed by GitHub
parent d495a180d8
commit f05a23b41c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 193 additions and 63 deletions

View File

@ -22,7 +22,10 @@ import xgboost
def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str,
n_samples_per_batch: int,
n_features: int,
n_batches: int,
tmpdir: str,
) -> List[Tuple[str, str]]:
files: List[Tuple[str, str]] = []
rng = np.random.RandomState(1994)
@ -38,6 +41,7 @@ def make_batches(
class Iterator(xgboost.DataIter):
"""A custom iterator for loading files in batches."""
def __init__(self, file_paths: List[Tuple[str, str]]):
self._file_paths = file_paths
self._it = 0

View File

@ -23,12 +23,16 @@ Core Data Structure
:show-inheritance:
.. autoclass:: xgboost.QuantileDMatrix
:members:
:show-inheritance:
.. autoclass:: xgboost.Booster
:members:
:show-inheritance:
.. autoclass:: xgboost.DataIter
:members:
:show-inheritance:
Learning API
------------

View File

@ -8,7 +8,9 @@ from typing import (
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
@ -20,8 +22,6 @@ import numpy as np
DataType = Any
# xgboost accepts some other possible types in practice for historical reasons, but those
# are less tested. For now we encourage users to pass a simple list of strings.
FeatureInfo = Sequence[str]
FeatureNames = FeatureInfo
FeatureTypes = FeatureInfo
@ -97,6 +97,13 @@ else:
ctypes._Pointer,
]
# The second arg is actually Optional[List[cudf.Series]], skipped for easier type check.
# The cudf Series is the obtained cat codes, preserved in the `DataIter` to prevent it
# being freed.
TransformedData = Tuple[
Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes]
]
# template parameter
_T = TypeVar("_T")
_F = TypeVar("_F", bound=Callable[..., Any])

View File

@ -9,6 +9,7 @@ import os
import re
import sys
import warnings
import weakref
from abc import ABC, abstractmethod
from collections.abc import Mapping
from enum import IntEnum, unique
@ -51,6 +52,7 @@ from ._typing import (
FeatureTypes,
ModelIn,
NumpyOrCupy,
TransformedData,
c_bst_ulong,
)
from .compat import PANDAS_INSTALLED, DataFrame, py_str
@ -486,7 +488,16 @@ def _prediction_output(
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
"""The interface for user defined data iterator.
"""The interface for user defined data iterator. The iterator facilitates
distributed training, :py:class:`QuantileDMatrix`, and external memory support using
:py:class:`DMatrix`. Most of the time, users don't need to interact with this class
directly.
.. note::
The class caches some intermediate results using the `data` input (predictor
`X`) as key. Don't reuse the same `X` for multiple batches with different meta data
(like `label`); make a copy if necessary.
Parameters
----------
@ -510,13 +521,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
self._allow_host = True
self._release = release_data
# Stage data in Python until reset or next is called to avoid data being freed.
self._temporary_data: Optional[Tuple[Any, Any, Any, Any]] = None
self._input_id: int = 0
self._temporary_data: Optional[TransformedData] = None
self._data_ref: Optional[weakref.ReferenceType] = None
def get_callbacks(
self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]:
"""Get callback functions for iterating in C."""
"""Get callback functions for iterating in C. This is an internal function."""
assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
self._reset_wrapper
@ -591,7 +602,19 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
from .data import _proxy_transform, dispatch_proxy_set_data
# Reduce the amount of transformation that's needed for QuantileDMatrix.
if self._temporary_data is not None and id(data) == self._input_id:
#
# To construct the QDM, one needs 4 iterations on CPU, or 2 iterations on
# GPU. If the QDM has only one batch of input (most of the cases), we can
# avoid transforming the data repeatedly.
try:
ref = weakref.ref(data)
except TypeError:
ref = None
if (
self._temporary_data is not None
and ref is not None
and ref is self._data_ref
):
new, cat_codes, feature_names, feature_types = self._temporary_data
else:
new, cat_codes, feature_names, feature_types = _proxy_transform(
@ -608,7 +631,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
feature_types=feature_types,
**kwargs,
)
self._input_id = id(data)
self._data_ref = ref
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0)
@ -1134,7 +1157,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
testing purposes. If this is a quantized DMatrix then quantized values are
returned instead of input values.
.. versionadded:: 1.7.0
.. versionadded:: 1.7.0
"""
indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
@ -1155,7 +1178,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return ret
def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get quantile cuts for quantization."""
"""Get quantile cuts for quantization.
.. versionadded:: 2.0.0
"""
n_features = self.num_col()
c_sindptr = ctypes.c_char_p()

View File

@ -5,7 +5,7 @@ import ctypes
import json
import os
import warnings
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast
import numpy as np
@ -17,6 +17,7 @@ from ._typing import (
FloatCompatible,
NumpyDType,
PandasDType,
TransformedData,
c_bst_ulong,
)
from .compat import DataFrame, lazy_isinstance
@ -1268,12 +1269,7 @@ def _proxy_transform(
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
enable_categorical: bool,
) -> Tuple[
Union[bool, ctypes.c_void_p, np.ndarray],
Optional[list],
Optional[FeatureNames],
Optional[FeatureTypes],
]:
) -> TransformedData:
if _is_cudf_df(data) or _is_cudf_ser(data):
return _transform_cudf_df(
data, feature_names, feature_types, enable_categorical

View File

@ -230,7 +230,7 @@ class IteratorForTest(xgb.core.DataIter):
def as_arrays(
self,
) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, ArrayLike]:
) -> Tuple[Union[np.ndarray, sparse.csr_matrix], ArrayLike, Optional[ArrayLike]]:
if isinstance(self.X[0], sparse.csr_matrix):
X = sparse.vstack(self.X, format="csr")
else:
@ -244,7 +244,12 @@ class IteratorForTest(xgb.core.DataIter):
def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
n_samples_per_batch: int,
n_features: int,
n_batches: int,
use_cupy: bool = False,
*,
vary_size: bool = False,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
X = []
y = []
@ -255,10 +260,11 @@ def make_batches(
rng = cupy.random.RandomState(1994)
else:
rng = np.random.RandomState(1994)
for _ in range(n_batches):
_X = rng.randn(n_samples_per_batch, n_features)
_y = rng.randn(n_samples_per_batch)
_w = rng.uniform(low=0, high=1, size=n_samples_per_batch)
for i in range(n_batches):
n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
_X = rng.randn(n_samples, n_features)
_y = rng.randn(n_samples)
_w = rng.uniform(low=0, high=1, size=n_samples)
X.append(_X)
y.append(_y)
w.append(_w)

View File

@ -9,12 +9,12 @@
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
#include <algorithm>
#include <cstddef> // for size_t
#include <cstddef> // for size_t, byte
#include <cstdint> // for uint8_t
#include <limits>
#include <memory>
#include <utility> // for move
#include <vector>
#include <type_traits> // for enable_if_t, is_same_v, is_signed_v
#include <utility> // for move
#include "../data/adapter.h"
#include "../data/gradient_index.h"
@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxT> {
*/
template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> {
public:
using ByteType = bool;
private:
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
@ -153,8 +150,17 @@ class ColumnMatrix {
* @brief A bit set for indicating whether an element in a dense column is missing.
*/
struct MissingIndicator {
LBitField32 missing;
RefResourceView<std::uint32_t> storage;
using BitFieldT = LBitField32;
using T = typename BitFieldT::value_type;
BitFieldT missing;
RefResourceView<T> storage;
static_assert(std::is_same_v<T, std::uint32_t>);
template <typename U>
[[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
return init ? ~U{0} : U{0};
}
MissingIndicator() = default;
/**
@ -163,7 +169,7 @@ class ColumnMatrix {
*/
MissingIndicator(std::size_t n_elements, bool init) {
auto m_size = missing.ComputeStorageSize(n_elements);
storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
this->InitView();
}
/** @brief Set the i^th element to be a valid element (instead of missing). */
@ -181,11 +187,12 @@ class ColumnMatrix {
if (m_size == storage.size()) {
return;
}
// grow the storage
auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
CHECK(resource);
resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};
auto new_storage =
common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
storage = std::move(new_storage);
this->InitView();
}
};
@ -210,7 +217,6 @@ class ColumnMatrix {
}
public:
using ByteType = bool;
// get number of features
[[nodiscard]] bst_feature_t GetNumFeature() const {
return static_cast<bst_feature_t>(type_.size());
@ -408,6 +414,7 @@ class ColumnMatrix {
// IO procedures for external memory.
[[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
[[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
[[nodiscard]] MissingIndicator const& Missing() const { return missing_; }
private:
RefResourceView<std::uint8_t> index_;

View File

@ -10,7 +10,7 @@
#include <dmlc/io.h>
#include <rabit/rabit.h>
#include <algorithm> // for min
#include <algorithm> // for min, fill_n, copy_n
#include <array> // for array
#include <cstddef> // for byte, size_t
#include <cstdlib> // for malloc, realloc, free
@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
* @param n_bytes The new size.
*/
template <bool force_malloc = false>
void Resize(std::size_t n_bytes) {
void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
// realloc(ptr, 0) works, but is deprecated.
if (n_bytes == 0) {
this->Clear();
@ -236,7 +236,7 @@ class MallocResource : public ResourceHandler {
std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
}
// default initialize
std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
// free the old ptr if malloc is used.
if (need_copy) {
this->Clear();

View File

@ -42,6 +42,7 @@ class LintersPaths:
"demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py",
"demo/guide-python/spark_estimator_examples.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py",
"demo/guide-python/multioutput_regression.py",

View File

@ -2,15 +2,26 @@
* Copyright 2018-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/base.h> // for bst_bin_t
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for BatchIterator, BatchSet, DMatrix, Met...
#include "../../../src/common/column_matrix.h"
#include "../helpers.h"
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint16_t, uint8_t
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr, __shared_ptr_access, allo...
#include <type_traits> // for remove_reference_t
#include "../../../src/common/column_matrix.h" // for ColumnMatrix, Column, DenseColumnIter
#include "../../../src/common/hist_util.h" // for DispatchBinType, BinTypeSize, Index
#include "../../../src/common/ref_resource_view.h" // for RefResourceView
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
#include "../../../src/tree/param.h" // for TrainParam
#include "../helpers.h" // for RandomDataGenerator, NumpyArrayIterFo...
namespace xgboost {
namespace common {
TEST(DenseColumn, Test) {
namespace xgboost::common {
TEST(ColumnMatrix, Basic) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -22,7 +33,7 @@ TEST(DenseColumn, Test) {
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, sparse_thresh, false};
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
column_matrix.InitFromSparse(page, gmat, sparse_thresh, ctx.Threads());
}
ASSERT_GE(column_matrix.GetTypeSize(), last);
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
@ -59,7 +70,7 @@ void CheckSparseColumn(SparseColumnIter<BinIdxType>* p_col, const GHistIndexMatr
}
}
TEST(SparseColumn, Test) {
TEST(ColumnMatrix, SparseColumn) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -69,7 +80,7 @@ TEST(SparseColumn, Test) {
GHistIndexMatrix gmat{&ctx, dmat.get(), max_num_bin, 0.5f, false};
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
column_matrix.InitFromSparse(page, gmat, 1.0, ctx.Threads());
}
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
@ -90,7 +101,7 @@ void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
}
}
TEST(DenseColumnWithMissing, Test) {
TEST(ColumnMatrix, DenseColumnWithMissing) {
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
@ -100,7 +111,7 @@ TEST(DenseColumnWithMissing, Test) {
GHistIndexMatrix gmat(&ctx, dmat.get(), max_num_bin, 0.2, false);
ColumnMatrix column_matrix;
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
column_matrix.InitFromSparse(page, gmat, 0.2, ctx.Threads());
}
ASSERT_TRUE(column_matrix.AnyMissing());
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
@ -110,5 +121,29 @@ TEST(DenseColumnWithMissing, Test) {
});
}
}
} // namespace common
} // namespace xgboost
TEST(ColumnMatrix, GrowMissing) {
float sparsity = 0.5;
NumpyArrayIterForTest iter(sparsity);
auto n_threads = 0;
bst_bin_t n_bins = 16;
BatchParam batch{n_bins, tree::TrainParam::DftSparseThreshold()};
Context ctx;
auto m = std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(),
n_threads, n_bins);
for (auto const& page : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
auto const& column_matrix = page.Transpose();
auto const& missing = column_matrix.Missing();
auto n = NumpyArrayIterForTest::Rows() * NumpyArrayIterForTest::Cols();
auto expected = std::remove_reference_t<decltype(missing)>::BitFieldT::ComputeStorageSize(n);
auto got = missing.storage.size();
ASSERT_EQ(expected, got);
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
auto col = column_matrix.DenseColumn<T, true>(0);
CheckColumWithMissingValue(col, page);
});
}
}
} // namespace xgboost::common

View File

@ -119,6 +119,20 @@ TEST(IO, Resource) {
for (std::size_t i = n; i < 2 * n; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0);
}
ptr = malloc_resource->DataAs<std::uint8_t>();
std::fill_n(ptr, malloc_resource->Size(), 7);
if (force_malloc) {
malloc_resource->Resize<true>(n * 3, std::byte{3});
} else {
malloc_resource->Resize<false>(n * 3, std::byte{3});
}
for (std::size_t i = 0; i < n * 2; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 7);
}
for (std::size_t i = n * 2; i < n * 3; ++i) {
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 3);
}
};
test_malloc_resize(true);
test_malloc_resize(false);

View File

@ -12,8 +12,7 @@
#include "../helpers.h"
#include "xgboost/data.h" // DMatrix
namespace xgboost {
namespace data {
namespace xgboost::data {
TEST(IterativeDMatrix, Ref) {
Context ctx;
TestRefDMatrix<GHistIndexMatrix, NumpyArrayIterForTest>(
@ -21,7 +20,7 @@ TEST(IterativeDMatrix, Ref) {
}
TEST(IterativeDMatrix, IsDense) {
int n_bins = 16;
bst_bin_t n_bins = 16;
auto test = [n_bins](float sparsity) {
NumpyArrayIterForTest iter(sparsity);
auto n_threads = 0;
@ -38,5 +37,4 @@ TEST(IterativeDMatrix, IsDense) {
test(0.1);
test(1.0);
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data

View File

@ -1,4 +1,5 @@
from typing import Callable, Dict, List
import weakref
from typing import Any, Callable, Dict, List
import numpy as np
import pytest
@ -179,5 +180,18 @@ def test_data_cache() -> None:
data = make_batches(n_samples_per_batch, n_features, n_batches, False)
batches = [v[0] for v in data]
it = IterForCacheTest(*batches)
transform = xgb.data._proxy_transform
called = 0
def mock(*args: Any, **kwargs: Any) -> Any:
nonlocal called
called += 1
return transform(*args, **kwargs)
xgb.data._proxy_transform = mock
xgb.QuantileDMatrix(it)
assert it._input_id == id(batches[0])
assert it._data_ref is weakref.ref(batches[0])
assert called == 1
xgb.data._proxy_transform = transform

View File

@ -103,12 +103,29 @@ class TestQuantileDMatrix:
*make_batches_sparse(
n_samples_per_batch, n_features, n_batches, sparsity
),
None
None,
)
Xy = xgb.QuantileDMatrix(it)
assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features
def test_different_size(self) -> None:
n_samples_per_batch = 317
n_features = 8
n_batches = 7
it = IteratorForTest(
*make_batches(
n_samples_per_batch, n_features, n_batches, False, vary_size=True
),
cache=None,
)
Xy = xgb.QuantileDMatrix(it)
assert Xy.num_row() == 2429
X, y, w = it.as_arrays()
Xy1 = xgb.QuantileDMatrix(X, y, weight=w)
assert predictor_equal(Xy, Xy1)
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.5, 0.8, 0.9])
def test_training(self, sparsity: float) -> None:
n_samples_per_batch = 317
@ -123,7 +140,7 @@ class TestQuantileDMatrix:
*make_batches_sparse(
n_samples_per_batch, n_features, n_batches, sparsity
),
None
None,
)
parameters = {"tree_method": "hist", "max_bin": 256}