Avoid caching allocator for large allocations. (#10582)
commit a19bbc9be5
parent b2cae34a8e
@@ -227,7 +227,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
     });
     detail::SortByWeight(&entry_weight, &sorted_entries);
   } else {
-    thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(),
+    thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
                  detail::EntryCompareOp());
   }
 
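This first hunk is the headline change for the weighted-batch sketching path: sorting `sorted_entries`, the largest temporary here, now goes through the CUDA context's plain Thrust policy (`TP()`) instead of the caching one (`CTP()`), so the sort's scratch memory is returned to the driver rather than parked in a caching allocator. Below is a minimal sketch of how such a policy pair can be built, assuming a toy exact-size cache in the style of Thrust's custom temporary-allocation example; `CachingAllocator` is illustrative, not XGBoost's allocator, and error handling is elided.

// Toy caching allocator: freed blocks are kept for reuse. For very large
// temporaries this ties up device memory, which is what the commit avoids.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

#include <cstddef>
#include <map>

struct CachingAllocator {
  using value_type = char;
  std::multimap<std::ptrdiff_t, char*> free_blocks;

  char* allocate(std::ptrdiff_t num_bytes) {
    auto it = free_blocks.find(num_bytes);
    if (it != free_blocks.end()) {  // exact-size hit: reuse the cached block
      char* ptr = it->second;
      free_blocks.erase(it);
      return ptr;
    }
    char* ptr = nullptr;
    cudaMalloc(&ptr, num_bytes);  // error handling elided in this sketch
    return ptr;
  }
  void deallocate(char* ptr, std::size_t num_bytes) {
    free_blocks.emplace(static_cast<std::ptrdiff_t>(num_bytes), ptr);  // cache, don't free
  }
  ~CachingAllocator() {
    for (auto& kv : free_blocks) cudaFree(kv.second);
  }
};

int main() {
  thrust::device_vector<int> data(1 << 20, 1);
  CachingAllocator alloc;
  // Caching policy (CTP-like): scratch space comes from, and returns to, the cache.
  thrust::sort(thrust::cuda::par(alloc), data.begin(), data.end());
  // Plain policy (TP-like): scratch space is allocated and freed per call, so a
  // large temporary does not linger after the sort completes.
  thrust::sort(thrust::cuda::par, data.begin(), data.end());
  return 0;
}

With the caching policy, a multi-gigabyte sort temporary would stay in the cache after the call; the plain policy trades a slower allocate/free round trip for not holding that memory.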
@@ -10,14 +10,20 @@
 #include "row_partitioner.cuh"
 
 namespace xgboost::tree {
-RowPartitioner::RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid)
-    : device_idx_(ctx->Device()), ridx_(n_samples), ridx_tmp_(n_samples) {
-  dh::safe_cuda(cudaSetDevice(device_idx_.ordinal));
-  ridx_segments_.emplace_back(NodePositionInfo{Segment(0, n_samples)});
+void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid) {
+  ridx_segments_.clear();
+  ridx_.resize(n_samples);
+  ridx_tmp_.resize(n_samples);
+  tmp_.clear();
+
+  CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
+  ridx_segments_.emplace_back(
+      NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});
+
   thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
 }
 
-RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_.ordinal)); }
+RowPartitioner::~RowPartitioner() = default;
 
 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
   auto segment = ridx_segments_.at(nidx).segment;
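`Reset()` restores the partitioner to its freshly-constructed state: one root segment spanning all rows (now guarded so the count fits the 32-bit row index type), buffers resized in place, and `ridx_` refilled with consecutive row indices starting at `base_rowid`. A standalone sketch of the refill step, with hypothetical sizes and the stream/execution-policy plumbing omitted:

// What Reset() leaves behind: ridx = {base_rowid, base_rowid + 1, ...}.
// Compile with nvcc; sizes here are hypothetical.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

#include <cstdint>
#include <iostream>

int main() {
  std::uint32_t n_samples = 8, base_rowid = 100;
  thrust::device_vector<std::uint32_t> ridx(n_samples);
  // Same call Reset() makes, minus the CUDA-context plumbing:
  thrust::sequence(ridx.begin(), ridx.end(), base_rowid);
  thrust::host_vector<std::uint32_t> h = ridx;  // copy back for printing
  for (std::uint32_t v : h) std::cout << v << " ";  // 100 101 ... 107
  std::cout << "\n";
  return 0;
}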
@@ -7,25 +7,34 @@
 #include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator
 
 #include <algorithm>  // for max
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t, uint32_t
 #include <vector>     // for vector
 
 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
 #include "xgboost/base.h"                   // for bst_idx_t
 #include "xgboost/context.h"                // for Context
+#include "xgboost/span.h"                   // for Span
 
-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
+namespace cuda_impl {
+using RowIndexT = std::uint32_t;
+}  // namespace cuda_impl
 
-/** \brief Used to demarcate a contiguous set of row indices associated with
- * some tree node. */
+/**
+ * @brief Used to demarcate a contiguous set of row indices associated with some tree
+ * node.
+ */
 struct Segment {
-  bst_uint begin{0};
-  bst_uint end{0};
+  cuda_impl::RowIndexT begin{0};
+  cuda_impl::RowIndexT end{0};
 
   Segment() = default;
 
-  Segment(bst_uint begin, bst_uint end) : begin(begin), end(end) { CHECK_GE(end, begin); }
-  __host__ __device__ size_t Size() const { return end - begin; }
+  Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
+    CHECK_GE(end, begin);
+  }
+  __host__ __device__ bst_idx_t Size() const { return end - begin; }
 };
 
 // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
@@ -39,7 +48,7 @@ struct PerNodeData {
 template <typename BatchIterT>
 __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
                                             int* batch_idx, std::size_t* item_idx) {
-  bst_uint sum = 0;
+  cuda_impl::RowIndexT sum = 0;
   for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) {
     if (sum + batch_info[i].segment.Size() > global_thread_idx) {
       *batch_idx = i;
@@ -65,9 +74,9 @@ __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
 // We can scan over this tuple, where the scan gives us information on how to partition inputs
 // according to the flag
 struct IndexFlagTuple {
-  bst_uint idx;        // The location of the item we are working on in ridx_
-  bst_uint flag_scan;  // This gets populated after scanning
-  int batch_idx;       // Which node in the batch does this item belong to
+  cuda_impl::RowIndexT idx;        // The location of the item we are working on in ridx_
+  cuda_impl::RowIndexT flag_scan;  // This gets populated after scanning
+  std::int32_t batch_idx;          // Which node in the batch does this item belong to
   bool flag;  // Result of op (is this item going left?)
 };
 
@@ -86,18 +95,18 @@ struct IndexFlagOp {
 template <typename OpDataT>
 struct WriteResultsFunctor {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
-  const bst_uint* ridx_in;
-  bst_uint* ridx_out;
-  bst_uint* counts;
+  cuda_impl::RowIndexT const* ridx_in;
+  cuda_impl::RowIndexT* ridx_out;
+  cuda_impl::RowIndexT* counts;
 
   __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     std::size_t scatter_address;
     const Segment& segment = batch_info[x.batch_idx].segment;
     if (x.flag) {
-      bst_uint num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
+      cuda_impl::RowIndexT num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
       scatter_address = segment.begin + num_previous_flagged;
     } else {
-      bst_uint num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
+      cuda_impl::RowIndexT num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
       scatter_address = segment.end - num_previous_unflagged - 1;
     }
     ridx_out[scatter_address] = ridx_in[x.idx];
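The scatter arithmetic in `WriteResultsFunctor` is worth spelling out: after an inclusive scan of the left/right flags, a row going left lands at `segment.begin + (number flagged before it)`, while a row going right fills the segment from the back. A host-side sketch of the same arithmetic, using illustrative values and a sequential loop in place of the device scan:

// Standalone CPU sketch of the scan-based partition used by WriteResultsFunctor.
// Values here are illustrative, not part of the XGBoost source.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::uint32_t> ridx{3, 5, 1, 13, 31};        // row indices in one node segment
  std::vector<bool> flag{true, false, true, true, false};  // op result: does the row go left?
  std::uint32_t begin = 0, end = 5;                        // segment [begin, end)

  std::vector<std::uint32_t> out(ridx.size());
  std::uint32_t flag_scan = 0;  // inclusive scan of flags, as in IndexFlagTuple::flag_scan
  for (std::uint32_t idx = begin; idx < end; ++idx) {
    flag_scan += flag[idx];
    std::uint32_t scatter;
    if (flag[idx]) {
      scatter = begin + (flag_scan - 1);  // -1 because the scan is inclusive
    } else {
      scatter = end - ((idx - begin) - flag_scan) - 1;  // right side fills from the back
    }
    out[scatter] = ridx[idx];
  }
  for (auto v : out) std::cout << v << " ";  // prints: 3 1 13 31 5
  std::cout << "\n";
  return 0;
}

The output keeps left rows in their original order and places right rows in reverse order at the tail, which matches what the inclusive-scan kernel produces.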
@@ -115,7 +124,7 @@ struct WriteResultsFunctor {
 template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
-                       common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
+                       common::Span<cuda_impl::RowIndexT> d_counts, std::size_t total_rows, OpT op,
                        dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
@@ -130,7 +139,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
         auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
-        return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
+        return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), op_res, batch_idx, op_res};
       });
   size_t temp_bytes = 0;
   if (tmp->empty()) {
@@ -195,29 +204,31 @@ __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
  * partition training rows into different leaf nodes. */
 class RowPartitioner {
  public:
-  using RowIndexT = bst_uint;
+  using RowIndexT = cuda_impl::RowIndexT;
   static constexpr bst_node_t kIgnoredTreePosition = -1;
 
  private:
-  DeviceOrd device_idx_;
-  /*! \brief In here if you want to find the rows belong to a node nid, first you need to
-   * get the indices segment from ridx_segments[nid], then get the row index that
-   * represents position of row in input data X. `RowPartitioner::GetRows` would be a
-   * good starting place to get a sense what are these vector storing.
+  /**
+   * In here if you want to find the rows belong to a node nid, first you need to get the
+   * indices segment from ridx_segments[nid], then get the row index that represents
+   * position of row in input data X. `RowPartitioner::GetRows` would be a good starting
+   * place to get a sense what are these vector storing.
    *
    * node id -> segment -> indices of rows belonging to node
    */
-  /*! \brief Range of row index for each node, pointers into ridx below. */
+  /** @brief Range of row index for each node, pointers into ridx below. */
   std::vector<NodePositionInfo> ridx_segments_;
-  /*! \brief mapping for node id -> rows.
+  /**
+   * @brief mapping for node id -> rows.
+   *
    * This looks like:
    * node id | 1 | 2 |
    * rows idx | 3, 5, 1 | 13, 31 |
    */
-  dh::TemporaryArray<RowIndexT> ridx_;
+  dh::DeviceUVector<RowIndexT> ridx_;
   // Staging area for sorting ridx
-  dh::TemporaryArray<RowIndexT> ridx_tmp_;
+  dh::DeviceUVector<RowIndexT> ridx_tmp_;
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
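The storage swap here connects back to the commit title: `ridx_` and `ridx_tmp_` scale with the number of training rows, and they move from `dh::TemporaryArray`, which presumably draws from the caching allocator, to `dh::DeviceUVector`, which can be resized in place and reused across `Reset()` calls. Below is a sketch of a grow-only device buffer in that spirit; it is illustrative, not the actual `dh::DeviceUVector`.

// Illustrative grow-only device buffer: memory comes straight from cudaMalloc
// rather than a caching pool, so row indices for a large batch do not occupy
// the cache. Error handling elided.
#include <cuda_runtime.h>

#include <cstddef>

template <typename T>
class GrowOnlyDeviceBuffer {
  T* ptr_{nullptr};
  std::size_t size_{0}, capacity_{0};

 public:
  void resize(std::size_t n) {
    if (n > capacity_) {
      T* fresh = nullptr;
      cudaMalloc(&fresh, n * sizeof(T));
      if (ptr_) {
        cudaMemcpy(fresh, ptr_, size_ * sizeof(T), cudaMemcpyDeviceToDevice);
        cudaFree(ptr_);
      }
      ptr_ = fresh;
      capacity_ = n;
    }
    size_ = n;  // shrinking or equal-size Reset reuses the existing allocation
  }
  T* data() { return ptr_; }
  std::size_t size() const { return size_; }
  ~GrowOnlyDeviceBuffer() { cudaFree(ptr_); }
};

int main() {
  GrowOnlyDeviceBuffer<unsigned> ridx;
  ridx.resize(1000);  // allocates
  ridx.resize(800);   // reuses the allocation, no device malloc
  ridx.resize(1000);  // still within capacity, no device malloc
  return 0;
}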
@@ -228,7 +239,9 @@ class RowPartitioner {
    * @param n_samples The number of samples in each batch.
    * @param base_rowid The base row index for the current batch.
    */
-  RowPartitioner(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
+  RowPartitioner() = default;
+  void Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t base_rowid);
 
   ~RowPartitioner();
   RowPartitioner(const RowPartitioner&) = delete;
   RowPartitioner& operator=(const RowPartitioner&) = delete;
@@ -285,8 +298,8 @@ class RowPartitioner {
                               cudaMemcpyDefault));
 
     // Temporary arrays
-    auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
-    dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);
+    auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size(), 0);
+    dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
 
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
@@ -299,7 +312,7 @@ class RowPartitioner {
     dh::DefaultStream().Sync();
 
     // Update segments
-    for (size_t i = 0; i < nidx.size(); i++) {
+    for (std::size_t i = 0; i < nidx.size(); i++) {
       auto segment = ridx_segments_.at(nidx[i]).segment;
       auto left_count = h_counts[i];
       CHECK_LE(left_count, segment.Size());
@@ -336,11 +349,9 @@ class RowPartitioner {
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
-    common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
-        dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
+    common::Span<RowIndexT const> d_ridx{ridx_.data(), ridx_.size()};
+    FinalisePositionKernel<kBlockSize>
+        <<<grid_size, kBlockSize, 0>>>(dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
@@ -145,9 +145,11 @@ struct GPUHistMakerDevice {
 
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
 
-    row_partitioner.reset();  // Release the device memory first before reallocating
+    if (!row_partitioner) {
+      row_partitioner = std::make_unique<RowPartitioner>();
+    }
+    row_partitioner->Reset(ctx_, sample.sample_rows, page->base_rowid);
     CHECK_EQ(page->base_rowid, 0);
-    row_partitioner = std::make_unique<RowPartitioner>(ctx_, sample.sample_rows, page->base_rowid);
 
     // Init histogram
     hist.Init(ctx_->Device(), page->Cuts().TotalBins());
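With `Reset()` available, the updater builds the partitioner once and recycles it on every batch; the old code destroyed and rebuilt it each iteration, with the explicit `reset()` releasing device memory before reallocation. A host-only sketch of the new lifetime, with `FakePartitioner` standing in for `RowPartitioner` (names and batch sizes are hypothetical):

// One partitioner object, Reset() per batch: only the first, largest batch
// triggers an allocation.
#include <cstdio>
#include <memory>
#include <vector>

struct FakePartitioner {
  std::vector<unsigned> ridx;
  int allocations = 0;
  void Reset(std::size_t n_samples, std::size_t base_rowid) {
    if (n_samples > ridx.capacity()) ++allocations;  // grow only when needed
    ridx.resize(n_samples);
    for (std::size_t i = 0; i < n_samples; ++i) ridx[i] = base_rowid + i;
  }
};

int main() {
  std::unique_ptr<FakePartitioner> part;
  std::size_t base = 0;
  for (std::size_t n : {1000, 800, 1000}) {  // per-batch row counts
    if (!part) part = std::make_unique<FakePartitioner>();
    part->Reset(n, base);
    base += n;
  }
  std::printf("allocations: %d\n", part->allocations);  // prints 1, not 3
  return 0;
}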
@@ -66,7 +66,8 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
   for (auto const& batch : matrix->GetBatches<EllpackPage>(&ctx, batch_param)) {
     auto* page = batch.Impl();
 
-    tree::RowPartitioner row_partitioner{&ctx, kRows, page->base_rowid};
+    tree::RowPartitioner row_partitioner;
+    row_partitioner.Reset(&ctx, kRows, page->base_rowid);
     auto ridx = row_partitioner.GetRows(0);
 
     bst_bin_t num_bins = kBins * kCols;
@@ -171,7 +172,8 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto cat_m = GetDMatrixFromData(x, kRows, 1);
   cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
   auto batch_param = BatchParam{kBins, tree::TrainParam::DftSparseThreshold()};
-  tree::RowPartitioner row_partitioner{&ctx, kRows, 0};
+  tree::RowPartitioner row_partitioner;
+  row_partitioner.Reset(&ctx, kRows, 0);
   auto ridx = row_partitioner.GetRows(0);
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
@@ -343,8 +345,8 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
       cuts = std::make_shared<common::HistogramCuts>(impl->Cuts());
     }
 
-    partitioners.emplace_back(
-        std::make_unique<RowPartitioner>(&ctx, impl->Size(), impl->base_rowid));
+    partitioners.emplace_back(std::make_unique<RowPartitioner>());
+    partitioners.back()->Reset(&ctx, impl->Size(), impl->base_rowid);
 
     auto ridx = partitioners.at(k)->GetRows(0);
     auto d_histogram = dh::ToSpan(multi_hist);
@@ -362,7 +364,9 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
   /**
    * Single page.
    */
-  RowPartitioner partitioner{&ctx, p_fmat->Info().num_row_, 0};
+  RowPartitioner partitioner;
+  partitioner.Reset(&ctx, p_fmat->Info().num_row_, 0);
 
   SparsePage concat;
   std::vector<float> hess(p_fmat->Info().num_row_, 1.0f);
   for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
@@ -16,7 +16,8 @@ namespace xgboost::tree {
 void TestUpdatePositionBatch() {
   const int kNumRows = 10;
   auto ctx = MakeCUDACtx(0);
-  RowPartitioner rp{&ctx, kNumRows, 0};
+  RowPartitioner rp;
+  rp.Reset(&ctx, kNumRows, 0);
   auto rows = rp.GetRowsHost(0);
   EXPECT_EQ(rows.size(), kNumRows);
   for (auto i = 0ull; i < kNumRows; i++) {
@@ -64,7 +64,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   }
   gpair.SetDevice(ctx.Device());
 
-  maker.row_partitioner = std::make_unique<RowPartitioner>(&ctx, kNRows, 0);
+  maker.row_partitioner = std::make_unique<RowPartitioner>();
+  maker.row_partitioner->Reset(&ctx, kNRows, 0);
 
   maker.hist.Init(ctx.Device(), page->Cuts().TotalBins());
   maker.hist.AllocateHistograms(&ctx, {0});