/*!
 * Copyright (c) by Contributors 2020
 * \file intrusive_ptr.h
 * \brief Implementation of Intrusive Ptr.
 */
#ifndef XGBOOST_INTRUSIVE_PTR_H_
#define XGBOOST_INTRUSIVE_PTR_H_

#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <utility>

namespace xgboost {
/*!
 * \brief Helper class for embedding reference counting into client objects.  See
 *        https://www.boost.org/doc/libs/1_74_0/doc/html/atomic/usage_examples.html for
 *        discussions of memory order.
 *
 *   The counter starts at zero; only IntrusivePtr (a friend) may modify it.
 */
class IntrusivePtrCell {
 private:
  std::atomic<std::int32_t> count_;
  template <typename T>
  friend class IntrusivePtr;

  /*! \brief Add one reference. Returns the counter value *before* the increment. */
  std::int32_t IncRef() noexcept { return count_.fetch_add(1, std::memory_order_relaxed); }
  /*! \brief Drop one reference. Returns the counter value *before* the decrement. */
  std::int32_t DecRef() noexcept { return count_.fetch_sub(1, std::memory_order_release); }
  bool IsZero() const noexcept { return Count() == 0; }

 public:
  IntrusivePtrCell() noexcept : count_{0} {}
  /*! \brief Current value of the counter (relaxed load). */
  std::int32_t Count() const noexcept { return count_.load(std::memory_order_relaxed); }
};

/*!
 * \brief User defined function for returning embedded reference count.
 */
template <typename T>
IntrusivePtrCell &IntrusivePtrRefCount(T const *ptr) noexcept;

/*!
 * \brief Implementation of Intrusive Pointer.  A smart pointer that points to an object
 *        with an embedded reference counter.  The underlying object must implement a
 *        friend function IntrusivePtrRefCount() that returns the ref counter (of type
 *        IntrusivePtrCell).  The intrusive pointer is faster than std::shared_ptr<>:
 *        std::shared_ptr<> makes an extra memory allocation for the ref counter whereas
 *        the intrusive pointer does not.
 *
 * \code
 *
 *   class ForIntrusivePtrTest {
 *    public:
 *     mutable class IntrusivePtrCell ref;
 *     float data { 0 };
 *
 *     friend IntrusivePtrCell &
 *     IntrusivePtrRefCount(ForIntrusivePtrTest const *t) noexcept {  // NOLINT
 *       return t->ref;
 *     }
 *
 *     ForIntrusivePtrTest() = default;
 *     ForIntrusivePtrTest(float a, int32_t b) : data{a + static_cast<float>(b)} {}
 *
 *     explicit ForIntrusivePtrTest(NotCopyConstructible a) : data{a.data} {}
 *   };
 *
 *   IntrusivePtr<ForIntrusivePtrTest> ptr {new ForIntrusivePtrTest};
 *
 * \endcode
 */
template <typename T>
class IntrusivePtr {
 private:
  /*! \brief Increment the embedded counter of `ptr`; no-op for nullptr. */
  void IncRef(T *ptr) {
    if (ptr) {
      IntrusivePtrRefCount(ptr).IncRef();
    }
  }
  /*! \brief Decrement the embedded counter of `ptr` and delete the object when the
   *         last reference is dropped; no-op for nullptr. */
  void DecRef(T *ptr) {
    if (ptr) {
      // DecRef() returns the previous value, so 1 means we held the last reference.
      if (IntrusivePtrRefCount(ptr).DecRef() == 1) {
        // Pairs with the release decrement performed by other owners so that their
        // writes to the object are visible before it is destroyed.
        std::atomic_thread_fence(std::memory_order_acquire);
        delete ptr;
      }
    }
  }

 protected:
  T *ptr_{nullptr};

 public:
  using element_type = T;  // NOLINT
  /*! \brief Hash functor based on the address of the pointee. */
  struct Hash {
    std::size_t operator()(IntrusivePtr<element_type> const &ptr) const noexcept {
      return std::hash<element_type *>()(ptr.get());
    }
  };
  /*!
   * \brief Construct an IntrusivePtr from raw pointer.  IntrusivePtr takes the ownership.
   *
   * \param p Raw pointer to object
   */
  explicit IntrusivePtr(T *p) : ptr_{p} { IncRef(ptr_); }  // IncRef() ignores nullptr.

  IntrusivePtr() noexcept = default;
  IntrusivePtr(IntrusivePtr const &that) : ptr_{that.ptr_} { IncRef(ptr_); }
  IntrusivePtr(IntrusivePtr &&that) noexcept : ptr_{that.ptr_} { that.ptr_ = nullptr; }

  ~IntrusivePtr() { DecRef(ptr_); }

  IntrusivePtr &operator=(IntrusivePtr const &that) {
    IntrusivePtr<T>{that}.swap(*this);
    return *this;
  }
  /*!
   * \brief Move assignment.  The reference previously held by `*this` is released
   *        before returning and `that` is left empty, so the old object's lifetime
   *        does not depend on when the moved-from pointer is destroyed.
   */
  IntrusivePtr &operator=(IntrusivePtr &&that) noexcept {
    IntrusivePtr<T>{std::move(that)}.swap(*this);
    return *this;
  }

  /*! \brief Release the held reference (if any) and become empty. */
  void reset() {  // NOLINT
    // Detach first: `ptr_` is already null while the pointee's destructor runs, so
    // re-entrant access to this pointer never observes a dangling address.
    IntrusivePtr<T>{}.swap(*this);
  }
  /*! \brief Release the held reference (if any) and take ownership of `that`. */
  void reset(element_type *that) { IntrusivePtr<T>{that}.swap(*this); }  // NOLINT

  element_type &operator*() const noexcept { return *ptr_; }
  element_type *operator->() const noexcept { return ptr_; }
  element_type *get() const noexcept { return ptr_; }  // NOLINT

  explicit operator bool() const noexcept { return static_cast<bool>(ptr_); }

  /*! \brief Value of the embedded counter, or 0 for an empty pointer. */
  std::int32_t use_count() const noexcept {  // NOLINT
    return ptr_ ? IntrusivePtrRefCount(ptr_).Count() : 0;
  }

  /*!
   * \brief Helper function for swapping 2 pointers.
   */
  void swap(IntrusivePtr<T> &that) noexcept {  // NOLINT
    std::swap(ptr_, that.ptr_);
  }
};

template <class T, class U>
bool operator==(IntrusivePtr<T> const &x, IntrusivePtr<U> const &y) noexcept {
  return x.get() == y.get();
}

template <class T, class U>
bool operator!=(IntrusivePtr<T> const &x, IntrusivePtr<U> const &y) noexcept {
  return x.get() != y.get();
}

template <class T, class U>
bool operator==(IntrusivePtr<T> const &x, U *y) noexcept {
  return x.get() == y;
}

template <class T, class U>
bool operator!=(IntrusivePtr<T> const &x, U *y) noexcept {
  return x.get() != y;
}

template <class T, class U>
bool operator==(T *x, IntrusivePtr<U> const &y) noexcept {
  return y == x;
}

template <class T, class U>
bool operator!=(T *x, IntrusivePtr<U> const &y) noexcept {
  return y != x;
}

// Ordering is by address; std::less / std::less_equal give a total order on pointers.
template <class T>
bool operator<(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
  return std::less<T *>{}(x.get(), y.get());
}

template <class T>
bool operator<=(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
  return std::less_equal<T *>{}(x.get(), y.get());
}

template <class T>
bool operator>(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
  return !(x <= y);
}

template <class T>
bool operator>=(IntrusivePtr<T> const &x, IntrusivePtr<T> const &y) noexcept {
  return !(x < y);
}

/*! \brief Stream the address of the pointee. */
template <class E, class T, class Y>
std::basic_ostream<E, T> &operator<<(std::basic_ostream<E, T> &os,
                                     IntrusivePtr<Y> const &p) {
  os << p.get();
  return os;
}
}  // namespace xgboost

namespace std {
// NOTE(review): this is an overload added to namespace std rather than a
// specialization; it is kept unchanged so existing `std::swap(a, b)` callers keep
// resolving to the cheap pointer swap.
template <class T>
void swap(xgboost::IntrusivePtr<T> &x,  // NOLINT
          xgboost::IntrusivePtr<T> &y) noexcept {
  x.swap(y);
}

template <class T>
struct hash<xgboost::IntrusivePtr<T>> : public xgboost::IntrusivePtr<T>::Hash {};
}  // namespace std
#endif  // XGBOOST_INTRUSIVE_PTR_H_