Cleanup GPU Hist tests. (#10677)

* Cleanup GPU Hist tests.

- Remove GPU Hist gradient sampling test. The same properties are tested in the gradient
  sampler test suite.
- Move basic histogram tests into the histogram test suite.
- Remove the header inclusion of the `updater_gpu_hist.cu` in tests.
This commit is contained in:
Jiaming Yuan 2024-08-06 11:50:44 +08:00 committed by GitHub
parent 6ccf116601
commit cc3b56fc37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 104 additions and 183 deletions

View File

@ -1,12 +1,12 @@
/*! /**
* Copyright 2021-2022 by XGBoost Contributors * Copyright 2021-2024, XGBoost Contributors
*/ */
#ifndef XGBOOST_TASK_H_ #ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_ #define XGBOOST_TASK_H_
#include <xgboost/base.h> #include <xgboost/base.h>
#include <cinttypes> #include <cstdint> // for uint8_t
namespace xgboost { namespace xgboost {
/*! /*!
@ -23,7 +23,7 @@ namespace xgboost {
*/ */
struct ObjInfo { struct ObjInfo {
// What kind of problem are we trying to solve // What kind of problem are we trying to solve
enum Task : uint8_t { enum Task : std::uint8_t {
kRegression = 0, kRegression = 0,
kBinary = 1, kBinary = 1,
kClassification = 2, kClassification = 2,

View File

@ -45,9 +45,7 @@
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
namespace xgboost::tree { namespace xgboost::tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
// Manage memory for a single GPU // Manage memory for a single GPU
struct GPUHistMakerDevice { struct GPUHistMakerDevice {
@ -831,13 +829,11 @@ class GPUHistMaker : public TreeUpdater {
std::shared_ptr<common::ColumnSampler> column_sampler_; std::shared_ptr<common::ColumnSampler> column_sampler_;
}; };
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
.describe("Grow tree with GPU.") .describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo const* task) { .set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUHistMaker(ctx, task); return new GPUHistMaker(ctx, task);
}); });
#endif // !defined(GTEST_TEST)
class GPUGlobalApproxMaker : public TreeUpdater { class GPUGlobalApproxMaker : public TreeUpdater {
public: public:
@ -960,11 +956,9 @@ class GPUGlobalApproxMaker : public TreeUpdater {
common::Monitor monitor_; common::Monitor monitor_;
}; };
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx") XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx")
.describe("Grow tree with GPU.") .describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo const* task) { .set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUGlobalApproxMaker(ctx, task); return new GPUGlobalApproxMaker(ctx, task);
}); });
#endif // !defined(GTEST_TEST)
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -10,9 +10,7 @@
#include "../../filesystem.h" // dmlc::TemporaryDirectory #include "../../filesystem.h" // dmlc::TemporaryDirectory
#include "../../helpers.h" #include "../../helpers.h"
namespace xgboost { namespace xgboost::tree {
namespace tree {
void VerifySampling(size_t page_size, void VerifySampling(size_t page_size,
float subsample, float subsample,
int sampling_method, int sampling_method,
@ -151,6 +149,4 @@ TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {
constexpr bool kFixedSizeSampling = false; constexpr bool kFixedSizeSampling = false;
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
} }
}; // namespace xgboost::tree
}; // namespace tree
}; // namespace xgboost

View File

@ -12,6 +12,7 @@
#include "../../../../src/tree/param.h" // for TrainParam #include "../../../../src/tree/param.h" // for TrainParam
#include "../../categorical_helpers.h" // for OneHotEncodeFeature #include "../../categorical_helpers.h" // for OneHotEncodeFeature
#include "../../helpers.h" #include "../../helpers.h"
#include "../../histogram_helpers.h" // for BuildEllpackPage
namespace xgboost::tree { namespace xgboost::tree {
TEST(Histogram, DeviceHistogramStorage) { TEST(Histogram, DeviceHistogramStorage) {
@ -54,6 +55,83 @@ TEST(Histogram, DeviceHistogramStorage) {
EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1});); EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1}););
} }
std::vector<GradientPairPrecise> GetHostHistGpair() {
// 24 bins, 3 bins for each feature (column).
std::vector<GradientPairPrecise> hist_gpair = {
{0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f},
{2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f},
{1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f},
{2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f},
{1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f},
{1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f},
{0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f},
{2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f}
};
return hist_gpair;
}
void TestBuildHist(bool use_shared_memory_histograms) {
int const kNRows = 16, kNCols = 8;
Context ctx{MakeCUDACtx(0)};
TrainParam param;
Args args{
{"max_depth", "6"},
{"max_leaves", "0"},
};
param.Init(args);
auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
BatchParam batch_param{};
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);
for (auto& gp : gpair.HostVector()) {
float grad = dist(&gen);
float hess = dist(&gen);
gp = GradientPair{grad, hess};
}
gpair.SetDevice(ctx.Device());
auto row_partitioner = std::make_unique<RowPartitioner>();
row_partitioner->Reset(&ctx, kNRows, 0);
auto quantiser = std::make_unique<GradientQuantiser>(&ctx, gpair.ConstDeviceSpan(), MetaInfo());
auto shm_size = use_shared_memory_histograms ? dh::MaxSharedMemoryOptin(ctx.Ordinal()) : 0;
FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64));
DeviceHistogramStorage hist;
hist.Init(ctx.Device(), page->Cuts().TotalBins());
hist.AllocateHistograms(&ctx, {0});
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), !use_shared_memory_histograms);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
row_partitioner->GetRows(0), hist.GetNodeHistogram(0), *quantiser);
auto node_histogram = hist.GetNodeHistogram(0);
std::vector<GradientPairInt64> h_result(node_histogram.size());
dh::CopyDeviceSpanToVector(&h_result, node_histogram);
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
for (size_t i = 0; i < h_result.size(); ++i) {
auto result = quantiser->ToFloatingPoint(h_result[i]);
ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
}
}
TEST(Histogram, BuildHistGlobalMem) {
TestBuildHist(false);
}
TEST(Histogram, BuildHistSharedMem) {
TestBuildHist(true);
}
void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) { void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) {
Context ctx = MakeCUDACtx(0); Context ctx = MakeCUDACtx(0);
size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16; size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;

View File

@ -2,173 +2,26 @@
* Copyright 2017-2024, XGBoost contributors * Copyright 2017-2024, XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thrust/device_vector.h> #include <xgboost/base.h> // for Args
#include <thrust/host_vector.h> #include <xgboost/context.h> // for Context
#include <xgboost/base.h> #include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/json.h> // for Jons
#include <xgboost/task.h> // for ObjInfo
#include <xgboost/tree_model.h> // for RegTree
#include <xgboost/tree_updater.h> // for TreeUpdater
#include <string> #include <memory> // for unique_ptr
#include <vector> #include <string> // for string
#include <vector> // for vector
#include "../../../src/common/common.h" #include "../../../src/common/random.h" // for GlobalRandom
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
#include "../../../src/data/ellpack_page.h" // for EllpackPage #include "../../../src/data/ellpack_page.h" // for EllpackPage
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../../../src/tree/updater_gpu_hist.cu"
#include "../collective/test_worker.h" // for BaseMGPUTest #include "../collective/test_worker.h" // for BaseMGPUTest
#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" #include "../helpers.h"
#include "../histogram_helpers.h"
#include "xgboost/context.h"
#include "xgboost/json.h"
namespace xgboost::tree { namespace xgboost::tree {
std::vector<GradientPairPrecise> GetHostHistGpair() {
// 24 bins, 3 bins for each feature (column).
std::vector<GradientPairPrecise> hist_gpair = {
{0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f},
{2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f},
{1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f},
{2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f},
{1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f},
{1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f},
{0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f},
{2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f}
};
return hist_gpair;
}
template <typename GradientSumT>
void TestBuildHist(bool use_shared_memory_histograms) {
int const kNRows = 16, kNCols = 8;
Context ctx{MakeCUDACtx(0)};
TrainParam param;
Args args{
{"max_depth", "6"},
{"max_leaves", "0"},
};
param.Init(args);
auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
BatchParam batch_param{};
auto cs = std::make_shared<common::ColumnSampler>(0);
GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
batch_param, MetaInfo());
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);
for (auto& gp : gpair.HostVector()) {
float grad = dist(&gen);
float hess = dist(&gen);
gp = GradientPair{grad, hess};
}
gpair.SetDevice(ctx.Device());
maker.row_partitioner = std::make_unique<RowPartitioner>();
maker.row_partitioner->Reset(&ctx, kNRows, 0);
maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
maker.hist.AllocateHistograms(&ctx, {0});
maker.gpair = gpair.DeviceSpan();
maker.quantiser = std::make_unique<GradientQuantiser>(&ctx, maker.gpair, MetaInfo());
maker.page = page.get();
maker.InitFeatureGroupsOnce();
DeviceHistogramBuilder builder;
builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
!use_shared_memory_histograms);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
*maker.quantiser);
DeviceHistogramStorage<>& d_hist = maker.hist;
auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair
thrust::host_vector<GradientPairInt64> h_result (node_histogram.size());
dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), node_histogram.size_bytes(),
cudaMemcpyDeviceToHost));
std::vector<GradientPairPrecise> solution = GetHostHistGpair();
for (size_t i = 0; i < h_result.size(); ++i) {
auto result = maker.quantiser->ToFloatingPoint(h_result[i]);
ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
}
}
TEST(GpuHist, BuildHistGlobalMem) {
TestBuildHist<GradientPairPrecise>(false);
}
TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPairPrecise>(true);
}
std::shared_ptr<detail::HistogramCutsWrapper> GetHostCutMatrix () {
auto cmat = std::make_shared<detail::HistogramCutsWrapper>();
cmat->SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
cmat->SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
// 24 cut fields, 3 cut fields for each feature (column).
// Each row of the cut represents the cuts for a data column.
cmat->SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f});
return cmat;
}
void TestHistogramIndexImpl() {
// Test if the compressed histogram index matches when using a sparse
// dmatrix with and without using external memory
int constexpr kNRows = 1000, kNCols = 10;
// Build 2 matrices and build a histogram maker with that
Context ctx(MakeCUDACtx(0));
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
dmlc::TemporaryDirectory tempdir;
std::unique_ptr<DMatrix> hist_maker_ext_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir));
Args training_params = {{"max_depth", "10"}, {"max_leaves", "0"}};
TrainParam param;
param.UpdateAllowUnknown(training_params);
hist_maker.Configure(training_params);
hist_maker.InitDataOnce(&param, hist_maker_dmat.get());
hist_maker_ext.Configure(training_params);
hist_maker_ext.InitDataOnce(&param, hist_maker_ext_dmat.get());
// Extract the device maker from the histogram makers and from that its compressed
// histogram index
const auto &maker = hist_maker.maker;
auto grad = GenerateRandomGradients(kNRows);
grad.SetDevice(DeviceOrd::CUDA(0));
maker->Reset(&grad, hist_maker_dmat.get(), kNCols);
const auto &maker_ext = hist_maker_ext.maker;
maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols);
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());
}
TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl();
}
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds, size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
float subsample = 1.0f, const std::string& sampling_method = "uniform", float subsample = 1.0f, const std::string& sampling_method = "uniform",
@ -200,14 +53,14 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
param.UpdateAllowUnknown(args); param.UpdateAllowUnknown(args);
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{ctx, &task}; std::unique_ptr<TreeUpdater> hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)};
hist_maker.Configure(Args{}); hist_maker->Configure(Args{});
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position}, hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
{tree}); {tree});
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1); auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
hist_maker.UpdatePredictionCache(dmat, cache); hist_maker->UpdatePredictionCache(dmat, cache);
} }
TEST(GpuHist, UniformSampling) { TEST(GpuHist, UniformSampling) {