[EM] Have one partitioner for each batch. (#10760)
- Initialize one partitioner for each batch.
- Collect partition size during initialization.
- Support base ridx in the finalization.
parent 3043827efc
commit 4fe67f10b4
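In short, the GPU hist updater now keeps one RowPartitioner per ELLPACK batch instead of a single global partitioner, records the cumulative batch sizes in a batch_ptr_ prefix sum, and passes each batch's base row index (base_ridx) into FinalisePosition so positions are written at the correct global offset. Below is a minimal standalone sketch of that bookkeeping only; FakePartitioner and the batch sizes are hypothetical stand-ins for illustration, not the real xgboost::tree::RowPartitioner API.

#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for xgboost::tree::RowPartitioner, used only to show the
// per-batch bookkeeping this commit introduces.
struct FakePartitioner {
  std::size_t base_ridx{0};
  std::size_t n_samples{0};
  void Reset(std::size_t n, std::size_t base) {
    n_samples = n;
    base_ridx = base;
  }
  std::size_t Size() const { return n_samples; }
};

int main() {
  // Per-batch row counts, e.g. gathered while iterating the ELLPACK batches once.
  std::vector<std::size_t> batch_sizes{100, 80, 120};

  // batch_ptr is the prefix sum of the batch sizes: {0, 100, 180, 300}.
  std::vector<std::size_t> batch_ptr{0};
  for (auto n : batch_sizes) {
    batch_ptr.push_back(batch_ptr.back() + n);
  }

  // One partitioner per batch; partitioner k owns rows [batch_ptr[k], batch_ptr[k + 1]).
  std::vector<std::unique_ptr<FakePartitioner>> partitioners;
  for (std::size_t k = 0; k + 1 < batch_ptr.size(); ++k) {
    partitioners.emplace_back(std::make_unique<FakePartitioner>());
    auto base_ridx = batch_ptr[k];
    partitioners[k]->Reset(batch_ptr[k + 1] - base_ridx, base_ridx);
  }

  // Total rows covered equals batch_ptr.back(); with a single (in-core) batch this
  // degenerates to one partitioner whose base_ridx is 0.
  return partitioners.size() == batch_sizes.size() ? 0 : 1;
}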
@@ -387,11 +387,6 @@ void CopyTo(Src const &src, Dst *dst) {
                             src.size() * sizeof(SVT), cudaMemcpyDefault));
 }
 
-template <class HContainer, class DContainer>
-void CopyToD(HContainer const &h, DContainer *d) {
-  CopyTo(h, d);
-}
-
 // Keep track of pinned memory allocation
 struct PinnedMemory {
   void *temp_storage{nullptr};
@@ -124,7 +124,7 @@ void NameThread(std::thread* t, StringView name) {
   char old[16];
   auto ret = pthread_getname_np(handle, old, 16);
   if (ret != 0) {
-    LOG(WARNING) << "Failed to get the name from thread";
+    LOG(DEBUG) << "Failed to get the name from thread";
   }
   auto new_name = std::string{old} + ">" + name.c_str();  // NOLINT
   if (new_name.size() > 15) {
@@ -132,7 +132,7 @@ void NameThread(std::thread* t, StringView name) {
   }
   ret = pthread_setname_np(handle, new_name.c_str());
   if (ret != 0) {
-    LOG(WARNING) << "Failed to name thread:" << ret << " :" << new_name;
+    LOG(DEBUG) << "Failed to name thread:" << ret << " :" << new_name;
   }
 #else
   (void)name;
@@ -152,7 +152,7 @@ NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_pa
 
 GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
                                        DMatrix* dmat) {
-  return {dmat->Info().num_row_, dmat, gpair};
+  return {dmat, gpair};
 }
 
 ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
@@ -179,7 +179,7 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
     this->p_fmat_new_ =
         std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
   }
-  return {p_fmat->Info().num_row_, this->p_fmat_new_.get(), gpair};
+  return {this->p_fmat_new_.get(), gpair};
 }
 
 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
@@ -192,7 +192,7 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
-  return {p_fmat->Info().num_row_, p_fmat, gpair};
+  return {p_fmat, gpair};
 }
 
 ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -252,7 +252,8 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
   // Create the new DMatrix
   this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
-  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
+  CHECK_EQ(sample_rows, this->p_fmat_new_->Info().num_row_);
+  return {this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
@@ -274,7 +275,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
                     thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
                     PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
-  return {n_rows, dmat, gpair};
+  return {dmat, gpair};
 }
 
 ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
@@ -334,7 +335,8 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
   // Create the new DMatrix
   this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
-  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
+  CHECK_EQ(sample_rows, this->p_fmat_new_->Info().num_row_);
+  return {this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
@@ -12,11 +12,9 @@
 
 namespace xgboost::tree {
 struct GradientBasedSample {
-  /*!\brief Number of sampled rows. */
-  bst_idx_t sample_rows;
-  /*!\brief Sampled rows in ELLPACK format. */
+  /** @brief Sampled rows in ELLPACK format. */
   DMatrix* p_fmat;
-  /*!\brief Gradient pairs for the sampled rows. */
+  /** @brief Gradient pairs for the sampled rows. */
   common::Span<GradientPair const> gpair;
 };
 
@@ -31,7 +31,7 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t
   return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
 }
 
-common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
+common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() const {
   return dh::ToSpan(ridx_);
 }
 
@@ -200,11 +200,11 @@ XGBOOST_DEV_INLINE int GetPositionFromSegments(std::size_t idx,
 
 template <int kBlockSize, typename RowIndexT, typename OpT>
 __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
-    const common::Span<const NodePositionInfo> d_node_info,
+    const common::Span<const NodePositionInfo> d_node_info, bst_idx_t base_ridx,
     const common::Span<const RowIndexT> d_ridx, common::Span<bst_node_t> d_out_position, OpT op) {
   for (auto idx : dh::GridStrideRange<std::size_t>(0, d_ridx.size())) {
     auto position = GetPositionFromSegments(idx, d_node_info.data());
-    RowIndexT ridx = d_ridx[idx];
+    RowIndexT ridx = d_ridx[idx] - base_ridx;
     bst_node_t new_position = op(ridx, position);
     d_out_position[ridx] = new_position;
   }
@@ -264,7 +264,12 @@ class RowPartitioner {
   /**
    * \brief Gets all training rows in the set.
    */
-  common::Span<const RowIndexT> GetRows();
+  common::Span<const RowIndexT> GetRows() const;
+  /**
+   * @brief Get the number of rows in this partitioner.
+   */
+  std::size_t Size() const { return this->GetRows().size(); }
 
   [[nodiscard]] bst_node_t GetNumNodes() const { return n_nodes_; }
 
   /**
@@ -351,7 +356,8 @@ class RowPartitioner {
    * argument and return the new position for this training instance.
    */
  template <typename FinalisePositionOpT>
-  void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) const {
+  void FinalisePosition(common::Span<bst_node_t> d_out_position, bst_idx_t base_ridx,
+                        FinalisePositionOpT op) const {
    dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
    dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                  sizeof(NodePositionInfo) * ridx_segments_.size(),
@@ -361,8 +367,8 @@ class RowPartitioner {
    const int kItemsThread = 8;
    const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
    common::Span<RowIndexT const> d_ridx{ridx_.data(), ridx_.size()};
-    FinalisePositionKernel<kBlockSize>
-        <<<grid_size, kBlockSize, 0>>>(dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
+        dh::ToSpan(d_node_info_storage), base_ridx, d_ridx, d_out_position, op);
  }
 };
 };  // namespace xgboost::tree
@@ -1,25 +1,24 @@
 /**
  * Copyright 2017-2024, XGBoost contributors
  */
-#include <thrust/copy.h>
-#include <thrust/reduce.h>
-#include <xgboost/tree_updater.h>
+#include <thrust/functional.h>  // for plus
+#include <thrust/transform.h>   // for transform
 
-#include <algorithm>
-#include <cmath>
+#include <algorithm>  // for max
+#include <cmath>      // for isnan
 #include <cstddef>    // for size_t
 #include <memory>     // for unique_ptr, make_unique
 #include <utility>    // for move
-#include <vector>
+#include <vector>     // for vector
 
 #include "../collective/aggregator.h"
-#include "../collective/broadcast.h"
-#include "../common/bitfield.h"
-#include "../common/categorical.h"
+#include "../collective/broadcast.h"   // for Broadcast
+#include "../common/categorical.h"     // for KCatBitField
 #include "../common/cuda_context.cuh"  // for CUDAContext
 #include "../common/cuda_rt_utils.h"   // for CheckComputeCapability
 #include "../common/device_helpers.cuh"
-#include "../common/hist_util.h"
+#include "../common/device_vector.cuh"  // for device_vector
+#include "../common/hist_util.h"        // for HistogramCuts
 #include "../common/random.h"  // for ColumnSampler, GlobalRandom
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
@@ -31,20 +30,20 @@
 #include "gpu_hist/feature_groups.cuh"
 #include "gpu_hist/gradient_based_sampler.cuh"
 #include "gpu_hist/histogram.cuh"
-#include "gpu_hist/row_partitioner.cuh"
-#include "hist/param.h"
-#include "param.h"
+#include "gpu_hist/row_partitioner.cuh"  // for RowPartitioner
+#include "hist/param.h"                  // for HistMakerTrainParam
+#include "param.h"                       // for TrainParam
 #include "sample_position.h"       // for SamplePosition
 #include "updater_gpu_common.cuh"  // for HistBatch
-#include "xgboost/base.h"
-#include "xgboost/context.h"
-#include "xgboost/data.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/json.h"
-#include "xgboost/span.h"
+#include "xgboost/base.h"                // for bst_idx_t
+#include "xgboost/context.h"             // for Context
+#include "xgboost/data.h"                // for DMatrix
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"                // for Json
+#include "xgboost/span.h"                // for Span
 #include "xgboost/task.h"                // for ObjInfo
-#include "xgboost/tree_model.h"
-#include "xgboost/tree_updater.h"
+#include "xgboost/tree_model.h"          // for RegTree
+#include "xgboost/tree_updater.h"        // for TreeUpdater
 
 namespace xgboost::tree {
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
@@ -57,21 +56,6 @@ using cuda_impl::HistBatch;
 // parameter to avoid any regen.
 using cuda_impl::StaticBatch;
 
-// GPU tree updater implementation.
-struct GPUHistMakerDevice {
- private:
-  GPUHistEvaluator evaluator_;
-  Context const* ctx_;
-  std::shared_ptr<common::ColumnSampler> column_sampler_;
-  MetaInfo const& info_;
-
-  DeviceHistogramBuilder histogram_;
-  // node idx for each sample
-  dh::device_vector<bst_node_t> positions_;
-  std::unique_ptr<RowPartitioner> row_partitioner_;
-  std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
-
- public:
 // Extra data for each node that is passed to the update position function
 struct NodeSplitData {
   RegTree::Node split_node;
@@ -80,9 +64,23 @@ struct GPUHistMakerDevice {
 };
 static_assert(std::is_trivially_copyable_v<NodeSplitData>);
 
- public:
-  common::Span<FeatureType const> feature_types;
+// GPU tree updater implementation.
+struct GPUHistMakerDevice {
+ private:
+  GPUHistEvaluator evaluator_;
+  Context const* ctx_;
+  std::shared_ptr<common::ColumnSampler> column_sampler_;
+  // Set of row partitioners, one for each batch (external memory). When the training is
+  // in-core, there's only one partitioner.
+  std::vector<std::unique_ptr<RowPartitioner>> partitioners_;
+
+  DeviceHistogramBuilder histogram_;
+  std::vector<bst_idx_t> batch_ptr_;
+  // node idx for each sample
+  dh::device_vector<bst_node_t> positions_;
+  std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
+
+ public:
   DeviceHistogramStorage<> hist{};
 
   dh::device_vector<GradientPair> d_gpair;  // storage for gpair;
@@ -104,21 +102,20 @@ struct GPUHistMakerDevice {
 
   std::unique_ptr<FeatureGroups> feature_groups;
 
-  GPUHistMakerDevice(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
-                     bool is_external_memory, common::Span<FeatureType const> _feature_types,
-                     TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
-                     BatchParam batch_param, MetaInfo const& info)
+  GPUHistMakerDevice(Context const* ctx, TrainParam _param,
+                     std::shared_ptr<common::ColumnSampler> column_sampler, BatchParam batch_param,
+                     MetaInfo const& info, std::vector<bst_idx_t> batch_ptr,
+                     std::shared_ptr<common::HistogramCuts const> cuts)
       : evaluator_{_param, static_cast<bst_feature_t>(info.num_col_), ctx->Device()},
         ctx_(ctx),
-        feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler_(std::move(column_sampler)),
-        interaction_constraints(param, info.num_col_),
-        info_{info},
+        interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
+        batch_ptr_{std::move(batch_ptr)},
         cuts_{std::move(cuts)} {
     sampler =
         std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
-                                               param.sampling_method, is_external_memory);
+                                               param.sampling_method, batch_ptr_.size() > 2);
     if (!param.monotone_constraints.empty()) {
       // Copy assigning an empty vector causes an exception in MSVC debug builds
       monotone_constraints = param.monotone_constraints;
@@ -149,27 +146,45 @@ struct GPUHistMakerDevice {
 
     this->interaction_constraints.Reset();
 
-    if (d_gpair.size() != dh_gpair->Size()) {
-      d_gpair.resize(dh_gpair->Size());
-    }
-    dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
-                                  dh_gpair->Size() * sizeof(GradientPair),
-                                  cudaMemcpyDeviceToDevice));
-    auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
+    // Sampling
+    dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair);  // backup the gradient
+    auto sample = this->sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
     this->gpair = sample.gpair;
-    p_fmat = sample.p_fmat;
-    CHECK(p_fmat->SingleColBlock());
+    p_fmat = sample.p_fmat;  // Update p_fmat before allocating partitioners
+    p_fmat->Info().feature_types.SetDevice(ctx_->Device());
+    std::size_t n_batches = p_fmat->NumBatches();
+    bool is_concat = (n_batches + 1) != this->batch_ptr_.size();
+    std::vector<bst_idx_t> batch_ptr{batch_ptr_};
+    if (is_concat) {
+      // Concatenate the batch ptrs as well.
+      batch_ptr = {static_cast<bst_idx_t>(0), p_fmat->Info().num_row_};
+    }
+    // Initialize partitions
+    if (!partitioners_.empty()) {
+      CHECK_EQ(partitioners_.size(), n_batches);
+    }
+    for (std::size_t k = 0; k < n_batches; ++k) {
+      if (partitioners_.size() != n_batches) {
+        // First run.
+        partitioners_.emplace_back(std::make_unique<RowPartitioner>());
+      }
+      auto base_ridx = batch_ptr[k];
+      auto n_samples = batch_ptr.at(k + 1) - base_ridx;
+      partitioners_[k]->Reset(ctx_, n_samples, base_ridx);
+    }
+    CHECK_EQ(partitioners_.size(), n_batches);
+    if (is_concat) {
+      CHECK_EQ(partitioners_.size(), 1);
+      CHECK_EQ(partitioners_.front()->Size(), p_fmat->Info().num_row_);
+    }
 
-    this->evaluator_.Reset(*cuts_, feature_types, p_fmat->Info().num_col_, param,
-                           p_fmat->Info().IsColumnSplit(), ctx_->Device());
+    // Other initializations
+    this->evaluator_.Reset(*cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(),
+                           p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit(),
+                           this->ctx_->Device());
 
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
 
-    if (!row_partitioner_) {
-      row_partitioner_ = std::make_unique<RowPartitioner>();
-    }
-    row_partitioner_->Reset(ctx_, sample.sample_rows, 0);
-
     // Init histogram
     hist.Init(ctx_->Device(), this->cuts_->TotalBins());
     hist.Reset(ctx_);
@@ -181,22 +196,20 @@ struct GPUHistMakerDevice {
   }
 
   GPUExpandEntry EvaluateRootSplit(DMatrix const* p_fmat, GradientPairInt64 root_sum) {
-    int nidx = RegTree::kRoot;
+    bst_node_t nidx = RegTree::kRoot;
     GPUTrainingParam gpu_param(param);
     auto sampled_features = column_sampler_->GetFeatureSet(0);
     sampled_features->SetDevice(ctx_->Device());
     common::Span<bst_feature_t> feature_set =
         interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
     EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
-    EvaluateSplitSharedInputs shared_inputs{
-        gpu_param,
+    EvaluateSplitSharedInputs shared_inputs{gpu_param,
         *quantiser,
-        feature_types,
+        p_fmat->Info().feature_types.ConstDeviceSpan(),
         cuts_->cut_ptrs_.ConstDeviceSpan(),
         cuts_->cut_values_.ConstDeviceSpan(),
         cuts_->min_vals_.ConstDeviceSpan(),
-        p_fmat->IsDense() && !collective::IsDistributed()
-    };
+        p_fmat->IsDense() && !collective::IsDistributed()};
     auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
     return split;
   }
@@ -212,8 +225,9 @@ struct GPUHistMakerDevice {
     std::vector<bst_node_t> nidx(2 * candidates.size());
     auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
     EvaluateSplitSharedInputs shared_inputs{
-        GPUTrainingParam{param}, *quantiser, feature_types, cuts_->cut_ptrs_.ConstDeviceSpan(),
-        cuts_->cut_values_.ConstDeviceSpan(), cuts_->min_vals_.ConstDeviceSpan(),
+        GPUTrainingParam{param}, *quantiser, p_fmat->Info().feature_types.ConstDeviceSpan(),
+        cuts_->cut_ptrs_.ConstDeviceSpan(), cuts_->cut_values_.ConstDeviceSpan(),
+        cuts_->min_vals_.ConstDeviceSpan(),
         // is_dense represents the local data
         p_fmat->IsDense() && !collective::IsDistributed()};
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
@@ -262,7 +276,7 @@ struct GPUHistMakerDevice {
 
   void BuildHist(EllpackPageImpl const* page, int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
-    auto d_ridx = row_partitioner_->GetRows(nidx);
+    auto d_ridx = partitioners_.front()->GetRows(nidx);
     this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
                                     feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
                                     d_node_hist, *quantiser);
@@ -335,7 +349,7 @@ struct GPUHistMakerDevice {
     };
     collective::SafeColl(rc);
 
-    row_partitioner_->UpdatePositionBatch(
+    partitioners_.front()->UpdatePositionBatch(
         nidx, left_nidx, right_nidx, split_data,
         [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
           auto const index = ridx * num_candidates + nidx_in_batch;
@@ -396,16 +410,17 @@ struct GPUHistMakerDevice {
       CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
     }
 
+    CHECK_EQ(p_fmat->NumBatches(), 1);
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
 
-      if (info_.IsColumnSplit()) {
+      if (p_fmat->Info().IsColumnSplit()) {
         UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
         monitor.Stop(__func__);
         return;
       }
       auto go_left = GoLeftOp{d_matrix};
-      row_partitioner_->UpdatePositionBatch(
+      partitioners_.front()->UpdatePositionBatch(
           nidx, left_nidx, right_nidx, split_data,
           [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
                          const NodeSplitData& data) { return go_left(ridx, data); });
@@ -423,25 +438,30 @@ struct GPUHistMakerDevice {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
     if (p_fmat->Info().num_row_ != n_samples) {
-      // Subsampling with external memory. Not supported.
+      // External memory with concatenation. Not supported.
       p_out_position->Resize(0);
       positions_.clear();
       return;
     }
 
     p_out_position->SetDevice(ctx_->Device());
-    p_out_position->Resize(row_partitioner_->GetRows().size());
+    p_out_position->Resize(p_fmat->Info().num_row_);
     auto d_out_position = p_out_position->DeviceSpan();
 
     auto d_gpair = this->gpair;
-    auto encode_op = [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
-      bool is_invalid = d_gpair[row_id].GetHess() - .0f == 0.f;
+    auto encode_op = [=] __device__(bst_idx_t ridx, bst_node_t nidx) {
+      bool is_invalid = d_gpair[ridx].GetHess() - .0f == 0.f;
       return SamplePosition::Encode(nidx, !is_invalid);
     };  // NOLINT
 
     if (!p_fmat->SingleColBlock()) {
-      CHECK_EQ(row_partitioner_->GetNumNodes(), p_tree->NumNodes());
-      row_partitioner_->FinalisePosition(d_out_position, encode_op);
+      for (std::size_t k = 0; k < partitioners_.size(); ++k) {
+        auto& part = partitioners_.at(k);
+        CHECK_EQ(part->GetNumNodes(), p_tree->NumNodes());
+        auto base_ridx = batch_ptr_[k];
+        auto n_samples = batch_ptr_.at(k + 1) - base_ridx;
+        part->FinalisePosition(d_out_position.subspan(base_ridx, n_samples), base_ridx, encode_op);
+      }
       dh::CopyTo(d_out_position, &positions_);
       return;
     }
@@ -465,11 +485,11 @@ struct GPUHistMakerDevice {
 
       auto go_left_op = GoLeftOp{d_matrix};
       dh::caching_device_vector<NodeSplitData> d_split_data;
-      dh::CopyToD(split_data, &d_split_data);
+      dh::CopyTo(split_data, &d_split_data);
      auto s_split_data = dh::ToSpan(d_split_data);
 
-      row_partitioner_->FinalisePosition(d_out_position,
-                                         [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
+      partitioners_.front()->FinalisePosition(
+          d_out_position, page.BaseRowId(), [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
            auto split_data = s_split_data[nidx];
            auto node = split_data.split_node;
            while (!node.IsLeaf()) {
@@ -479,10 +499,9 @@ struct GPUHistMakerDevice {
             }
             return encode_op(row_id, nidx);
           });
-    }
 
       dh::CopyTo(d_out_position, &positions_);
     }
+  }
 
   bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
     if (positions_.empty()) {
@@ -513,17 +532,16 @@ struct GPUHistMakerDevice {
   }
 
   // num histograms is the number of contiguous histograms in memory to reduce over
-  void AllReduceHist(bst_node_t nidx, int num_histograms) {
-    monitor.Start("AllReduce");
-    auto d_node_hist = hist.GetNodeHistogram(nidx).data();
-    using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
+  void AllReduceHist(MetaInfo const& info, bst_node_t nidx, int num_histograms) {
+    monitor.Start(__func__);
+    auto d_node_hist = hist.GetNodeHistogram(nidx);
+    using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
     auto rc = collective::GlobalSum(
-        ctx_, info_,
-        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
-                        cuts_->TotalBins() * 2 * num_histograms, ctx_->Device()));
+        ctx_, info,
+        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
+                        d_node_hist.size() * 2 * num_histograms, ctx_->Device()));
     SafeColl(rc);
-
-    monitor.Stop("AllReduce");
+    monitor.Stop(__func__);
   }
 
   /**
@@ -566,7 +584,7 @@ struct GPUHistMakerDevice {
     // Reduce all in one go
     // This gives much better latency in a distributed setting
     // when processing a large batch
-    this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
+    this->AllReduceHist(p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
 
     for (size_t i = 0; i < subtraction_nidx.size(); i++) {
       auto build_hist_nidx = hist_nidx.at(i);
@@ -578,7 +596,7 @@ struct GPUHistMakerDevice {
        for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
          this->BuildHist(page.Impl(), subtraction_trick_nidx);
        }
-        this->AllReduceHist(subtraction_trick_nidx, 1);
+        this->AllReduceHist(p_fmat->Info(), subtraction_trick_nidx, 1);
      }
    }
    this->monitor.Stop(__func__);
@@ -588,18 +606,16 @@ struct GPUHistMakerDevice {
     RegTree& tree = *p_tree;
 
     // Sanity check - have we created a leaf with no training instances?
-    if (!collective::IsDistributed() && row_partitioner_) {
-      CHECK(row_partitioner_->GetRows(candidate.nid).size() > 0)
+    if (!collective::IsDistributed() && partitioners_.size() == 1) {
+      CHECK(partitioners_.front()->GetRows(candidate.nid).size() > 0)
           << "No training instances in this leaf!";
     }
 
     auto base_weight = candidate.base_weight;
     auto left_weight = candidate.left_weight * param.learning_rate;
     auto right_weight = candidate.right_weight * param.learning_rate;
-    auto parent_hess = quantiser
-                           ->ToFloatingPoint(candidate.split.left_sum +
-                                             candidate.split.right_sum)
-                           .GetHess();
+    auto parent_hess =
+        quantiser->ToFloatingPoint(candidate.split.left_sum + candidate.split.right_sum).GetHess();
     auto left_hess =
         quantiser->ToFloatingPoint(candidate.split.left_sum).GetHess();
     auto right_hess =
@@ -640,22 +656,21 @@ struct GPUHistMakerDevice {
     dh::XGBCachingDeviceAllocator<char> alloc;
     auto quantiser = *this->quantiser;
     auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
-        dh::tbegin(gpair), [=] __device__(auto const &gpair) {
-          return quantiser.ToFixedPoint(gpair);
-        });
+        dh::tbegin(gpair),
+        [=] __device__(auto const& gpair) { return quantiser.ToFixedPoint(gpair); });
     GradientPairInt64 root_sum_quantised =
-        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
-                   GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
+        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{},
+                   thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
     auto rc = collective::GlobalSum(
-        ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
+        ctx_, p_fmat->Info(), linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
     collective::SafeColl(rc);
 
     hist.AllocateHistograms(ctx_, {kRootNIdx});
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       this->BuildHist(page.Impl(), kRootNIdx);
     }
-    this->AllReduceHist(kRootNIdx, 1);
+    this->AllReduceHist(p_fmat->Info(), kRootNIdx, 1);
 
     // Remember root stats
     auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@@ -719,12 +734,30 @@ struct GPUHistMakerDevice {
     // restrictions like min loss change after evalaution. Therefore, the check condition
     // is greater than or equal to.
     if (is_single_block) {
-      CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
+      CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
     }
     this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
   }
 };
 
+std::shared_ptr<common::HistogramCuts const> InitBatchCuts(Context const* ctx, DMatrix* p_fmat,
+                                                           BatchParam batch,
+                                                           std::vector<bst_idx_t>* p_batch_ptr) {
+  std::vector<bst_idx_t>& batch_ptr = *p_batch_ptr;
+  batch_ptr = {0};
+  std::shared_ptr<common::HistogramCuts const> cuts;
+
+  for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, batch)) {
+    batch_ptr.push_back(page.Size());
+    cuts = page.Impl()->CutsShared();
+    CHECK(cuts->cut_values_.DeviceCanRead());
+  }
+  CHECK(cuts);
+  CHECK_EQ(p_fmat->NumBatches(), batch_ptr.size() - 1);
+  std::partial_sum(batch_ptr.cbegin(), batch_ptr.cend(), batch_ptr.begin());
+  return cuts;
+}
+
 class GPUHistMaker : public TreeUpdater {
   using GradientSumT = GradientPairPrecise;
 
@@ -774,23 +807,20 @@ class GPUHistMaker : public TreeUpdater {
     CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
 
     // Synchronise the column sampling seed
-    uint32_t column_sampling_seed = common::GlobalRandom()();
-    auto rc = collective::Broadcast(
-        ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0);
-    SafeColl(rc);
+    std::uint32_t column_sampling_seed = common::GlobalRandom()();
+    SafeColl(collective::Broadcast(
+        ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0));
     this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
 
-    std::shared_ptr<common::HistogramCuts const> cuts;
-    auto batch = HistBatch(*param);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, HistBatch(*param))) {
-      cuts = page.Impl()->CutsShared();
-    }
-
     dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
     p_fmat->Info().feature_types.SetDevice(ctx_->Device());
-    maker = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
-                                                 p_fmat->Info().feature_types.ConstDeviceSpan(),
-                                                 *param, column_sampler_, batch, p_fmat->Info());
+
+    std::vector<bst_idx_t> batch_ptr;
+    auto batch = HistBatch(*param);
+    auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
+
+    this->maker = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
+                                                       p_fmat->Info(), batch_ptr, cuts);
 
     p_last_fmat_ = p_fmat;
     initialised_ = true;
@@ -896,15 +926,14 @@ class GPUGlobalApproxMaker : public TreeUpdater {
 
     auto const& info = p_fmat->Info();
     info.feature_types.SetDevice(ctx_->Device());
-    std::shared_ptr<common::HistogramCuts const> cuts;
+
+    std::vector<bst_idx_t> batch_ptr;
     auto batch = ApproxBatch(*param, hess, *task_);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, batch)) {
-      cuts = page.Impl()->CutsShared();
-    }
+    auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
     batch.regen = false;  // Regen only at the beginning of the iteration.
-    maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
-                                                  info.feature_types.ConstDeviceSpan(), *param,
-                                                  column_sampler_, batch, p_fmat->Info());
+
+    this->maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
+                                                        p_fmat->Info(), batch_ptr, cuts);
 
     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
@@ -38,11 +38,9 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
 
   if (fixed_size_sampling) {
-    EXPECT_EQ(sample.sample_rows, kRows);
     EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
     EXPECT_EQ(sample.gpair.size(), kRows);
   } else {
-    EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
     EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
     EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
   }
@@ -89,7 +87,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
   auto p_fmat = sample.p_fmat;
-  EXPECT_EQ(sample.sample_rows, kRows);
+  EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
   EXPECT_EQ(sample.gpair.size(), gpair.Size());
   EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
   EXPECT_EQ(p_fmat->Info().num_row_, kRows);
@@ -12,7 +12,7 @@ from test_data_iterator import test_single_batch as cpu_single_batch
 
 
 def test_gpu_single_batch() -> None:
-    cpu_single_batch("gpu_hist")
+    cpu_single_batch("hist", "cuda")
 
 
 @pytest.mark.skipif(**no_cupy())
@@ -17,7 +17,7 @@ from xgboost.testing.updater import check_quantile_loss_extmem
 pytestmark = tm.timeout(30)
 
 
-def test_single_batch(tree_method: str = "approx") -> None:
+def test_single_batch(tree_method: str = "approx", device: str = "cpu") -> None:
     from sklearn.datasets import load_breast_cancer
 
     n_rounds = 10
@@ -25,17 +25,19 @@ def test_single_batch(tree_method: str = "approx") -> None:
     X = X.astype(np.float32)
     y = y.astype(np.float32)
 
+    params = {"tree_method": tree_method, "device": device}
+
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
-    from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_it = xgb.train(params, Xy, num_boost_round=n_rounds)
 
     Xy = xgb.DMatrix(X, y)
-    from_dmat = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_dmat = xgb.train(params, Xy, num_boost_round=n_rounds)
     assert from_it.get_dump() == from_dmat.get_dump()
 
     X, y = load_breast_cancer(return_X_y=True, as_frame=True)
     X = X.astype(np.float32)
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
-    from_pd = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_pd = xgb.train(params, Xy, num_boost_round=n_rounds)
     # remove feature info to generate exact same text representation.
     from_pd.feature_names = None
     from_pd.feature_types = None
@@ -45,11 +47,11 @@ def test_single_batch(tree_method: str = "approx") -> None:
     X, y = load_breast_cancer(return_X_y=True)
     X = csr_matrix(X)
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
-    from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_it = xgb.train(params, Xy, num_boost_round=n_rounds)
 
     X, y = load_breast_cancer(return_X_y=True)
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y), missing=0.0)
-    from_np = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_np = xgb.train(params, Xy, num_boost_round=n_rounds)
     assert from_np.get_dump() == from_it.get_dump()
 
 