Implement intrusive ptr (#6129)

* Use intrusive ptr for JSON.
This commit is contained in:
Jiaming Yuan
2020-09-20 20:07:16 +08:00
committed by GitHub
parent e319b63f9e
commit a069a21e03
4 changed files with 365 additions and 9 deletions

View File

@@ -0,0 +1,216 @@
/*!
* Copyright (c) by Contributors 2020
* \file intrusive_ptr.h
* \brief Implementation of Intrusive Ptr.
*/
#ifndef XGBOOST_INTRUSIVE_PTR_H_
#define XGBOOST_INTRUSIVE_PTR_H_
#include <atomic>
#include <cinttypes>
#include <functional>
namespace xgboost {
/*!
* \brief Helper class for embedding reference counting into client objects. See
* https://www.boost.org/doc/libs/1_74_0/doc/html/atomic/usage_examples.html for
* discussions of memory order.
*/
class IntrusivePtrCell {
private:
std::atomic<int32_t> count_;
template <typename T> friend class IntrusivePtr;
std::int32_t IncRef() noexcept {
return count_.fetch_add(1, std::memory_order_relaxed);
}
std::int32_t DecRef() noexcept {
return count_.fetch_sub(1, std::memory_order_release);
}
bool IsZero() const { return Count() == 0; }
public:
IntrusivePtrCell() noexcept : count_{0} {}
int32_t Count() const { return count_.load(std::memory_order_relaxed); }
};
/*!
* \brief User defined function for returing embedded reference count.
*/
template <typename T> IntrusivePtrCell &IntrusivePtrRefCount(T const *ptr) noexcept;
/*!
* \brief Implementation of Intrusive Pointer. A smart pointer that points to an object
* with an embedded reference counter. The underlying object must implement a
* friend function IntrusivePtrRefCount() that returns the ref counter (of type
* IntrusivePtrCell). The intrusive pointer is faster than std::shared_ptr<>:
* std::shared_ptr<> makes an extra memory allocation for the ref counter whereas
* the intrusive pointer does not.
*
* \code
*
* class ForIntrusivePtrTest {
* public:
* mutable class IntrusivePtrCell ref;
* float data { 0 };
*
* friend IntrusivePtrCell &
* IntrusivePtrRefCount(ForIntrusivePtrTest const *t) noexcept { // NOLINT
* return t->ref;
* }
*
* ForIntrusivePtrTest() = default;
* ForIntrusivePtrTest(float a, int32_t b) : data{a + static_cast<float>(b)} {}
*
* explicit ForIntrusivePtrTest(NotCopyConstructible a) : data{a.data} {}
* };
*
* IntrusivePtr<ForIntrusivePtrTest> ptr {new ForIntrusivePtrTest};
*
* \endcode
*/
template <typename T> class IntrusivePtr {
private:
void IncRef(T *ptr) {
if (ptr) {
IntrusivePtrRefCount(ptr).IncRef();
}
}
void DecRef(T *ptr) {
if (ptr) {
if (IntrusivePtrRefCount(ptr).DecRef() == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
delete ptr;
}
}
}
protected:
T *ptr_{nullptr};
public:
using element_type = T; // NOLINT
struct Hash {
std::size_t operator()(IntrusivePtr<element_type> const &ptr) const noexcept {
return std::hash<element_type *>()(ptr.get());
}
};
/*!
* \brief Contruct an IntrusivePtr from raw pointer. IntrusivePtr takes the ownership.
*
* \param p Raw pointer to object
*/
explicit IntrusivePtr(T *p) : ptr_{p} {
if (ptr_) {
IncRef(ptr_);
}
}
IntrusivePtr() noexcept = default;
IntrusivePtr(IntrusivePtr const &that) : ptr_{that.ptr_} { IncRef(ptr_); }
IntrusivePtr(IntrusivePtr &&that) noexcept : ptr_{that.ptr_} { that.ptr_ = nullptr; }
~IntrusivePtr() { DecRef(ptr_); }
IntrusivePtr<T> &operator=(IntrusivePtr<T> const &that) {
IntrusivePtr<T>{that}.swap(*this);
return *this;
}
IntrusivePtr<T> &operator=(IntrusivePtr<T> &&that) noexcept {
std::swap(ptr_, that.ptr_);
return *this;
}
void reset() { // NOLINT
DecRef(ptr_);
ptr_ = nullptr;
}
void reset(element_type *that) { IntrusivePtr{that}.swap(*this); } // NOLINT
element_type &operator*() const noexcept { return *ptr_; }
element_type *operator->() const noexcept { return ptr_; }
element_type *get() const noexcept { return ptr_; } // NOLINT
explicit operator bool() const noexcept { return static_cast<bool>(ptr_); }
int32_t use_count() noexcept { // NOLINT
return ptr_ ? IntrusivePtrRefCount(ptr_).Count() : 0;
}
/*
* \brief Helper function for swapping 2 pointers.
*/
void swap(IntrusivePtr<T> &that) noexcept { // NOLINT
std::swap(ptr_, that.ptr_);
}
};
template <class T, class U>
bool operator==(IntrusivePtr<T> const &x, IntrusivePtr<U> const &y) noexcept {
return x.get() == y.get();
}
template <class T, class U>
bool operator!=(IntrusivePtr<T> const &x, IntrusivePtr<U> const &y) noexcept {
return x.get() != y.get();
}
template <class T, class U>
bool operator==(IntrusivePtr<T> const &x, U *y) noexcept {
return x.get() == y;
}
template <class T, class U>
bool operator!=(IntrusivePtr<T> const &x, U *y) noexcept {
return x.get() != y;
}
template <class T, class U>
bool operator==(T *x, IntrusivePtr<U> const &y) noexcept {
return y == x;
}
template <class T, class U>
bool operator!=(T *x, IntrusivePtr<U> const &y) noexcept {
return y != x;
}
template <class T>
bool operator<(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
return std::less<T*>{}(x.get(), y.get());
}
template <class T>
bool operator<=(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
return std::less_equal<T*>{}(x.get(), y.get());
}
template <class T>
bool operator>(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
return !(x <= y);
}
template <class T>
bool operator>=(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
return !(x < y);
}
template <class E, class T, class Y>
std::basic_ostream<E, T> &operator<<(std::basic_ostream<E, T> &os,
IntrusivePtr<Y> const &p) {
os << p.get();
return os;
}
} // namespace xgboost
namespace std {
template <class T>
void swap(xgboost::IntrusivePtr<T> &x, // NOLINT
xgboost::IntrusivePtr<T> &y) noexcept {
x.swap(y);
}
template <typename T>
struct hash<xgboost::IntrusivePtr<T>> : public xgboost::IntrusivePtr<T>::Hash {};
} // namespace std
#endif // XGBOOST_INTRUSIVE_PTR_H_

View File

@@ -6,6 +6,7 @@
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <xgboost/intrusive_ptr.h>
#include <map>
#include <memory>
@@ -21,6 +22,13 @@ class JsonReader;
class JsonWriter;
class Value {
private:
mutable class IntrusivePtrCell ref_;
friend IntrusivePtrCell &
IntrusivePtrRefCount(xgboost::Value const *t) noexcept {
return t->ref_;
}
public:
/*!\brief Simplified implementation of LLVM RTTI. */
enum class ValueKind {
@@ -69,12 +77,15 @@ T* Cast(U* value) {
class JsonString : public Value {
std::string str_;
public:
JsonString() : Value(ValueKind::kString) {}
JsonString(std::string const& str) : // NOLINT
Value(ValueKind::kString), str_{str} {}
JsonString(std::string&& str) : // NOLINT
Value(ValueKind::kString), str_{std::move(str)} {}
JsonString(JsonString&& str) : // NOLINT
Value(ValueKind::kString), str_{std::move(str.str_)} {}
void Save(JsonWriter* writer) override;
@@ -167,6 +178,8 @@ class JsonNumber : public Value {
typename std::enable_if<std::is_same<FloatT, double>::value>::type* = nullptr>
JsonNumber(FloatT value) : Value{ValueKind::kNumber}, // NOLINT
number_{static_cast<Float>(value)} {}
JsonNumber(JsonNumber const& that) = delete;
JsonNumber(JsonNumber&& that) : Value{ValueKind::kNumber}, number_{that.number_} {}
void Save(JsonWriter* writer) override;
@@ -214,6 +227,9 @@ class JsonInteger : public Value {
: Value(ValueKind::kInteger),
integer_{static_cast<Int>(value)} {}
JsonInteger(JsonInteger &&that)
: Value{ValueKind::kInteger}, integer_{that.integer_} {}
Json& operator[](std::string const & key) override;
Json& operator[](int ind) override;
@@ -234,6 +250,7 @@ class JsonNull : public Value {
public:
JsonNull() : Value(ValueKind::kNull) {}
JsonNull(std::nullptr_t) : Value(ValueKind::kNull) {} // NOLINT
JsonNull(JsonNull&& that) : Value(ValueKind::kNull) {}
void Save(JsonWriter* writer) override;
@@ -261,6 +278,8 @@ class JsonBoolean : public Value {
std::is_same<Bool, bool const>::value>::type* = nullptr>
JsonBoolean(Bool value) : // NOLINT
Value(ValueKind::kBoolean), boolean_{value} {}
JsonBoolean(JsonBoolean&& value) : // NOLINT
Value(ValueKind::kBoolean), boolean_{value.boolean_} {}
void Save(JsonWriter* writer) override;
@@ -336,14 +355,14 @@ class Json {
Json() : ptr_{new JsonNull} {}
// number
explicit Json(JsonNumber number) : ptr_{new JsonNumber(number)} {}
explicit Json(JsonNumber number) : ptr_{new JsonNumber(std::move(number))} {}
Json& operator=(JsonNumber number) {
ptr_.reset(new JsonNumber(std::move(number)));
return *this;
}
// integer
explicit Json(JsonInteger integer) : ptr_{new JsonInteger(integer)} {}
explicit Json(JsonInteger integer) : ptr_{new JsonInteger(std::move(integer))} {}
Json& operator=(JsonInteger integer) {
ptr_.reset(new JsonInteger(std::move(integer)));
return *this;
@@ -418,7 +437,7 @@ class Json {
}
private:
std::shared_ptr<Value> ptr_;
IntrusivePtr<Value> ptr_;
};
template <typename T>