Extract CPU sampling routines. (#8697)
This commit is contained in:
parent
7a068af1a3
commit
e49e0998c0
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 by XGBoost Contributors
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
* \file linalg.h
|
* \file linalg.h
|
||||||
* \brief Linear algebra related utilities.
|
* \brief Linear algebra related utilities.
|
||||||
*/
|
*/
|
||||||
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#include <dmlc/endian.h>
|
#include <dmlc/endian.h>
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <xgboost/context.h> // fixme(jiamingy): Remove the dependency on this header.
|
#include <xgboost/context.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
#include <xgboost/json.h>
|
#include <xgboost/json.h>
|
||||||
#include <xgboost/span.h>
|
#include <xgboost/span.h>
|
||||||
@ -834,9 +834,26 @@ class Tensor {
|
|||||||
int32_t DeviceIdx() const { return data_.DeviceIdx(); }
|
int32_t DeviceIdx() const { return data_.DeviceIdx(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using Matrix = Tensor<T, 2>;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using Vector = Tensor<T, 1>;
|
using Vector = Tensor<T, 1>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Create an array without initialization.
|
||||||
|
*/
|
||||||
|
template <typename T, typename... Index>
|
||||||
|
auto Empty(Context const *ctx, Index &&...index) {
|
||||||
|
Tensor<T, sizeof...(Index)> t;
|
||||||
|
t.SetDevice(ctx->gpu_id);
|
||||||
|
t.Reshape(index...);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Create an array with value v.
|
||||||
|
*/
|
||||||
template <typename T, typename... Index>
|
template <typename T, typename... Index>
|
||||||
auto Constant(Context const *ctx, T v, Index &&...index) {
|
auto Constant(Context const *ctx, T v, Index &&...index) {
|
||||||
Tensor<T, sizeof...(Index)> t;
|
Tensor<T, sizeof...(Index)> t;
|
||||||
@ -846,7 +863,6 @@ auto Constant(Context const *ctx, T v, Index &&...index) {
|
|||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros.
|
* \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros.
|
||||||
*/
|
*/
|
||||||
|
|||||||
109
src/tree/hist/sampler.h
Normal file
109
src/tree/hist/sampler.h
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_TREE_HIST_SAMPLER_H_
|
||||||
|
#define XGBOOST_TREE_HIST_SAMPLER_H_
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <cstdint> // std::uint64_t
|
||||||
|
#include <random> // bernoulli_distribution, linear_congruential_engine
|
||||||
|
|
||||||
|
#include "../../common/random.h" // GlobalRandom
|
||||||
|
#include "../param.h" // TrainParam
|
||||||
|
#include "xgboost/base.h" // GradientPair
|
||||||
|
#include "xgboost/context.h" // Context
|
||||||
|
#include "xgboost/data.h" // MetaInfo
|
||||||
|
#include "xgboost/linalg.h" // TensorView
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
struct RandomReplace {
|
||||||
|
public:
|
||||||
|
// similar value as for minstd_rand
|
||||||
|
static constexpr std::uint64_t kBase = 16807;
|
||||||
|
static constexpr std::uint64_t kMod = static_cast<std::uint64_t>(1) << 63;
|
||||||
|
|
||||||
|
using EngineT = std::linear_congruential_engine<uint64_t, kBase, 0, kMod>;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation
|
||||||
|
*/
|
||||||
|
static std::uint64_t SimpleSkip(std::uint64_t exponent, std::uint64_t initial_seed,
|
||||||
|
std::uint64_t base, std::uint64_t mod) {
|
||||||
|
CHECK_LE(exponent, mod);
|
||||||
|
std::uint64_t result = 1;
|
||||||
|
while (exponent > 0) {
|
||||||
|
if (exponent % 2 == 1) {
|
||||||
|
result = (result * base) % mod;
|
||||||
|
}
|
||||||
|
base = (base * base) % mod;
|
||||||
|
exponent = exponent >> 1;
|
||||||
|
}
|
||||||
|
// with result we can now find the new seed
|
||||||
|
return (result * initial_seed) % mod;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Only uniform sampling, no gradient-based yet.
|
||||||
|
inline void SampleGradient(Context const* ctx, TrainParam param,
|
||||||
|
linalg::MatrixView<GradientPair> out) {
|
||||||
|
CHECK(out.Contiguous());
|
||||||
|
CHECK_EQ(param.sampling_method, TrainParam::kUniform)
|
||||||
|
<< "Only uniform sampling is supported, gradient-based sampling is only support by GPU Hist.";
|
||||||
|
|
||||||
|
if (param.subsample >= 1.0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
bst_row_t n_samples = out.Shape(0);
|
||||||
|
auto& rnd = common::GlobalRandom();
|
||||||
|
|
||||||
|
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||||
|
std::bernoulli_distribution coin_flip(param.subsample);
|
||||||
|
CHECK_EQ(out.Shape(1), 1) << "Multi-target with sampling for R is not yet supported.";
|
||||||
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
|
if (!(out(i, 0).GetHess() >= 0.0f && coin_flip(rnd)) || out(i, 0).GetGrad() == 0.0f) {
|
||||||
|
out(i, 0) = GradientPair(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
std::uint64_t initial_seed = rnd();
|
||||||
|
|
||||||
|
auto n_threads = static_cast<size_t>(ctx->Threads());
|
||||||
|
std::size_t const discard_size = n_samples / n_threads;
|
||||||
|
std::bernoulli_distribution coin_flip(param.subsample);
|
||||||
|
|
||||||
|
dmlc::OMPException exc;
|
||||||
|
#pragma omp parallel num_threads(n_threads)
|
||||||
|
{
|
||||||
|
exc.Run([&]() {
|
||||||
|
const size_t tid = omp_get_thread_num();
|
||||||
|
const size_t ibegin = tid * discard_size;
|
||||||
|
const size_t iend = (tid == (n_threads - 1)) ? n_samples : ibegin + discard_size;
|
||||||
|
|
||||||
|
const uint64_t displaced_seed = RandomReplace::SimpleSkip(
|
||||||
|
ibegin, initial_seed, RandomReplace::kBase, RandomReplace::kMod);
|
||||||
|
RandomReplace::EngineT eng(displaced_seed);
|
||||||
|
std::size_t n_targets = out.Shape(1);
|
||||||
|
if (n_targets > 1) {
|
||||||
|
for (std::size_t i = ibegin; i < iend; ++i) {
|
||||||
|
if (!coin_flip(eng)) {
|
||||||
|
for (std::size_t j = 0; j < n_targets; ++j) {
|
||||||
|
out(i, j) = GradientPair{};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (std::size_t i = ibegin; i < iend; ++i) {
|
||||||
|
if (!coin_flip(eng)) {
|
||||||
|
out(i, 0) = GradientPair{};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
exc.Rethrow();
|
||||||
|
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||||
|
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_TREE_HIST_SAMPLER_H_
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 XGBoost contributors
|
* Copyright 2021-2023 by XGBoost contributors
|
||||||
*
|
*
|
||||||
* \brief Implementation for the approx tree method.
|
* \brief Implementation for the approx tree method.
|
||||||
*/
|
*/
|
||||||
@ -14,9 +14,12 @@
|
|||||||
#include "driver.h"
|
#include "driver.h"
|
||||||
#include "hist/evaluate_splits.h"
|
#include "hist/evaluate_splits.h"
|
||||||
#include "hist/histogram.h"
|
#include "hist/histogram.h"
|
||||||
|
#include "hist/sampler.h" // SampleGradient
|
||||||
#include "param.h"
|
#include "param.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/data.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
#include "xgboost/linalg.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
@ -256,8 +259,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
ObjInfo task_;
|
ObjInfo task_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task)
|
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {
|
||||||
: TreeUpdater(ctx), task_{task} {
|
|
||||||
monitor_.Init(__func__);
|
monitor_.Init(__func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,24 +274,11 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void InitData(TrainParam const ¶m, HostDeviceVector<GradientPair> const *gpair,
|
void InitData(TrainParam const ¶m, HostDeviceVector<GradientPair> const *gpair,
|
||||||
std::vector<GradientPair> *sampled) {
|
linalg::Matrix<GradientPair> *sampled) {
|
||||||
auto const &h_gpair = gpair->ConstHostVector();
|
*sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1);
|
||||||
sampled->resize(h_gpair.size());
|
sampled->Data()->Copy(*gpair);
|
||||||
std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin());
|
|
||||||
auto &rnd = common::GlobalRandom();
|
|
||||||
|
|
||||||
if (param.subsample != 1.0) {
|
SampleGradient(ctx_, param, sampled->HostView());
|
||||||
CHECK(param.sampling_method != TrainParam::kGradientBased)
|
|
||||||
<< "Gradient based sampling is not supported for approx tree method.";
|
|
||||||
std::bernoulli_distribution coin_flip(param.subsample);
|
|
||||||
std::transform(sampled->begin(), sampled->end(), sampled->begin(), [&](GradientPair &g) {
|
|
||||||
if (coin_flip(rnd)) {
|
|
||||||
return g;
|
|
||||||
} else {
|
|
||||||
return GradientPair{};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
char const *Name() const override { return "grow_histmaker"; }
|
char const *Name() const override { return "grow_histmaker"; }
|
||||||
@ -303,18 +292,19 @@ class GlobalApproxUpdater : public TreeUpdater {
|
|||||||
pimpl_ = std::make_unique<GloablApproxBuilder>(param_, m->Info(), ctx_, column_sampler_, task_,
|
pimpl_ = std::make_unique<GloablApproxBuilder>(param_, m->Info(), ctx_, column_sampler_, task_,
|
||||||
&monitor_);
|
&monitor_);
|
||||||
|
|
||||||
std::vector<GradientPair> h_gpair;
|
linalg::Matrix<GradientPair> h_gpair;
|
||||||
InitData(param_, gpair, &h_gpair);
|
|
||||||
// Obtain the hessian values for weighted sketching
|
// Obtain the hessian values for weighted sketching
|
||||||
std::vector<float> hess(h_gpair.size());
|
InitData(param_, gpair, &h_gpair);
|
||||||
std::transform(h_gpair.begin(), h_gpair.end(), hess.begin(),
|
std::vector<float> hess(h_gpair.Size());
|
||||||
|
auto const &s_gpair = h_gpair.Data()->ConstHostVector();
|
||||||
|
std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(),
|
||||||
[](auto g) { return g.GetHess(); });
|
[](auto g) { return g.GetHess(); });
|
||||||
|
|
||||||
cached_ = m;
|
cached_ = m;
|
||||||
|
|
||||||
size_t t_idx = 0;
|
size_t t_idx = 0;
|
||||||
for (auto p_tree : trees) {
|
for (auto p_tree : trees) {
|
||||||
this->pimpl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]);
|
this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
|
||||||
++t_idx;
|
++t_idx;
|
||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
param_.learning_rate = lr;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by XGBoost Contributors
|
* Copyright 2017-2023 by XGBoost Contributors
|
||||||
* \file updater_quantile_hist.cc
|
* \file updater_quantile_hist.cc
|
||||||
* \brief use quantized feature values to construct a tree
|
* \brief use quantized feature values to construct a tree
|
||||||
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
||||||
@ -7,6 +7,7 @@
|
|||||||
#include "./updater_quantile_hist.h"
|
#include "./updater_quantile_hist.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -14,9 +15,11 @@
|
|||||||
|
|
||||||
#include "common_row_partitioner.h"
|
#include "common_row_partitioner.h"
|
||||||
#include "constraints.h"
|
#include "constraints.h"
|
||||||
#include "hist/histogram.h"
|
|
||||||
#include "hist/evaluate_splits.h"
|
#include "hist/evaluate_splits.h"
|
||||||
|
#include "hist/histogram.h"
|
||||||
|
#include "hist/sampler.h"
|
||||||
#include "param.h"
|
#include "param.h"
|
||||||
|
#include "xgboost/linalg.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
@ -257,43 +260,6 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(DMatrix const *data,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat,
|
|
||||||
std::vector<GradientPair> *gpair) {
|
|
||||||
monitor_->Start(__func__);
|
|
||||||
const auto &info = fmat.Info();
|
|
||||||
auto& rnd = common::GlobalRandom();
|
|
||||||
std::vector<GradientPair>& gpair_ref = *gpair;
|
|
||||||
|
|
||||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
|
||||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
|
||||||
for (size_t i = 0; i < info.num_row_; ++i) {
|
|
||||||
if (!(gpair_ref[i].GetHess() >= 0.0f && coin_flip(rnd)) || gpair_ref[i].GetGrad() == 0.0f) {
|
|
||||||
gpair_ref[i] = GradientPair(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
uint64_t initial_seed = rnd();
|
|
||||||
|
|
||||||
auto n_threads = static_cast<size_t>(ctx_->Threads());
|
|
||||||
const size_t discard_size = info.num_row_ / n_threads;
|
|
||||||
std::bernoulli_distribution coin_flip(param_.subsample);
|
|
||||||
|
|
||||||
dmlc::OMPException exc;
|
|
||||||
#pragma omp parallel num_threads(n_threads)
|
|
||||||
{
|
|
||||||
exc.Run([&]() {
|
|
||||||
const size_t tid = omp_get_thread_num();
|
|
||||||
const size_t ibegin = tid * discard_size;
|
|
||||||
const size_t iend = (tid == (n_threads - 1)) ? info.num_row_ : ibegin + discard_size;
|
|
||||||
RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) {
|
|
||||||
return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng));
|
|
||||||
}, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
exc.Rethrow();
|
|
||||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
|
||||||
monitor_->Stop(__func__);
|
|
||||||
}
|
|
||||||
size_t QuantileHistMaker::Builder::GetNumberOfTrees() { return n_trees_; }
|
size_t QuantileHistMaker::Builder::GetNumberOfTrees() { return n_trees_; }
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||||
@ -317,12 +283,9 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
|||||||
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||||
collective::IsDistributed());
|
collective::IsDistributed());
|
||||||
|
|
||||||
if (param_.subsample < 1.0f) {
|
auto m_gpair =
|
||||||
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
|
||||||
<< "Only uniform sampling is supported, "
|
SampleGradient(ctx_, param_, m_gpair);
|
||||||
<< "gradient-based sampling is only support by GPU Hist.";
|
|
||||||
InitSampling(*fmat, gpair);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// store a pointer to the tree
|
// store a pointer to the tree
|
||||||
|
|||||||
@ -36,47 +36,6 @@
|
|||||||
#include "../common/column_matrix.h"
|
#include "../common/column_matrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
struct RandomReplace {
|
|
||||||
public:
|
|
||||||
// similar value as for minstd_rand
|
|
||||||
static constexpr uint64_t kBase = 16807;
|
|
||||||
static constexpr uint64_t kMod = static_cast<uint64_t>(1) << 63;
|
|
||||||
|
|
||||||
using EngineT = std::linear_congruential_engine<uint64_t, kBase, 0, kMod>;
|
|
||||||
|
|
||||||
/*
|
|
||||||
Right-to-left binary method: https://en.wikipedia.org/wiki/Modular_exponentiation
|
|
||||||
*/
|
|
||||||
static uint64_t SimpleSkip(uint64_t exponent, uint64_t initial_seed,
|
|
||||||
uint64_t base, uint64_t mod) {
|
|
||||||
CHECK_LE(exponent, mod);
|
|
||||||
uint64_t result = 1;
|
|
||||||
while (exponent > 0) {
|
|
||||||
if (exponent % 2 == 1) {
|
|
||||||
result = (result * base) % mod;
|
|
||||||
}
|
|
||||||
base = (base * base) % mod;
|
|
||||||
exponent = exponent >> 1;
|
|
||||||
}
|
|
||||||
// with result we can now find the new seed
|
|
||||||
return (result * initial_seed) % mod;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename Condition, typename ContainerData>
|
|
||||||
static void MakeIf(Condition condition, const typename ContainerData::value_type replace_value,
|
|
||||||
const uint64_t initial_seed, const size_t ibegin,
|
|
||||||
const size_t iend, ContainerData* gpair) {
|
|
||||||
ContainerData& gpair_ref = *gpair;
|
|
||||||
const uint64_t displaced_seed = SimpleSkip(ibegin, initial_seed, kBase, kMod);
|
|
||||||
EngineT eng(displaced_seed);
|
|
||||||
for (size_t i = ibegin; i < iend; ++i) {
|
|
||||||
if (condition(i, eng)) {
|
|
||||||
gpair_ref[i] = replace_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace tree {
|
namespace tree {
|
||||||
inline BatchParam HistBatch(TrainParam const& param) {
|
inline BatchParam HistBatch(TrainParam const& param) {
|
||||||
return {param.max_bin, param.sparse_threshold};
|
return {param.max_bin, param.sparse_threshold};
|
||||||
@ -141,8 +100,6 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
size_t GetNumberOfTrees();
|
size_t GetNumberOfTrees();
|
||||||
|
|
||||||
void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
|
|
||||||
|
|
||||||
CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree,
|
CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree,
|
||||||
const std::vector<GradientPair>& gpair_h);
|
const std::vector<GradientPair>& gpair_h);
|
||||||
|
|
||||||
|
|||||||
57
tests/cpp/tree/hist/test_sampler.cc
Normal file
57
tests/cpp/tree/hist/test_sampler.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <string> // std::to_string
|
||||||
|
|
||||||
|
#include "../../../../src/tree/hist/sampler.h" // SampleGradient
|
||||||
|
#include "../../../../src/tree/param.h" // TrainParam
|
||||||
|
#include "xgboost/base.h" // GradientPair,bst_target_t
|
||||||
|
#include "xgboost/context.h" // Context
|
||||||
|
#include "xgboost/data.h" // MetaInfo
|
||||||
|
#include "xgboost/linalg.h" // Matrix,Constant
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
// Checks uniform row sampling for single- and multi-target gradients: every row must be
// either kept intact or zeroed for all of its targets, and the fraction of kept rows
// must be within 50% of the requested subsample rate.
TEST(Sampler, Basic) {
  std::size_t constexpr kRows = 1024;
  double constexpr kSubsample = .2;
  TrainParam param;
  param.UpdateAllowUnknown(Args{{"subsample", std::to_string(kSubsample)}});
  Context ctx;

  auto run = [&](bst_target_t n_targets) {
    auto const init = GradientPair{1.0f, 1.0f};
    linalg::Matrix<GradientPair> gpair = linalg::Constant(&ctx, init, kRows, n_targets);
    auto h_gpair = gpair.HostView();
    SampleGradient(&ctx, param, h_gpair);

    std::size_t n_sampled{0};
    for (std::size_t i = 0; i < kRows; ++i) {
      // The first target decides whether the row counts as sampled.
      bool const sampled = h_gpair(i, 0).GetGrad() != 0.0f;
      if (sampled) {
        ++n_sampled;
      }
      // A sampled row keeps its initial value, a dropped row is all zeros. This must
      // hold for every target, including the hessian of the first one.
      GradientPair const expected = sampled ? init : GradientPair{};
      for (bst_target_t t = 0; t < n_targets; ++t) {
        ASSERT_EQ(h_gpair(i, t).GetGrad(), expected.GetGrad());
        ASSERT_EQ(h_gpair(i, t).GetHess(), expected.GetHess());
      }
    }

    auto const ratio = static_cast<double>(n_sampled) / static_cast<double>(kRows);
    ASSERT_LT(ratio, kSubsample * 1.5);
    ASSERT_GT(ratio, kSubsample * 0.5);
  };

  run(1);
  run(3);
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user