Implement fit stump. (#8607)
This commit is contained in:
parent
20e6087579
commit
8d545ab2a2
@ -55,6 +55,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
||||
$(PKGROOT)/src/tree/constraints.o \
|
||||
$(PKGROOT)/src/tree/param.o \
|
||||
$(PKGROOT)/src/tree/fit_stump.o \
|
||||
$(PKGROOT)/src/tree/tree_model.o \
|
||||
$(PKGROOT)/src/tree/tree_updater.o \
|
||||
$(PKGROOT)/src/tree/updater_approx.o \
|
||||
@ -85,6 +86,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/pseudo_huber.o \
|
||||
$(PKGROOT)/src/common/quantile.o \
|
||||
$(PKGROOT)/src/common/random.o \
|
||||
$(PKGROOT)/src/common/stats.o \
|
||||
$(PKGROOT)/src/common/survival_util.o \
|
||||
$(PKGROOT)/src/common/threading_utils.o \
|
||||
$(PKGROOT)/src/common/timer.o \
|
||||
|
||||
@ -55,6 +55,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
||||
$(PKGROOT)/src/tree/constraints.o \
|
||||
$(PKGROOT)/src/tree/param.o \
|
||||
$(PKGROOT)/src/tree/fit_stump.o \
|
||||
$(PKGROOT)/src/tree/tree_model.o \
|
||||
$(PKGROOT)/src/tree/tree_updater.o \
|
||||
$(PKGROOT)/src/tree/updater_approx.o \
|
||||
@ -85,6 +86,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/common/pseudo_huber.o \
|
||||
$(PKGROOT)/src/common/quantile.o \
|
||||
$(PKGROOT)/src/common/random.o \
|
||||
$(PKGROOT)/src/common/stats.o \
|
||||
$(PKGROOT)/src/common/survival_util.o \
|
||||
$(PKGROOT)/src/common/threading_utils.o \
|
||||
$(PKGROOT)/src/common/timer.o \
|
||||
|
||||
@ -134,6 +134,8 @@ using bst_row_t = std::size_t; // NOLINT
|
||||
using bst_node_t = int32_t; // NOLINT
|
||||
/*! \brief Type for ranking group index. */
|
||||
using bst_group_t = uint32_t; // NOLINT
|
||||
/*! \brief Type for indexing target variables. */
|
||||
using bst_target_t = std::size_t; // NOLINT
|
||||
|
||||
namespace detail {
|
||||
/*! \brief Implementation of gradient statistics pair. Template specialisation
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cinttypes> // std::int32_t
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
@ -388,9 +389,9 @@ class TensorView {
|
||||
* \brief Create a tensor with data, shape and strides. Don't use this constructor if
|
||||
* stride can be calculated from shape.
|
||||
*/
|
||||
template <typename I, int32_t D>
|
||||
template <typename I, std::int32_t D>
|
||||
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
|
||||
int32_t device)
|
||||
std::int32_t device)
|
||||
: data_{data}, ptr_{data_.data()}, device_{device} {
|
||||
static_assert(D == kDim, "Invalid shape & stride.");
|
||||
detail::UnrollLoop<D>([&](auto i) {
|
||||
@ -833,6 +834,27 @@ class Tensor {
|
||||
int32_t DeviceIdx() const { return data_.DeviceIdx(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using Vector = Tensor<T, 1>;
|
||||
|
||||
template <typename T, typename... Index>
|
||||
auto Constant(Context const *ctx, T v, Index &&...index) {
|
||||
Tensor<T, sizeof...(Index)> t;
|
||||
t.SetDevice(ctx->gpu_id);
|
||||
t.Reshape(index...);
|
||||
t.Data()->Fill(std::move(v));
|
||||
return t;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros.
|
||||
*/
|
||||
template <typename T, typename... Index>
|
||||
auto Zeros(Context const *ctx, Index &&...index) {
|
||||
return Constant(ctx, static_cast<T>(0), index...);
|
||||
}
|
||||
|
||||
// Only first axis is supported for now.
|
||||
template <typename T, int32_t D>
|
||||
void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
|
||||
|
||||
@ -93,7 +93,7 @@ class ObjFunction : public Configurable {
|
||||
* \brief Return number of targets for input matrix. Right now XGBoost supports only
|
||||
* multi-target regression.
|
||||
*/
|
||||
virtual uint32_t Targets(MetaInfo const& info) const {
|
||||
virtual bst_target_t Targets(MetaInfo const& info) const {
|
||||
if (info.labels.Shape(1) > 1) {
|
||||
LOG(FATAL) << "multioutput is not supported by current objective function";
|
||||
}
|
||||
|
||||
@ -172,6 +172,7 @@ void HostDeviceVector<T>::SetDevice(int) const {}
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<GradientPairPrecise>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
template class HostDeviceVector<FeatureType>;
|
||||
|
||||
@ -404,6 +404,7 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
|
||||
template class HostDeviceVector<bst_float>;
|
||||
template class HostDeviceVector<double>;
|
||||
template class HostDeviceVector<GradientPair>;
|
||||
template class HostDeviceVector<GradientPairPrecise>;
|
||||
template class HostDeviceVector<int32_t>; // bst_node_t
|
||||
template class HostDeviceVector<uint8_t>;
|
||||
template class HostDeviceVector<FeatureType>;
|
||||
|
||||
@ -3,10 +3,8 @@
|
||||
*/
|
||||
#include "numeric.h"
|
||||
|
||||
#include <numeric> // std::accumulate
|
||||
#include <type_traits> // std::is_same
|
||||
|
||||
#include "threading_utils.h" // MemStackAllocator, ParallelFor, DefaultMaxThreads
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
|
||||
@ -15,14 +13,11 @@ namespace common {
|
||||
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
|
||||
if (ctx->IsCPU()) {
|
||||
auto const& h_values = values.ConstHostVector();
|
||||
MemStackAllocator<double, DefaultMaxThreads()> result_tloc(ctx->Threads(), 0);
|
||||
ParallelFor(h_values.size(), ctx->Threads(),
|
||||
[&](auto i) { result_tloc[omp_get_thread_num()] += h_values[i]; });
|
||||
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cend(), 0.0);
|
||||
auto result = cpu_impl::Reduce(ctx, h_values.cbegin(), h_values.cend(), 0.0);
|
||||
static_assert(std::is_same<decltype(result), double>::value, "");
|
||||
return result;
|
||||
}
|
||||
return cuda::Reduce(ctx, values);
|
||||
return cuda_impl::Reduce(ctx, values);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -2,24 +2,22 @@
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/functional.h> // thrust:plus
|
||||
|
||||
#include "device_helpers.cuh" // dh::Reduce, safe_cuda, dh::XGBCachingDeviceAllocator
|
||||
#include "device_helpers.cuh" // dh::Reduce, dh::XGBCachingDeviceAllocator
|
||||
#include "numeric.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace cuda {
|
||||
namespace cuda_impl {
|
||||
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
|
||||
values.SetDevice(ctx->gpu_id);
|
||||
auto const d_values = values.ConstDeviceSpan();
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto res = dh::Reduce(thrust::cuda::par(alloc), d_values.data(),
|
||||
d_values.data() + d_values.size(), 0.0, thrust::plus<double>{});
|
||||
return res;
|
||||
return dh::Reduce(thrust::cuda::par(alloc), dh::tcbegin(d_values), dh::tcend(d_values), 0.0,
|
||||
thrust::plus<float>{});
|
||||
}
|
||||
} // namespace cuda
|
||||
} // namespace cuda_impl
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -95,7 +95,7 @@ void PartialSum(int32_t n_threads, InIt begin, InIt end, T init, OutIt out_it) {
|
||||
exc.Rethrow();
|
||||
}
|
||||
|
||||
namespace cuda {
|
||||
namespace cuda_impl {
|
||||
double Reduce(Context const* ctx, HostDeviceVector<float> const& values);
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline double Reduce(Context const*, HostDeviceVector<float> const&) {
|
||||
@ -103,9 +103,25 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
|
||||
return 0;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda
|
||||
} // namespace cuda_impl
|
||||
|
||||
/**
|
||||
* \brief Reduction with summation.
|
||||
* \brief Reduction with iterator. init must be additive identity. (0 for primitive types)
|
||||
*/
|
||||
namespace cpu_impl {
|
||||
template <typename It, typename V = typename It::value_type>
|
||||
V Reduce(Context const* ctx, It first, It second, V const& init) {
|
||||
size_t n = std::distance(first, second);
|
||||
common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(ctx->Threads(), init);
|
||||
common::ParallelFor(n, ctx->Threads(),
|
||||
[&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
|
||||
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init);
|
||||
return result;
|
||||
}
|
||||
} // namespace cpu_impl
|
||||
|
||||
/**
|
||||
* \brief Reduction on host device vector.
|
||||
*/
|
||||
double Reduce(Context const* ctx, HostDeviceVector<float> const& values);
|
||||
|
||||
|
||||
@ -641,7 +641,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
thrust::equal_to<bst_feature_t>{},
|
||||
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
|
||||
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values));
|
||||
auto max_it = common::MakeIndexTransformIter([&](auto i) {
|
||||
auto max_it = MakeIndexTransformIter([&](auto i) {
|
||||
if (IsCat(h_feature_types, i)) {
|
||||
return max_values[i].value;
|
||||
}
|
||||
|
||||
64
src/common/stats.cc
Normal file
64
src/common/stats.cc
Normal file
@ -0,0 +1,64 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include "stats.h"
|
||||
|
||||
#include <numeric> // std::accumulate
|
||||
|
||||
#include "common.h" // OptionalWeights
|
||||
#include "threading_utils.h" // ParallelFor, MemStackAllocator
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/linalg.h" // Tensor, UnravelIndex, Apply
|
||||
#include "xgboost/logging.h" // CHECK_EQ
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights) {
|
||||
CHECK_LE(t.Shape(1), 1) << "Matrix is not yet supported.";
|
||||
if (!ctx->IsCPU()) {
|
||||
weights.SetDevice(ctx->gpu_id);
|
||||
auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
|
||||
auto t_v = t.View(ctx->gpu_id);
|
||||
return cuda_impl::Median(ctx, t_v, opt_weights);
|
||||
}
|
||||
|
||||
auto opt_weights = OptionalWeights(weights.ConstHostSpan());
|
||||
auto t_v = t.HostView();
|
||||
auto iter = common::MakeIndexTransformIter(
|
||||
[&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); });
|
||||
float q{0};
|
||||
if (opt_weights.Empty()) {
|
||||
q = common::Quantile(0.5, iter, iter + t_v.Size());
|
||||
} else {
|
||||
CHECK_NE(t_v.Shape(1), 0);
|
||||
auto w_it = common::MakeIndexTransformIter([&](size_t i) {
|
||||
auto sample_idx = i / t_v.Shape(1);
|
||||
return opt_weights[sample_idx];
|
||||
});
|
||||
q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it);
|
||||
}
|
||||
return q;
|
||||
}
|
||||
|
||||
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out) {
|
||||
v.SetDevice(ctx->gpu_id);
|
||||
out->SetDevice(ctx->gpu_id);
|
||||
out->Reshape(1);
|
||||
|
||||
if (ctx->IsCPU()) {
|
||||
auto h_v = v.HostView();
|
||||
float n = v.Size();
|
||||
MemStackAllocator<float, DefaultMaxThreads()> tloc(ctx->Threads(), 0.0f);
|
||||
ParallelFor(v.Size(), ctx->Threads(),
|
||||
[&](auto i) { tloc[omp_get_thread_num()] += h_v(i) / n; });
|
||||
auto ret = std::accumulate(tloc.cbegin(), tloc.cend(), .0f);
|
||||
out->HostView()(0) = ret;
|
||||
} else {
|
||||
cuda_impl::Mean(ctx, v.View(ctx->gpu_id), out->View(ctx->gpu_id));
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -13,7 +13,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace cuda {
|
||||
namespace cuda_impl {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
common::OptionalWeights weights) {
|
||||
HostDeviceVector<size_t> segments{0, t.Size()};
|
||||
@ -42,6 +42,17 @@ float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
CHECK_EQ(quantile.Size(), 1);
|
||||
return quantile.HostVector().front();
|
||||
}
|
||||
} // namespace cuda
|
||||
|
||||
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out) {
|
||||
float n = v.Size();
|
||||
auto it = dh::MakeTransformIterator<float>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return v(i) / n; });
|
||||
std::size_t bytes;
|
||||
CHECK_EQ(out.Size(), 1);
|
||||
cub::DeviceReduce::Sum(nullptr, bytes, it, out.Values().data(), v.Size());
|
||||
dh::TemporaryArray<char> temp{bytes};
|
||||
cub::DeviceReduce::Sum(temp.data().get(), bytes, it, out.Values().data(), v.Size());
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -8,10 +8,11 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h" // AssertGPUSupport
|
||||
#include "common.h" // AssertGPUSupport, OptionalWeights
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/logging.h" // CHECK_GE
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -93,43 +94,25 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
|
||||
return val(idx);
|
||||
}
|
||||
|
||||
namespace cuda {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t,
|
||||
common::OptionalWeights weights);
|
||||
namespace cuda_impl {
|
||||
float Median(Context const* ctx, linalg::TensorView<float const, 2> t, OptionalWeights weights);
|
||||
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::VectorView<float> out);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline float Median(Context const*, linalg::TensorView<float const, 2>, common::OptionalWeights) {
|
||||
AssertGPUSupport();
|
||||
inline float Median(Context const*, linalg::TensorView<float const, 2>, OptionalWeights) {
|
||||
common::AssertGPUSupport();
|
||||
return 0;
|
||||
}
|
||||
inline void Mean(Context const*, linalg::VectorView<float const>, linalg::VectorView<float>) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda
|
||||
} // namespace cuda_impl
|
||||
|
||||
inline float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights) {
|
||||
if (!ctx->IsCPU()) {
|
||||
weights.SetDevice(ctx->gpu_id);
|
||||
auto opt_weights = OptionalWeights(weights.ConstDeviceSpan());
|
||||
auto t_v = t.View(ctx->gpu_id);
|
||||
return cuda::Median(ctx, t_v, opt_weights);
|
||||
}
|
||||
float Median(Context const* ctx, linalg::Tensor<float, 2> const& t,
|
||||
HostDeviceVector<float> const& weights);
|
||||
|
||||
auto opt_weights = OptionalWeights(weights.ConstHostSpan());
|
||||
auto t_v = t.HostView();
|
||||
auto iter = common::MakeIndexTransformIter(
|
||||
[&](size_t i) { return linalg::detail::Apply(t_v, linalg::UnravelIndex(i, t_v.Shape())); });
|
||||
float q{0};
|
||||
if (opt_weights.Empty()) {
|
||||
q = common::Quantile(0.5, iter, iter + t_v.Size());
|
||||
} else {
|
||||
CHECK_NE(t_v.Shape(1), 0);
|
||||
auto w_it = common::MakeIndexTransformIter([&](size_t i) {
|
||||
auto sample_idx = i / t_v.Shape(1);
|
||||
return opt_weights[sample_idx];
|
||||
});
|
||||
q = common::WeightedQuantile(0.5, iter, iter + t_v.Size(), w_it);
|
||||
}
|
||||
return q;
|
||||
}
|
||||
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_STATS_H_
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -81,7 +81,7 @@ class RegLossObj : public ObjFunction {
|
||||
|
||||
ObjInfo Task() const override { return Loss::Info(); }
|
||||
|
||||
uint32_t Targets(MetaInfo const& info) const override {
|
||||
bst_target_t Targets(MetaInfo const& info) const override {
|
||||
// Multi-target regression.
|
||||
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
|
||||
}
|
||||
@ -220,7 +220,7 @@ class PseudoHuberRegression : public ObjFunction {
|
||||
public:
|
||||
void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
uint32_t Targets(MetaInfo const& info) const override {
|
||||
bst_target_t Targets(MetaInfo const& info) const override {
|
||||
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
|
||||
}
|
||||
|
||||
|
||||
82
src/tree/fit_stump.cc
Normal file
82
src/tree/fit_stump.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*
|
||||
* \brief Utilities for estimating initial score.
|
||||
*/
|
||||
#include "fit_stump.h"
|
||||
|
||||
#include <cinttypes> // std::int32_t
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#include "../common/numeric.h" // cpu_impl::Reduce
|
||||
#include "../common/threading_utils.h" // ParallelFor
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/base.h" // bst_target_t, GradientPairPrecise
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h" // TensorView, Tensor, Constant
|
||||
#include "xgboost/logging.h" // CHECK_EQ
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace cpu_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out) {
|
||||
auto n_targets = out.Size();
|
||||
CHECK_EQ(n_targets, gpair.Shape(1));
|
||||
linalg::Tensor<GradientPairPrecise, 2> sum_tloc =
|
||||
linalg::Constant(ctx, GradientPairPrecise{}, ctx->Threads(), n_targets);
|
||||
auto h_sum_tloc = sum_tloc.HostView();
|
||||
// first dim for gpair is samples, second dim is target.
|
||||
// Reduce by column, parallel by samples
|
||||
common::ParallelFor(gpair.Shape(0), ctx->Threads(), [&](auto i) {
|
||||
for (bst_target_t t = 0; t < n_targets; ++t) {
|
||||
h_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)};
|
||||
}
|
||||
});
|
||||
// Aggregate to the first row.
|
||||
auto h_sum = h_sum_tloc.Slice(0, linalg::All());
|
||||
for (std::int32_t i = 1; i < ctx->Threads(); ++i) {
|
||||
for (bst_target_t j = 0; j < n_targets; ++j) {
|
||||
h_sum(j) += h_sum_tloc(i, j);
|
||||
}
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
|
||||
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
||||
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
||||
}
|
||||
}
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
|
||||
linalg::VectorView<float>) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda_impl
|
||||
|
||||
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out) {
|
||||
out->SetDevice(ctx->gpu_id);
|
||||
out->Reshape(n_targets);
|
||||
auto n_samples = gpair.Size() / n_targets;
|
||||
|
||||
gpair.SetDevice(ctx->gpu_id);
|
||||
linalg::TensorView<GradientPair const, 2> gpair_t{
|
||||
ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(),
|
||||
{n_samples, n_targets},
|
||||
ctx->gpu_id};
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
63
src/tree/fit_stump.cu
Normal file
63
src/tree/fit_stump.cu
Normal file
@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*
|
||||
* \brief Utilities for estimating initial score.
|
||||
*/
|
||||
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
#include <thrust/execution_policy.h> // cuda::par
|
||||
#include <thrust/iterator/counting_iterator.h> // thrust::make_counting_iterator
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/device_communicator.cuh" // DeviceCommunicator
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h" // TensorView, Tensor, Constant
|
||||
#include "xgboost/logging.h" // CHECK_EQ
|
||||
#include "xgboost/span.h" // span
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out) {
|
||||
auto n_targets = out.Size();
|
||||
CHECK_EQ(n_targets, gpair.Shape(1));
|
||||
linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
|
||||
CHECK(out.Contiguous());
|
||||
|
||||
// Reduce by column
|
||||
auto key_it = dh::MakeTransformIterator<bst_target_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) -> bst_target_t { return i / gpair.Shape(0); });
|
||||
auto grad_it = dh::MakeTransformIterator<GradientPairPrecise>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) -> GradientPairPrecise {
|
||||
auto target = i / gpair.Shape(0);
|
||||
auto sample = i % gpair.Shape(0);
|
||||
return GradientPairPrecise{gpair(sample, target)};
|
||||
});
|
||||
auto d_sum = sum.View(ctx->gpu_id);
|
||||
CHECK(d_sum.CContiguous());
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
auto policy = thrust::cuda::par(alloc);
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
|
||||
communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
out(i) = static_cast<float>(
|
||||
CalcUnregularizedWeight(d_sum(i).GetGrad(), d_sum(i).GetHess()));
|
||||
});
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
37
src/tree/fit_stump.h
Normal file
37
src/tree/fit_stump.h
Normal file
@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*
|
||||
* \brief Utilities for estimating initial score.
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_TREE_FIT_STUMP_H_
|
||||
#define XGBOOST_TREE_FIT_STUMP_H_
|
||||
|
||||
#if !defined(NOMINMAX) && defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
|
||||
#include <algorithm> // std::max
|
||||
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#include "xgboost/base.h" // GradientPair
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/linalg.h" // TensorView
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
|
||||
return -sum_grad / std::max(sum_hess, static_cast<double>(kRtEps));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fit a tree stump as an estimation of base_score.
|
||||
*/
|
||||
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_FIT_STUMP_H_
|
||||
@ -3,8 +3,10 @@
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/linalg.h> // Tensor,Vector
|
||||
|
||||
#include "../../../src/common/stats.h"
|
||||
#include "../../../src/common/transform_iterator.h" // common::MakeIndexTransformIter
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -69,5 +71,35 @@ TEST(Stats, Median) {
|
||||
ASSERT_EQ(m, .5f);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
namespace {
|
||||
void TestMean(Context const* ctx) {
|
||||
std::size_t n{128};
|
||||
linalg::Vector<float> data({n}, ctx->gpu_id);
|
||||
auto h_v = data.HostView().Values();
|
||||
std::iota(h_v.begin(), h_v.end(), .0f);
|
||||
|
||||
auto nf = static_cast<float>(n);
|
||||
float mean = nf * (nf - 1) / 2 / n;
|
||||
|
||||
linalg::Vector<float> res{{1}, ctx->gpu_id};
|
||||
Mean(ctx, data, &res);
|
||||
auto h_res = res.HostView();
|
||||
ASSERT_EQ(h_res.Size(), 1);
|
||||
ASSERT_EQ(mean, h_res(0));
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Stats, Mean) {
|
||||
Context ctx;
|
||||
TestMean(&ctx);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST(Stats, GPUMean) {
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
TestMean(&ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/common/stats.cuh"
|
||||
#include "../../../src/common/stats.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
@ -66,7 +66,7 @@ TEST(Learner, CheckGroup) {
|
||||
|
||||
std::shared_ptr<DMatrix> p_mat{
|
||||
RandomDataGenerator{kNumRows, kNumCols, 0.0f}.GenerateDMatrix()};
|
||||
std::vector<bst_float> weight(kNumGroups);
|
||||
std::vector<bst_float> weight(kNumGroups, 1);
|
||||
std::vector<bst_int> group(kNumGroups);
|
||||
group[0] = 2;
|
||||
group[1] = 3;
|
||||
|
||||
48
tests/cpp/tree/test_fit_stump.cc
Normal file
48
tests/cpp/tree/test_fit_stump.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/linalg.h>
|
||||
|
||||
#include "../../src/common/linalg_op.h"
|
||||
#include "../../src/tree/fit_stump.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace {
|
||||
void TestFitStump(Context const *ctx) {
|
||||
std::size_t constexpr kRows = 16, kTargets = 2;
|
||||
HostDeviceVector<GradientPair> gpair;
|
||||
auto &h_gpair = gpair.HostVector();
|
||||
h_gpair.resize(kRows * kTargets);
|
||||
for (std::size_t i = 0; i < kRows; ++i) {
|
||||
for (std::size_t t = 0; t < kTargets; ++t) {
|
||||
h_gpair.at(i * kTargets + t) = GradientPair{static_cast<float>(i), 1};
|
||||
}
|
||||
}
|
||||
linalg::Vector<float> out;
|
||||
FitStump(ctx, gpair, kTargets, &out);
|
||||
auto h_out = out.HostView();
|
||||
for (auto it = linalg::cbegin(h_out); it != linalg::cend(h_out); ++it) {
|
||||
// sum_hess == kRows
|
||||
auto n = static_cast<float>(kRows);
|
||||
auto sum_grad = n * (n - 1) / 2;
|
||||
ASSERT_EQ(static_cast<float>(-sum_grad / n), *it);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(InitEstimation, FitStump) {
|
||||
Context ctx;
|
||||
TestFitStump(&ctx);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST(InitEstimation, GPUFitStump) {
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
TestFitStump(&ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user