Serialize expand entry for allgather. (#9702)
This commit is contained in:
parent
ee8b29c843
commit
7a02facc9d
@ -153,7 +153,7 @@ class JsonTypedArray : public Value {
|
|||||||
using Type = T;
|
using Type = T;
|
||||||
|
|
||||||
JsonTypedArray() : Value(kind) {}
|
JsonTypedArray() : Value(kind) {}
|
||||||
explicit JsonTypedArray(size_t n) : Value(kind) { vec_.resize(n); }
|
explicit JsonTypedArray(std::size_t n) : Value(kind) { vec_.resize(n); }
|
||||||
JsonTypedArray(JsonTypedArray&& that) noexcept : Value{kind}, vec_{std::move(that.vec_)} {}
|
JsonTypedArray(JsonTypedArray&& that) noexcept : Value{kind}, vec_{std::move(that.vec_)} {}
|
||||||
|
|
||||||
bool operator==(Value const& rhs) const override;
|
bool operator==(Value const& rhs) const override;
|
||||||
@ -171,21 +171,21 @@ class JsonTypedArray : public Value {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Typed UBJSON array for 32-bit floating point.
|
* @brief Typed UBJSON array for 32-bit floating point.
|
||||||
*/
|
*/
|
||||||
using F32Array = JsonTypedArray<float, Value::ValueKind::kNumberArray>;
|
using F32Array = JsonTypedArray<float, Value::ValueKind::kNumberArray>;
|
||||||
/**
|
/**
|
||||||
* \brief Typed UBJSON array for uint8_t.
|
* @brief Typed UBJSON array for uint8_t.
|
||||||
*/
|
*/
|
||||||
using U8Array = JsonTypedArray<uint8_t, Value::ValueKind::kU8Array>;
|
using U8Array = JsonTypedArray<std::uint8_t, Value::ValueKind::kU8Array>;
|
||||||
/**
|
/**
|
||||||
* \brief Typed UBJSON array for int32_t.
|
* @brief Typed UBJSON array for int32_t.
|
||||||
*/
|
*/
|
||||||
using I32Array = JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
|
using I32Array = JsonTypedArray<std::int32_t, Value::ValueKind::kI32Array>;
|
||||||
/**
|
/**
|
||||||
* \brief Typed UBJSON array for int64_t.
|
* @brief Typed UBJSON array for int64_t.
|
||||||
*/
|
*/
|
||||||
using I64Array = JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
|
using I64Array = JsonTypedArray<std::int64_t, Value::ValueKind::kI64Array>;
|
||||||
|
|
||||||
class JsonObject : public Value {
|
class JsonObject : public Value {
|
||||||
public:
|
public:
|
||||||
|
|||||||
@ -9,7 +9,8 @@
|
|||||||
#include <type_traits> // for remove_cv_t
|
#include <type_traits> // for remove_cv_t
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "comm.h" // for Comm, Channel, EraseType
|
#include "../common/type.h" // for EraseType
|
||||||
|
#include "comm.h" // for Comm, Channel
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
@ -33,7 +34,7 @@ namespace cpu_impl {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||||
auto n_bytes = sizeof(T) * size;
|
auto n_bytes = sizeof(T) * size;
|
||||||
auto erased = EraseType(data);
|
auto erased = common::EraseType(data);
|
||||||
|
|
||||||
auto rank = comm.Rank();
|
auto rank = comm.Rank();
|
||||||
auto prev = BootstrapPrev(rank, comm.World());
|
auto prev = BootstrapPrev(rank, comm.World());
|
||||||
@ -65,8 +66,8 @@ template <typename T>
|
|||||||
auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||||
result.resize(n_total_bytes / sizeof(T));
|
result.resize(n_total_bytes / sizeof(T));
|
||||||
auto h_result = common::Span{result.data(), result.size()};
|
auto h_result = common::Span{result.data(), result.size()};
|
||||||
auto erased_result = EraseType(h_result);
|
auto erased_result = common::EraseType(h_result);
|
||||||
auto erased_data = EraseType(data);
|
auto erased_data = common::EraseType(data);
|
||||||
std::vector<std::int64_t> offset(world + 1);
|
std::vector<std::int64_t> offset(world + 1);
|
||||||
|
|
||||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||||
|
|||||||
@ -4,8 +4,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <cstdint> // for int8_t
|
#include <cstdint> // for int8_t
|
||||||
#include <functional> // for function
|
#include <functional> // for function
|
||||||
#include <type_traits> // for is_invocable_v
|
#include <type_traits> // for is_invocable_v, enable_if_t
|
||||||
|
|
||||||
|
#include "../common/type.h" // for EraseType, RestoreType
|
||||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
#include "comm.h" // for Comm, RestoreType
|
#include "comm.h" // for Comm, RestoreType
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
@ -23,14 +24,14 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
|||||||
template <typename T, typename Fn>
|
template <typename T, typename Fn>
|
||||||
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
|
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
|
||||||
Comm const& comm, common::Span<T> data, Fn redop) {
|
Comm const& comm, common::Span<T> data, Fn redop) {
|
||||||
auto erased = EraseType(data);
|
auto erased = common::EraseType(data);
|
||||||
auto type = ToDType<T>::kType;
|
auto type = ToDType<T>::kType;
|
||||||
|
|
||||||
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||||
common::Span<std::int8_t> out) {
|
common::Span<std::int8_t> out) {
|
||||||
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||||
auto lhs_t = RestoreType<T const>(lhs);
|
auto lhs_t = common::RestoreType<T const>(lhs);
|
||||||
auto rhs_t = RestoreType<T>(out);
|
auto rhs_t = common::RestoreType<T>(out);
|
||||||
redop(lhs_t, rhs_t);
|
redop(lhs_t, rhs_t);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -137,20 +137,4 @@ class Channel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 };
|
||||||
|
|
||||||
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
|
|
||||||
std::add_const_t<std::int8_t>, std::int8_t>>
|
|
||||||
common::Span<U> EraseType(common::Span<T> data) {
|
|
||||||
auto n_total_bytes = data.size_bytes();
|
|
||||||
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
|
|
||||||
return erased;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
|
||||||
common::Span<T> RestoreType(common::Span<U> data) {
|
|
||||||
static_assert(std::is_same_v<std::remove_const_t<U>, std::int8_t>);
|
|
||||||
auto n_total_bytes = data.size_bytes();
|
|
||||||
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
|
|
||||||
return restored;
|
|
||||||
}
|
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -327,6 +327,5 @@ inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs
|
|||||||
|
|
||||||
return {offsets, all_sizes, all_inputs};
|
return {offsets, all_sizes, all_inputs};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace collective
|
} // namespace collective
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -9,7 +9,7 @@
|
|||||||
#include <bitset> // for bitset
|
#include <bitset> // for bitset
|
||||||
#include <cstdint> // for uint32_t, uint64_t, uint8_t
|
#include <cstdint> // for uint32_t, uint64_t, uint8_t
|
||||||
#include <ostream> // for ostream
|
#include <ostream> // for ostream
|
||||||
#include <type_traits> // for conditional_t, is_signed_v
|
#include <type_traits> // for conditional_t, is_signed_v, add_const_t
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
|
|||||||
24
src/common/type.h
Normal file
24
src/common/type.h
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <cstdint> // for int8_t
|
||||||
|
#include <type_traits> // for is_const_v, add_const_t, conditional_t, add_pointer_t
|
||||||
|
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
namespace xgboost::common {
|
||||||
|
template <typename T, typename U = std::conditional_t<std::is_const_v<T>,
|
||||||
|
std::add_const_t<std::int8_t>, std::int8_t>>
|
||||||
|
common::Span<U> EraseType(common::Span<T> data) {
|
||||||
|
auto n_total_bytes = data.size_bytes();
|
||||||
|
auto erased = common::Span{reinterpret_cast<std::add_pointer_t<U>>(data.data()), n_total_bytes};
|
||||||
|
return erased;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
common::Span<T> RestoreType(common::Span<U> data) {
|
||||||
|
auto n_total_bytes = data.size_bytes();
|
||||||
|
auto restored = common::Span{reinterpret_cast<T*>(data.data()), n_total_bytes / sizeof(T)};
|
||||||
|
return restored;
|
||||||
|
}
|
||||||
|
} // namespace xgboost::common
|
||||||
@ -1,31 +1,36 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2020 by XGBoost Contributors
|
* Copyright 2020-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef EXPAND_ENTRY_CUH_
|
#ifndef EXPAND_ENTRY_CUH_
|
||||||
#define EXPAND_ENTRY_CUH_
|
#define EXPAND_ENTRY_CUH_
|
||||||
#include <xgboost/span.h>
|
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <utility> // for move
|
||||||
|
|
||||||
#include "../param.h"
|
#include "../param.h"
|
||||||
#include "../updater_gpu_common.cuh"
|
#include "../updater_gpu_common.cuh"
|
||||||
|
#include "xgboost/base.h" // for bst_node_t
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
struct GPUExpandEntry {
|
struct GPUExpandEntry {
|
||||||
int nid;
|
bst_node_t nid;
|
||||||
int depth;
|
bst_node_t depth;
|
||||||
DeviceSplitCandidate split;
|
DeviceSplitCandidate split;
|
||||||
|
|
||||||
float base_weight { std::numeric_limits<float>::quiet_NaN() };
|
float base_weight{std::numeric_limits<float>::quiet_NaN()};
|
||||||
float left_weight { std::numeric_limits<float>::quiet_NaN() };
|
float left_weight{std::numeric_limits<float>::quiet_NaN()};
|
||||||
float right_weight { std::numeric_limits<float>::quiet_NaN() };
|
float right_weight{std::numeric_limits<float>::quiet_NaN()};
|
||||||
|
|
||||||
GPUExpandEntry() = default;
|
GPUExpandEntry() = default;
|
||||||
XGBOOST_DEVICE GPUExpandEntry(int nid, int depth, DeviceSplitCandidate split,
|
XGBOOST_DEVICE GPUExpandEntry(bst_node_t nid, bst_node_t depth, DeviceSplitCandidate split,
|
||||||
float base, float left, float right)
|
float base, float left, float right)
|
||||||
: nid(nid), depth(depth), split(std::move(split)), base_weight{base},
|
: nid(nid),
|
||||||
left_weight{left}, right_weight{right} {}
|
depth(depth),
|
||||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
split(std::move(split)),
|
||||||
|
base_weight{base},
|
||||||
|
left_weight{left},
|
||||||
|
right_weight{right} {}
|
||||||
|
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
if (split.loss_chg <= kRtEps) return false;
|
if (split.loss_chg <= kRtEps) return false;
|
||||||
if (split.left_sum.GetQuantisedHess() == 0 || split.right_sum.GetQuantisedHess() == 0) {
|
if (split.left_sum.GetQuantisedHess() == 0 || split.right_sum.GetQuantisedHess() == 0) {
|
||||||
return false;
|
return false;
|
||||||
@ -42,17 +47,11 @@ struct GPUExpandEntry {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bst_float GetLossChange() const {
|
[[nodiscard]] float GetLossChange() const { return split.loss_chg; }
|
||||||
return split.loss_chg;
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetNodeId() const {
|
[[nodiscard]] bst_node_t GetNodeId() const { return nid; }
|
||||||
return nid;
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetDepth() const {
|
[[nodiscard]] bst_node_t GetDepth() const { return depth; }
|
||||||
return depth;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const GPUExpandEntry& e) {
|
friend std::ostream& operator<<(std::ostream& os, const GPUExpandEntry& e) {
|
||||||
os << "GPUExpandEntry: \n";
|
os << "GPUExpandEntry: \n";
|
||||||
@ -63,9 +62,69 @@ struct GPUExpandEntry {
|
|||||||
os << "right_sum: " << e.split.right_sum << "\n";
|
os << "right_sum: " << e.split.right_sum << "\n";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace tree
|
void Save(Json* p_out) const {
|
||||||
} // namespace xgboost
|
auto& out = *p_out;
|
||||||
|
|
||||||
|
out["nid"] = Integer{this->nid};
|
||||||
|
out["depth"] = Integer{this->depth};
|
||||||
|
// GPU specific
|
||||||
|
out["base_weight"] = this->base_weight;
|
||||||
|
out["left_weight"] = this->left_weight;
|
||||||
|
out["right_weight"] = this->right_weight;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle split
|
||||||
|
*/
|
||||||
|
out["split"] = Object{};
|
||||||
|
auto& split = out["split"];
|
||||||
|
split["loss_chg"] = this->split.loss_chg;
|
||||||
|
split["sindex"] = Integer{this->split.findex};
|
||||||
|
split["split_value"] = this->split.fvalue;
|
||||||
|
|
||||||
|
// cat
|
||||||
|
split["thresh"] = Integer{this->split.thresh};
|
||||||
|
split["is_cat"] = Boolean{this->split.is_cat};
|
||||||
|
/**
|
||||||
|
* Gradients
|
||||||
|
*/
|
||||||
|
auto save = [&](std::string const& name, GradientPairInt64 const& sum) {
|
||||||
|
out[name] = I64Array{2};
|
||||||
|
auto& array = get<I64Array>(out[name]);
|
||||||
|
array[0] = sum.GetQuantisedGrad();
|
||||||
|
array[1] = sum.GetQuantisedHess();
|
||||||
|
};
|
||||||
|
save("left_sum", this->split.left_sum);
|
||||||
|
save("right_sum", this->split.right_sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Load(Json const& in) {
|
||||||
|
this->nid = get<Integer const>(in["nid"]);
|
||||||
|
this->depth = get<Integer const>(in["depth"]);
|
||||||
|
// GPU specific
|
||||||
|
this->base_weight = get<Number const>(in["base_weight"]);
|
||||||
|
this->left_weight = get<Number const>(in["left_weight"]);
|
||||||
|
this->right_weight = get<Number const>(in["right_weight"]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle split
|
||||||
|
*/
|
||||||
|
auto const& split = in["split"];
|
||||||
|
this->split.loss_chg = get<Number const>(split["loss_chg"]);
|
||||||
|
this->split.findex = get<Integer const>(split["sindex"]);
|
||||||
|
this->split.fvalue = get<Number const>(split["split_value"]);
|
||||||
|
// cat
|
||||||
|
this->split.thresh = get<Integer const>(split["thresh"]);
|
||||||
|
this->split.is_cat = get<Boolean const>(split["is_cat"]);
|
||||||
|
/**
|
||||||
|
* Gradients
|
||||||
|
*/
|
||||||
|
auto const& left_sum = get<I64Array const>(in["left_sum"]);
|
||||||
|
this->split.left_sum = GradientPairInt64{left_sum[0], left_sum[1]};
|
||||||
|
auto const& right_sum = get<I64Array const>(in["right_sum"]);
|
||||||
|
this->split.right_sum = GradientPairInt64{right_sum[0], right_sum[1]};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost::tree
|
||||||
|
|
||||||
#endif // EXPAND_ENTRY_CUH_
|
#endif // EXPAND_ENTRY_CUH_
|
||||||
|
|||||||
@ -1,16 +1,20 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2021-2023 XGBoost contributors
|
* Copyright 2021-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||||
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||||
|
|
||||||
#include <algorithm> // for all_of
|
#include <algorithm> // for all_of
|
||||||
#include <ostream> // for ostream
|
#include <ostream> // for ostream
|
||||||
#include <utility> // for move
|
#include <string> // for string
|
||||||
#include <vector> // for vector
|
#include <type_traits> // for add_const_t
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam
|
#include "../../common/type.h" // for EraseType
|
||||||
#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t
|
#include "../param.h" // for SplitEntry, SplitEntryContainer, TrainParam
|
||||||
|
#include "xgboost/base.h" // for GradientPairPrecise, bst_node_t
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
/**
|
/**
|
||||||
@ -29,6 +33,66 @@ struct ExpandEntryImpl {
|
|||||||
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
[[nodiscard]] bool IsValid(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
|
return static_cast<Impl const*>(this)->IsValidImpl(param, num_leaves);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Save(Json* p_out) const {
|
||||||
|
auto& out = *p_out;
|
||||||
|
auto self = static_cast<Impl const*>(this);
|
||||||
|
|
||||||
|
out["nid"] = Integer{this->nid};
|
||||||
|
out["depth"] = Integer{this->depth};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle split
|
||||||
|
*/
|
||||||
|
out["split"] = Object{};
|
||||||
|
auto& split = out["split"];
|
||||||
|
split["loss_chg"] = self->split.loss_chg;
|
||||||
|
split["sindex"] = Integer{self->split.sindex};
|
||||||
|
split["split_value"] = self->split.split_value;
|
||||||
|
|
||||||
|
auto const& cat_bits = self->split.cat_bits;
|
||||||
|
auto s_cat_bits = common::Span{cat_bits.data(), cat_bits.size()};
|
||||||
|
split["cat_bits"] = U8Array{s_cat_bits.size_bytes()};
|
||||||
|
auto& j_cat_bits = get<U8Array>(split["cat_bits"]);
|
||||||
|
using T = typename decltype(self->split.cat_bits)::value_type;
|
||||||
|
auto erased =
|
||||||
|
common::EraseType<std::add_const_t<T>, std::add_const_t<std::uint8_t>>(s_cat_bits);
|
||||||
|
for (std::size_t i = 0; i < erased.size(); ++i) {
|
||||||
|
j_cat_bits[i] = erased[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
split["is_cat"] = Boolean{self->split.is_cat};
|
||||||
|
|
||||||
|
self->SaveGrad(&split);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Load(Json const& in) {
|
||||||
|
auto self = static_cast<Impl*>(this);
|
||||||
|
|
||||||
|
this->nid = get<Integer const>(in["nid"]);
|
||||||
|
this->depth = get<Integer const>(in["depth"]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle split
|
||||||
|
*/
|
||||||
|
auto const& split = in["split"];
|
||||||
|
self->split.loss_chg = get<Number const>(split["loss_chg"]);
|
||||||
|
self->split.sindex = get<Integer const>(split["sindex"]);
|
||||||
|
self->split.split_value = get<Number const>(split["split_value"]);
|
||||||
|
|
||||||
|
auto const& j_cat_bits = get<U8Array const>(split["cat_bits"]);
|
||||||
|
using T = typename decltype(self->split.cat_bits)::value_type;
|
||||||
|
auto restored = common::RestoreType<std::add_const_t<T>>(
|
||||||
|
common::Span{j_cat_bits.data(), j_cat_bits.size()});
|
||||||
|
self->split.cat_bits.resize(restored.size());
|
||||||
|
for (std::size_t i = 0; i < restored.size(); ++i) {
|
||||||
|
self->split.cat_bits[i] = restored[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
self->split.is_cat = get<Boolean const>(split["is_cat"]);
|
||||||
|
|
||||||
|
self->LoadGrad(split);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
|
struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
|
||||||
@ -39,6 +103,24 @@ struct CPUExpandEntry : public ExpandEntryImpl<CPUExpandEntry> {
|
|||||||
: ExpandEntryImpl{nidx, depth}, split(std::move(split)) {}
|
: ExpandEntryImpl{nidx, depth}, split(std::move(split)) {}
|
||||||
CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
CPUExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||||
|
|
||||||
|
void SaveGrad(Json* p_out) const {
|
||||||
|
auto& out = *p_out;
|
||||||
|
auto save = [&](std::string const& name, GradStats const& sum) {
|
||||||
|
out[name] = F32Array{2};
|
||||||
|
auto& array = get<F32Array>(out[name]);
|
||||||
|
array[0] = sum.GetGrad();
|
||||||
|
array[1] = sum.GetHess();
|
||||||
|
};
|
||||||
|
save("left_sum", this->split.left_sum);
|
||||||
|
save("right_sum", this->split.right_sum);
|
||||||
|
}
|
||||||
|
void LoadGrad(Json const& in) {
|
||||||
|
auto const& left_sum = get<F32Array const>(in["left_sum"]);
|
||||||
|
this->split.left_sum = GradStats{left_sum[0], left_sum[1]};
|
||||||
|
auto const& right_sum = get<F32Array const>(in["right_sum"]);
|
||||||
|
this->split.right_sum = GradStats{right_sum[0], right_sum[1]};
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
if (split.loss_chg <= kRtEps) return false;
|
if (split.loss_chg <= kRtEps) return false;
|
||||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||||
@ -88,6 +170,32 @@ struct MultiExpandEntry : public ExpandEntryImpl<MultiExpandEntry> {
|
|||||||
MultiExpandEntry() = default;
|
MultiExpandEntry() = default;
|
||||||
MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
MultiExpandEntry(bst_node_t nidx, bst_node_t depth) : ExpandEntryImpl{nidx, depth} {}
|
||||||
|
|
||||||
|
void SaveGrad(Json* p_out) const {
|
||||||
|
auto& out = *p_out;
|
||||||
|
auto save = [&](std::string const& name, std::vector<GradientPairPrecise> const& sum) {
|
||||||
|
out[name] = F32Array{sum.size() * 2};
|
||||||
|
auto& array = get<F32Array>(out[name]);
|
||||||
|
for (std::size_t i = 0, j = 0; i < sum.size(); i++, j += 2) {
|
||||||
|
array[j] = sum[i].GetGrad();
|
||||||
|
array[j + 1] = sum[i].GetHess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
save("left_sum", this->split.left_sum);
|
||||||
|
save("right_sum", this->split.right_sum);
|
||||||
|
}
|
||||||
|
void LoadGrad(Json const& in) {
|
||||||
|
auto load = [&](std::string const& name, std::vector<GradientPairPrecise>* p_sum) {
|
||||||
|
auto const& array = get<F32Array const>(in[name]);
|
||||||
|
auto& sum = *p_sum;
|
||||||
|
sum.resize(array.size() / 2);
|
||||||
|
for (std::size_t i = 0, j = 0; i < sum.size(); ++i, j += 2) {
|
||||||
|
sum[i] = GradientPairPrecise{array[j], array[j + 1]};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
load("left_sum", &this->split.left_sum);
|
||||||
|
load("right_sum", &this->split.right_sum);
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
[[nodiscard]] bool IsValidImpl(TrainParam const& param, bst_node_t num_leaves) const {
|
||||||
if (split.loss_chg <= kRtEps) return false;
|
if (split.loss_chg <= kRtEps) return false;
|
||||||
auto is_zero = [](auto const& sum) {
|
auto is_zero = [](auto const& sum) {
|
||||||
|
|||||||
@ -401,7 +401,7 @@ struct SplitEntryContainer {
|
|||||||
/*! \brief split index */
|
/*! \brief split index */
|
||||||
bst_feature_t sindex{0};
|
bst_feature_t sindex{0};
|
||||||
bst_float split_value{0.0f};
|
bst_float split_value{0.0f};
|
||||||
std::vector<uint32_t> cat_bits;
|
std::vector<std::uint32_t> cat_bits;
|
||||||
bool is_cat{false};
|
bool is_cat{false};
|
||||||
|
|
||||||
GradientT left_sum;
|
GradientT left_sum;
|
||||||
|
|||||||
@ -14,9 +14,7 @@
|
|||||||
#include "gpu_hist/histogram.cuh"
|
#include "gpu_hist/histogram.cuh"
|
||||||
#include "param.h"
|
#include "param.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
|
||||||
|
|
||||||
struct GPUTrainingParam {
|
struct GPUTrainingParam {
|
||||||
// minimum amount of hessian(weight) allowed in a child
|
// minimum amount of hessian(weight) allowed in a child
|
||||||
float min_child_weight;
|
float min_child_weight;
|
||||||
@ -136,5 +134,4 @@ struct SumCallbackOp {
|
|||||||
return old_prefix;
|
return old_prefix;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace xgboost::tree
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -4,7 +4,9 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/collective/comm.h"
|
#include "../../../src/collective/comm.h"
|
||||||
#include "test_worker.h"
|
#include "../../../src/common/type.h" // for EraseType
|
||||||
|
#include "test_worker.h" // for TrackerTest
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace {
|
namespace {
|
||||||
class CommTest : public TrackerTest {};
|
class CommTest : public TrackerTest {};
|
||||||
|
|||||||
28
tests/cpp/tree/gpu_hist/test_expand_entry.cu
Normal file
28
tests/cpp/tree/gpu_hist/test_expand_entry.cu
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
#include <xgboost/tree_model.h> // for RegTree
|
||||||
|
|
||||||
|
#include "../../../../src/tree/gpu_hist/expand_entry.cuh"
|
||||||
|
|
||||||
|
namespace xgboost::tree {
|
||||||
|
TEST(ExpandEntry, IOGPU) {
|
||||||
|
DeviceSplitCandidate split;
|
||||||
|
GPUExpandEntry entry{RegTree::kRoot, 0, split, 3.0, 1.0, 2.0};
|
||||||
|
|
||||||
|
Json je{Object{}};
|
||||||
|
entry.Save(&je);
|
||||||
|
|
||||||
|
GPUExpandEntry loaded;
|
||||||
|
loaded.Load(je);
|
||||||
|
|
||||||
|
ASSERT_EQ(entry.base_weight, loaded.base_weight);
|
||||||
|
ASSERT_EQ(entry.left_weight, loaded.left_weight);
|
||||||
|
ASSERT_EQ(entry.right_weight, loaded.right_weight);
|
||||||
|
|
||||||
|
ASSERT_EQ(entry.GetDepth(), loaded.GetDepth());
|
||||||
|
ASSERT_EQ(entry.GetLossChange(), loaded.GetLossChange());
|
||||||
|
}
|
||||||
|
} // namespace xgboost::tree
|
||||||
57
tests/cpp/tree/hist/test_expand_entry.cc
Normal file
57
tests/cpp/tree/hist/test_expand_entry.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/json.h> // for Json
|
||||||
|
#include <xgboost/tree_model.h> // for RegTree
|
||||||
|
|
||||||
|
#include "../../../../src/tree/hist/expand_entry.h"
|
||||||
|
|
||||||
|
namespace xgboost::tree {
|
||||||
|
TEST(ExpandEntry, IO) {
|
||||||
|
CPUExpandEntry entry{RegTree::kRoot, 0};
|
||||||
|
entry.split.Update(1.0, 1, /*new_split_value=*/0.3, true, true, GradStats{1.0, 1.0},
|
||||||
|
GradStats{2.0, 2.0});
|
||||||
|
bst_bin_t n_bins_feature = 256;
|
||||||
|
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
|
||||||
|
entry.split.cat_bits = decltype(entry.split.cat_bits)(n, 0);
|
||||||
|
common::CatBitField cat_bits{entry.split.cat_bits};
|
||||||
|
cat_bits.Set(n_bins_feature / 2);
|
||||||
|
|
||||||
|
Json je{Object{}};
|
||||||
|
entry.Save(&je);
|
||||||
|
|
||||||
|
CPUExpandEntry loaded;
|
||||||
|
loaded.Load(je);
|
||||||
|
|
||||||
|
ASSERT_EQ(loaded.split.is_cat, entry.split.is_cat);
|
||||||
|
ASSERT_EQ(loaded.split.cat_bits, entry.split.cat_bits);
|
||||||
|
ASSERT_EQ(loaded.split.left_sum.GetGrad(), entry.split.left_sum.GetGrad());
|
||||||
|
ASSERT_EQ(loaded.split.right_sum.GetHess(), entry.split.right_sum.GetHess());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ExpandEntry, IOMulti) {
|
||||||
|
MultiExpandEntry entry{RegTree::kRoot, 0};
|
||||||
|
auto left_sum = std::vector<GradientPairPrecise>{{1.0, 1.0}, {1.0, 1.0}};
|
||||||
|
auto right_sum = std::vector<GradientPairPrecise>{{2.0, 2.0}, {2.0, 2.0}};
|
||||||
|
entry.split.Update(1.0, 1, /*new_split_value=*/0.3, true, true,
|
||||||
|
linalg::MakeVec(left_sum.data(), left_sum.size()),
|
||||||
|
linalg::MakeVec(right_sum.data(), right_sum.size()));
|
||||||
|
bst_bin_t n_bins_feature = 256;
|
||||||
|
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
|
||||||
|
entry.split.cat_bits = decltype(entry.split.cat_bits)(n, 0);
|
||||||
|
common::CatBitField cat_bits{entry.split.cat_bits};
|
||||||
|
cat_bits.Set(n_bins_feature / 2);
|
||||||
|
|
||||||
|
Json je{Object{}};
|
||||||
|
entry.Save(&je);
|
||||||
|
|
||||||
|
MultiExpandEntry loaded;
|
||||||
|
loaded.Load(je);
|
||||||
|
|
||||||
|
ASSERT_EQ(loaded.split.is_cat, entry.split.is_cat);
|
||||||
|
ASSERT_EQ(loaded.split.cat_bits, entry.split.cat_bits);
|
||||||
|
ASSERT_EQ(loaded.split.left_sum, entry.split.left_sum);
|
||||||
|
ASSERT_EQ(loaded.split.right_sum, entry.split.right_sum);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::tree
|
||||||
Loading…
x
Reference in New Issue
Block a user