Avoid caching allocator for large allocations. (#10582)

Jiaming Yuan 2024-07-23 03:48:03 +08:00 committed by GitHub
parent b2cae34a8e
commit a19bbc9be5
7 changed files with 80 additions and 55 deletions

View File

@@ -227,7 +227,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
     });
     detail::SortByWeight(&entry_weight, &sorted_entries);
   } else {
-    thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
+    thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
                  detail::EntryCompareOp());
   }
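The one functional change in this hunk is the execution policy passed to thrust::sort: cuctx->CTP() (which, per the commit title, routes temporary storage through the caching device allocator) is replaced with cuctx->TP(), the plain stream policy, so the large scratch buffer needed to sort the quantile entries is released right after the call instead of being retained by the pool. A minimal standalone sketch of the same idea using stock Thrust only (illustrative, not XGBoost code; buffer size and key type are arbitrary):

// Standalone sketch: sort a large key buffer with the default device policy.
// Thrust allocates its temporary storage for the sort and frees it as soon as
// the call returns, rather than keeping it cached in an allocator pool.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

int main() {
  thrust::device_vector<float> keys(1 << 26);  // ~256 MiB of float keys
  thrust::sort(thrust::device, keys.begin(), keys.end());
  return 0;
}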

View File

@@ -10,14 +10,20 @@
 #include "row_partitioner.cuh"
 namespace xgboost::tree {
-RowPartitioner::RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid)
-    : device_idx_(ctx->Device()), ridx_(n_samples), ridx_tmp_(n_samples) {
-  dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
-  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, n_samples)});
+void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid) {
+  ridx_segments_.clear();
+  ridx_.resize(n_samples);
+  ridx_tmp_.resize(n_samples);
+  tmp_.clear();
+  CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
+  ridx_segments_.emplace_back(
+      NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});
   thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
 }
-RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_.ordinal)); }
+RowPartitioner::~RowPartitioner() = default;
 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
   auto segment = ridx_segments_.at(nidx).segment;

View File

@@ -7,25 +7,34 @@
 #include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator
 #include <algorithm>  // for max
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t, uint32_t
 #include <vector>     // for vector
 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
 #include "xgboost/base.h"     // for bst_idx_t
 #include "xgboost/context.h"  // for Context
+#include "xgboost/span.h"     // for Span
-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
+namespace cuda_impl {
+using RowIndexT = std::uint32_t;
+}
-/** \brief Used to demarcate a contiguous set of row indices associated with
- * some tree node. */
+/**
+ * @brief Used to demarcate a contiguous set of row indices associated with some tree
+ * node.
+ */
 struct Segment {
-  bst_uint begin{0};
-  bst_uint end{0};
+  cuda_impl::RowIndexT begin{0};
+  cuda_impl::RowIndexT end{0};
   Segment() = default;
-  Segment(bst_uint begin, bst_uint end) : begin(begin), end(end) { CHECK_GE(end, begin); }
-  __host__ __device__ size_t Size() const { return end - begin; }
+  Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
+    CHECK_GE(end, begin);
+  }
+  __host__ __device__ bst_idx_t Size() const { return end - begin; }
 };
 // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
@@ -39,7 +48,7 @@ struct PerNodeData {
 template <typename BatchIterT>
 __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
                                             int* batch_idx, std::size_t* item_idx) {
-  bst_uint sum = 0;
+  cuda_impl::RowIndexT sum = 0;
   for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) {
     if (sum + batch_info[i].segment.Size() > global_thread_idx) {
       *batch_idx = i;
@@ -65,9 +74,9 @@ __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
 // We can scan over this tuple, where the scan gives us information on how to partition inputs
 // according to the flag
 struct IndexFlagTuple {
-  bst_uint idx;        // The location of the item we are working on in ridx_
-  bst_uint flag_scan;  // This gets populated after scanning
-  int batch_idx;       // Which node in the batch does this item belong to
+  cuda_impl::RowIndexT idx;        // The location of the item we are working on in ridx_
+  cuda_impl::RowIndexT flag_scan;  // This gets populated after scanning
+  std::int32_t batch_idx;          // Which node in the batch does this item belong to
   bool flag;           // Result of op (is this item going left?)
 };
@@ -86,18 +95,18 @@ struct IndexFlagOp {
 template <typename OpDataT>
 struct WriteResultsFunctor {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
-  const bst_uint* ridx_in;
-  bst_uint* ridx_out;
-  bst_uint* counts;
+  cuda_impl::RowIndexT const* ridx_in;
+  cuda_impl::RowIndexT* ridx_out;
+  cuda_impl::RowIndexT* counts;
   __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     std::size_t scatter_address;
     const Segment& segment = batch_info[x.batch_idx].segment;
     if (x.flag) {
-      bst_uint num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
+      cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
       scatter_address = segment.begin + num_previous_flagged;
     } else {
-      bst_uint num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
+      cuda_impl::RowIndexT num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
       scatter_address = segment.end - num_previous_unflagged - 1;
     }
     ridx_out[scatter_address] = ridx_in[x.idx];
@@ -115,7 +124,7 @@ struct WriteResultsFunctor {
 template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
-                       common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
+                       common::Span<cuda_impl::RowIndexT> d_counts, std::size_t total_rows, OpT op,
                        dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
@@ -130,7 +139,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
         auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
-        return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
+        return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), op_res, batch_idx, op_res};
       });
   size_t temp_bytes = 0;
   if (tmp->empty()) {
@@ -195,29 +204,31 @@ __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
  * partition training rows into different leaf nodes. */
 class RowPartitioner {
  public:
-  using RowIndexT = bst_uint;
+  using RowIndexT = cuda_impl::RowIndexT;
   static constexpr bst_node_t kIgnoredTreePosition = -1;
  private:
-  DeviceOrd device_idx_;
-  /*! \brief In here if you want to find the rows belong to a node nid, first you need to
-   * get the indices segment from ridx_segments[nid], then get the row index that
-   * represents position of row in input data X. `RowPartitioner::GetRows` would be a
-   * good starting place to get a sense what are these vector storing.
+  /**
+   * In here if you want to find the rows belong to a node nid, first you need to get the
+   * indices segment from ridx_segments[nid], then get the row index that represents
+   * position of row in input data X. `RowPartitioner::GetRows` would be a good starting
+   * place to get a sense what are these vector storing.
    *
    * node id -> segment -> indices of rows belonging to node
    */
-  /*! \brief Range of row index for each node, pointers into ridx below. */
+  /** @brief Range of row index for each node, pointers into ridx below. */
   std::vector<NodePositionInfo> ridx_segments_;
-  /*! \brief mapping for node id -> rows.
+  /**
+   * @brief mapping for node id -> rows.
+   *
    * This looks like:
    * node id  |    1    |    2   |
    * rows idx | 3, 5, 1 | 13, 31 |
    */
-  dh::TemporaryArray<RowIndexT> ridx_;
+  dh::DeviceUVector<RowIndexT> ridx_;
   // Staging area for sorting ridx
-  dh::TemporaryArray<RowIndexT> ridx_tmp_;
+  dh::DeviceUVector<RowIndexT> ridx_tmp_;
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
@@ -228,7 +239,9 @@ class RowPartitioner {
   * @param n_samples The number of samples in each batch.
   * @param base_rowid The base row index for the current batch.
   */
-  RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
+  RowPartitioner() = default;
+  void Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
   ~RowPartitioner();
   RowPartitioner(const RowPartitioner&) = delete;
   RowPartitioner& operator=(const RowPartitioner&) = delete;
@@ -285,8 +298,8 @@ class RowPartitioner {
                              cudaMemcpyDefault));
     // Temporary arrays
-    auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
-    dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);
+    auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size(), 0);
+    dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
@@ -299,7 +312,7 @@ class RowPartitioner {
     dh::DefaultStream().Sync();
     // Update segments
-    for (size_t i = 0; i < nidx.size(); i++) {
+    for (std::size_t i = 0; i < nidx.size(); i++) {
       auto segment = ridx_segments_.at(nidx[i]).segment;
       auto left_count = h_counts[i];
       CHECK_LE(left_count, segment.Size());
@@ -336,11 +349,9 @@ class RowPartitioner {
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
-    common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
-        dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
+    common::Span<RowIndexT const> d_ridx{ridx_.data(), ridx_.size()};
+    FinalisePositionKernel<kBlockSize>
+        <<<grid_size, kBlockSize, 0>>>(dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
-};  // namespace tree
-};  // namespace xgboost
+};  // namespace xgboost::tree

View File

@@ -145,9 +145,11 @@ struct GPUHistMakerDevice {
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
-    row_partitioner.reset();  // Release the device memory first before reallocating
+    if (!row_partitioner) {
+      row_partitioner = std::make_unique<RowPartitioner>();
+    }
+    row_partitioner->Reset(ctx_, sample.sample_rows, page->base_rowid);
     CHECK_EQ(page->base_rowid, 0);
-    row_partitioner = std::make_unique<RowPartitioner>(ctx_, sample.sample_rows, page->base_rowid);
     // Init histogram
     hist.Init(ctx_->Device(), page->Cuts().TotalBins());
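This hunk is the consumer side of the new RowPartitioner::Reset(): the partitioner is constructed once and re-initialised for each batch, so its row-index buffers are resized in place instead of being freed and reallocated through the caching allocator on every iteration. A standalone sketch of that reuse pattern (illustrative only; the buffer name and batch loop are made up, not XGBoost code):

// Reuse one device buffer across batches: resize() resets the logical size and
// only reallocates when the new batch is larger than the current capacity.
#include <thrust/device_vector.h>
#include <thrust/sequence.h>

#include <cstddef>
#include <vector>

void process_batches(std::vector<std::size_t> const& batch_sizes) {
  thrust::device_vector<unsigned> ridx;  // allocated lazily, then reused
  for (std::size_t n : batch_sizes) {
    ridx.resize(n);                                  // in-place reset of the row-index buffer
    thrust::sequence(ridx.begin(), ridx.end(), 0u);  // ridx = {0, 1, ..., n - 1}
    // ... partition the rows of this batch ...
  }
}

int main() {
  process_batches({1u << 20, 1u << 18, 1u << 21});
  return 0;
}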

View File

@@ -66,7 +66,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
-    tree::RowPartitioner row_partitioner{&ctx, kRows, page->base_rowid};
+    tree::RowPartitioner row_partitioner;
+    row_partitioner.Reset(&ctx, kRows, page->base_rowid);
     auto ridx = row_partitioner.GetRows(0);
     bst_bin_t num_bins = kBins * kCols;
@@ -171,7 +172,8 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto cat_m = GetDMatrixFromData(x, kRows, 1);
   cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
   auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
-  tree::RowPartitioner row_partitioner{&ctx, kRows, 0};
+  tree::RowPartitioner row_partitioner;
+  row_partitioner.Reset(&ctx, kRows, 0);
   auto ridx = row_partitioner.GetRows(0);
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
@@ -343,8 +345,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
         cuts = std::make_shared<common::HistogramCuts>(impl->Cuts());
       }
-      partitioners.emplace_back(
-          std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
+      partitioners.emplace_back(std::make_unique<RowPartitioner>());
+      partitioners.back()->Reset(&ctx, impl->Size(), impl->base_rowid);
       auto ridx = partitioners.at(k)->GetRows(0);
       auto d_histogram = dh::ToSpan(multi_hist);
@@ -362,7 +364,9 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
     /**
      * Single page.
      */
-    RowPartitioner partitioner{&ctx, p_fmat->Info().num_row_, 0};
+    RowPartitioner partitioner;
+    partitioner.Reset(&ctx, p_fmat->Info().num_row_, 0);
     SparsePage concat;
     std::vector<float> hess(p_fmat->Info().num_row_, 1.0f);
     for (auto const& page : p_fmat->GetBatches<SparsePage>()) {

View File

@@ -16,7 +16,8 @@ namespace xgboost::tree {
 void TestUpdatePositionBatch() {
   const int kNumRows = 10;
   auto ctx = MakeCUDACtx(0);
-  RowPartitioner rp{&ctx, kNumRows, 0};
+  RowPartitioner rp;
+  rp.Reset(&ctx, kNumRows, 0);
   auto rows = rp.GetRowsHost(0);
   EXPECT_EQ(rows.size(), kNumRows);
   for (auto i = 0ull; i < kNumRows; i++) {

View File

@@ -64,7 +64,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   }
   gpair.SetDevice(ctx.Device());
-  maker.row_partitioner = std::make_unique<RowPartitioner>(&ctx, kNRows, 0);
+  maker.row_partitioner = std::make_unique<RowPartitioner>();
+  maker.row_partitioner->Reset(&ctx, kNRows, 0);
   maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms(&ctx, {0});