merge latest changes
This commit is contained in:
@@ -37,7 +37,7 @@
|
||||
* \brief Whether to customize global PRNG.
|
||||
*/
|
||||
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
|
||||
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
|
||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
|
||||
/*!
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
@@ -160,10 +162,16 @@ struct Result {
|
||||
|
||||
// We don't have monad, a simple helper would do.
|
||||
template <typename Fn>
|
||||
Result operator<<(Result&& r, Fn&& fn) {
|
||||
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
|
||||
if (!r.OK()) {
|
||||
return std::forward<Result>(r);
|
||||
}
|
||||
return fn();
|
||||
}
|
||||
|
||||
inline void SafeColl(Result const& rc) {
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* Copyright 2019-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_JSON_H_
|
||||
#define XGBOOST_JSON_H_
|
||||
@@ -42,7 +42,8 @@ class Value {
|
||||
kBoolean,
|
||||
kNull,
|
||||
// typed array for ubjson
|
||||
kNumberArray,
|
||||
kF32Array,
|
||||
kF64Array,
|
||||
kU8Array,
|
||||
kI32Array,
|
||||
kI64Array
|
||||
@@ -173,7 +174,11 @@ class JsonTypedArray : public Value {
|
||||
/**
|
||||
* @brief Typed UBJSON array for 32-bit floating point.
|
||||
*/
|
||||
using F32Array = JsonTypedArray<float, Value::ValueKind::kNumberArray>;
|
||||
using F32Array = JsonTypedArray<float, Value::ValueKind::kF32Array>;
|
||||
/**
|
||||
* @brief Typed UBJSON array for 64-bit floating point.
|
||||
*/
|
||||
using F64Array = JsonTypedArray<double, Value::ValueKind::kF64Array>;
|
||||
/**
|
||||
* @brief Typed UBJSON array for uint8_t.
|
||||
*/
|
||||
@@ -457,9 +462,9 @@ class Json {
|
||||
Json& operator[](int ind) const { return (*ptr_)[ind]; }
|
||||
|
||||
/*! \brief Return the reference to stored Json value. */
|
||||
Value const& GetValue() const & { return *ptr_; }
|
||||
Value const& GetValue() && { return *ptr_; }
|
||||
Value& GetValue() & { return *ptr_; }
|
||||
[[nodiscard]] Value const& GetValue() const& { return *ptr_; }
|
||||
Value const& GetValue() && { return *ptr_; }
|
||||
Value& GetValue() & { return *ptr_; }
|
||||
|
||||
bool operator==(Json const& rhs) const {
|
||||
return *ptr_ == *(rhs.ptr_);
|
||||
@@ -472,7 +477,7 @@ class Json {
|
||||
return os;
|
||||
}
|
||||
|
||||
IntrusivePtr<Value> const& Ptr() const { return ptr_; }
|
||||
[[nodiscard]] IntrusivePtr<Value> const& Ptr() const { return ptr_; }
|
||||
|
||||
private:
|
||||
IntrusivePtr<Value> ptr_{new JsonNull};
|
||||
|
||||
@@ -142,6 +142,7 @@ class JsonWriter {
|
||||
|
||||
virtual void Visit(JsonArray const* arr);
|
||||
virtual void Visit(F32Array const* arr);
|
||||
virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; }
|
||||
virtual void Visit(U8Array const* arr);
|
||||
virtual void Visit(I32Array const* arr);
|
||||
virtual void Visit(I64Array const* arr);
|
||||
@@ -244,7 +245,8 @@ class UBJReader : public JsonReader {
|
||||
*/
|
||||
class UBJWriter : public JsonWriter {
|
||||
void Visit(JsonArray const* arr) override;
|
||||
void Visit(F32Array const* arr) override;
|
||||
void Visit(F32Array const* arr) override;
|
||||
void Visit(F64Array const* arr) override;
|
||||
void Visit(U8Array const* arr) override;
|
||||
void Visit(I32Array const* arr) override;
|
||||
void Visit(I64Array const* arr) override;
|
||||
|
||||
@@ -295,6 +295,9 @@ class TensorView {
|
||||
using ShapeT = std::size_t[kDim];
|
||||
using StrideT = ShapeT;
|
||||
|
||||
using element_type = T; // NOLINT
|
||||
using value_type = std::remove_cv_t<T>; // NOLINT
|
||||
|
||||
private:
|
||||
StrideT stride_{1};
|
||||
ShapeT shape_{0};
|
||||
@@ -314,7 +317,7 @@ class TensorView {
|
||||
}
|
||||
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename I>
|
||||
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
|
||||
LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
|
||||
detail::RangeTag<I> &&range) const {
|
||||
static_assert(new_dim < D);
|
||||
static_assert(old_dim < kDim);
|
||||
@@ -528,9 +531,10 @@ class TensorView {
|
||||
LINALG_HD auto Stride(size_t i) const { return stride_[i]; }
|
||||
|
||||
/**
|
||||
* \brief Number of items in the tensor.
|
||||
* @brief Number of items in the tensor.
|
||||
*/
|
||||
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
|
||||
[[nodiscard]] bool Empty() const { return Size() == 0; }
|
||||
/**
|
||||
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
|
||||
*/
|
||||
@@ -865,7 +869,9 @@ class Tensor {
|
||||
auto HostView() { return this->View(DeviceOrd::CPU()); }
|
||||
auto HostView() const { return this->View(DeviceOrd::CPU()); }
|
||||
|
||||
[[nodiscard]] size_t Size() const { return data_.Size(); }
|
||||
[[nodiscard]] std::size_t Size() const { return data_.Size(); }
|
||||
[[nodiscard]] bool Empty() const { return Size() == 0; }
|
||||
|
||||
auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
|
||||
auto Shape(size_t i) const { return shape_[i]; }
|
||||
|
||||
|
||||
@@ -738,10 +738,10 @@ class IterSpan {
|
||||
return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT
|
||||
return {this, 0};
|
||||
return it_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT
|
||||
return {this, size()};
|
||||
return it_ + size();
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file tree_model.h
|
||||
* \brief model structure for tree
|
||||
* \author Tianqi Chen
|
||||
@@ -688,6 +688,9 @@ class RegTree : public Model {
|
||||
}
|
||||
return (*this)[nidx].DefaultLeft();
|
||||
}
|
||||
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
|
||||
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
|
||||
}
|
||||
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
|
||||
if (IsMultiTarget()) {
|
||||
return nidx == kRoot;
|
||||
|
||||
Reference in New Issue
Block a user