From a069a21e038b12a66e87f4158037b6a101d69675 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 20 Sep 2020 20:07:16 +0800 Subject: [PATCH] Implement intrusive ptr (#6129) * Use intrusive ptr for JSON. --- doc/contrib/unit_tests.rst | 24 ++- include/xgboost/intrusive_ptr.h | 216 +++++++++++++++++++++++++ include/xgboost/json.h | 25 ++- tests/cpp/common/test_intrusive_ptr.cc | 109 +++++++++++++ 4 files changed, 365 insertions(+), 9 deletions(-) create mode 100644 include/xgboost/intrusive_ptr.h create mode 100644 tests/cpp/common/test_intrusive_ptr.cc diff --git a/doc/contrib/unit_tests.rst b/doc/contrib/unit_tests.rst index de31cf26f..5131dbabb 100644 --- a/doc/contrib/unit_tests.rst +++ b/doc/contrib/unit_tests.rst @@ -134,16 +134,18 @@ One can also run all unit test using ctest tool which provides higher flexibilit Sanitizers: Detect memory errors and data races *********************************************** -By default, sanitizers are bundled in GCC and Clang/LLVM. One can enable -sanitizers with GCC >= 4.8 or LLVM >= 3.1, But some distributions might package -sanitizers separately. Here is a list of supported sanitizers with -corresponding library names: +By default, sanitizers are bundled in GCC and Clang/LLVM. One can enable sanitizers with +GCC >= 4.8 or LLVM >= 3.1, But some distributions might package sanitizers separately. +Here is a list of supported sanitizers with corresponding library names: - Address sanitizer: libasan +- Undefined sanitizer: libubsan - Leak sanitizer: liblsan - Thread sanitizer: libtsan -Memory sanitizer is exclusive to LLVM, hence not supported in XGBoost. +Memory sanitizer is exclusive to LLVM, hence not supported in XGBoost. With latest +compilers like gcc-9, when sanitizer flags are specified, the compiler driver should be +able to link the runtime libraries automatically. How to build XGBoost with sanitizers ==================================== @@ -175,5 +177,15 @@ environment variable: ASAN_OPTIONS=protect_shadow_gap=0 ${BUILD_DIR}/testxgboost -For details, please consult `official documentation `_ for sanitizers. +Other sanitizer runtime options +=============================== + +By default undefined sanitizer doesn't print out the backtrace. You can enable it by +exporting environment variable: + +.. code-block:: + + UBSAN_OPTIONS=print_stacktrace=1 ${BUILD_DIR}/testxgboost + +For details, please consult `official documentation `_ for sanitizers. diff --git a/include/xgboost/intrusive_ptr.h b/include/xgboost/intrusive_ptr.h new file mode 100644 index 000000000..879c0b48a --- /dev/null +++ b/include/xgboost/intrusive_ptr.h @@ -0,0 +1,216 @@ +/*! + * Copyright (c) by Contributors 2020 + * \file intrusive_ptr.h + * \brief Implementation of Intrusive Ptr. + */ +#ifndef XGBOOST_INTRUSIVE_PTR_H_ +#define XGBOOST_INTRUSIVE_PTR_H_ + +#include +#include +#include + +namespace xgboost { +/*! + * \brief Helper class for embedding reference counting into client objects. See + * https://www.boost.org/doc/libs/1_74_0/doc/html/atomic/usage_examples.html for + * discussions of memory order. + */ +class IntrusivePtrCell { + private: + std::atomic count_; + template friend class IntrusivePtr; + + std::int32_t IncRef() noexcept { + return count_.fetch_add(1, std::memory_order_relaxed); + } + std::int32_t DecRef() noexcept { + return count_.fetch_sub(1, std::memory_order_release); + } + bool IsZero() const { return Count() == 0; } + + public: + IntrusivePtrCell() noexcept : count_{0} {} + int32_t Count() const { return count_.load(std::memory_order_relaxed); } +}; + +/*! + * \brief User defined function for returing embedded reference count. + */ +template IntrusivePtrCell &IntrusivePtrRefCount(T const *ptr) noexcept; + +/*! + * \brief Implementation of Intrusive Pointer. A smart pointer that points to an object + * with an embedded reference counter. The underlying object must implement a + * friend function IntrusivePtrRefCount() that returns the ref counter (of type + * IntrusivePtrCell). The intrusive pointer is faster than std::shared_ptr<>: + * std::shared_ptr<> makes an extra memory allocation for the ref counter whereas + * the intrusive pointer does not. + * + * \code + * + * class ForIntrusivePtrTest { + * public: + * mutable class IntrusivePtrCell ref; + * float data { 0 }; + * + * friend IntrusivePtrCell & + * IntrusivePtrRefCount(ForIntrusivePtrTest const *t) noexcept { // NOLINT + * return t->ref; + * } + * + * ForIntrusivePtrTest() = default; + * ForIntrusivePtrTest(float a, int32_t b) : data{a + static_cast(b)} {} + * + * explicit ForIntrusivePtrTest(NotCopyConstructible a) : data{a.data} {} + * }; + * + * IntrusivePtr ptr {new ForIntrusivePtrTest}; + * + * \endcode + */ +template class IntrusivePtr { + private: + void IncRef(T *ptr) { + if (ptr) { + IntrusivePtrRefCount(ptr).IncRef(); + } + } + void DecRef(T *ptr) { + if (ptr) { + if (IntrusivePtrRefCount(ptr).DecRef() == 1) { + std::atomic_thread_fence(std::memory_order_acquire); + delete ptr; + } + } + } + + protected: + T *ptr_{nullptr}; + + public: + using element_type = T; // NOLINT + struct Hash { + std::size_t operator()(IntrusivePtr const &ptr) const noexcept { + return std::hash()(ptr.get()); + } + }; + /*! + * \brief Contruct an IntrusivePtr from raw pointer. IntrusivePtr takes the ownership. + * + * \param p Raw pointer to object + */ + explicit IntrusivePtr(T *p) : ptr_{p} { + if (ptr_) { + IncRef(ptr_); + } + } + + IntrusivePtr() noexcept = default; + IntrusivePtr(IntrusivePtr const &that) : ptr_{that.ptr_} { IncRef(ptr_); } + IntrusivePtr(IntrusivePtr &&that) noexcept : ptr_{that.ptr_} { that.ptr_ = nullptr; } + + ~IntrusivePtr() { DecRef(ptr_); } + + IntrusivePtr &operator=(IntrusivePtr const &that) { + IntrusivePtr{that}.swap(*this); + return *this; + } + IntrusivePtr &operator=(IntrusivePtr &&that) noexcept { + std::swap(ptr_, that.ptr_); + return *this; + } + + void reset() { // NOLINT + DecRef(ptr_); + ptr_ = nullptr; + } + void reset(element_type *that) { IntrusivePtr{that}.swap(*this); } // NOLINT + + element_type &operator*() const noexcept { return *ptr_; } + element_type *operator->() const noexcept { return ptr_; } + element_type *get() const noexcept { return ptr_; } // NOLINT + + explicit operator bool() const noexcept { return static_cast(ptr_); } + + int32_t use_count() noexcept { // NOLINT + return ptr_ ? IntrusivePtrRefCount(ptr_).Count() : 0; + } + + /* + * \brief Helper function for swapping 2 pointers. + */ + void swap(IntrusivePtr &that) noexcept { // NOLINT + std::swap(ptr_, that.ptr_); + } +}; + +template +bool operator==(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return x.get() == y.get(); +} + +template +bool operator!=(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return x.get() != y.get(); +} + +template +bool operator==(IntrusivePtr const &x, U *y) noexcept { + return x.get() == y; +} + +template +bool operator!=(IntrusivePtr const &x, U *y) noexcept { + return x.get() != y; +} + +template +bool operator==(T *x, IntrusivePtr const &y) noexcept { + return y == x; +} + +template +bool operator!=(T *x, IntrusivePtr const &y) noexcept { + return y != x; +} + +template +bool operator<(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return std::less{}(x.get(), y.get()); +} + +template +bool operator<=(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return std::less_equal{}(x.get(), y.get()); +} + +template +bool operator>(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return !(x <= y); +} + +template +bool operator>=(IntrusivePtr const &x, IntrusivePtr const &y) noexcept { + return !(x < y); +} + +template +std::basic_ostream &operator<<(std::basic_ostream &os, + IntrusivePtr const &p) { + os << p.get(); + return os; +} +} // namespace xgboost + +namespace std { +template +void swap(xgboost::IntrusivePtr &x, // NOLINT + xgboost::IntrusivePtr &y) noexcept { + x.swap(y); +} + +template +struct hash> : public xgboost::IntrusivePtr::Hash {}; +} // namespace std +#endif // XGBOOST_INTRUSIVE_PTR_H_ diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 5048bb7ec..4c4bafb1a 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -21,6 +22,13 @@ class JsonReader; class JsonWriter; class Value { + private: + mutable class IntrusivePtrCell ref_; + friend IntrusivePtrCell & + IntrusivePtrRefCount(xgboost::Value const *t) noexcept { + return t->ref_; + } + public: /*!\brief Simplified implementation of LLVM RTTI. */ enum class ValueKind { @@ -69,12 +77,15 @@ T* Cast(U* value) { class JsonString : public Value { std::string str_; + public: JsonString() : Value(ValueKind::kString) {} JsonString(std::string const& str) : // NOLINT Value(ValueKind::kString), str_{str} {} JsonString(std::string&& str) : // NOLINT Value(ValueKind::kString), str_{std::move(str)} {} + JsonString(JsonString&& str) : // NOLINT + Value(ValueKind::kString), str_{std::move(str.str_)} {} void Save(JsonWriter* writer) override; @@ -167,6 +178,8 @@ class JsonNumber : public Value { typename std::enable_if::value>::type* = nullptr> JsonNumber(FloatT value) : Value{ValueKind::kNumber}, // NOLINT number_{static_cast(value)} {} + JsonNumber(JsonNumber const& that) = delete; + JsonNumber(JsonNumber&& that) : Value{ValueKind::kNumber}, number_{that.number_} {} void Save(JsonWriter* writer) override; @@ -214,6 +227,9 @@ class JsonInteger : public Value { : Value(ValueKind::kInteger), integer_{static_cast(value)} {} + JsonInteger(JsonInteger &&that) + : Value{ValueKind::kInteger}, integer_{that.integer_} {} + Json& operator[](std::string const & key) override; Json& operator[](int ind) override; @@ -234,6 +250,7 @@ class JsonNull : public Value { public: JsonNull() : Value(ValueKind::kNull) {} JsonNull(std::nullptr_t) : Value(ValueKind::kNull) {} // NOLINT + JsonNull(JsonNull&& that) : Value(ValueKind::kNull) {} void Save(JsonWriter* writer) override; @@ -261,6 +278,8 @@ class JsonBoolean : public Value { std::is_same::value>::type* = nullptr> JsonBoolean(Bool value) : // NOLINT Value(ValueKind::kBoolean), boolean_{value} {} + JsonBoolean(JsonBoolean&& value) : // NOLINT + Value(ValueKind::kBoolean), boolean_{value.boolean_} {} void Save(JsonWriter* writer) override; @@ -336,14 +355,14 @@ class Json { Json() : ptr_{new JsonNull} {} // number - explicit Json(JsonNumber number) : ptr_{new JsonNumber(number)} {} + explicit Json(JsonNumber number) : ptr_{new JsonNumber(std::move(number))} {} Json& operator=(JsonNumber number) { ptr_.reset(new JsonNumber(std::move(number))); return *this; } // integer - explicit Json(JsonInteger integer) : ptr_{new JsonInteger(integer)} {} + explicit Json(JsonInteger integer) : ptr_{new JsonInteger(std::move(integer))} {} Json& operator=(JsonInteger integer) { ptr_.reset(new JsonInteger(std::move(integer))); return *this; @@ -418,7 +437,7 @@ class Json { } private: - std::shared_ptr ptr_; + IntrusivePtr ptr_; }; template diff --git a/tests/cpp/common/test_intrusive_ptr.cc b/tests/cpp/common/test_intrusive_ptr.cc new file mode 100644 index 000000000..a41697f17 --- /dev/null +++ b/tests/cpp/common/test_intrusive_ptr.cc @@ -0,0 +1,109 @@ +#include +#include + +namespace xgboost { +namespace { +class NotCopyConstructible { + public: + float data; + + explicit NotCopyConstructible(float d) : data{d} {} + NotCopyConstructible(NotCopyConstructible const &that) = delete; + NotCopyConstructible &operator=(NotCopyConstructible const &that) = delete; + NotCopyConstructible(NotCopyConstructible&& that) = default; +}; +static_assert( + !std::is_trivially_copy_constructible::value, ""); +static_assert( + !std::is_trivially_copy_assignable::value, ""); + +class ForIntrusivePtrTest { + public: + mutable class IntrusivePtrCell ref; + float data { 0 }; + + friend IntrusivePtrCell & + IntrusivePtrRefCount(ForIntrusivePtrTest const *t) noexcept { // NOLINT + return t->ref; + } + + ForIntrusivePtrTest() = default; + ForIntrusivePtrTest(float a, int32_t b) : data{a + static_cast(b)} {} + + explicit ForIntrusivePtrTest(NotCopyConstructible a) : data{a.data} {} +}; +} // anonymous namespace + +TEST(IntrusivePtr, Basic) { + IntrusivePtr ptr {new ForIntrusivePtrTest}; + auto p = ptr.get(); + + // Copy ctor + IntrusivePtr ptr_1 { ptr }; + ASSERT_EQ(ptr_1.get(), p); + + ASSERT_EQ((*ptr_1).data, ptr_1->data); + ASSERT_EQ(ptr.use_count(), 2); + + // hash + ASSERT_EQ(std::hash>{}(ptr_1), + std::hash{}(ptr_1.get())); + + // Raw ptr comparison + ASSERT_EQ(ptr, p); + ASSERT_EQ(ptr_1, ptr); + + ForIntrusivePtrTest* raw_ptr {nullptr}; + ASSERT_NE(ptr_1, raw_ptr); + ASSERT_NE(raw_ptr, ptr_1); + + // Reset with raw ptr. + auto p_1 = new ForIntrusivePtrTest; + ptr.reset(p_1); + + ASSERT_EQ(ptr_1.use_count(), 1); + ASSERT_EQ(ptr.use_count(), 1); + + ASSERT_TRUE(ptr); + ASSERT_TRUE(ptr_1); + + // Swap + std::swap(ptr, ptr_1); + ASSERT_NE(ptr, p_1); + ASSERT_EQ(ptr_1, p_1); + + // Reset + ptr.reset(); + ASSERT_FALSE(ptr); + ASSERT_EQ(ptr.use_count(), 0); + + // Comparison operators + ASSERT_EQ(ptr < ptr_1, ptr.get() < ptr_1.get()); + ASSERT_EQ(ptr > ptr_1, ptr.get() > ptr_1.get()); + + ASSERT_LE(ptr, ptr); + ASSERT_GE(ptr, ptr); + + // Copy assign + IntrusivePtr ptr_2; + ptr_2 = ptr_1; + ASSERT_EQ(ptr_2, ptr_1); + ASSERT_EQ(ptr_2.use_count(), 2); + + // Move assign + IntrusivePtr ptr_3; + ptr_3 = std::move(ptr_2); + ASSERT_EQ(ptr_2.use_count(), 0); // NOLINT + ASSERT_EQ(ptr_3.use_count(), 2); + + // Move ctor + IntrusivePtr ptr_4 { std::move(ptr_3) }; + ASSERT_EQ(ptr_3.use_count(), 0); // NOLINT + ASSERT_EQ(ptr_4.use_count(), 2); + + // Comparison + ASSERT_EQ(ptr_1 > ptr_2, ptr_1.get() > ptr_2.get()); + ASSERT_EQ(ptr_1, ptr_1); + ASSERT_EQ(ptr_1 < ptr_2, ptr_1.get() < ptr_2.get()); +} +} // namespace xgboost