Use weakref instead of id for DataIter cache. (#9445)

- Fix case where Python reuses id from freed objects.
- Small optimization to column matrix with QDM by using `realloc` instead of copying data.
This commit is contained in:
Jiaming Yuan
2023-08-10 00:40:06 +08:00
committed by GitHub
parent d495a180d8
commit f05a23b41c
14 changed files with 193 additions and 63 deletions

View File

@@ -9,12 +9,12 @@
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
#include <algorithm>
#include <cstddef> // for size_t
#include <cstddef> // for size_t, byte
#include <cstdint> // for uint8_t
#include <limits>
#include <memory>
#include <utility> // for move
#include <vector>
#include <type_traits> // for enable_if_t, is_same_v, is_signed_v
#include <utility> // for move
#include "../data/adapter.h"
#include "../data/gradient_index.h"
@@ -112,9 +112,6 @@ class SparseColumnIter : public Column<BinIdxT> {
*/
template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> {
public:
using ByteType = bool;
private:
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
@@ -153,8 +150,17 @@ class ColumnMatrix {
* @brief A bit set for indicating whether an element in a dense column is missing.
*/
struct MissingIndicator {
LBitField32 missing;
RefResourceView<std::uint32_t> storage;
using BitFieldT = LBitField32;
using T = typename BitFieldT::value_type;
BitFieldT missing;
RefResourceView<T> storage;
static_assert(std::is_same_v<T, std::uint32_t>);
template <typename U>
[[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
return init ? ~U{0} : U{0};
}
MissingIndicator() = default;
/**
@@ -163,7 +169,7 @@ class ColumnMatrix {
*/
MissingIndicator(std::size_t n_elements, bool init) {
auto m_size = missing.ComputeStorageSize(n_elements);
storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
this->InitView();
}
/** @brief Set the i^th element to be a valid element (instead of missing). */
@@ -181,11 +187,12 @@ class ColumnMatrix {
if (m_size == storage.size()) {
return;
}
// grow the storage
auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
CHECK(resource);
resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};
auto new_storage =
common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
storage = std::move(new_storage);
this->InitView();
}
};
@@ -210,7 +217,6 @@ class ColumnMatrix {
}
public:
using ByteType = bool;
// get number of features
[[nodiscard]] bst_feature_t GetNumFeature() const {
return static_cast<bst_feature_t>(type_.size());
@@ -408,6 +414,7 @@ class ColumnMatrix {
// IO procedures for external memory.
[[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
[[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
[[nodiscard]] MissingIndicator const& Missing() const { return missing_; }
private:
RefResourceView<std::uint8_t> index_;

View File

@@ -10,7 +10,7 @@
#include <dmlc/io.h>
#include <rabit/rabit.h>
#include <algorithm> // for min
#include <algorithm> // for min, fill_n, copy_n
#include <array> // for array
#include <cstddef> // for byte, size_t
#include <cstdlib> // for malloc, realloc, free
@@ -207,7 +207,7 @@ class MallocResource : public ResourceHandler {
* @param n_bytes The new size.
*/
template <bool force_malloc = false>
void Resize(std::size_t n_bytes) {
void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) {
// realloc(ptr, 0) works, but is deprecated.
if (n_bytes == 0) {
this->Clear();
@@ -236,7 +236,7 @@ class MallocResource : public ResourceHandler {
std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
}
// default initialize
std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
std::fill_n(reinterpret_cast<std::byte*>(new_ptr) + n_, n_bytes - n_, init);
// free the old ptr if malloc is used.
if (need_copy) {
this->Clear();