Split Features into Groups to Compute Histograms in Shared Memory (#5795)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include "../../helpers.h"
|
||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
||||
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
||||
@@ -7,11 +8,12 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename Gradient>
|
||||
void TestDeterminsticHistogram() {
|
||||
size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16;
|
||||
void TestDeterministicHistogram(bool is_dense, int shm_size) {
|
||||
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
|
||||
float constexpr kLower = -1e-2, kUpper = 1e2;
|
||||
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, 0.5).GenerateDMatrix();
|
||||
float sparsity = is_dense ? 0.0f : 0.5f;
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix();
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0};
|
||||
|
||||
for (auto const& batch : matrix->GetBatches<EllpackPage>(batch_param)) {
|
||||
@@ -20,48 +22,80 @@ void TestDeterminsticHistogram() {
|
||||
tree::RowPartitioner row_partitioner(0, kRows);
|
||||
auto ridx = row_partitioner.GetRows(0);
|
||||
|
||||
dh::device_vector<Gradient> histogram(kBins * kCols);
|
||||
int num_bins = kBins * kCols;
|
||||
dh::device_vector<Gradient> histogram(num_bins);
|
||||
auto d_histogram = dh::ToSpan(histogram);
|
||||
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
|
||||
gpair.SetDevice(0);
|
||||
|
||||
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
|
||||
sizeof(Gradient));
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, rounding);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
|
||||
ridx, d_histogram, rounding);
|
||||
|
||||
std::vector<Gradient> histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < kRounds; ++i) {
|
||||
dh::device_vector<Gradient> new_histogram(kBins * kCols);
|
||||
auto d_histogram = dh::ToSpan(new_histogram);
|
||||
dh::device_vector<Gradient> new_histogram(num_bins);
|
||||
auto d_new_histogram = dh::ToSpan(new_histogram);
|
||||
|
||||
auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, rounding);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
feature_groups.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, d_new_histogram,
|
||||
rounding);
|
||||
|
||||
for (size_t j = 0; j < new_histogram.size(); ++j) {
|
||||
ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(),
|
||||
((Gradient)histogram[j]).GetGrad());
|
||||
ASSERT_EQ(((Gradient)new_histogram[j]).GetHess(),
|
||||
((Gradient)histogram[j]).GetHess());
|
||||
std::vector<Gradient> new_histogram_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
for (size_t j = 0; j < new_histogram_h.size(); ++j) {
|
||||
ASSERT_EQ(new_histogram_h[j].GetGrad(), histogram_h[j].GetGrad());
|
||||
ASSERT_EQ(new_histogram_h[j].GetHess(), histogram_h[j].GetHess());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
|
||||
gpair.SetDevice(0);
|
||||
dh::device_vector<Gradient> baseline(kBins * kCols);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(baseline), rounding);
|
||||
|
||||
// Use a single feature group to compute the baseline.
|
||||
FeatureGroups single_group(page->Cuts());
|
||||
|
||||
dh::device_vector<Gradient> baseline(num_bins);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(0),
|
||||
single_group.DeviceAccessor(0),
|
||||
gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
|
||||
rounding);
|
||||
|
||||
std::vector<Gradient> baseline_h(num_bins);
|
||||
dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
|
||||
num_bins * sizeof(Gradient),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
for (size_t i = 0; i < baseline.size(); ++i) {
|
||||
EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(), ((Gradient)histogram[i]).GetGrad(),
|
||||
((Gradient)baseline[i]).GetGrad() * 1e-3);
|
||||
EXPECT_NEAR(baseline_h[i].GetGrad(), histogram_h[i].GetGrad(),
|
||||
baseline_h[i].GetGrad() * 1e-3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Histogram, GPUDeterminstic) {
|
||||
TestDeterminsticHistogram<GradientPair>();
|
||||
TestDeterminsticHistogram<GradientPairPrecise>();
|
||||
TEST(Histogram, GPUDeterministic) {
|
||||
std::vector<bool> is_dense_array{false, true};
|
||||
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
|
||||
for (bool is_dense : is_dense_array) {
|
||||
for (int shm_size : shm_sizes) {
|
||||
TestDeterministicHistogram<GradientPair>(is_dense, shm_size);
|
||||
TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user