Split Features into Groups to Compute Histograms in Shared Memory (#5795)
This commit is contained in:
parent
93c44a9a64
commit
ac3f0e78dc
64
src/tree/gpu_hist/feature_groups.cu
Normal file
64
src/tree/gpu_hist/feature_groups.cu
Normal file
@ -0,0 +1,64 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*/
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#include "../../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
|
||||
size_t shm_size, size_t bin_size) {
|
||||
// Only use a single feature group for sparse matrices.
|
||||
bool single_group = !is_dense;
|
||||
if (single_group) {
|
||||
InitSingle(cuts);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int>& feature_segments_h = feature_segments.HostVector();
|
||||
std::vector<int>& bin_segments_h = bin_segments.HostVector();
|
||||
feature_segments_h.push_back(0);
|
||||
bin_segments_h.push_back(0);
|
||||
|
||||
const std::vector<uint32_t>& cut_ptrs = cuts.Ptrs();
|
||||
int max_shmem_bins = shm_size / bin_size;
|
||||
max_group_bins = 0;
|
||||
|
||||
for (size_t i = 2; i < cut_ptrs.size(); ++i) {
|
||||
int last_start = bin_segments_h.back();
|
||||
if (cut_ptrs[i] - last_start > max_shmem_bins) {
|
||||
feature_segments_h.push_back(i - 1);
|
||||
bin_segments_h.push_back(cut_ptrs[i - 1]);
|
||||
max_group_bins = std::max(max_group_bins,
|
||||
bin_segments_h.back() - last_start);
|
||||
}
|
||||
}
|
||||
feature_segments_h.push_back(cut_ptrs.size() - 1);
|
||||
bin_segments_h.push_back(cut_ptrs.back());
|
||||
max_group_bins = std::max(max_group_bins,
|
||||
bin_segments_h.back() -
|
||||
bin_segments_h[bin_segments_h.size() - 2]);
|
||||
}
|
||||
|
||||
void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
|
||||
std::vector<int>& feature_segments_h = feature_segments.HostVector();
|
||||
feature_segments_h.push_back(0);
|
||||
feature_segments_h.push_back(cuts.Ptrs().size() - 1);
|
||||
|
||||
std::vector<int>& bin_segments_h = bin_segments.HostVector();
|
||||
bin_segments_h.push_back(0);
|
||||
bin_segments_h.push_back(cuts.TotalBins());
|
||||
|
||||
max_group_bins = cuts.TotalBins();
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
119
src/tree/gpu_hist/feature_groups.cuh
Normal file
119
src/tree/gpu_hist/feature_groups.cuh
Normal file
@ -0,0 +1,119 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef FEATURE_GROUPS_CUH_
|
||||
#define FEATURE_GROUPS_CUH_
|
||||
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/span.h>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
// Forward declarations.
|
||||
namespace common {
|
||||
class HistogramCuts;
|
||||
} // namespace common
|
||||
|
||||
namespace tree {
|
||||
|
||||
/** \brief FeatureGroup is a feature group. It is defined by a range of
|
||||
consecutive feature indices, and also contains a range of all bin indices
|
||||
associated with those features. */
|
||||
struct FeatureGroup {
|
||||
__host__ __device__ FeatureGroup(int start_feature_, int num_features_,
|
||||
int start_bin_, int num_bins_) :
|
||||
start_feature(start_feature_), num_features(num_features_),
|
||||
start_bin(start_bin_), num_bins(num_bins_) {}
|
||||
/** The first feature of the group. */
|
||||
int start_feature;
|
||||
/** The number of features in the group. */
|
||||
int num_features;
|
||||
/** The first bin in the group. */
|
||||
int start_bin;
|
||||
/** The number of bins in the group. */
|
||||
int num_bins;
|
||||
};
|
||||
|
||||
/** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */
|
||||
struct FeatureGroupsAccessor {
|
||||
FeatureGroupsAccessor(common::Span<const int> feature_segments_,
|
||||
common::Span<const int> bin_segments_, int max_group_bins_) :
|
||||
feature_segments(feature_segments_), bin_segments(bin_segments_),
|
||||
max_group_bins(max_group_bins_) {}
|
||||
|
||||
common::Span<const int> feature_segments;
|
||||
common::Span<const int> bin_segments;
|
||||
int max_group_bins;
|
||||
|
||||
/** \brief Gets the number of feature groups. */
|
||||
__host__ __device__ int NumGroups() const {
|
||||
return feature_segments.size() - 1;
|
||||
}
|
||||
|
||||
/** \brief Gets the information about a feature group with index i. */
|
||||
__host__ __device__ FeatureGroup operator[](int i) const {
|
||||
return {feature_segments[i], feature_segments[i + 1] - feature_segments[i],
|
||||
bin_segments[i], bin_segments[i + 1] - bin_segments[i]};
|
||||
}
|
||||
};
|
||||
|
||||
/** \brief FeatureGroups contains information that defines a split of features
|
||||
into groups. Bins of a single feature group typically fit into shared
|
||||
memory, so the histogram for the features of a single group can be computed
|
||||
faster.
|
||||
|
||||
\notes Known limitations:
|
||||
|
||||
- splitting features into groups currently works only for dense matrices,
|
||||
where it is easy to get a feature value in a row by its index; for sparse
|
||||
matrices, the structure contains only a single group containing all
|
||||
features;
|
||||
|
||||
- if a single feature requires more bins than fit into shared memory, the
|
||||
histogram is computed in global memory even if there are multiple feature
|
||||
groups; note that this is unlikely to occur in practice, as the default
|
||||
number of bins per feature is 256, whereas a thread block with 48 KiB
|
||||
shared memory can contain 3072 bins if each gradient sum component is a
|
||||
64-bit floating-point value (double)
|
||||
*/
|
||||
struct FeatureGroups {
|
||||
/** Group cuts for features. Size equals to (number of groups + 1). */
|
||||
HostDeviceVector<int> feature_segments;
|
||||
/** Group cuts for bins. Size equals to (number of groups + 1) */
|
||||
HostDeviceVector<int> bin_segments;
|
||||
/** Maximum number of bins in a group. Useful to compute the amount of dynamic
|
||||
shared memory when launching a kernel. */
|
||||
int max_group_bins;
|
||||
|
||||
/** Creates feature groups by splitting features into groups.
|
||||
\param cuts Histogram cuts that given the number of bins per feature.
|
||||
\param is_dense Whether the data matrix is dense.
|
||||
\param shm_size Available size of shared memory per thread block (in
|
||||
bytes) used to compute feature groups.
|
||||
\param bin_size Size of a single bin of the histogram. */
|
||||
FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
|
||||
size_t shm_size, size_t bin_size);
|
||||
|
||||
/** Creates a single feature group containing all features and bins.
|
||||
\notes This is used as a fallback for sparse matrices, and is also useful
|
||||
for testing.
|
||||
*/
|
||||
explicit FeatureGroups(const common::HistogramCuts& cuts) {
|
||||
InitSingle(cuts);
|
||||
}
|
||||
|
||||
FeatureGroupsAccessor DeviceAccessor(int device) const {
|
||||
feature_segments.SetDevice(device);
|
||||
bin_segments.SetDevice(device);
|
||||
return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(),
|
||||
max_group_bins};
|
||||
}
|
||||
|
||||
private:
|
||||
void InitSingle(const common::HistogramCuts& cuts);
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // FEATURE_GROUPS_CUH_
|
||||
@ -102,23 +102,26 @@ template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpai
|
||||
|
||||
template <typename GradientSumT>
|
||||
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
|
||||
FeatureGroupsAccessor feature_groups,
|
||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
||||
GradientSumT* __restrict__ d_node_hist,
|
||||
const GradientPair* __restrict__ d_gpair,
|
||||
size_t n_elements,
|
||||
GradientSumT const rounding,
|
||||
bool use_shared_memory_histograms) {
|
||||
using T = typename GradientSumT::ValueT;
|
||||
extern __shared__ char smem[];
|
||||
FeatureGroup group = feature_groups[blockIdx.y];
|
||||
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
|
||||
if (use_shared_memory_histograms) {
|
||||
dh::BlockFill(smem_arr, matrix.NumBins(), GradientSumT());
|
||||
dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
|
||||
__syncthreads();
|
||||
}
|
||||
int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
|
||||
size_t n_elements = feature_stride * d_ridx.size();
|
||||
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
|
||||
int ridx = d_ridx[idx / matrix.row_stride];
|
||||
int gidx =
|
||||
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
|
||||
int ridx = d_ridx[idx / feature_stride];
|
||||
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
|
||||
idx % feature_stride];
|
||||
if (gidx != matrix.NumBins()) {
|
||||
GradientSumT truncated {
|
||||
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
|
||||
@ -127,7 +130,8 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
|
||||
// If we are not using shared memory, accumulate the values directly into
|
||||
// global memory
|
||||
GradientSumT* atomic_add_ptr =
|
||||
use_shared_memory_histograms ? smem_arr : d_node_hist;
|
||||
use_shared_memory_histograms ? smem_arr : d_node_hist;
|
||||
gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
|
||||
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
|
||||
}
|
||||
}
|
||||
@ -135,18 +139,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
|
||||
if (use_shared_memory_histograms) {
|
||||
// Write shared memory back to global memory
|
||||
__syncthreads();
|
||||
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.NumBins())) {
|
||||
GradientSumT truncated {
|
||||
TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
|
||||
TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
|
||||
for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
|
||||
GradientSumT truncated{
|
||||
TruncateWithRoundingFactor<T>(rounding.GetGrad(),
|
||||
smem_arr[i].GetGrad()),
|
||||
TruncateWithRoundingFactor<T>(rounding.GetHess(),
|
||||
smem_arr[i].GetHess()),
|
||||
};
|
||||
dh::AtomicAddGpair(d_node_hist + i, truncated);
|
||||
dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> d_ridx,
|
||||
common::Span<GradientSumT> histogram,
|
||||
@ -155,7 +162,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
int device = 0;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
int max_shared_memory = dh::MaxSharedMemoryOptin(device);
|
||||
size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
|
||||
size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
|
||||
bool shared = smem_size <= max_shared_memory;
|
||||
smem_size = shared ? smem_size : 0;
|
||||
|
||||
@ -169,6 +176,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
|
||||
// determine the launch configuration
|
||||
unsigned block_threads = shared ? 1024 : 256;
|
||||
int num_groups = feature_groups.NumGroups();
|
||||
int n_mps = 0;
|
||||
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
|
||||
int n_blocks_per_mp = 0;
|
||||
@ -176,15 +184,31 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
(&n_blocks_per_mp, kernel, block_threads, smem_size));
|
||||
unsigned grid_size = n_blocks_per_mp * n_mps;
|
||||
|
||||
auto n_elements = d_ridx.size() * matrix.row_stride;
|
||||
dh::LaunchKernel {grid_size, block_threads, smem_size} (
|
||||
kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
|
||||
rounding, shared);
|
||||
// TODO(canonizer): This is really a hack, find a better way to distribute the
|
||||
// data among thread blocks.
|
||||
// The intention is to generate enough thread blocks to fill the GPU, but
|
||||
// avoid having too many thread blocks, as this is less efficient when the
|
||||
// number of rows is low. At least one thread block per feature group is
|
||||
// required.
|
||||
// The number of thread blocks:
|
||||
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
|
||||
// - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
|
||||
// around grid_size * num_groups_threshold
|
||||
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
|
||||
int num_groups_threshold = 4;
|
||||
grid_size = common::DivRoundUp(grid_size,
|
||||
common::DivRoundUp(num_groups, num_groups_threshold));
|
||||
|
||||
dh::LaunchKernel {dim3(grid_size, num_groups), block_threads, smem_size} (
|
||||
kernel,
|
||||
matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
|
||||
shared);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
|
||||
template void BuildGradientHistogram<GradientPair>(
|
||||
EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPair> histogram,
|
||||
@ -192,6 +216,7 @@ template void BuildGradientHistogram<GradientPair>(
|
||||
|
||||
template void BuildGradientHistogram<GradientPairPrecise>(
|
||||
EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPairPrecise> histogram,
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
#ifndef HISTOGRAM_CUH_
|
||||
#define HISTOGRAM_CUH_
|
||||
#include <thrust/transform.h>
|
||||
|
||||
#include "feature_groups.cuh"
|
||||
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@ -19,6 +22,7 @@ DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x)
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientSumT> histogram,
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "constraints.cuh"
|
||||
#include "gpu_hist/feature_groups.cuh"
|
||||
#include "gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
#include "gpu_hist/histogram.cuh"
|
||||
@ -203,6 +204,8 @@ struct GPUHistMakerDevice {
|
||||
|
||||
std::unique_ptr<GradientBasedSampler> sampler;
|
||||
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
|
||||
GPUHistMakerDevice(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
bst_uint _n_rows,
|
||||
@ -229,6 +232,9 @@ struct GPUHistMakerDevice {
|
||||
// Init histogram
|
||||
hist.Init(device_id, page->Cuts().TotalBins());
|
||||
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
|
||||
feature_groups.reset(new FeatureGroups(
|
||||
page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
|
||||
sizeof(GradientSumT)));
|
||||
}
|
||||
|
||||
~GPUHistMakerDevice() { // NOLINT
|
||||
@ -372,8 +378,9 @@ struct GPUHistMakerDevice {
|
||||
hist.AllocateHistogram(nidx);
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(device_id), gpair, d_ridx, d_node_hist,
|
||||
histogram_rounding);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(device_id),
|
||||
feature_groups->DeviceAccessor(device_id), gpair,
|
||||
d_ridx, d_node_hist, histogram_rounding);
|
||||
}
|
||||
|
||||
void SubtractionTrick(int nidx_parent, int nidx_histogram,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include "../../helpers.h"
|
||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
||||
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
||||
@ -7,11 +8,12 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename Gradient>
|
||||
void TestDeterminsticHistogram() {
|
||||
size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16;
|
||||
void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
||||
float constexpr kLower = -1e-2, kUpper = 1e2;
|
||||
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, 0.5).GenerateDMatrix();
|
||||
float sparsity = is_dense ? 0.0f : 0.5f;
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix();
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0};
|
||||
|
||||
for (auto const& batch : matrix->GetBatches<EllpackPage>(batch_param)) {
|
||||
@ -20,48 +22,80 @@ void TestDeterminsticHistogram() {
|
||||
tree::RowPartitioner row_partitioner(0, kRows);
|
||||
auto ridx = row_partitioner.GetRows(0);
|
||||
|
||||
dh::device_vector<Gradient> histogram(kBins * kCols);
|
||||
int num_bins = kBins * kCols;
|
||||
dh::device_vector<Gradient> histogram(num_bins);
|
||||
auto d_histogram = dh::ToSpan(histogram);
|
||||
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
|
||||
gpair.SetDevice(0);
|
||||
|
||||
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
|
||||
sizeof(Gradient));
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, rounding);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
|
||||
ridx, d_histogram, rounding);
|
||||
|
||||
std::vector<Gradient> histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < kRounds; ++i) {
|
||||
dh::device_vector<Gradient> new_histogram(kBins * kCols);
|
||||
auto d_histogram = dh::ToSpan(new_histogram);
|
||||
dh::device_vector<Gradient> new_histogram(num_bins);
|
||||
auto d_new_histogram = dh::ToSpan(new_histogram);
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, rounding);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, d_new_histogram,
|
||||
rounding);
|
||||
|
||||
for (size_t j = 0; j < new_histogram.size(); ++j) {
|
||||
ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(),
|
||||
((Gradient)histogram[j]).GetGrad());
|
||||
ASSERT_EQ(((Gradient)new_histogram[j]).GetHess(),
|
||||
((Gradient)histogram[j]).GetHess());
|
||||
std::vector<Gradient> new_histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
for (size_t j = 0; j < new_histogram_h.size(); ++j) {
|
||||
ASSERT_EQ(new_histogram_h[j].GetGrad(), histogram_h[j].GetGrad());
|
||||
ASSERT_EQ(new_histogram_h[j].GetHess(), histogram_h[j].GetHess());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
|
||||
gpair.SetDevice(0);
|
||||
dh::device_vector<Gradient> baseline(kBins * kCols);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(baseline), rounding);
|
||||
|
||||
// Use a single feature group to compute the baseline.
|
||||
FeatureGroups single_group(page->Cuts());
|
||||
|
||||
dh::device_vector<Gradient> baseline(num_bins);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
|
||||
rounding);
|
||||
|
||||
std::vector<Gradient> baseline_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < baseline.size(); ++i) {
|
||||
EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(), ((Gradient)histogram[i]).GetGrad(),
|
||||
((Gradient)baseline[i]).GetGrad() * 1e-3);
|
||||
EXPECT_NEAR(baseline_h[i].GetGrad(), histogram_h[i].GetGrad(),
|
||||
baseline_h[i].GetGrad() * 1e-3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Histogram, GPUDeterminstic) {
|
||||
TestDeterminsticHistogram<GradientPair>();
|
||||
TestDeterminsticHistogram<GradientPairPrecise>();
|
||||
TEST(Histogram, GPUDeterministic) {
|
||||
std::vector<bool> is_dense_array{false, true};
|
||||
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
|
||||
for (bool is_dense : is_dense_array) {
|
||||
for (int shm_size : shm_sizes) {
|
||||
TestDeterministicHistogram<GradientPair>(is_dense, shm_size);
|
||||
TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user