[breaking] Use integer atomic for GPU histogram. (#7180)

On GPU we use rouding factor to truncate the gradient for deterministic results. This PR changes the gradient representation to fixed point number with exponent aligned with rounding factor.

    [breaking] Drop non-deterministic histogram.
    Use fixed point for shared memory.

This PR is to improve the performance of GPU Hist. 

Co-authored-by: Andy Adinets <aadinets@nvidia.com>
This commit is contained in:
Jiaming Yuan
2021-08-28 05:17:05 +08:00
committed by GitHub
parent e7d7ab6bc3
commit 7a1d67f9cb
11 changed files with 295 additions and 142 deletions

View File

@@ -1,7 +1,10 @@
/*!
* Copyright 2017 XGBoost contributors
* Copyright 2017-2021 XGBoost contributors
*/
#include <cstddef>
#include <cstdint>
#include <thrust/device_vector.h>
#include <vector>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
@@ -101,8 +104,6 @@ struct IsSorted {
} // namespace
namespace xgboost {
namespace common {
void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_duplicated) {
std::vector<bst_feature_t> segments{0, static_cast<bst_feature_t>(values.size())};
@@ -194,5 +195,73 @@ TEST(DeviceHelpers, ArgSort) {
ASSERT_TRUE(thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(),
thrust::greater<size_t>{}));
}
} // namespace common
namespace {
// Atomic add as type cast for test.
XGBOOST_DEV_INLINE int64_t atomicAdd(int64_t *dst, int64_t src) { // NOLINT
uint64_t* u_dst = reinterpret_cast<uint64_t*>(dst);
uint64_t u_src = *reinterpret_cast<uint64_t*>(&src);
uint64_t ret = ::atomicAdd(u_dst, u_src);
return *reinterpret_cast<int64_t*>(&ret);
}
}
void TestAtomicAdd() {
size_t n_elements = 1024;
dh::device_vector<int64_t> result_a(1, 0);
auto d_result_a = result_a.data().get();
dh::device_vector<int64_t> result_b(1, 0);
auto d_result_b = result_b.data().get();
/**
* Test for simple inputs
*/
std::vector<int64_t> h_inputs(n_elements);
for (size_t i = 0; i < h_inputs.size(); ++i) {
h_inputs[i] = (i % 2 == 0) ? i : -i;
}
dh::device_vector<int64_t> inputs(h_inputs);
auto d_inputs = inputs.data().get();
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
/**
* Test for positive values that don't fit into 32 bit integer.
*/
thrust::fill(inputs.begin(), inputs.end(),
(std::numeric_limits<uint32_t>::max() / 2));
thrust::fill(result_a.begin(), result_a.end(), 0);
thrust::fill(result_b.begin(), result_b.end(), 0);
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
ASSERT_GT(result_a[0], std::numeric_limits<uint32_t>::max());
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
/**
* Test for negative values that don't fit into 32 bit integer.
*/
thrust::fill(inputs.begin(), inputs.end(),
(std::numeric_limits<int32_t>::min() / 2));
thrust::fill(result_a.begin(), result_a.end(), 0);
thrust::fill(result_b.begin(), result_b.end(), 0);
dh::LaunchN(n_elements, [=] __device__(size_t i) {
dh::AtomicAdd64As32(d_result_a, d_inputs[i]);
atomicAdd(d_result_b, d_inputs[i]);
});
ASSERT_EQ(result_a[0], result_b[0]);
ASSERT_LT(result_a[0], std::numeric_limits<int32_t>::min());
CHECK_EQ(thrust::reduce(inputs.begin(), inputs.end(), int64_t(0)), result_a[0]);
}
TEST(AtomicAdd, Int64) {
TestAtomicAdd();
}
} // namespace xgboost

View File

@@ -1,5 +1,9 @@
#include "test_ranking_obj.cc"
/*!
* Copyright 2019-2021 by XGBoost Contributors
*/
#include <thrust/host_vector.h>
#include "test_ranking_obj.cc"
#include "../../../src/objective/rank_obj.cu"
namespace xgboost {

View File

@@ -1,8 +1,13 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../helpers.h"

View File

@@ -1,8 +1,9 @@
/*!
* Copyright 2017-2020 XGBoost contributors
* Copyright 2017-2021 XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <dmlc/filesystem.h>
#include <xgboost/base.h>
#include <random>
@@ -80,8 +81,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
param.Init(args);
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
true, batch_param);
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param,
kNCols, kNCols, batch_param);
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);
@@ -93,14 +94,18 @@ void TestBuildHist(bool use_shared_memory_histograms) {
gpair.SetDevice(0);
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
maker.hist.AllocateHistogram(0);
maker.gpair = gpair.DeviceSpan();
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);;
maker.BuildHist(0);
DeviceHistogram<GradientSumT> d_hist = maker.hist;
BuildGradientHistogram(
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
!use_shared_memory_histograms);
DeviceHistogram<GradientSumT>& d_hist = maker.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair
@@ -115,6 +120,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
std::cout << std::fixed;
for (size_t i = 0; i < h_result.size(); ++i) {
ASSERT_FALSE(std::isnan(h_result[i].GetGrad()));
EXPECT_NEAR(h_result[i].GetGrad(), solution[i].GetGrad(), 0.01f);
EXPECT_NEAR(h_result[i].GetHess(), solution[i].GetHess(), 0.01f);
}
@@ -158,7 +164,8 @@ TEST(GpuHist, ApplySplit) {
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols,
bparam);
updater.ApplySplit(candidate, &tree);
ASSERT_EQ(tree.GetSplitTypes().size(), 3);
@@ -217,8 +224,8 @@ TEST(GpuHist, EvaluateRootSplit) {
// Initialize GPUHistMakerDevice
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
GPUHistMakerDevice<GradientPairPrecise>
maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
GPUHistMakerDevice<GradientPairPrecise> maker(
0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param);
// Initialize GPUHistMakerDevice::node_sum_gradients
maker.node_sum_gradients = {};