[EM] Handle base idx in GPU histogram. (#10549)
This commit is contained in:
parent
34b154c284
commit
5f910cd4ff
@ -1,8 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2020-2024, XGBoost Contributors
|
* Copyright 2020-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/iterator/transform_iterator.h>
|
#include <thrust/iterator/transform_iterator.h> // for make_transform_iterator
|
||||||
#include <thrust/reduce.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint> // uint32_t, int32_t
|
#include <cstdint> // uint32_t, int32_t
|
||||||
@ -101,9 +100,8 @@ GradientQuantiser::GradientQuantiser(Context const* ctx, common::Span<GradientPa
|
|||||||
static_cast<T>(1) / to_floating_point_.GetHess());
|
static_cast<T>(1) / to_floating_point_.GetHess());
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEV_INLINE void
|
XGBOOST_DEV_INLINE void AtomicAddGpairShared(xgboost::GradientPairInt64* dest,
|
||||||
AtomicAddGpairShared(xgboost::GradientPairInt64 *dest,
|
xgboost::GradientPairInt64 const& gpair) {
|
||||||
xgboost::GradientPairInt64 const &gpair) {
|
|
||||||
auto dst_ptr = reinterpret_cast<int64_t *>(dest);
|
auto dst_ptr = reinterpret_cast<int64_t *>(dest);
|
||||||
auto g = gpair.GetQuantisedGrad();
|
auto g = gpair.GetQuantisedGrad();
|
||||||
auto h = gpair.GetQuantisedHess();
|
auto h = gpair.GetQuantisedHess();
|
||||||
@ -131,7 +129,9 @@ template <int kBlockThreads, int kItemsPerThread,
|
|||||||
class HistogramAgent {
|
class HistogramAgent {
|
||||||
GradientPairInt64* smem_arr_;
|
GradientPairInt64* smem_arr_;
|
||||||
GradientPairInt64* d_node_hist_;
|
GradientPairInt64* d_node_hist_;
|
||||||
dh::LDGIterator<const RowPartitioner::RowIndexT> d_ridx_;
|
using Idx = RowPartitioner::RowIndexT;
|
||||||
|
|
||||||
|
dh::LDGIterator<const Idx> d_ridx_;
|
||||||
const GradientPair* d_gpair_;
|
const GradientPair* d_gpair_;
|
||||||
const FeatureGroup group_;
|
const FeatureGroup group_;
|
||||||
const EllpackDeviceAccessor& matrix_;
|
const EllpackDeviceAccessor& matrix_;
|
||||||
@ -142,8 +142,7 @@ class HistogramAgent {
|
|||||||
public:
|
public:
|
||||||
__device__ HistogramAgent(GradientPairInt64* smem_arr,
|
__device__ HistogramAgent(GradientPairInt64* smem_arr,
|
||||||
GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
|
GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
|
||||||
const EllpackDeviceAccessor& matrix,
|
const EllpackDeviceAccessor& matrix, common::Span<const Idx> d_ridx,
|
||||||
common::Span<const RowPartitioner::RowIndexT> d_ridx,
|
|
||||||
const GradientQuantiser& rounding, const GradientPair* d_gpair)
|
const GradientQuantiser& rounding, const GradientPair* d_gpair)
|
||||||
: smem_arr_(smem_arr),
|
: smem_arr_(smem_arr),
|
||||||
d_node_hist_(d_node_hist),
|
d_node_hist_(d_node_hist),
|
||||||
@ -154,15 +153,15 @@ class HistogramAgent {
|
|||||||
n_elements_(feature_stride_ * d_ridx.size()),
|
n_elements_(feature_stride_ * d_ridx.size()),
|
||||||
rounding_(rounding),
|
rounding_(rounding),
|
||||||
d_gpair_(d_gpair) {}
|
d_gpair_(d_gpair) {}
|
||||||
|
|
||||||
__device__ void ProcessPartialTileShared(std::size_t offset) {
|
__device__ void ProcessPartialTileShared(std::size_t offset) {
|
||||||
for (std::size_t idx = offset + threadIdx.x;
|
for (std::size_t idx = offset + threadIdx.x;
|
||||||
idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_);
|
idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_);
|
||||||
idx += kBlockThreads) {
|
idx += kBlockThreads) {
|
||||||
int ridx = d_ridx_[idx / feature_stride_];
|
Idx ridx = d_ridx_[idx / feature_stride_];
|
||||||
int gidx =
|
Idx midx = (ridx - matrix_.base_rowid) * matrix_.row_stride + group_.start_feature +
|
||||||
matrix_
|
idx % feature_stride_;
|
||||||
.gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_] -
|
bst_bin_t gidx = matrix_.gidx_iter[midx] - group_.start_bin;
|
||||||
group_.start_bin;
|
|
||||||
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
||||||
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
||||||
AtomicAddGpairShared(smem_arr_ + gidx, adjusted);
|
AtomicAddGpairShared(smem_arr_ + gidx, adjusted);
|
||||||
@ -188,8 +187,8 @@ class HistogramAgent {
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < kItemsPerThread; i++) {
|
for (int i = 0; i < kItemsPerThread; i++) {
|
||||||
gpair[i] = d_gpair_[ridx[i]];
|
gpair[i] = d_gpair_[ridx[i]];
|
||||||
gidx[i] = matrix_.gidx_iter[ridx[i] * matrix_.row_stride + group_.start_feature +
|
gidx[i] = matrix_.gidx_iter[(ridx[i] - matrix_.base_rowid) * matrix_.row_stride +
|
||||||
idx[i] % feature_stride_];
|
group_.start_feature + idx[i] % feature_stride_];
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < kItemsPerThread; i++) {
|
for (int i = 0; i < kItemsPerThread; i++) {
|
||||||
@ -200,7 +199,7 @@ class HistogramAgent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
__device__ void BuildHistogramWithShared() {
|
__device__ void BuildHistogramWithShared() {
|
||||||
dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64());
|
dh::BlockFill(smem_arr_, group_.num_bins, GradientPairInt64{});
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
std::size_t offset = blockIdx.x * kItemsPerTile;
|
std::size_t offset = blockIdx.x * kItemsPerTile;
|
||||||
@ -219,10 +218,9 @@ class HistogramAgent {
|
|||||||
|
|
||||||
__device__ void BuildHistogramWithGlobal() {
|
__device__ void BuildHistogramWithGlobal() {
|
||||||
for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
|
for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
|
||||||
int ridx = d_ridx_[idx / feature_stride_];
|
Idx ridx = d_ridx_[idx / feature_stride_];
|
||||||
int gidx =
|
bst_bin_t gidx = matrix_.gidx_iter[(ridx - matrix_.base_rowid) * matrix_.row_stride +
|
||||||
matrix_
|
group_.start_feature + idx % feature_stride_];
|
||||||
.gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_];
|
|
||||||
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
|
||||||
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
|
||||||
AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted);
|
AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted);
|
||||||
@ -231,8 +229,7 @@ class HistogramAgent {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <bool use_shared_memory_histograms, int kBlockThreads,
|
template <bool use_shared_memory_histograms, int kBlockThreads, int kItemsPerThread>
|
||||||
int kItemsPerThread>
|
|
||||||
__global__ void __launch_bounds__(kBlockThreads)
|
__global__ void __launch_bounds__(kBlockThreads)
|
||||||
SharedMemHistKernel(const EllpackDeviceAccessor matrix,
|
SharedMemHistKernel(const EllpackDeviceAccessor matrix,
|
||||||
const FeatureGroupsAccessor feature_groups,
|
const FeatureGroupsAccessor feature_groups,
|
||||||
@ -251,6 +248,7 @@ __global__ void __launch_bounds__(kBlockThreads)
|
|||||||
agent.BuildHistogramWithGlobal();
|
agent.BuildHistogramWithGlobal();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr std::int32_t kBlockThreads = 1024;
|
constexpr std::int32_t kBlockThreads = 1024;
|
||||||
constexpr std::int32_t kItemsPerThread = 8;
|
constexpr std::int32_t kItemsPerThread = 8;
|
||||||
|
|||||||
@ -78,5 +78,4 @@ class DeviceHistogramBuilder {
|
|||||||
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
|
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
|
||||||
};
|
};
|
||||||
} // namespace xgboost::tree
|
} // namespace xgboost::tree
|
||||||
|
|
||||||
#endif // HISTOGRAM_CUH_
|
#endif // HISTOGRAM_CUH_
|
||||||
|
|||||||
@ -1,28 +1,23 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 XGBoost contributors
|
* Copyright 2017-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/iterator/discard_iterator.h>
|
#include <thrust/sequence.h> // for sequence
|
||||||
#include <thrust/iterator/transform_output_iterator.h>
|
|
||||||
#include <thrust/sequence.h>
|
|
||||||
|
|
||||||
#include <vector>
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../common/device_helpers.cuh"
|
#include "../../common/cuda_context.cuh" // for CUDAContext
|
||||||
|
#include "../../common/device_helpers.cuh" // for CopyDeviceSpanToVector, ToSpan
|
||||||
#include "row_partitioner.cuh"
|
#include "row_partitioner.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::tree {
|
||||||
namespace tree {
|
RowPartitioner::RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid)
|
||||||
|
: device_idx_(ctx->Device()), ridx_(n_samples), ridx_tmp_(n_samples) {
|
||||||
RowPartitioner::RowPartitioner(DeviceOrd device_idx, size_t num_rows)
|
|
||||||
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
|
dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
|
||||||
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
|
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, n_samples)});
|
||||||
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
|
thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
|
||||||
}
|
}
|
||||||
|
|
||||||
RowPartitioner::~RowPartitioner() {
|
RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_.ordinal)); }
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
|
|
||||||
}
|
|
||||||
|
|
||||||
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
|
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
|
||||||
auto segment = ridx_segments_.at(nidx).segment;
|
auto segment = ridx_segments_.at(nidx).segment;
|
||||||
@ -39,6 +34,4 @@ std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(bst_node_t ni
|
|||||||
dh::CopyDeviceSpanToVector(&rows, span);
|
dh::CopyDeviceSpanToVector(&rows, span);
|
||||||
return rows;
|
return rows;
|
||||||
}
|
}
|
||||||
|
}; // namespace xgboost::tree
|
||||||
}; // namespace tree
|
|
||||||
}; // namespace xgboost
|
|
||||||
|
|||||||
@ -1,17 +1,17 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 XGBoost contributors
|
* Copyright 2017-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
|
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
|
||||||
|
#include <thrust/iterator/transform_output_iterator.h> // for make_transform_output_iterator
|
||||||
|
|
||||||
#include <limits>
|
#include <algorithm> // for max
|
||||||
#include <vector>
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../common/device_helpers.cuh"
|
#include "../../common/device_helpers.cuh" // for MakeTransformIterator
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h" // for bst_idx_t
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/task.h"
|
|
||||||
#include "xgboost/tree_model.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -223,7 +223,12 @@ class RowPartitioner {
|
|||||||
dh::PinnedMemory pinned2_;
|
dh::PinnedMemory pinned2_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
RowPartitioner(DeviceOrd device_idx, size_t num_rows);
|
/**
|
||||||
|
* @param ctx Context for device ordinal and stream.
|
||||||
|
* @param n_samples The number of samples in each batch.
|
||||||
|
* @param base_rowid The base row index for the current batch.
|
||||||
|
*/
|
||||||
|
RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
|
||||||
~RowPartitioner();
|
~RowPartitioner();
|
||||||
RowPartitioner(const RowPartitioner&) = delete;
|
RowPartitioner(const RowPartitioner&) = delete;
|
||||||
RowPartitioner& operator=(const RowPartitioner&) = delete;
|
RowPartitioner& operator=(const RowPartitioner&) = delete;
|
||||||
|
|||||||
@ -251,7 +251,8 @@ struct GPUHistMakerDevice {
|
|||||||
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
|
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
|
||||||
|
|
||||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||||
row_partitioner = std::make_unique<RowPartitioner>(ctx_->Device(), sample.sample_rows);
|
CHECK_EQ(page->base_rowid, 0);
|
||||||
|
row_partitioner = std::make_unique<RowPartitioner>(ctx_, sample.sample_rows, page->base_rowid);
|
||||||
|
|
||||||
// Init histogram
|
// Init histogram
|
||||||
hist.Init(ctx_->Device(), page->Cuts().TotalBins());
|
hist.Init(ctx_->Device(), page->Cuts().TotalBins());
|
||||||
|
|||||||
@ -2,13 +2,15 @@
|
|||||||
* Copyright 2020-2024, XGBoost Contributors
|
* Copyright 2020-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/context.h> // for Context
|
||||||
|
|
||||||
#include <vector>
|
#include <memory> // for unique_ptr
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
#include "../../../../src/tree/gpu_hist/histogram.cuh"
|
||||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh" // for RowPartitioner
|
||||||
#include "../../../../src/tree/param.h" // TrainParam
|
#include "../../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../../categorical_helpers.h"
|
#include "../../categorical_helpers.h" // for OneHotEncodeFeature
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
@ -24,7 +26,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
|||||||
for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
|
||||||
auto* page = batch.Impl();
|
auto* page = batch.Impl();
|
||||||
|
|
||||||
tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
|
tree::RowPartitioner row_partitioner{&ctx, kRows, page->base_rowid};
|
||||||
auto ridx = row_partitioner.GetRows(0);
|
auto ridx = row_partitioner.GetRows(0);
|
||||||
|
|
||||||
bst_bin_t num_bins = kBins * kCols;
|
bst_bin_t num_bins = kBins * kCols;
|
||||||
@ -129,7 +131,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
|||||||
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
||||||
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
|
auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
|
||||||
tree::RowPartitioner row_partitioner(ctx.Device(), kRows);
|
tree::RowPartitioner row_partitioner{&ctx, kRows, 0};
|
||||||
auto ridx = row_partitioner.GetRows(0);
|
auto ridx = row_partitioner.GetRows(0);
|
||||||
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
dh::device_vector<GradientPairInt64> cat_hist(num_categories);
|
||||||
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
||||||
@ -262,4 +264,105 @@ TEST(Histogram, Quantiser) {
|
|||||||
ASSERT_EQ(gh.GetHess(), 1.0);
|
ASSERT_EQ(gh.GetHess(), 1.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
namespace {
|
||||||
|
class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<float, bool>> {
|
||||||
|
public:
|
||||||
|
void Run(float sparsity, bool force_global) {
|
||||||
|
bst_idx_t n_samples{512}, n_features{12}, n_batches{3};
|
||||||
|
std::vector<std::unique_ptr<RowPartitioner>> partitioners;
|
||||||
|
auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}
|
||||||
|
.Batches(n_batches)
|
||||||
|
.GenerateSparsePageDMatrix("cache", true);
|
||||||
|
bst_bin_t n_bins = 16;
|
||||||
|
BatchParam p{n_bins, TrainParam::DftSparseThreshold()};
|
||||||
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
|
||||||
|
std::unique_ptr<FeatureGroups> fg;
|
||||||
|
dh::device_vector<GradientPairInt64> single_hist;
|
||||||
|
dh::device_vector<GradientPairInt64> multi_hist;
|
||||||
|
|
||||||
|
auto gpair = GenerateRandomGradients(n_samples);
|
||||||
|
gpair.SetDevice(ctx.Device());
|
||||||
|
auto quantiser = GradientQuantiser{&ctx, gpair.ConstDeviceSpan(), p_fmat->Info()};
|
||||||
|
std::shared_ptr<common::HistogramCuts> cuts;
|
||||||
|
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* Multi page.
|
||||||
|
*/
|
||||||
|
std::int32_t k{0};
|
||||||
|
for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
|
||||||
|
auto impl = page.Impl();
|
||||||
|
if (k == 0) {
|
||||||
|
// Initialization
|
||||||
|
auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
|
||||||
|
fg = std::make_unique<FeatureGroups>(impl->Cuts());
|
||||||
|
auto init = GradientPairInt64{0, 0};
|
||||||
|
multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
|
||||||
|
single_hist = decltype(single_hist)(impl->Cuts().TotalBins(), init);
|
||||||
|
cuts = std::make_shared<common::HistogramCuts>(impl->Cuts());
|
||||||
|
}
|
||||||
|
|
||||||
|
partitioners.emplace_back(
|
||||||
|
std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
|
||||||
|
|
||||||
|
auto ridx = partitioners.at(k)->GetRows(0);
|
||||||
|
auto d_histogram = dh::ToSpan(multi_hist);
|
||||||
|
DeviceHistogramBuilder builder;
|
||||||
|
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
|
||||||
|
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
|
||||||
|
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||||
|
d_histogram, quantiser);
|
||||||
|
++k;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(k, n_batches);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* Single page.
|
||||||
|
*/
|
||||||
|
RowPartitioner partitioner{&ctx, p_fmat->Info().num_row_, 0};
|
||||||
|
SparsePage concat;
|
||||||
|
std::vector<float> hess(p_fmat->Info().num_row_, 1.0f);
|
||||||
|
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
concat.Push(page);
|
||||||
|
}
|
||||||
|
EllpackPageImpl page{
|
||||||
|
ctx.Device(), cuts, concat, p_fmat->IsDense(), p_fmat->Info().num_col_, {}};
|
||||||
|
auto ridx = partitioner.GetRows(0);
|
||||||
|
auto d_histogram = dh::ToSpan(single_hist);
|
||||||
|
DeviceHistogramBuilder builder;
|
||||||
|
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
|
||||||
|
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
|
||||||
|
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||||
|
d_histogram, quantiser);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<GradientPairInt64> h_single(single_hist.size());
|
||||||
|
thrust::copy(single_hist.begin(), single_hist.end(), h_single.begin());
|
||||||
|
std::vector<GradientPairInt64> h_multi(multi_hist.size());
|
||||||
|
thrust::copy(multi_hist.begin(), multi_hist.end(), h_multi.begin());
|
||||||
|
|
||||||
|
for (std::size_t i = 0; i < single_hist.size(); ++i) {
|
||||||
|
ASSERT_EQ(h_single[i].GetQuantisedGrad(), h_multi[i].GetQuantisedGrad());
|
||||||
|
ASSERT_EQ(h_single[i].GetQuantisedHess(), h_multi[i].GetQuantisedHess());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST_P(HistogramExternalMemoryTest, ExternalMemory) {
|
||||||
|
std::apply(&HistogramExternalMemoryTest::Run, std::tuple_cat(std::make_tuple(this), GetParam()));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(Histogram, HistogramExternalMemoryTest, ::testing::ValuesIn([]() {
|
||||||
|
std::vector<std::tuple<float, bool>> params;
|
||||||
|
for (auto global : {true, false}) {
|
||||||
|
for (auto sparsity : {0.0f, 0.2f, 0.8f}) {
|
||||||
|
params.emplace_back(sparsity, global);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}()));
|
||||||
} // namespace xgboost::tree
|
} // namespace xgboost::tree
|
||||||
|
|||||||
@ -1,25 +1,22 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2022 by XGBoost Contributors
|
* Copyright 2019-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
#include <thrust/host_vector.h>
|
|
||||||
#include <thrust/sequence.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <cstddef> // for size_t
|
||||||
#include <vector>
|
#include <cstdint> // for uint32_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/context.h"
|
|
||||||
#include "xgboost/task.h"
|
|
||||||
#include "xgboost/tree_model.h"
|
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
void TestUpdatePositionBatch() {
|
void TestUpdatePositionBatch() {
|
||||||
const int kNumRows = 10;
|
const int kNumRows = 10;
|
||||||
RowPartitioner rp(FstCU(), kNumRows);
|
auto ctx = MakeCUDACtx(0);
|
||||||
|
RowPartitioner rp{&ctx, kNumRows, 0};
|
||||||
auto rows = rp.GetRowsHost(0);
|
auto rows = rp.GetRowsHost(0);
|
||||||
EXPECT_EQ(rows.size(), kNumRows);
|
EXPECT_EQ(rows.size(), kNumRows);
|
||||||
for (auto i = 0ull; i < kNumRows; i++) {
|
for (auto i = 0ull; i < kNumRows; i++) {
|
||||||
|
|||||||
@ -106,7 +106,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
|||||||
gpair.SetDevice(ctx.Device());
|
gpair.SetDevice(ctx.Device());
|
||||||
|
|
||||||
thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
|
thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
|
||||||
maker.row_partitioner = std::make_unique<RowPartitioner>(ctx.Device(), kNRows);
|
maker.row_partitioner = std::make_unique<RowPartitioner>(&ctx, kNRows, 0);
|
||||||
|
|
||||||
maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
|
maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
|
||||||
maker.hist.AllocateHistograms({0});
|
maker.hist.AllocateHistograms({0});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user