diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index ca816bcdb..2f84bb1cb 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2023 by XGBoost Contributors * \file linalg.h * \brief Linear algebra related utilities. */ @@ -8,7 +8,7 @@ #include #include -#include // fixme(jiamingy): Remove the dependency on this header. +#include #include #include #include @@ -834,9 +834,26 @@ class Tensor { int32_t DeviceIdx() const { return data_.DeviceIdx(); } }; +template +using Matrix = Tensor; + template using Vector = Tensor; +/** + * \brief Create an array without initialization. + */ +template +auto Empty(Context const *ctx, Index &&...index) { + Tensor t; + t.SetDevice(ctx->gpu_id); + t.Reshape(index...); + return t; +} + +/** + * \brief Create an array with value v. + */ template auto Constant(Context const *ctx, T v, Index &&...index) { Tensor t; @@ -846,7 +863,6 @@ auto Constant(Context const *ctx, T v, Index &&...index) { return t; } - /** * \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros. */ diff --git a/src/tree/hist/sampler.h b/src/tree/hist/sampler.h new file mode 100644 index 000000000..803e40d54 --- /dev/null +++ b/src/tree/hist/sampler.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TREE_HIST_SAMPLER_H_ +#define XGBOOST_TREE_HIST_SAMPLER_H_ + +#include // std::size-t +#include // std::uint64_t +#include // bernoulli_distribution, linear_congruential_engine + +#include "../../common/random.h" // GlobalRandom +#include "../param.h" // TrainParam +#include "xgboost/base.h" // GradientPair +#include "xgboost/context.h" // Context +#include "xgboost/data.h" // MetaInfo +#include "xgboost/linalg.h" // TensorView + +namespace xgboost { +namespace tree { +struct RandomReplace { + public: + // similar value as for minstd_rand + static constexpr std::uint64_t kBase = 16807; + static constexpr std::uint64_t kMod = static_cast(1) << 63; + + using EngineT = std::linear_congruential_engine; + + /* + Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation + */ + static std::uint64_t SimpleSkip(std::uint64_t exponent, std::uint64_t initial_seed, + std::uint64_t base, std::uint64_t mod) { + CHECK_LE(exponent, mod); + std::uint64_t result = 1; + while (exponent > 0) { + if (exponent % 2 == 1) { + result = (result * base) % mod; + } + base = (base * base) % mod; + exponent = exponent >> 1; + } + // with result we can now find the new seed + return (result * initial_seed) % mod; + } +}; + +// Only uniform sampling, no gradient-based yet. +inline void SampleGradient(Context const* ctx, TrainParam param, + linalg::MatrixView out) { + CHECK(out.Contiguous()); + CHECK_EQ(param.sampling_method, TrainParam::kUniform) + << "Only uniform sampling is supported, gradient-based sampling is only support by GPU Hist."; + + if (param.subsample >= 1.0) { + return; + } + bst_row_t n_samples = out.Shape(0); + auto& rnd = common::GlobalRandom(); + +#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG + std::bernoulli_distribution coin_flip(param.subsample); + CHECK_EQ(out.Shape(1), 1) << "Multi-target with sampling for R is not yet supported."; + for (size_t i = 0; i < n_samples; ++i) { + if (!(out(i, 0).GetHess() >= 0.0f && coin_flip(rnd)) || out(i, 0).GetGrad() == 0.0f) { + out(i, 0) = GradientPair(0); + } + } +#else + std::uint64_t initial_seed = rnd(); + + auto n_threads = static_cast(ctx->Threads()); + std::size_t const discard_size = n_samples / n_threads; + std::bernoulli_distribution coin_flip(param.subsample); + + dmlc::OMPException exc; +#pragma omp parallel num_threads(n_threads) + { + exc.Run([&]() { + const size_t tid = omp_get_thread_num(); + const size_t ibegin = tid * discard_size; + const size_t iend = (tid == (n_threads - 1)) ? n_samples : ibegin + discard_size; + + const uint64_t displaced_seed = RandomReplace::SimpleSkip( + ibegin, initial_seed, RandomReplace::kBase, RandomReplace::kMod); + RandomReplace::EngineT eng(displaced_seed); + std::size_t n_targets = out.Shape(1); + if (n_targets > 1) { + for (std::size_t i = ibegin; i < iend; ++i) { + if (!coin_flip(eng)) { + for (std::size_t j = 0; j < n_targets; ++j) { + out(i, j) = GradientPair{}; + } + } + } + } else { + for (std::size_t i = ibegin; i < iend; ++i) { + if (!coin_flip(eng)) { + out(i, 0) = GradientPair{}; + } + } + } + }); + } + exc.Rethrow(); +#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG +} +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_HIST_SAMPLER_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index dc333247a..0e3675888 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 XGBoost contributors +/** + * Copyright 2021-2023 by XGBoost contributors * * \brief Implementation for the approx tree method. */ @@ -14,9 +14,12 @@ #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/histogram.h" +#include "hist/sampler.h" // SampleGradient #include "param.h" #include "xgboost/base.h" +#include "xgboost/data.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" @@ -256,8 +259,7 @@ class GlobalApproxUpdater : public TreeUpdater { ObjInfo task_; public: - explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) - : TreeUpdater(ctx), task_{task} { + explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} { monitor_.Init(__func__); } @@ -272,24 +274,11 @@ class GlobalApproxUpdater : public TreeUpdater { } void InitData(TrainParam const ¶m, HostDeviceVector const *gpair, - std::vector *sampled) { - auto const &h_gpair = gpair->ConstHostVector(); - sampled->resize(h_gpair.size()); - std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); - auto &rnd = common::GlobalRandom(); + linalg::Matrix *sampled) { + *sampled = linalg::Empty(ctx_, gpair->Size(), 1); + sampled->Data()->Copy(*gpair); - if (param.subsample != 1.0) { - CHECK(param.sampling_method != TrainParam::kGradientBased) - << "Gradient based sampling is not supported for approx tree method."; - std::bernoulli_distribution coin_flip(param.subsample); - std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) { - if (coin_flip(rnd)) { - return g; - } else { - return GradientPair{}; - } - }); - } + SampleGradient(ctx_, param, sampled->HostView()); } char const *Name() const override { return "grow_histmaker"; } @@ -303,18 +292,19 @@ class GlobalApproxUpdater : public TreeUpdater { pimpl_ = std::make_unique(param_, m->Info(), ctx_, column_sampler_, task_, &monitor_); - std::vector h_gpair; - InitData(param_, gpair, &h_gpair); + linalg::Matrix h_gpair; // Obtain the hessian values for weighted sketching - std::vector hess(h_gpair.size()); - std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(), + InitData(param_, gpair, &h_gpair); + std::vector hess(h_gpair.Size()); + auto const &s_gpair = h_gpair.Data()->ConstHostVector(); + std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(), [](auto g) { return g.GetHess(); }); cached_ = m; size_t t_idx = 0; for (auto p_tree : trees) { - this->pimpl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); + this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]); ++t_idx; } param_.learning_rate = lr; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 525376730..f7cf73f1d 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2017-2022 by XGBoost Contributors +/** + * Copyright 2017-2023 by XGBoost Contributors * \file updater_quantile_hist.cc * \brief use quantized feature values to construct a tree * \author Philip Cho, Tianqi Checn, Egor Smirnov @@ -7,6 +7,7 @@ #include "./updater_quantile_hist.h" #include +#include #include #include #include @@ -14,9 +15,11 @@ #include "common_row_partitioner.h" #include "constraints.h" -#include "hist/histogram.h" #include "hist/evaluate_splits.h" +#include "hist/histogram.h" +#include "hist/sampler.h" #include "param.h" +#include "xgboost/linalg.h" #include "xgboost/logging.h" #include "xgboost/tree_updater.h" @@ -257,43 +260,6 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(DMatrix const *data, return true; } -void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, - std::vector *gpair) { - monitor_->Start(__func__); - const auto &info = fmat.Info(); - auto& rnd = common::GlobalRandom(); - std::vector& gpair_ref = *gpair; - -#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG - std::bernoulli_distribution coin_flip(param_.subsample); - for (size_t i = 0; i < info.num_row_; ++i) { - if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) { - gpair_ref[i] = GradientPair(0); - } - } -#else - uint64_t initial_seed = rnd(); - - auto n_threads = static_cast(ctx_->Threads()); - const size_t discard_size = info.num_row_ / n_threads; - std::bernoulli_distribution coin_flip(param_.subsample); - - dmlc::OMPException exc; - #pragma omp parallel num_threads(n_threads) - { - exc.Run([&]() { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * discard_size; - const size_t iend = (tid == (n_threads - 1)) ? info.num_row_ : ibegin + discard_size; - RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) { - return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng)); - }, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref); - }); - } - exc.Rethrow(); -#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG - monitor_->Stop(__func__); -} size_t QuantileHistMaker::Builder::GetNumberOfTrees() { return n_trees_; } void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, @@ -317,12 +283,9 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, collective::IsDistributed()); - if (param_.subsample < 1.0f) { - CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported, " - << "gradient-based sampling is only support by GPU Hist."; - InitSampling(*fmat, gpair); - } + auto m_gpair = + linalg::MakeTensorView(*gpair, {gpair->size(), static_cast(1)}, ctx_->gpu_id); + SampleGradient(ctx_, param_, m_gpair); } // store a pointer to the tree diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index dfb9c45b0..ea7000651 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -36,47 +36,6 @@ #include "../common/column_matrix.h" namespace xgboost { -struct RandomReplace { - public: - // similar value as for minstd_rand - static constexpr uint64_t kBase = 16807; - static constexpr uint64_t kMod = static_cast(1) << 63; - - using EngineT = std::linear_congruential_engine; - - /* - Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation - */ - static uint64_t SimpleSkip(uint64_t exponent, uint64_t initial_seed, - uint64_t base, uint64_t mod) { - CHECK_LE(exponent, mod); - uint64_t result = 1; - while (exponent > 0) { - if (exponent % 2 == 1) { - result = (result * base) % mod; - } - base = (base * base) % mod; - exponent = exponent >> 1; - } - // with result we can now find the new seed - return (result * initial_seed) % mod; - } - - template - static void MakeIf(Condition condition, const typename ContainerData::value_type replace_value, - const uint64_t initial_seed, const size_t ibegin, - const size_t iend, ContainerData* gpair) { - ContainerData& gpair_ref = *gpair; - const uint64_t displaced_seed = SimpleSkip(ibegin, initial_seed, kBase, kMod); - EngineT eng(displaced_seed); - for (size_t i = ibegin; i < iend; ++i) { - if (condition(i, eng)) { - gpair_ref[i] = replace_value; - } - } - } -}; - namespace tree { inline BatchParam HistBatch(TrainParam const& param) { return {param.max_bin, param.sparse_threshold}; @@ -141,8 +100,6 @@ class QuantileHistMaker: public TreeUpdater { size_t GetNumberOfTrees(); - void InitSampling(const DMatrix& fmat, std::vector* gpair); - CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree, const std::vector& gpair_h); diff --git a/tests/cpp/tree/hist/test_sampler.cc b/tests/cpp/tree/hist/test_sampler.cc new file mode 100644 index 000000000..5d747f04b --- /dev/null +++ b/tests/cpp/tree/hist/test_sampler.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include + +#include // std::size_t +#include // std::to_string + +#include "../../../../src/tree/hist/sampler.h" // SampleGradient +#include "../../../../src/tree/param.h" // TrainParam +#include "xgboost/base.h" // GradientPair,bst_target_t +#include "xgboost/context.h" // Context +#include "xgboost/data.h" // MetaInfo +#include "xgboost/linalg.h" // Matrix,Constants + +namespace xgboost { +namespace tree { +TEST(Sampler, Basic) { + std::size_t constexpr kRows = 1024; + double constexpr kSubsample = .2; + TrainParam param; + param.UpdateAllowUnknown(Args{{"subsample", std::to_string(kSubsample)}}); + Context ctx; + + auto run = [&](bst_target_t n_targets) { + auto init = GradientPair{1.0f, 1.0f}; + linalg::Matrix gpair = linalg::Constant(&ctx, init, kRows, n_targets); + auto h_gpair = gpair.HostView(); + SampleGradient(&ctx, param, h_gpair); + std::size_t n_sampled{0}; + for (std::size_t i = 0; i < kRows; ++i) { + bool sampled{false}; + if (h_gpair(i, 0).GetGrad() - .0f != .0f) { + sampled = true; + n_sampled++; + } + for (bst_target_t t = 1; t < n_targets; ++t) { + if (sampled) { + ASSERT_EQ(h_gpair(i, t).GetGrad() - init.GetGrad(), .0f); + ASSERT_EQ(h_gpair(i, t).GetHess() - init.GetHess(), .0f); + + } else { + ASSERT_EQ(h_gpair(i, t).GetGrad() - .0f, .0f); + ASSERT_EQ(h_gpair(i, t).GetHess() - .0f, .0f); + } + } + } + auto ratio = static_cast(n_sampled) / static_cast(kRows); + ASSERT_LT(ratio, kSubsample * 1.5); + ASSERT_GT(ratio, kSubsample * 0.5); + }; + + run(1); + run(3); +} +} // namespace tree +} // namespace xgboost