[EM] Refactor GPU histogram builder. (#10764)

- Expose the maximum number of cached nodes to be consistent with the CPU implementation. Also easier for testing.
- Extract the subtraction trick for easier testing.
- Split up the `GradientQuantiser` to avoid circular dependency.
This commit is contained in:
Jiaming Yuan 2024-08-30 02:39:14 +08:00 committed by GitHub
parent 34937fea41
commit 61dd854a52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 394 additions and 187 deletions

View File

@ -232,12 +232,12 @@ Parameters for Tree Booster
* ``max_cached_hist_node``, [default = 65536]
Maximum number of cached nodes for CPU histogram.
Maximum number of cached nodes for histogram.
.. versionadded:: 2.0.0
- For most of the cases this parameter should not be set except for growing deep trees
on CPU.
- For most of the cases this parameter should not be set except for growing deep
trees. After 3.0, this parameter affects GPU algorithms as well.
.. _cat-param:

View File

@ -522,6 +522,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
the corresponding booster training parameter.
* - on_host (optional): Whether the data should be placed on host memory. Used by GPU inputs.
* @param out The created Quantile DMatrix.
*
* @return 0 when success, -1 when failure happens

View File

@ -60,10 +60,10 @@ template <typename T>
RET_IF_NOT(fi->Read(&impl->is_dense));
RET_IF_NOT(fi->Read(&impl->row_stride));
if (has_hmm_ats_ && !this->param_.prefetch_copy) {
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
} else {
if (this->param_.prefetch_copy || !has_hmm_ats_) {
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
} else {
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
}
RET_IF_NOT(fi->Read(&impl->base_rowid));
dh::DefaultStream().Sync();
@ -95,7 +95,7 @@ template <typename T>
CHECK(this->cuts_->cut_values_.DeviceCanRead());
impl->SetCuts(this->cuts_);
fi->Read(page, this->param_.prefetch_copy);
fi->Read(page, this->param_.prefetch_copy || !this->has_hmm_ats_);
dh::DefaultStream().Sync();
return true;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023, XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#ifndef EXPAND_ENTRY_CUH_
#define EXPAND_ENTRY_CUH_
@ -7,9 +7,9 @@
#include <limits> // for numeric_limits
#include <utility> // for move
#include "../param.h"
#include "../updater_gpu_common.cuh"
#include "xgboost/base.h" // for bst_node_t
#include "../param.h" // for TrainParam
#include "../updater_gpu_common.cuh" // for DeviceSplitCandidate
#include "xgboost/base.h" // for bst_node_t
namespace xgboost::tree {
struct GPUExpandEntry {

View File

@ -356,13 +356,19 @@ class DeviceHistogramBuilderImpl {
};
DeviceHistogramBuilder::DeviceHistogramBuilder()
: p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {}
: p_impl_{std::make_unique<DeviceHistogramBuilderImpl>()} {
monitor_.Init(__func__);
}
DeviceHistogramBuilder::~DeviceHistogramBuilder() = default;
void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory) {
void DeviceHistogramBuilder::Reset(Context const* ctx, std::size_t max_cached_hist_nodes,
FeatureGroupsAccessor const& feature_groups,
bst_bin_t n_total_bins, bool force_global_memory) {
this->monitor_.Start(__func__);
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
this->hist_.Reset(ctx, n_total_bins, max_cached_hist_nodes);
this->monitor_.Stop(__func__);
}
void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
@ -372,6 +378,21 @@ void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
common::Span<const cuda_impl::RowIndexT> ridx,
common::Span<GradientPairInt64> histogram,
GradientQuantiser rounding) {
this->monitor_.Start(__func__);
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
this->monitor_.Stop(__func__);
}
void DeviceHistogramBuilder::AllReduceHist(Context const* ctx, MetaInfo const& info,
bst_node_t nidx, std::size_t num_histograms) {
this->monitor_.Start(__func__);
auto d_node_hist = hist_.GetNodeHistogram(nidx);
using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
auto rc = collective::GlobalSum(
ctx, info,
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
d_node_hist.size() * 2 * num_histograms, ctx->Device()));
SafeColl(rc);
this->monitor_.Stop(__func__);
}
} // namespace xgboost::tree

View File

@ -9,7 +9,9 @@
#include "../../common/device_helpers.cuh" // for LaunchN
#include "../../common/device_vector.cuh" // for device_vector
#include "../../data/ellpack_page.cuh" // for EllpackDeviceAccessor
#include "expand_entry.cuh" // for GPUExpandEntry
#include "feature_groups.cuh" // for FeatureGroupsAccessor
#include "quantiser.cuh" // for GradientQuantiser
#include "xgboost/base.h" // for GradientPair, GradientPairInt64
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
@ -34,77 +36,51 @@ XGBOOST_DEV_INLINE void AtomicAdd64As32(int64_t* dst, int64_t src) {
atomicAdd(y_high, sig);
}
class GradientQuantiser {
private:
/* Convert gradient to fixed point representation. */
GradientPairPrecise to_fixed_point_;
/* Convert fixed point representation back to floating point. */
GradientPairPrecise to_floating_point_;
public:
GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair, MetaInfo const& info);
[[nodiscard]] XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
gpair.GetHess() * to_fixed_point_.GetHess());
return adjusted;
}
[[nodiscard]] XGBOOST_DEVICE GradientPairInt64
ToFixedPoint(GradientPairPrecise const& gpair) const {
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
gpair.GetHess() * to_fixed_point_.GetHess());
return adjusted;
}
[[nodiscard]] XGBOOST_DEVICE GradientPairPrecise
ToFloatingPoint(const GradientPairInt64& gpair) const {
auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
return {g,h};
}
};
namespace cuda_impl {
// Start with about 16mb
std::size_t constexpr DftReserveSize() { return 1 << 22; }
} // namespace cuda_impl
/**
* @brief Data storage for node histograms on device. Automatically expands.
*
* @tparam kStopGrowingSize Do not grow beyond this size
*
* @author Rory
* @date 28/07/2018
*/
template <size_t kStopGrowingSize = 1 << 28>
class DeviceHistogramStorage {
private:
using GradientSumT = GradientPairInt64;
std::size_t stop_growing_size_{0};
/** @brief Map nidx to starting index of its histogram. */
std::map<int, size_t> nidx_map_;
// Large buffer of zeroed memory, caches histograms
dh::device_vector<typename GradientSumT::ValueT> data_;
// If we run out of storage allocate one histogram at a time
// in overflow. Not cached, overwritten when a new histogram
// is requested
// If we run out of storage allocate one histogram at a time in overflow. Not cached,
// overwritten when a new histogram is requested
dh::device_vector<typename GradientSumT::ValueT> overflow_;
std::map<int, size_t> overflow_nidx_map_;
int n_bins_;
DeviceOrd device_id_;
static constexpr size_t kNumItemsInGradientSum =
static constexpr std::size_t kNumItemsInGradientSum =
sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");
public:
// Start with about 16mb
DeviceHistogramStorage() { data_.reserve(1 << 22); }
void Init(DeviceOrd device_id, int n_bins) {
this->n_bins_ = n_bins;
this->device_id_ = device_id;
}
explicit DeviceHistogramStorage() { data_.reserve(cuda_impl::DftReserveSize()); }
void Reset(Context const* ctx) {
void Reset(Context const* ctx, bst_bin_t n_total_bins, std::size_t max_cached_nodes) {
this->n_bins_ = n_total_bins;
auto d_data = data_.data().get();
dh::LaunchN(data_.size(), ctx->CUDACtx()->Stream(),
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
nidx_map_.clear();
overflow_nidx_map_.clear();
auto max_cached_bin_values =
static_cast<std::size_t>(n_total_bins) * max_cached_nodes * kNumItemsInGradientSum;
this->stop_growing_size_ = max_cached_bin_values;
}
[[nodiscard]] bool HistogramExists(int nidx) const {
[[nodiscard]] bool HistogramExists(bst_node_t nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend() ||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
}
@ -112,14 +88,15 @@ class DeviceHistogramStorage {
[[nodiscard]] size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
void AllocateHistograms(Context const* ctx, const std::vector<int>& new_nidxs) {
void AllocateHistograms(Context const* ctx, std::vector<bst_node_t> const& new_nidxs) {
for (int nidx : new_nidxs) {
CHECK(!HistogramExists(nidx));
}
// Number of items currently used in data
const size_t used_size = nidx_map_.size() * HistogramSize();
const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
if (used_size >= kStopGrowingSize) {
CHECK_GE(this->stop_growing_size_, kNumItemsInGradientSum);
if (used_size >= this->stop_growing_size_) {
// Use overflow
// Delete previous entries
overflow_nidx_map_.clear();
@ -171,18 +148,77 @@ class DeviceHistogramBuilderImpl;
class DeviceHistogramBuilder {
std::unique_ptr<DeviceHistogramBuilderImpl> p_impl_;
DeviceHistogramStorage hist_;
common::Monitor monitor_;
public:
DeviceHistogramBuilder();
explicit DeviceHistogramBuilder();
~DeviceHistogramBuilder();
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
void Reset(Context const* ctx, std::size_t max_cached_hist_nodes,
FeatureGroupsAccessor const& feature_groups, bst_bin_t n_total_bins,
bool force_global_memory);
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
[[nodiscard]] auto GetNodeHistogram(bst_node_t nidx) { return hist_.GetNodeHistogram(nidx); }
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(Context const* ctx, MetaInfo const& info, bst_node_t nidx,
std::size_t num_histograms);
// Attempt to do subtraction trick
// return true if succeeded
[[nodiscard]] bool SubtractionTrick(bst_node_t nidx_parent, bst_node_t nidx_histogram,
bst_node_t nidx_subtraction) {
if (!hist_.HistogramExists(nidx_histogram) || !hist_.HistogramExists(nidx_parent)) {
return false;
}
auto d_node_hist_parent = hist_.GetNodeHistogram(nidx_parent);
auto d_node_hist_histogram = hist_.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist_.GetNodeHistogram(nidx_subtraction);
dh::LaunchN(d_node_hist_parent.size(), [=] __device__(size_t idx) {
d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
});
return true;
}
[[nodiscard]] auto SubtractHist(std::vector<GPUExpandEntry> const& candidates,
std::vector<bst_node_t> const& build_nidx,
std::vector<bst_node_t> const& subtraction_nidx) {
this->monitor_.Start(__func__);
std::vector<bst_node_t> need_build;
for (std::size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = build_nidx.at(i);
auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid;
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
need_build.push_back(subtraction_trick_nidx);
}
}
this->monitor_.Stop(__func__);
return need_build;
}
void AllocateHistograms(Context const* ctx, std::vector<bst_node_t> const& nodes_to_build,
std::vector<bst_node_t> const& nodes_to_sub) {
this->monitor_.Start(__func__);
std::vector<bst_node_t> all_new = nodes_to_build;
all_new.insert(all_new.end(), nodes_to_sub.cbegin(), nodes_to_sub.cend());
// Allocate the histograms
// Guaranteed contiguous memory
this->AllocateHistograms(ctx, all_new);
this->monitor_.Stop(__func__);
}
void AllocateHistograms(Context const* ctx, std::vector<int> const& new_nidxs) {
this->hist_.AllocateHistograms(ctx, new_nidxs);
}
};
} // namespace xgboost::tree
#endif // HISTOGRAM_CUH_

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#pragma once
#include "xgboost/base.h" // for GradientPairPrecise, GradientPairInt64
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/span.h" // for Span
namespace xgboost::tree {
class GradientQuantiser {
private:
/* Convert gradient to fixed point representation. */
GradientPairPrecise to_fixed_point_;
/* Convert fixed point representation back to floating point. */
GradientPairPrecise to_floating_point_;
public:
GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair,
MetaInfo const& info);
[[nodiscard]] XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
gpair.GetHess() * to_fixed_point_.GetHess());
return adjusted;
}
[[nodiscard]] XGBOOST_DEVICE GradientPairInt64
ToFixedPoint(GradientPairPrecise const& gpair) const {
auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
gpair.GetHess() * to_fixed_point_.GetHess());
return adjusted;
}
[[nodiscard]] XGBOOST_DEVICE GradientPairPrecise
ToFloatingPoint(const GradientPairInt64& gpair) const {
auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
return {g, h};
}
};
} // namespace xgboost::tree

View File

@ -11,7 +11,7 @@
#include "../../common/hist_util.h" // for GHistRow, ConstGHistRow
#include "../../common/ref_resource_view.h" // for ReallocVector
#include "xgboost/base.h" // for bst_node_t, bst_bin_t
#include "xgboost/logging.h" // for CHECK_GT
#include "xgboost/logging.h" // for CHECK_EQ
#include "xgboost/span.h" // for Span
namespace xgboost::tree {
@ -40,7 +40,7 @@ class BoundedHistCollection {
// number of histogram bins across all features
bst_bin_t n_total_bins_{0};
// limits the number of nodes that can be in the cache for each tree
std::size_t n_cached_nodes_{0};
std::size_t max_cached_nodes_{0};
// whether the tree has grown beyond the cache limit
bool has_exceeded_{false};
@ -58,7 +58,7 @@ class BoundedHistCollection {
}
void Reset(bst_bin_t n_total_bins, std::size_t n_cached_nodes) {
n_total_bins_ = n_total_bins;
n_cached_nodes_ = n_cached_nodes;
max_cached_nodes_ = n_cached_nodes;
this->Clear(false);
}
/**
@ -73,7 +73,7 @@ class BoundedHistCollection {
[[nodiscard]] bool CanHost(common::Span<bst_node_t const> nodes_to_build,
common::Span<bst_node_t const> nodes_to_sub) const {
auto n_new_nodes = nodes_to_build.size() + nodes_to_sub.size();
return n_new_nodes + node_map_.size() <= n_cached_nodes_;
return n_new_nodes + node_map_.size() <= max_cached_nodes_;
}
/**

View File

@ -61,7 +61,7 @@ class HistogramBuilder {
bool is_col_split, HistMakerTrainParam const *param) {
n_threads_ = ctx->Threads();
param_ = p;
hist_.Reset(total_bins, param->max_cached_hist_node);
hist_.Reset(total_bins, param->MaxCachedHistNodes(ctx->Device()));
buffer_.Init(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;

View File

@ -1,31 +1,47 @@
/**
* Copyright 2021-2023, XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <limits> // for numeric_limits
#include "xgboost/parameter.h" // for XGBoostParameter
#include "xgboost/tree_model.h" // for RegTree
#include "xgboost/context.h" // for DeviceOrd
namespace xgboost::tree {
struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
constexpr static std::size_t DefaultNodes() { return static_cast<std::size_t>(1) << 16; }
private:
constexpr static std::size_t NotSet() { return std::numeric_limits<std::size_t>::max(); }
std::size_t max_cached_hist_node{NotSet()}; // NOLINT
public:
// Smaller for GPU due to memory limitation.
constexpr static std::size_t CpuDefaultNodes() { return static_cast<std::size_t>(1) << 16; }
constexpr static std::size_t CudaDefaultNodes() { return static_cast<std::size_t>(1) << 12; }
bool debug_synchronize{false};
std::size_t max_cached_hist_node{DefaultNodes()};
void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const;
std::size_t MaxCachedHistNodes(DeviceOrd device) const {
if (max_cached_hist_node != NotSet()) {
return max_cached_hist_node;
}
return device.IsCPU() ? CpuDefaultNodes() : CudaDefaultNodes();
}
// declare parameters
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
DMLC_DECLARE_FIELD(debug_synchronize)
.set_default(false)
.describe("Check if all distributed tree are identical after tree construction.");
DMLC_DECLARE_FIELD(max_cached_hist_node)
.set_default(DefaultNodes())
.set_default(NotSet())
.set_lower_bound(1)
.describe("Maximum number of nodes in CPU histogram cache. Only for internal usage.");
.describe("Maximum number of nodes in histogram cache.");
}
};
} // namespace xgboost::tree

View File

@ -5,10 +5,10 @@
#include <limits> // for numeric_limits
#include <ostream> // for ostream
#include "gpu_hist/histogram.cuh"
#include "param.h" // for TrainParam
#include "xgboost/base.h"
#include "xgboost/task.h" // for ObjInfo
#include "gpu_hist/quantiser.cuh" // for GradientQuantiser
#include "param.h" // for TrainParam
#include "xgboost/base.h" // for bst_bin_t
#include "xgboost/task.h" // for ObjInfo
namespace xgboost::tree {
struct GPUTrainingParam {

View File

@ -64,6 +64,47 @@ struct NodeSplitData {
};
static_assert(std::is_trivially_copyable_v<NodeSplitData>);
// To be tuned.
constexpr double ExtMemPrefetchThresh() { return 4.0; }
// Some nodes we will manually compute histograms, others we will do by subtraction
[[nodiscard]] bool AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
std::vector<GPUExpandEntry> const& candidates,
common::Span<bst_node_t> nodes_to_build,
common::Span<bst_node_t> nodes_to_sub) {
auto const& tree = *p_tree;
std::size_t nidx_in_set{0};
double total{0.0}, smaller{0.0};
auto p_build_nidx = nodes_to_build.data();
auto p_sub_nidx = nodes_to_sub.data();
for (auto& e : candidates) {
// Decide whether to build the left histogram or right histogram Use sum of Hessian as
// a heuristic to select node with fewest training instances This optimization is for
// distributed training to avoid an allreduce call for synchronizing the number of
// instances for each node.
auto left_sum = quantizer->ToFloatingPoint(e.split.left_sum);
auto right_sum = quantizer->ToFloatingPoint(e.split.right_sum);
bool fewer_right = right_sum.GetHess() < left_sum.GetHess();
total += left_sum.GetHess() + right_sum.GetHess();
if (fewer_right) {
p_build_nidx[nidx_in_set] = tree[e.nid].RightChild();
p_sub_nidx[nidx_in_set] = tree[e.nid].LeftChild();
smaller += right_sum.GetHess();
} else {
p_build_nidx[nidx_in_set] = tree[e.nid].LeftChild();
p_sub_nidx[nidx_in_set] = tree[e.nid].RightChild();
smaller += left_sum.GetHess();
}
++nidx_in_set;
}
if (-kRtEps < smaller && smaller < kRtEps) { // Too close to 0, don't prefetch.
return false;
}
// Prefetch if these smaller nodes are not quite small.
return (total / smaller) < ExtMemPrefetchThresh();
}
// GPU tree updater implementation.
struct GPUHistMakerDevice {
private:
@ -78,11 +119,31 @@ struct GPUHistMakerDevice {
std::vector<bst_idx_t> batch_ptr_;
// node idx for each sample
dh::device_vector<bst_node_t> positions_;
HistMakerTrainParam const* hist_param_;
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
public:
DeviceHistogramStorage<> hist{};
auto CreatePartitionNodes(RegTree const* p_tree, std::vector<GPUExpandEntry> const& candidates) {
std::vector<bst_node_t> nidx(candidates.size());
std::vector<bst_node_t> left_nidx(candidates.size());
std::vector<bst_node_t> right_nidx(candidates.size());
std::vector<NodeSplitData> split_data(candidates.size());
for (std::size_t i = 0, n = candidates.size(); i < n; i++) {
auto const& e = candidates[i];
RegTree::Node split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(e.nid);
nidx.at(i) = e.nid;
left_nidx[i] = split_node.LeftChild();
right_nidx[i] = split_node.RightChild();
split_data[i] =
NodeSplitData{split_node, split_type, this->evaluator_.GetDeviceNodeCats(e.nid)};
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
}
return std::make_tuple(nidx, left_nidx, right_nidx, split_data);
}
public:
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair const> gpair;
@ -102,7 +163,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<FeatureGroups> feature_groups;
GPUHistMakerDevice(Context const* ctx, TrainParam _param,
GPUHistMakerDevice(Context const* ctx, TrainParam _param, HistMakerTrainParam const* hist_param,
std::shared_ptr<common::ColumnSampler> column_sampler, BatchParam batch_param,
MetaInfo const& info, std::vector<bst_idx_t> batch_ptr,
std::shared_ptr<common::HistogramCuts const> cuts)
@ -112,8 +173,9 @@ struct GPUHistMakerDevice {
column_sampler_(std::move(column_sampler)),
interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
batch_ptr_{std::move(batch_ptr)},
hist_param_{hist_param},
cuts_{std::move(cuts)} {
sampler =
this->sampler =
std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
param.sampling_method, batch_ptr_.size() > 2);
if (!param.monotone_constraints.empty()) {
@ -132,7 +194,7 @@ struct GPUHistMakerDevice {
CHECK(cuts_);
feature_groups = std::make_unique<FeatureGroups>(*cuts_, info.IsDense(),
dh::MaxSharedMemoryOptin(ctx_->Ordinal()),
sizeof(GradientPairPrecise));
sizeof(GradientPairInt64));
}
}
@ -142,7 +204,7 @@ struct GPUHistMakerDevice {
this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
common::SetDevice(ctx_->Ordinal());
this->interaction_constraints.Reset();
@ -185,13 +247,12 @@ struct GPUHistMakerDevice {
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
// Init histogram
hist.Init(ctx_->Device(), this->cuts_->TotalBins());
hist.Reset(ctx_);
this->InitFeatureGroupsOnce(info);
this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(),
false);
return p_fmat;
}
@ -202,7 +263,7 @@ struct GPUHistMakerDevice {
sampled_features->SetDevice(ctx_->Device());
common::Span<bst_feature_t> feature_set =
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, histogram_.GetNodeHistogram(nidx)};
EvaluateSplitSharedInputs shared_inputs{gpu_param,
*quantiser,
p_fmat->Info().feature_types.ConstDeviceSpan(),
@ -250,12 +311,10 @@ struct GPUHistMakerDevice {
common::Span<bst_feature_t> right_feature_set =
interaction_constraints.Query(right_sampled_features->DeviceSpan(),
right_nidx);
h_node_inputs[i * 2] = {left_nidx, candidate.depth + 1,
candidate.split.left_sum, left_feature_set,
hist.GetNodeHistogram(left_nidx)};
h_node_inputs[i * 2 + 1] = {right_nidx, candidate.depth + 1,
candidate.split.right_sum, right_feature_set,
hist.GetNodeHistogram(right_nidx)};
h_node_inputs[i * 2] = {left_nidx, candidate.depth + 1, candidate.split.left_sum,
left_feature_set, histogram_.GetNodeHistogram(left_nidx)};
h_node_inputs[i * 2 + 1] = {right_nidx, candidate.depth + 1, candidate.split.right_sum,
right_feature_set, histogram_.GetNodeHistogram(right_nidx)};
}
bst_feature_t max_active_features = 0;
for (auto input : h_node_inputs) {
@ -274,28 +333,17 @@ struct GPUHistMakerDevice {
this->monitor.Stop(__func__);
}
void BuildHist(EllpackPageImpl const* page, int nidx) {
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = partitioners_.front()->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
void BuildHist(EllpackPage const& page, std::int32_t k, bst_bin_t nidx) {
monitor.Start(__func__);
auto d_node_hist = histogram_.GetNodeHistogram(nidx);
auto batch = page.Impl();
auto acc = batch->GetDeviceAccessor(ctx_->Device());
auto d_ridx = partitioners_.at(k)->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser);
}
// Attempt to do subtraction trick
// return true if succeeded
bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
return false;
}
auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
dh::LaunchN(cuts_->TotalBins(), [=] __device__(size_t idx) {
d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
});
return true;
monitor.Stop(__func__);
}
void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
@ -349,6 +397,7 @@ struct GPUHistMakerDevice {
};
collective::SafeColl(rc);
CHECK_EQ(partitioners_.size(), 1) << "External memory with column split is not yet supported.";
partitioners_.front()->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data,
[=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
@ -393,10 +442,7 @@ struct GPUHistMakerDevice {
monitor.Start(__func__);
std::vector<bst_node_t> nidx(candidates.size());
std::vector<bst_node_t> left_nidx(candidates.size());
std::vector<bst_node_t> right_nidx(candidates.size());
std::vector<NodeSplitData> split_data(candidates.size());
auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);
for (size_t i = 0; i < candidates.size(); i++) {
auto const& e = candidates[i];
@ -531,19 +577,6 @@ struct GPUHistMakerDevice {
return true;
}
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(MetaInfo const& info, bst_node_t nidx, int num_histograms) {
monitor.Start(__func__);
auto d_node_hist = hist.GetNodeHistogram(nidx);
using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
auto rc = collective::GlobalSum(
ctx_, info,
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
d_node_hist.size() * 2 * num_histograms, ctx_->Device()));
SafeColl(rc);
monitor.Stop(__func__);
}
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
@ -555,48 +588,44 @@ struct GPUHistMakerDevice {
this->monitor.Start(__func__);
// Some nodes we will manually compute histograms
// others we will do by subtraction
std::vector<int> hist_nidx;
std::vector<int> subtraction_nidx;
for (auto& e : candidates) {
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = e.split.right_sum.GetQuantisedHess() < e.split.left_sum.GetQuantisedHess();
if (fewer_right) {
hist_nidx.emplace_back(tree[e.nid].RightChild());
subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
} else {
hist_nidx.emplace_back(tree[e.nid].LeftChild());
subtraction_nidx.emplace_back(tree[e.nid].RightChild());
}
}
std::vector<bst_node_t> hist_nidx(candidates.size());
std::vector<bst_node_t> subtraction_nidx(candidates.size());
auto prefetch_copy =
AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);
std::vector<int> all_new = hist_nidx;
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
// Allocate the histograms
// Guaranteed contiguous memory
hist.AllocateHistograms(ctx_, all_new);
histogram_.AllocateHistograms(ctx_, all_new);
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
for (auto nidx : hist_nidx) {
this->BuildHist(page.Impl(), nidx);
this->BuildHist(page, k, nidx);
}
++k;
}
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid;
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
subtraction_trick_nidx)) {
// Calculate other histogram manually
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->BuildHist(page.Impl(), subtraction_trick_nidx);
this->BuildHist(page, k, subtraction_trick_nidx);
++k;
}
this->AllReduceHist(p_fmat->Info(), subtraction_trick_nidx, 1);
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
}
}
this->monitor.Stop(__func__);
@ -666,11 +695,13 @@ struct GPUHistMakerDevice {
ctx_, p_fmat->Info(), linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
collective::SafeColl(rc);
hist.AllocateHistograms(ctx_, {kRootNIdx});
histogram_.AllocateHistograms(ctx_, {kRootNIdx});
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->BuildHist(page.Impl(), kRootNIdx);
this->BuildHist(page, k, kRootNIdx);
++k;
}
this->AllReduceHist(p_fmat->Info(), kRootNIdx, 1);
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), kRootNIdx, 1);
// Remember root stats
auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@ -812,15 +843,15 @@ class GPUHistMaker : public TreeUpdater {
ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0));
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
common::SetDevice(ctx_->Ordinal());
p_fmat->Info().feature_types.SetDevice(ctx_->Device());
std::vector<bst_idx_t> batch_ptr;
auto batch = HistBatch(*param);
auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
this->maker = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
p_fmat->Info(), batch_ptr, cuts);
this->maker = std::make_unique<GPUHistMakerDevice>(
ctx_, *param, &hist_maker_param_, column_sampler_, batch, p_fmat->Info(), batch_ptr, cuts);
p_last_fmat_ = p_fmat;
initialised_ = true;
@ -888,9 +919,6 @@ class GPUGlobalApproxMaker : public TreeUpdater {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Approx]: Configure";
hist_maker_param_.UpdateAllowUnknown(args);
if (hist_maker_param_.max_cached_hist_node != HistMakerTrainParam::DefaultNodes()) {
LOG(WARNING) << "The `max_cached_hist_node` is ignored in GPU.";
}
common::CheckComputeCapability();
initialised_ = false;
@ -932,8 +960,8 @@ class GPUGlobalApproxMaker : public TreeUpdater {
auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
batch.regen = false; // Regen only at the beginning of the iteration.
this->maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
p_fmat->Info(), batch_ptr, cuts);
this->maker_ = std::make_unique<GPUHistMakerDevice>(
ctx_, *param, &hist_maker_param_, column_sampler_, batch, p_fmat->Info(), batch_ptr, cuts);
std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) {

View File

@ -9,6 +9,7 @@
#include "../../../../src/tree/gpu_hist/histogram.cuh"
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" // for RowPartitioner
#include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam
#include "../../../../src/tree/param.h" // for TrainParam
#include "../../categorical_helpers.h" // for OneHotEncodeFeature
#include "../../helpers.h"
@ -21,13 +22,13 @@ TEST(Histogram, DeviceHistogramStorage) {
constexpr size_t kNBins = 128;
constexpr int kNNodes = 4;
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
DeviceHistogramStorage<kStopGrowing> histogram;
histogram.Init(FstCU(), kNBins);
DeviceHistogramStorage histogram{};
histogram.Reset(&ctx, kNBins, kNNodes);
for (int i = 0; i < kNNodes; ++i) {
histogram.AllocateHistograms(&ctx, {i});
}
histogram.Reset(&ctx);
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
histogram.Reset(&ctx, kNBins, kNNodes);
// Use allocated memory but do not erase nidx_map.
for (int i = 0; i < kNNodes; ++i) {
@ -55,6 +56,35 @@ TEST(Histogram, DeviceHistogramStorage) {
EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1}););
}
TEST(Histogram, SubtractionTrack) {
auto ctx = MakeCUDACtx(0);
auto page = BuildEllpackPage(&ctx, 64, 4);
auto cuts = page->CutsShared();
FeatureGroups fg{*cuts, true, std::numeric_limits<std::size_t>::max(),
sizeof(GradientPairPrecise)};
auto fg_acc = fg.DeviceAccessor(ctx.Device());
auto n_total_bins = cuts->TotalBins();
// 2 nodes
auto max_cached_hist_nodes = 2ull;
DeviceHistogramBuilder histogram;
histogram.Reset(&ctx, max_cached_hist_nodes, fg_acc, n_total_bins, false);
histogram.AllocateHistograms(&ctx, {0, 1, 2});
GPUExpandEntry root;
root.nid = 0;
auto need_build = histogram.SubtractHist({root}, {0}, {1});
std::vector<GPUExpandEntry> candidates(2);
candidates[0].nid = 1;
candidates[1].nid = 2;
need_build = histogram.SubtractHist(candidates, {3, 5}, {4, 6});
ASSERT_EQ(need_build.size(), 2);
ASSERT_EQ(need_build[0], 4);
ASSERT_EQ(need_build[1], 6);
}
std::vector<GradientPairPrecise> GetHostHistGpair() {
// 24 bins, 3 bins for each feature (column).
std::vector<GradientPairPrecise> hist_gpair = {
@ -101,17 +131,16 @@ void TestBuildHist(bool use_shared_memory_histograms) {
auto shm_size = use_shared_memory_histograms ? dh::MaxSharedMemoryOptin(ctx.Ordinal()) : 0;
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
DeviceHistogramStorage hist;
hist.Init(ctx.Device(), page->Cuts().TotalBins());
hist.AllocateHistograms(&ctx, {0});
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), !use_shared_memory_histograms);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
!use_shared_memory_histograms);
builder.AllocateHistograms(&ctx, {0});
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
row_partitioner->GetRows(0), hist.GetNodeHistogram(0), *quantiser);
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
auto node_histogram = hist.GetNodeHistogram(0);
auto node_histogram = builder.GetNodeHistogram(0);
std::vector<GradientPairInt64> h_result(node_histogram.size());
dh::CopyDeviceSpanToVector(&h_result, node_histogram);
@ -158,7 +187,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
@ -173,7 +203,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser);
@ -197,7 +228,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
dh::device_vector<GradientPairInt64> baseline(num_bins);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
@ -264,7 +296,8 @@ void TestGPUHistogramCategorical(size_t num_categories) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
@ -280,7 +313,8 @@ void TestGPUHistogramCategorical(size_t num_categories) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
@ -429,7 +463,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto ridx = partitioners.at(k)->GetRows(0);
auto d_histogram = dh::ToSpan(multi_hist);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
@ -454,7 +489,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto ridx = partitioner.GetRows(0);
auto d_histogram = dh::ToSpan(single_hist);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);

View File

@ -51,7 +51,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
row_set_collection.Init();
HistMakerTrainParam hist_param;
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
hist.Reset(gmat.cut.Ptrs().back(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto const &elem = row_set_collection[0];
common::BuildHist<false>(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0],
@ -120,7 +120,7 @@ TEST(HistMultiEvaluator, Evaluate) {
linalg::Vector<GradientPairPrecise> root_sum({2}, DeviceOrd::CPU());
for (bst_target_t t{0}; t < n_targets; ++t) {
auto &hist = histogram[t];
hist.Reset(n_bins * n_features, hist_param.max_cached_hist_node);
hist.Reset(n_bins * n_features, hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
node_hist[0] = {-0.5, 0.5};
@ -237,7 +237,7 @@ auto CompareOneHotAndPartition(bool onehot) {
entries.front().nid = 0;
entries.front().depth = 0;
hist.Reset(gmat.cut.TotalBins(), hist_param.max_cached_hist_node);
hist.Reset(gmat.cut.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
@ -265,9 +265,10 @@ TEST(HistEvaluator, Categorical) {
}
TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
Context ctx;
BoundedHistCollection hist;
HistMakerTrainParam hist_param;
hist.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
hist.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist.AllocateHistograms({0});
auto node_hist = hist[0];
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
@ -277,10 +278,9 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
MetaInfo info;
info.num_col_ = 1;
info.feature_types = {FeatureType::kCategorical};
Context ctx;
auto evaluator = HistEvaluator{&ctx, &param_, info, sampler};
evaluator.InitRoot(GradStats{parent_sum_});
std::vector<CPUExpandEntry> entries(1);
RegTree tree;
evaluator.EvaluateSplits(hist, cuts_, info.feature_types.ConstHostSpan(), tree, &entries);

View File

@ -56,8 +56,9 @@ class TestPartitionBasedSplit : public ::testing::Test {
cuts_.min_vals_.Resize(1);
Context ctx;
HistMakerTrainParam hist_param;
hist_.Reset(cuts_.TotalBins(), hist_param.max_cached_hist_node);
hist_.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device()));
hist_.AllocateHistograms({0});
auto node_hist = hist_[0];

View File

@ -216,7 +216,7 @@ TEST(GpuHist, ConfigIO) {
}
TEST(GpuHist, MaxDepth) {
Context ctx(MakeCUDACtx(0));
auto ctx = MakeCUDACtx(0);
size_t constexpr kRows = 16;
size_t constexpr kCols = 4;
auto p_mat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();

View File

@ -10,6 +10,7 @@ from xgboost import testing as tm
from xgboost.testing.params import (
cat_parameter_strategy,
exact_parameter_strategy,
hist_cache_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import (
@ -46,6 +47,7 @@ class TestGPUUpdaters:
@given(
exact_parameter_strategy,
hist_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
)
@ -54,19 +56,44 @@ class TestGPUUpdaters:
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "hist", "device": "cuda"})
param.update(hist_param)
param.update(cache_param)
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_cache_size(self, tree_method: str) -> None:
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=4096, n_features=64, random_state=1994)
Xy = xgb.DMatrix(X, y)
results = []
for cache_size in [1, 3, 2048]:
params: Dict[str, Any] = {"tree_method": tree_method, "device": "cuda"}
params["max_cached_hist_node"] = cache_size
evals_result: Dict[str, Dict[str, list]] = {}
xgb.train(
params,
Xy,
num_boost_round=4,
evals=[(Xy, "Train")],
evals_result=evals_result,
)
results.append(evals_result["Train"]["rmse"])
for i in range(1, len(results)):
np.testing.assert_allclose(results[0], results[i])
@given(
exact_parameter_strategy,
hist_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
)
@ -75,11 +102,13 @@ class TestGPUUpdaters:
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "approx", "device": "cuda"})
param.update(hist_param)
param.update(cache_param)
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))