[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object will be required for collective operations. This PR passes the context object to some of the required functions to prepare for swapping out the implementation.
This commit is contained in:
  parent 6c0a190f6d
  commit 06bdc15e9b
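Most of the changes below follow one pattern: a `Context const *` parameter is threaded through a call chain so that device, stream, and thread-count information comes from the caller instead of from globals or defaults. A minimal standalone sketch of the pattern; the names `MyContainer` and `DoAllgather` are hypothetical, not from the diff:

```cpp
// Hypothetical sketch of the migration pattern used throughout this commit:
// the signature gains a Context today, even where the body still ignores it,
// so every call site is already correct when the implementation swaps.
namespace xgboost {
struct Context {};  // stand-in for the real xgboost::Context

class MyContainer {  // hypothetical container gaining a Context parameter
 public:
  // was: void DoAllgather();
  void DoAllgather(Context const * /*ctx*/) { /* unused for now */ }
};
}  // namespace xgboost

int main() {
  xgboost::Context ctx;
  xgboost::MyContainer{}.DoAllgather(&ctx);  // callers pass the context today
  return 0;
}
```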
@@ -178,7 +178,7 @@ class MetaInfo {
   * in vertical federated learning, since each worker loads its own list of columns,
   * we need to sum them.
   */
-  void SynchronizeNumberOfColumns();
+  void SynchronizeNumberOfColumns(Context const* ctx);

   /*! \brief Whether the data is split row-wise. */
   bool IsRowSplit() const {
@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) {  // NOL
   return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
 }

-template <typename T, typename... S>
-LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
   return TensorView<T, sizeof...(S)>{data, in_shape, device};
 }

-template <typename T, typename... S>
-auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
   return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
 }

-template <typename T, typename... S>
-auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
   return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
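The new `decltype(common::dynamic_extent) ext` template parameter lets `MakeTensorView` accept spans with a compile-time extent, not only `common::Span<T>` (which is really `Span<T, dynamic_extent>`). A standalone illustration with a simplified `Span` stand-in, not the real `xgboost::common::Span`:

```cpp
// Why the deduced `ext` parameter matters: a fixed-extent span is a different
// type and fails deduction against a parameter written as Span<T>.
#include <cstddef>
#include <cstdio>

constexpr std::size_t dynamic_extent = static_cast<std::size_t>(-1);

template <typename T, std::size_t Extent = dynamic_extent>
struct Span {  // simplified stand-in for xgboost::common::Span
  T *data;
  std::size_t size;
};

template <typename T>
void TakeDynamic(Span<T> /*s*/) {}  // old form: dynamic extent only

template <typename T, std::size_t ext>
void TakeAny(Span<T, ext> /*s*/) {}  // new form: extent is deduced

int main() {
  float buf[4] = {0, 1, 2, 3};
  Span<float, 4> fixed{buf, 4};
  Span<float> dynamic{buf, 4};
  // TakeDynamic(fixed);  // would not compile: deduction fails
  TakeDynamic(dynamic);
  TakeAny(fixed);         // binds with the deduced-extent signature
  TakeAny(dynamic);
  std::puts("ok");
  return 0;
}
```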
@@ -29,7 +29,7 @@ namespace {
   auto stub = fed->Handle();

   BroadcastRequest request;
-  request.set_sequence_number(*sequence_number++);
+  request.set_sequence_number((*sequence_number)++);
   request.set_rank(comm.Rank());
   if (comm.Rank() != root) {
     request.set_send_buffer(nullptr, 0);
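The change from `*sequence_number++` to `(*sequence_number)++` is a genuine bug fix, not a style edit: postfix `++` binds tighter than unary `*`, so the old expression advanced the pointer and read through its old value rather than incrementing the counter it points to. A standalone illustration:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  std::uint64_t seq = 0;
  std::uint64_t *p = &seq;
  // *p++ means *(p++): the pointer advances, the counter does not.
  // (*p)++ increments the value p points to, which is what Broadcast needs.
  (*p)++;
  assert(seq == 1 && p == &seq);
  return 0;
}
```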
@@ -90,9 +90,9 @@ Coll *FederatedColl::MakeCUDAVar() {
 [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                               std::int32_t root) {
   if (comm.Rank() == root) {
-    return BroadcastImpl(comm, &sequence_number_, data, root);
+    return BroadcastImpl(comm, &this->sequence_number_, data, root);
   } else {
-    return BroadcastImpl(comm, &sequence_number_, data, root);
+    return BroadcastImpl(comm, &this->sequence_number_, data, root);
   }
 }

@@ -62,6 +62,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,

 Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
                      ArrayInterfaceHandler::Type type) {
+  if (comm.World() == 1) {
+    return Success();
+  }
   return DispatchDType(type, [&](auto t) {
     using T = decltype(t);
     // Divide the data into segments according to the number of workers.
@@ -10,6 +10,7 @@
 #include <sstream>  // for stringstream
 #include <vector>   // for vector

+#include "../common/cuda_context.cuh"    // for CUDAContext
 #include "../common/device_helpers.cuh"  // for DefaultStream
 #include "../common/type.h"              // for EraseType
 #include "broadcast.h"                   // for Broadcast
@@ -60,7 +61,7 @@ Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
 NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
     : Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
            root.TaskID()},
-      stream_{dh::DefaultStream()} {
+      stream_{ctx->CUDACtx()->Stream()} {
   this->world_ = root.World();
   this->rank_ = root.Rank();
   this->domain_ = root.Domain();
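Taking the stream from `ctx->CUDACtx()->Stream()` instead of `dh::DefaultStream()` orders the communicator's work on the same stream as the computation driven by that `Context`. A standalone CUDA sketch of the idea with hypothetical kernels (compile with nvcc); a collective issued on the same stream would be ordered between them with no extra synchronization:

```cpp
// Work queued on one cudaStream_t executes in submission order, which is why
// NCCLComm now takes its stream from the Context rather than a global default.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void produce(float *buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
}
__global__ void consume(float const *buf, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = buf[i] * 2.0f;
}

int main() {
  int const n = 1024;
  float *buf, *out;
  cudaMalloc(&buf, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaStream_t s;  // stands in for ctx->CUDACtx()->Stream()
  cudaStreamCreate(&s);
  produce<<<(n + 255) / 256, 256, 0, s>>>(buf, n);
  // e.g. ncclAllReduce(..., s) here would run between produce and consume.
  consume<<<(n + 255) / 256, 256, 0, s>>>(buf, out, n);
  cudaStreamSynchronize(s);
  cudaFree(buf);
  cudaFree(out);
  cudaStreamDestroy(s);
  std::printf("done\n");
  return 0;
}
```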
@@ -105,7 +105,7 @@ CommGroup::CommGroup()
 }

 std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
-  static std::unique_ptr<collective::CommGroup> sptr;
+  static thread_local std::unique_ptr<collective::CommGroup> sptr;
   if (!sptr) {
     Json config{Null{}};
     sptr.reset(CommGroup::Create(config));
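Making the global communicator group `thread_local` gives each thread its own lazily constructed instance: construction on first use stays, but threads no longer share (or race on) one pointer. A standalone sketch of the idiom with a hypothetical stand-in type:

```cpp
// The `static thread_local` lazy-singleton idiom used by GlobalCommGroup():
// each thread constructs and owns its own instance on first use.
#include <cstdio>
#include <memory>
#include <thread>

struct CommGroupStub {  // hypothetical stand-in for collective::CommGroup
  CommGroupStub() { std::printf("constructed\n"); }
};

CommGroupStub &Global() {
  static thread_local std::unique_ptr<CommGroupStub> sptr;
  if (!sptr) {
    sptr = std::make_unique<CommGroupStub>();  // lazy, per-thread
  }
  return *sptr;
}

int main() {
  std::thread a([] { Global(); });
  std::thread b([] { Global(); });  // a second, independent instance
  a.join();
  b.join();
  Global();  // the main thread gets its own as well
  return 0;
}
```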
@@ -480,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
   cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
     // Configure allocator with maximum cached bin size of ~1GB and no limit on
     // maximum cached bytes
-    thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
+    thread_local std::unique_ptr<cub::CachingDeviceAllocator> allocator{
+        std::make_unique<cub::CachingDeviceAllocator>(2, 9, 29)};
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
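Swapping the raw `new` for a `thread_local std::unique_ptr` means the per-thread CUB allocator is actually destroyed at thread exit instead of being leaked; the old form never ran a destructor at all. A standalone sketch with a stand-in type:

```cpp
// thread_local std::unique_ptr<T> vs a leaked raw `new`: the unique_ptr's
// destructor runs when the owning thread exits, releasing the cache.
#include <cstdio>
#include <memory>
#include <thread>

struct Cache {  // stands in for cub::CachingDeviceAllocator
  ~Cache() { std::printf("cache destroyed\n"); }
};

Cache &GetThreadCache() {
  thread_local std::unique_ptr<Cache> cache{std::make_unique<Cache>()};
  return *cache;
}

int main() {
  std::thread t([] { GetThreadCache(); });  // prints "cache destroyed" at exit
  t.join();
  return 0;
}
```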
@@ -51,7 +51,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
     for (auto const &page : m->GetBatches<SparsePage>()) {
       container.PushRowPage(page, info, hessian);
     }
-    container.MakeCuts(m->Info(), &out);
+    container.MakeCuts(ctx, m->Info(), &out);
   } else {
     SortedSketchContainer container{ctx,
                                     max_bins,
@@ -61,7 +61,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
     for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
       container.PushColPage(page, info, hessian);
     }
-    container.MakeCuts(m->Info(), &out);
+    container.MakeCuts(ctx, m->Info(), &out);
   }

   return out;
@@ -359,7 +359,7 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
     }
   }

-  sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
+  sketch_container.MakeCuts(ctx, &cuts, p_fmat->Info().IsColumnSplit());
   return cuts;
 }
 }  // namespace xgboost::common
@@ -11,9 +11,7 @@
 #include "categorical.h"
 #include "hist_util.h"

-namespace xgboost {
-namespace common {
+namespace xgboost::common {

 template <typename WQSketch>
 SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
                                                    std::vector<bst_row_t> columns_size,
@@ -129,7 +127,7 @@ struct QuantileAllreduce {
    * \param rank rank of target worker
    * \param fidx feature idx
    */
-  auto Values(int32_t rank, bst_feature_t fidx) const {
+  [[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
     // get span for worker
     auto wsize = worker_indptr[rank + 1] - worker_indptr[rank];
     auto worker_values = global_values.subspan(worker_indptr[rank], wsize);
@@ -145,7 +143,7 @@ struct QuantileAllreduce {

 template <typename WQSketch>
 void SketchContainerImpl<WQSketch>::GatherSketchInfo(
-    MetaInfo const& info,
+    Context const *, MetaInfo const &info,
     std::vector<typename WQSketch::SummaryContainer> const &reduced,
     std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
     std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -206,7 +204,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
 }

 template <typename WQSketch>
-void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
+void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
   auto world_size = collective::GetWorldSize();
   auto rank = collective::GetRank();
   if (world_size == 1 || info.IsColumnSplit()) {
@@ -274,16 +272,15 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {

 template <typename WQSketch>
 void SketchContainerImpl<WQSketch>::AllReduce(
-    MetaInfo const& info,
-    std::vector<typename WQSketch::SummaryContainer> *p_reduced,
-    std::vector<int32_t>* p_num_cuts) {
+    Context const *ctx, MetaInfo const &info,
+    std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
   monitor_.Start(__func__);

   size_t n_columns = sketches_.size();
   collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
   CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

-  AllreduceCategories(info);
+  AllreduceCategories(ctx, info);

   auto& num_cuts = *p_num_cuts;
   CHECK_EQ(num_cuts.size(), 0);
@@ -324,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
   std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);

   std::vector<typename WQSketch::Entry> global_sketches;
-  this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
+  this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches);

   std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);

@@ -383,11 +380,12 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
 }

 template <typename WQSketch>
-void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
+void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const &info,
+                                             HistogramCuts *p_cuts) {
   monitor_.Start(__func__);
   std::vector<typename WQSketch::SummaryContainer> reduced;
   std::vector<int32_t> num_cuts;
-  this->AllReduce(info, &reduced, &num_cuts);
+  this->AllReduce(ctx, info, &reduced, &num_cuts);

   p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
   std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
@@ -496,5 +494,4 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
   });
   monitor_.Stop(__func__);
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
@@ -22,9 +22,7 @@
 #include "transform_iterator.h"  // MakeIndexTransformIter
 #include "xgboost/span.h"

-namespace xgboost {
-namespace common {
+namespace xgboost::common {

 using WQSketch = HostSketchContainer::WQSketch;
 using SketchEntry = WQSketch::Entry;

@@ -501,7 +499,7 @@ void SketchContainer::FixError() {
   });
 }

-void SketchContainer::AllReduce(bool is_column_split) {
+void SketchContainer::AllReduce(Context const*, bool is_column_split) {
   dh::safe_cuda(cudaSetDevice(device_.ordinal));
   auto world = collective::GetWorldSize();
   if (world == 1 || is_column_split) {
@@ -582,13 +580,13 @@ struct InvalidCatOp {
 };
 }  // anonymous namespace

-void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
+void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_.ordinal));
   p_cuts->min_vals_.Resize(num_columns_);

   // Sync between workers.
-  this->AllReduce(is_column_split);
+  this->AllReduce(ctx, is_column_split);

   // Prune to final number of bins.
   this->Prune(num_bins_ + 1);
@@ -731,5 +729,4 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
   p_cuts->SetCategorical(this->has_categorical_, max_cat);
   timer_.Stop(__func__);
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
@@ -151,9 +151,9 @@ class SketchContainer {
                  Span<SketchEntry const> that);

   /* \brief Merge quantiles from other GPU workers. */
-  void AllReduce(bool is_column_split);
+  void AllReduce(Context const* ctx, bool is_column_split);
   /* \brief Create the final histogram cut values. */
-  void MakeCuts(HistogramCuts* cuts, bool is_column_split);
+  void MakeCuts(Context const* ctx, HistogramCuts* cuts, bool is_column_split);

   Span<SketchEntry const> Data() const {
     return {this->Current().data().get(), this->Current().size()};
@@ -827,13 +827,14 @@ class SketchContainerImpl {
     return group_ind;
   }
   // Gather sketches from all workers.
-  void GatherSketchInfo(MetaInfo const& info,
+  void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
                         std::vector<typename WQSketch::SummaryContainer> const &reduced,
                         std::vector<bst_row_t> *p_worker_segments,
                         std::vector<bst_row_t> *p_sketches_scan,
                         std::vector<typename WQSketch::Entry> *p_global_sketches);
   // Merge sketches from all workers.
-  void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
+  void AllReduce(Context const *ctx, MetaInfo const &info,
+                 std::vector<typename WQSketch::SummaryContainer> *p_reduced,
                  std::vector<int32_t> *p_num_cuts);

   template <typename Batch, typename IsValid>
@@ -887,11 +888,11 @@ class SketchContainerImpl {
   /* \brief Push a CSR matrix. */
   void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});

-  void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);
+  void MakeCuts(Context const *ctx, MetaInfo const &info, HistogramCuts *cuts);

  private:
   // Merge all categories from other workers.
-  void AllreduceCategories(MetaInfo const& info);
+  void AllreduceCategories(Context const* ctx, MetaInfo const& info);
 };

 class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
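With the context threaded through, the sketch pipeline forwards one `ctx` from the public entry point down every internal step: `MakeCuts(ctx, ...)` calls `AllReduce(ctx, ...)`, which calls `GatherSketchInfo(ctx, ...)` and `AllreduceCategories(ctx, ...)`. A standalone toy mirroring that forwarding shape; the class and method names are hypothetical stand-ins, not the real containers:

```cpp
// Forwarding pattern: accept the Context once at the entry point and pass the
// same pointer through every internal step, never re-deriving it from globals.
#include <cstdio>

struct Context {};

class Sketcher {  // hypothetical stand-in for SketchContainerImpl
 public:
  void MakeCuts(Context const *ctx) {
    AllReduce(ctx);  // forwards the caller's context
  }

 private:
  void AllReduce(Context const *ctx) {
    GatherSketchInfo(ctx);  // forwards again
    AllreduceCategories(ctx);
  }
  void GatherSketchInfo(Context const *) { std::puts("gather"); }
  void AllreduceCategories(Context const *) { std::puts("categories"); }
};

int main() {
  Context ctx;
  Sketcher{}.MakeCuts(&ctx);
  return 0;
}
```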
@@ -745,7 +745,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
   }
 }

-void MetaInfo::SynchronizeNumberOfColumns() {
+void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
   if (IsColumnSplit()) {
     collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
   } else {
@@ -95,7 +95,7 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur

 namespace {
 // Synchronize feature type in case of empty DMatrix
-void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
+void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
   if (!collective::IsDistributed()) {
     return;
   }
@@ -193,7 +193,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
   // From here on Info() has the correct data shape
   Info().num_row_ = accumulated_rows;
   Info().num_nonzero_ = nnz;
-  Info().SynchronizeNumberOfColumns();
+  Info().SynchronizeNumberOfColumns(ctx);
   CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
     return f > accumulated_rows;
   })) << "Something went wrong during iteration.";
@@ -213,9 +213,9 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
   while (iter.Next()) {
     if (!p_sketch) {
       h_ft = proxy->Info().feature_types.ConstHostVector();
-      SyncFeatureType(&h_ft);
-      p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
-                                                     !proxy->Info().group_ptr_.empty()});
+      SyncFeatureType(ctx, &h_ft);
+      p_sketch = std::make_unique<common::HostSketchContainer>(ctx, p.max_bin, h_ft, column_sizes,
+                                                               !proxy->Info().group_ptr_.empty());
     }
     HostAdapterDispatch(proxy, [&](auto const& batch) {
       proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -230,7 +230,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
     CHECK_EQ(accumulated_rows, Info().num_row_);

     CHECK(p_sketch);
-    p_sketch->MakeCuts(Info(), &cuts);
+    p_sketch->MakeCuts(ctx, Info(), &cuts);
   }
   if (!h_ft.empty()) {
     CHECK_EQ(h_ft.size(), n_features);
@@ -105,7 +105,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
     sketch_containers.clear();
     sketch_containers.shrink_to_fit();

-    final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
+    final_sketch.MakeCuts(ctx, &cuts, this->info_.IsColumnSplit());
   } else {
     GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
   }
@@ -167,7 +167,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,

   iter.Reset();
   // Synchronise worker columns
-  info_.SynchronizeNumberOfColumns();
+  info_.SynchronizeNumberOfColumns(ctx);
 }

 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
@@ -283,7 +283,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
   // Synchronise worker columns
   info_.data_split_mode = data_split_mode;
   ReindexFeatures(&ctx);
-  info_.SynchronizeNumberOfColumns();
+  info_.SynchronizeNumberOfColumns(&ctx);

   if (adapter->NumRows() == kAdapterUnknownSize) {
     using IteratorAdapterT =
@@ -42,7 +42,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
   info_.num_row_ = adapter->NumRows();
   // Synchronise worker columns
   info_.data_split_mode = data_split_mode;
-  info_.SynchronizeNumberOfColumns();
+  info_.SynchronizeNumberOfColumns(&ctx);

   this->fmat_ctx_ = ctx;
 }
@@ -97,7 +97,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
   this->info_.num_col_ = n_features;
   this->info_.num_nonzero_ = nnz;

-  info_.SynchronizeNumberOfColumns();
+  info_.SynchronizeNumberOfColumns(&ctx);
   CHECK_NE(info_.num_col_, 0);

   fmat_ctx_ = ctx;
@@ -209,7 +209,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
   }
   // sanity check
-  void Validate() {
+  void Validate(Context const*) {
     if (!collective::IsDistributed()) {
       return;
     }
@@ -434,7 +434,7 @@ class LearnerConfiguration : public Learner {
       }
       // Update the shared model parameter
       this->ConfigureModelParamWithoutBaseScore();
-      mparam_.Validate();
+      mparam_.Validate(&ctx_);
     }
     CHECK(!std::isnan(mparam_.base_score));
     CHECK(!std::isinf(mparam_.base_score));
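Several of the new signatures, such as `Validate(Context const*)` and `SynchronizeNumberOfColumns(Context const*)`, leave the parameter unnamed: the slot is reserved in the API now while the body does not use it yet, and an unnamed parameter avoids unused-parameter warnings. A standalone illustration:

```cpp
// The unnamed-parameter idiom used here: the signature already takes the
// Context so call sites are future-proof, and leaving the parameter unnamed
// documents (warning-free) that the current body ignores it.
struct Context {};

void Validate(Context const *) {  // parameter reserved, intentionally unused
  // existing validation that does not need the context yet
}

int main() {
  Context ctx;
  Validate(&ctx);  // callers already pass the context today
  return 0;
}
```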
@@ -199,9 +199,9 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
   });
 }

-double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
-                    common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
-  dh::XGBDeviceAllocator<char> alloc;
+double ScaleClasses(Context const *ctx, common::Span<double> results,
+                    common::Span<double> local_area, common::Span<double> tp,
+                    common::Span<double> auc, size_t n_classes) {
   if (collective::IsDistributed()) {
     int32_t device = dh::CurrentDevice();
     CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
@@ -218,8 +218,8 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
   double tp_sum;
   double auc_sum;
   thrust::tie(auc_sum, tp_sum) =
-      thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
-                     Pair{0.0, 0.0}, PairPlus<double, double>{});
+      thrust::reduce(ctx->CUDACtx()->CTP(), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0},
+                     PairPlus<double, double>{});
   if (tp_sum != 0 && !std::isnan(auc_sum)) {
     auc_sum /= tp_sum;
   } else {
@@ -309,10 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
  * up each class in all kernels.
  */
 template <bool scale, typename Fn>
-double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
+double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
                            common::Span<uint32_t> d_class_ptr, size_t n_classes,
                            std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
-  dh::safe_cuda(cudaSetDevice(device.ordinal));
+  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
   /**
    * Sorted idx
    */
@@ -320,7 +320,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
   // Index is sorted within class.
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);

-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->Device());
   auto weights = info.weights_.ConstDeviceSpan();

   size_t n_samples = labels.Shape(0);
@@ -328,12 +328,11 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
   if (n_samples == 0) {
     dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
     auto d_results = dh::ToSpan(resutls);
-    dh::LaunchN(n_classes * 4,
-                [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
+    dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
     auto local_area = d_results.subspan(0, n_classes);
     auto tp = d_results.subspan(2 * n_classes, n_classes);
     auto auc = d_results.subspan(3 * n_classes, n_classes);
-    return ScaleClasses(d_results, local_area, tp, auc, n_classes);
+    return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
   }

   /**
@@ -437,7 +436,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
       tp[c] = 1.0f;
     }
   });
-  return ScaleClasses(d_results, local_area, tp, auc, n_classes);
+  return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
 }

 void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
@@ -472,8 +471,7 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
                 size_t /*class_id*/) {
     return TrapezoidArea(fp_prev, fp, tp_prev, tp);
   };
-  return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
-                                   fn);
+  return GPUMultiClassAUCOVR<true>(ctx, info, dh::ToSpan(class_ptr), n_classes, cache, fn);
 }

 namespace {
@@ -697,7 +695,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[class_id].first);
   };
-  return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<false>(ctx, info, d_class_ptr, n_classes, cache, fn);
 }

 template <typename Fn>
@@ -215,7 +215,7 @@ struct EvalError {
       has_param_ = false;
     }
   }
-  const char *Name() const {
+  [[nodiscard]] const char *Name() const {
     static thread_local std::string name;
     if (has_param_) {
       std::ostringstream os;
@@ -228,7 +228,7 @@ struct EvalError {
     }
   }

-  XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
+  [[nodiscard]] XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
     // assume label is in [0,1]
    return pred > threshold_ ? 1.0f - label : label;
   }
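Alongside the context plumbing, the commit adds `[[nodiscard]]` to pure const accessors such as `Name()` and `EvalRow()`; the attribute makes the compiler warn when a caller silently drops the return value of a function whose only effect is its result. A standalone example:

```cpp
// What the added [[nodiscard]] buys: ignoring a pure accessor's result now
// draws a compiler warning instead of passing silently.
[[nodiscard]] const char *Name() { return "error"; }

int main() {
  Name();                  // warning: ignoring return value of 'Name'
  const char *n = Name();  // fine
  return n[0] == 'e' ? 0 : 1;
}
```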
@@ -370,7 +370,7 @@ struct EvalEWiseBase : public MetricNoCache {
     return Policy::GetFinal(dat[0], dat[1]);
   }

-  const char* Name() const override { return policy_.Name(); }
+  [[nodiscard]] const char* Name() const override { return policy_.Name(); }

  private:
   Policy policy_;
@@ -162,7 +162,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
     return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
   }

-  const char* Name() const override {
+  [[nodiscard]] const char* Name() const override {
     return name.c_str();
   }

@@ -294,7 +294,7 @@ class EvalRankWithCache : public Metric {
 };

 namespace {
-double Finalize(MetaInfo const& info, double score, double sw) {
+double Finalize(Context const*, MetaInfo const& info, double score, double sw) {
   std::array<double, 2> dat{score, sw};
   collective::GlobalSum(info, &dat);
   std::tie(score, sw) = std::tuple_cat(dat);
@@ -323,7 +323,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {

     if (ctx_->IsCUDA()) {
       auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
-      return Finalize(info, pre.Residue(), pre.Weights());
+      return Finalize(ctx_, info, pre.Residue(), pre.Weights());
     }

     auto gptr = p_cache->DataGroupPtr(ctx_);
@@ -352,7 +352,7 @@ class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
     }

     auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
-    return Finalize(info, sum, sw);
+    return Finalize(ctx_, info, sum, sw);
   }
 };

@@ -369,7 +369,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
                 std::shared_ptr<ltr::NDCGCache> p_cache) override {
     if (ctx_->IsCUDA()) {
       auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
-      return Finalize(info, ndcg.Residue(), ndcg.Weights());
+      return Finalize(ctx_, info, ndcg.Residue(), ndcg.Weights());
     }

     // group local ndcg
@@ -415,7 +415,7 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
       sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
     }
     auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
-    return Finalize(info, ndcg, sum_w);
+    return Finalize(ctx_, info, ndcg, sum_w);
   }
 };

@@ -427,7 +427,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
                 std::shared_ptr<ltr::MAPCache> p_cache) override {
     if (ctx_->IsCUDA()) {
       auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
-      return Finalize(info, map.Residue(), map.Weights());
+      return Finalize(ctx_, info, map.Residue(), map.Weights());
     }

     auto gptr = p_cache->DataGroupPtr(ctx_);
@@ -469,7 +469,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
       sw += weight[i];
     }
     auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
-    return Finalize(info, sum, sw);
+    return Finalize(ctx_, info, sum, sw);
   }
 };

@@ -217,7 +217,7 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
     return Policy::GetFinal(dat[0], dat[1]);
   }

-  const char* Name() const override {
+  [[nodiscard]] const char* Name() const override {
     return policy_.Name();
   }

@@ -189,7 +189,7 @@ struct SparsePageView {

   explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); }
   SparsePage::Inst operator[](size_t i) { return view[i]; }
-  size_t Size() const { return view.Size(); }
+  [[nodiscard]] size_t Size() const { return view.Size(); }
 };

 struct SingleInstanceView {
@@ -250,7 +250,7 @@ struct GHistIndexMatrixView {
     }
     return ret;
   }
-  size_t Size() const { return page_.Size(); }
+  [[nodiscard]] size_t Size() const { return page_.Size(); }
 };

 template <typename Adapter>
@@ -290,7 +290,7 @@ class AdapterView {
     return ret;
   }

-  size_t Size() const { return adapter_->NumRows(); }
+  [[nodiscard]] size_t Size() const { return adapter_->NumRows(); }

   bst_row_t const static base_rowid = 0;  // NOLINT
 };
@@ -408,31 +408,33 @@ class ColumnSplitHelper {
   ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete;
   ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete;

-  void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
+  void PredictDMatrix(Context const *ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
     CHECK(xgboost::collective::IsDistributed())
         << "column-split prediction is only supported for distributed training";

     for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
                p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
-      PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&batch}, out_preds);
+      PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(ctx, SparsePageView{&batch}, out_preds);
     }
   }

-  void PredictInstance(SparsePage::Inst const &inst, std::vector<bst_float> *out_preds) {
+  void PredictInstance(Context const *ctx, SparsePage::Inst const &inst,
+                       std::vector<bst_float> *out_preds) {
     CHECK(xgboost::collective::IsDistributed())
         << "column-split prediction is only supported for distributed training";

-    PredictBatchKernel<SingleInstanceView, 1>(SingleInstanceView{inst}, out_preds);
+    PredictBatchKernel<SingleInstanceView, 1>(ctx, SingleInstanceView{inst}, out_preds);
   }

-  void PredictLeaf(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
+  void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
     CHECK(xgboost::collective::IsDistributed())
         << "column-split prediction is only supported for distributed training";

     for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
-      PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
+      PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(ctx, SparsePageView{&batch},
+                                                                 out_preds);
     }
   }

@@ -453,12 +455,13 @@ class ColumnSplitHelper {
     std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
   }

-  std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
+  [[nodiscard]] std::size_t BitIndex(std::size_t tree_id, std::size_t row_id,
+                                     std::size_t node_id) const {
     size_t tree_index = tree_id - tree_begin_;
     return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
   }

-  void AllreduceBitVectors() {
+  void AllreduceBitVectors(Context const*) {
     collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
                                                              decision_storage_.size());
     collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
@@ -547,7 +550,7 @@ class ColumnSplitHelper {
   }

   template <typename DataView, size_t block_of_rows_size, bool predict_leaf = false>
-  void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
+  void PredictBatchKernel(Context const* ctx, DataView batch, std::vector<bst_float> *out_preds) {
     auto const num_group = model_.learner_model_param->num_output_group;

     // parallel over local batch
@@ -568,7 +571,7 @@ class ColumnSplitHelper {
       FVecDrop(block_size, fvec_offset, &feat_vecs_);
     });

-    AllreduceBitVectors();
+    AllreduceBitVectors(ctx);

     // auto block_id has the same type as `n_blocks`.
     common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
@@ -646,7 +649,7 @@ class CPUPredictor : public Predictor {
         << "Predict DMatrix with column split" << MTNotImplemented();

     ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
-    helper.PredictDMatrix(p_fmat, out_preds);
+    helper.PredictDMatrix(ctx_, p_fmat, out_preds);
     return;
   }

@@ -779,7 +782,7 @@ class CPUPredictor : public Predictor {
         << "Predict instance with column split" << MTNotImplemented();

     ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
-    helper.PredictInstance(inst, out_preds);
+    helper.PredictInstance(ctx_, inst, out_preds);
     return;
   }

@@ -811,7 +814,7 @@ class CPUPredictor : public Predictor {
         << "Predict leaf with column split" << MTNotImplemented();

     ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
-    helper.PredictLeaf(p_fmat, &preds);
+    helper.PredictLeaf(ctx_, p_fmat, &preds);
     return;
   }

@@ -62,9 +62,7 @@ struct TreeView {
     cats.node_ptr = tree_cat_ptrs;
   }

-  __device__ bool HasCategoricalSplit() const {
-    return !cats.categories.empty();
-  }
+  [[nodiscard]] __device__ bool HasCategoricalSplit() const { return !cats.categories.empty(); }
 };

 struct SparsePageView {
@@ -77,7 +75,7 @@ struct SparsePageView {
                  common::Span<const bst_row_t> row_ptr,
                  bst_feature_t num_features)
       : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
-  __device__ float GetElement(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
     // Binary search
     auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
     auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
@@ -105,8 +103,8 @@ struct SparsePageView {
     // Value is missing
     return nanf("");
   }
-  XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
-  XGBOOST_DEVICE size_t NumCols() const { return num_features; }
+  [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
+  [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; }
 };

 struct SparsePageLoader {
@@ -137,7 +135,7 @@ struct SparsePageLoader {
       __syncthreads();
     }
   }
-  __device__ float GetElement(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
     if (use_shared) {
       return smem[threadIdx.x * data.num_features + fidx];
     } else {
@@ -151,7 +149,7 @@ struct EllpackLoader {
   XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_row_t,
                                size_t, float)
       : matrix{m} {}
-  __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
     auto gidx = matrix.GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return nan("");
@@ -395,11 +395,11 @@ void GPUHistEvaluator::CopyToHost(const std::vector<bst_node_t> &nidx) {
   }
 }

-void GPUHistEvaluator::EvaluateSplits(
-    const std::vector<bst_node_t> &nidx, bst_feature_t max_active_features,
+void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_node_t> &nidx,
+                                      bst_feature_t max_active_features,
                                       common::Span<const EvaluateSplitInputs> d_inputs,
                                       EvaluateSplitSharedInputs shared_inputs,
                                       common::Span<GPUExpandEntry> out_entries) {
   auto evaluator = this->tree_evaluator_.template GetEvaluator<GPUTrainingParam>();

   dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(d_inputs.size());
@@ -417,19 +417,20 @@ void GPUHistEvaluator::EvaluateSplits(
                               out_splits.size() * sizeof(DeviceSplitCandidate));

     // Reduce to get the best candidate from all workers.
-    dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
-      out_splits[i] = all_candidates[i];
-      for (auto rank = 1; rank < world_size; rank++) {
-        out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
-      }
-    });
+    dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
+                [world_size, all_candidates, out_splits] __device__(size_t i) {
+                  out_splits[i] = all_candidates[i];
+                  for (auto rank = 1; rank < world_size; rank++) {
+                    out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
+                  }
+                });
   }

   auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
   auto d_entries = out_entries;
   auto device_cats_accessor = this->DeviceCatStorage(nidx);
   // turn candidate into entry, along with handling sort based split.
-  dh::LaunchN(d_inputs.size(), [=] __device__(size_t i) mutable {
+  dh::LaunchN(d_inputs.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t i) mutable {
     auto const input = d_inputs[i];
     auto &split = out_splits[i];
     // Subtract parent gain here
@@ -464,12 +465,12 @@ void GPUHistEvaluator::EvaluateSplits(
   this->CopyToHost(nidx);
 }

-GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
-    EvaluateSplitInputs input, EvaluateSplitSharedInputs shared_inputs) {
+GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(Context const *ctx, EvaluateSplitInputs input,
+                                                     EvaluateSplitSharedInputs shared_inputs) {
   dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input};
   dh::TemporaryArray<GPUExpandEntry> out_entries(1);
-  this->EvaluateSplits({input.nidx}, input.feature_set.size(), dh::ToSpan(inputs), shared_inputs,
-                       dh::ToSpan(out_entries));
+  this->EvaluateSplits(ctx, {input.nidx}, input.feature_set.size(), dh::ToSpan(inputs),
+                       shared_inputs, dh::ToSpan(out_entries));
   GPUExpandEntry root_entry;
   dh::safe_cuda(cudaMemcpyAsync(&root_entry, out_entries.data().get(), sizeof(GPUExpandEntry),
                                 cudaMemcpyDeviceToHost));
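Here `dh::LaunchN` gains an explicit stream argument taken from `ctx->CUDACtx()->Stream()`, so the split-evaluation kernels no longer run on a default stream. A standalone CUDA sketch of a LaunchN-style helper with an explicit stream; this is a simplified stand-in, not the real `dh::LaunchN` signature, and it needs nvcc with `--extended-lambda`:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Index kernel: invokes fn(i) for every i in [0, n).
template <typename Fn>
__global__ void IndexKernel(size_t n, Fn fn) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) fn(i);
}

// LaunchN-style helper: the caller chooses the stream explicitly.
template <typename Fn>
void LaunchN(size_t n, cudaStream_t stream, Fn fn) {
  unsigned block = 256;
  unsigned grid = static_cast<unsigned>((n + block - 1) / block);
  IndexKernel<<<grid, block, 0, stream>>>(n, fn);
}

int main() {
  size_t const n = 1024;
  float *out;
  cudaMalloc(&out, n * sizeof(float));
  cudaStream_t s;  // stands in for ctx->CUDACtx()->Stream()
  cudaStreamCreate(&s);
  LaunchN(n, s, [=] __device__(size_t i) { out[i] = static_cast<float>(i); });
  cudaStreamSynchronize(s);
  cudaFree(out);
  cudaStreamDestroy(s);
  std::printf("done\n");
  return 0;
}
```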
@@ -193,7 +193,7 @@ class GPUHistEvaluator {
   /**
    * \brief Evaluate splits for left and right nodes.
    */
-  void EvaluateSplits(const std::vector<bst_node_t> &nidx,
+  void EvaluateSplits(Context const* ctx, const std::vector<bst_node_t> &nidx,
                       bst_feature_t max_active_features,
                       common::Span<const EvaluateSplitInputs> d_inputs,
                       EvaluateSplitSharedInputs shared_inputs,
@@ -201,7 +201,7 @@ class GPUHistEvaluator {
   /**
    * \brief Evaluate splits for root node.
    */
-  GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs input,
+  GPUExpandEntry EvaluateSingleSplit(Context const *ctx, EvaluateSplitInputs input,
                                      EvaluateSplitSharedInputs shared_inputs);
 };
 }  // namespace tree
@@ -16,8 +16,7 @@
 #include "row_partitioner.cuh"
 #include "xgboost/base.h"

-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
 namespace {
 struct Pair {
   GradientPair first;
@@ -53,7 +52,8 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
  *
  * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
  */
-GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info) {
+GradientQuantiser::GradientQuantiser(Context const*, common::Span<GradientPair const> gpair,
                                     MetaInfo const& info) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -99,7 +99,6 @@ GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, Met
                                        static_cast<T>(1) / to_floating_point_.GetHess());
 }

-
 XGBOOST_DEV_INLINE void
 AtomicAddGpairShared(xgboost::GradientPairInt64 *dest,
                      xgboost::GradientPairInt64 const &gpair) {
@@ -314,6 +313,4 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const&

   dh::safe_cuda(cudaGetLastError());
 }
-
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
|
|||||||
@ -39,18 +39,20 @@ private:
   GradientPairPrecise to_floating_point_;

 public:
-  GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info);
-  XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
+  GradientQuantiser(Context const* ctx, common::Span<GradientPair const> gpair, MetaInfo const& info);
+  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
     auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
                                       gpair.GetHess() * to_fixed_point_.GetHess());
     return adjusted;
   }
-  XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPairPrecise const& gpair) const {
+  [[nodiscard]] XGBOOST_DEVICE GradientPairInt64
+  ToFixedPoint(GradientPairPrecise const& gpair) const {
     auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
                                       gpair.GetHess() * to_fixed_point_.GetHess());
     return adjusted;
   }
-  XGBOOST_DEVICE GradientPairPrecise ToFloatingPoint(const GradientPairInt64&gpair) const {
+  [[nodiscard]] XGBOOST_DEVICE GradientPairPrecise
+  ToFloatingPoint(const GradientPairInt64& gpair) const {
     auto g = gpair.GetQuantisedGrad() * to_floating_point_.GetGrad();
     auto h = gpair.GetQuantisedHess() * to_floating_point_.GetHess();
     return {g,h};
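The quantiser exists so that histogram reductions sum integers instead of floats: integer addition is associative, which is what makes the GPU reduction tree reproducible (see the comment in histogram.cu above). A toy round trip under an assumed scale factor, not the real derivation performed in the constructor:

    #include <cstdint>
    int main() {
      double const to_fixed = 1 << 20;         // assumed scale; the real one is
      double const to_float = 1.0 / to_fixed;  // derived from the gradient sums
      double grad = 0.125;
      std::int64_t q = static_cast<std::int64_t>(grad * to_fixed);  // ToFixedPoint
      double back = q * to_float;                                   // ToFloatingPoint
      // Exact here because 0.125 * 2^20 is integral; summing q values across
      // threads is order-independent, unlike summing the raw doubles.
      return back == grad ? 0 : 1;
    }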
@ -171,7 +171,8 @@ class HistogramBuilder {
       }
     }

-  void SyncHistogram(RegTree const *p_tree, std::vector<bst_node_t> const &nodes_to_build,
+  void SyncHistogram(Context const *, RegTree const *p_tree,
+                     std::vector<bst_node_t> const &nodes_to_build,
                      std::vector<bst_node_t> const &nodes_to_trick) {
     auto n_total_bins = buffer_.TotalBins();
     common::BlockedSpace2d space(
@ -277,14 +278,14 @@ class MultiHistogramBuilder {
     }

     for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
-      this->target_builders_[t].SyncHistogram(p_tree, nodes, dummy_sub);
+      this->target_builders_[t].SyncHistogram(ctx_, p_tree, nodes, dummy_sub);
     }
   }
   /**
    * @brief Build histogram for left and right child of valid candidates
    */
   template <typename Partitioner, typename ExpandEntry>
-  void BuildHistLeftRight(DMatrix *p_fmat, RegTree const *p_tree,
+  void BuildHistLeftRight(Context const *ctx, DMatrix *p_fmat, RegTree const *p_tree,
                           std::vector<Partitioner> const &partitioners,
                           std::vector<ExpandEntry> const &valid_candidates,
                           linalg::MatrixView<GradientPair const> gpair, BatchParam const &param,
@ -318,7 +319,7 @@ class MultiHistogramBuilder {
     }

     for (bst_target_t t = 0; t < p_tree->NumTargets(); ++t) {
-      this->target_builders_[t].SyncHistogram(p_tree, nodes_to_build, nodes_to_sub);
+      this->target_builders_[t].SyncHistogram(ctx, p_tree, nodes_to_build, nodes_to_sub);
     }
   }
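Note that the CPU-side `SyncHistogram` takes `Context const *` with the parameter left unnamed: the signature changes now so every call site already compiles against the final interface, and the body starts using the context only once the collective implementation is swapped in (per the commit message). A self-contained toy of the same migration idiom, not XGBoost code:

    #include <vector>
    struct Context { int device{0}; };
    struct Builder {
      // Context accepted but intentionally unnamed until the body needs it;
      // call sites are already written against the final signature.
      void SyncHistogram(Context const*, std::vector<int> const& nodes) {
        (void)nodes;  // real synchronisation elided
      }
    };
    int main() {
      Context ctx;
      Builder{}.SyncHistogram(&ctx, {1, 2, 3});
    }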
@ -12,7 +12,7 @@
 namespace xgboost::tree {
 DMLC_REGISTER_PARAMETER(HistMakerTrainParam);

-void HistMakerTrainParam::CheckTreesSynchronized(RegTree const* local_tree) const {
+void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree) const {
   if (!this->debug_synchronize) {
     return;
   }
@ -15,7 +15,7 @@ struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
   bool debug_synchronize{false};
   std::size_t max_cached_hist_node{DefaultNodes()};

-  void CheckTreesSynchronized(RegTree const* local_tree) const;
+  void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const;

   // declare parameters
   DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
@ -140,7 +140,7 @@ class GloablApproxBuilder {
                  std::vector<GradientPair> const &gpair, common::Span<float> hess) {
     monitor_->Start(__func__);
     this->histogram_builder_.BuildHistLeftRight(
-        p_fmat, p_tree, partitioner_, valid_candidates,
+        ctx_, p_fmat, p_tree, partitioner_, valid_candidates,
         linalg::MakeTensorView(ctx_, gpair, gpair.size(), 1), BatchSpec(*param_, hess));
     monitor_->Stop(__func__);
   }
@ -300,7 +300,7 @@ class GlobalApproxUpdater : public TreeUpdater {
     std::size_t t_idx = 0;
     for (auto p_tree : trees) {
       this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
-      hist_param_.CheckTreesSynchronized(p_tree);
+      hist_param_.CheckTreesSynchronized(ctx_, p_tree);
       ++t_idx;
     }
   }
@ -246,7 +246,7 @@ struct GPUHistMakerDevice {
     this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
                            dmat->Info().IsColumnSplit(), ctx_->Device());

-    quantiser = std::make_unique<GradientQuantiser>(this->gpair, dmat->Info());
+    quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());

     row_partitioner.reset();  // Release the device memory first before reallocating
     row_partitioner = std::make_unique<RowPartitioner>(ctx_->Device(), sample.sample_rows);
@ -276,7 +276,7 @@ struct GPUHistMakerDevice {
         matrix.min_fvalue,
         matrix.is_dense && !collective::IsDistributed()
     };
-    auto split = this->evaluator_.EvaluateSingleSplit(inputs, shared_inputs);
+    auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
     return split;
   }

@ -329,7 +329,7 @@ struct GPUHistMakerDevice {
                                   d_node_inputs.data().get(), h_node_inputs.data(),
                                   h_node_inputs.size() * sizeof(EvaluateSplitInputs), cudaMemcpyDefault));

-    this->evaluator_.EvaluateSplits(nidx, max_active_features, dh::ToSpan(d_node_inputs),
+    this->evaluator_.EvaluateSplits(ctx_, nidx, max_active_features, dh::ToSpan(d_node_inputs),
                                     shared_inputs, dh::ToSpan(entries));
     dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(),
                                   entries.data().get(), sizeof(GPUExpandEntry) * entries.size(),
@ -842,7 +842,7 @@ class GPUHistMaker : public TreeUpdater {
     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
       this->UpdateTree(param, gpair_hdv, dmat, tree, &out_position[t_idx]);
-      this->hist_maker_param_.CheckTreesSynchronized(tree);
+      this->hist_maker_param_.CheckTreesSynchronized(ctx_, tree);
       ++t_idx;
     }
     dh::safe_cuda(cudaGetLastError());
@ -985,7 +985,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
       this->UpdateTree(gpair->Data(), p_fmat, tree, &out_position[t_idx]);
-      this->hist_maker_param_.CheckTreesSynchronized(tree);
+      this->hist_maker_param_.CheckTreesSynchronized(ctx_, tree);
       ++t_idx;
     }
@ -228,8 +228,8 @@ class MultiTargetHistBuilder {
                      std::vector<MultiExpandEntry> const &valid_candidates,
                      linalg::MatrixView<GradientPair const> gpair) {
     monitor_->Start(__func__);
-    histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates, gpair,
-                                           HistBatch(param_));
+    histogram_builder_->BuildHistLeftRight(ctx_, p_fmat, p_tree, partitioner_, valid_candidates,
+                                           gpair, HistBatch(param_));
     monitor_->Stop(__func__);
   }

@ -436,8 +436,8 @@ class HistUpdater {
                      std::vector<CPUExpandEntry> const &valid_candidates,
                      linalg::MatrixView<GradientPair const> gpair) {
     monitor_->Start(__func__);
-    this->histogram_builder_->BuildHistLeftRight(p_fmat, p_tree, partitioner_, valid_candidates,
-                                                 gpair, HistBatch(param_));
+    this->histogram_builder_->BuildHistLeftRight(ctx_, p_fmat, p_tree, partitioner_,
+                                                 valid_candidates, gpair, HistBatch(param_));
     monitor_->Stop(__func__);
   }

@ -537,7 +537,7 @@ class QuantileHistMaker : public TreeUpdater {
                            h_out_position, *tree_it);
     }

-    hist_param_.CheckTreesSynchronized(*tree_it);
+    hist_param_.CheckTreesSynchronized(ctx_, *tree_it);
     }
   }
@ -360,25 +360,27 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
 }

 template <typename Adapter>
-auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) {
+auto MakeUnweightedCutsForTest(Context const* ctx, Adapter adapter, int32_t num_bins, float missing,
+                               size_t batch_size = 0) {
   common::HistogramCuts batched_cuts;
   HostDeviceVector<FeatureType> ft;
   SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(),
                                    DeviceOrd::CUDA(0));
   MetaInfo info;
   AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
-  sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit());
+  sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit());
   return batched_cuts;
 }

 template <typename Adapter>
-void ValidateBatchedCuts(Adapter adapter, int num_bins, DMatrix* dmat, size_t batch_size = 0) {
+void ValidateBatchedCuts(Context const* ctx, Adapter adapter, int num_bins, DMatrix* dmat, size_t batch_size = 0) {
   common::HistogramCuts batched_cuts = MakeUnweightedCutsForTest(
-      adapter, num_bins, std::numeric_limits<float>::quiet_NaN(), batch_size);
+      ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN(), batch_size);
   ValidateCuts(batched_cuts, dmat, num_bins);
 }

 TEST(HistUtil, AdapterDeviceSketch) {
+  auto ctx = MakeCUDACtx(0);
   int rows = 5;
   int cols = 1;
   int num_bins = 4;
@ -391,8 +393,8 @@ TEST(HistUtil, AdapterDeviceSketch) {

   data::CupyAdapter adapter(str);

-  auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
-  Context ctx;
+  auto device_cuts = MakeUnweightedCutsForTest(&ctx, adapter, num_bins, missing);
+  ctx = ctx.MakeCPU();
   auto host_cuts = GetHostCuts(&ctx, &adapter, num_bins, missing);

   EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
@ -401,6 +403,7 @@ TEST(HistUtil, AdapterDeviceSketch) {
 }

 TEST(HistUtil, AdapterDeviceSketchMemory) {
+  auto ctx = MakeCUDACtx(0);
   int num_columns = 100;
   int num_rows = 1000;
   int num_bins = 256;
@ -410,7 +413,8 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {

   dh::GlobalMemoryLogger().Clear();
   ConsoleLogger::Configure({{"verbosity", "3"}});
-  auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
+  auto cuts =
+      MakeUnweightedCutsForTest(&ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
   ConsoleLogger::Configure({{"verbosity", "0"}});
   size_t bytes_required = detail::RequiredMemory(
       num_rows, num_columns, num_rows * num_columns, num_bins, false);
@ -419,6 +423,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
 }

 TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
+  auto ctx = MakeCUDACtx(0);
   int num_columns = 100;
   int num_rows = 1000;
   int num_bins = 256;
@ -435,7 +440,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
   AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
                       &sketch_container);
   HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
+  sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
   size_t bytes_required = detail::RequiredMemory(
       num_rows, num_columns, num_rows * num_columns, num_bins, false);
   EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
@ -444,6 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
 }

 TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
+  auto ctx = MakeCUDACtx(0);
   int num_columns = 100;
   int num_rows = 1000;
   int num_bins = 256;
@ -465,7 +471,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
                       &sketch_container);

   HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
+  sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
   ConsoleLogger::Configure({{"verbosity", "0"}});
   size_t bytes_required = detail::RequiredMemory(
       num_rows, num_columns, num_rows * num_columns, num_bins, true);
@ -475,6 +481,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {

 void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
                                   int32_t num_bins, bool weighted) {
+  auto ctx = MakeCUDACtx(0);
   auto h_x = GenerateRandomCategoricalSingleColumn(n, num_categories);
   thrust::device_vector<float> x(h_x);
   auto adapter = AdapterFromData(x, n, 1);
@ -498,7 +505,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
   AdapterDeviceSketch(adapter.Value(), num_bins, info,
                       std::numeric_limits<float>::quiet_NaN(), &container);
   HistogramCuts cuts;
-  container.MakeCuts(&cuts, info.IsColumnSplit());
+  container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());

   thrust::sort(x.begin(), x.end());
   auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin();
@ -522,6 +529,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
 TEST(HistUtil, AdapterDeviceSketchCategorical) {
   auto categorical_sizes = {2, 6, 8, 12};
   int num_bins = 256;
+  auto ctx = MakeCUDACtx(0);
   auto sizes = {25, 100, 1000};
   for (auto n : sizes) {
     for (auto num_categories : categorical_sizes) {
@ -529,7 +537,7 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) {
       auto dmat = GetDMatrixFromData(x, n, 1);
       auto x_device = thrust::device_vector<float>(x);
       auto adapter = AdapterFromData(x_device, n, 1);
-      ValidateBatchedCuts(adapter, num_bins, dmat.get());
+      ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
       TestCategoricalSketchAdapter(n, num_categories, num_bins, true);
       TestCategoricalSketchAdapter(n, num_categories, num_bins, false);
     }
@ -540,13 +548,14 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
   auto bin_sizes = {2, 16, 256, 512};
   auto sizes = {100, 1000, 1500};
   int num_columns = 5;
+  auto ctx = MakeCUDACtx(0);
   for (auto num_rows : sizes) {
     auto x = GenerateRandom(num_rows, num_columns);
     auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
     auto x_device = thrust::device_vector<float>(x);
     for (auto num_bins : bin_sizes) {
       auto adapter = AdapterFromData(x_device, num_rows, num_columns);
-      ValidateBatchedCuts(adapter, num_bins, dmat.get());
+      ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
     }
   }
 }
@ -556,12 +565,13 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
   int num_rows = 5000;
   auto batch_sizes = {0, 100, 1500, 6000};
   int num_columns = 5;
+  auto ctx = MakeCUDACtx(0);
   for (auto batch_size : batch_sizes) {
     auto x = GenerateRandom(num_rows, num_columns);
     auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
     auto x_device = thrust::device_vector<float>(x);
     auto adapter = AdapterFromData(x_device, num_rows, num_columns);
-    ValidateBatchedCuts(adapter, num_bins, dmat.get(), batch_size);
+    ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get(), batch_size);
   }
 }

@ -647,12 +657,12 @@ TEST(HistUtil, SketchingEquivalent) {
     auto x_device = thrust::device_vector<float>(x);
     auto adapter = AdapterFromData(x_device, num_rows, num_columns);
     common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest(
-        adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
+        &ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
     EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
     EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
     EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());

-    ValidateBatchedCuts(adapter, num_bins, dmat.get());
+    ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
   }
 }
 }
@ -702,7 +712,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
                      .Device(DeviceOrd::CUDA(0))
                      .GenerateArrayInterface(&storage);
   MetaInfo info;
-  Context ctx;
+  auto ctx = MakeCUDACtx(0);
   auto& h_weights = info.weights_.HostVector();
   if (with_group) {
     h_weights.resize(kGroups);
@ -731,7 +741,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
                       &sketch_container);

   common::HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
+  sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());

   auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
   if (with_group) {
@ -744,10 +754,9 @@ void TestAdapterSketchFromWeights(bool with_group) {
   ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
   ValidateCuts(cuts, dmat.get(), kBins);

-  auto cuda_ctx = MakeCUDACtx(0);
   if (with_group) {
     dmat->Info().weights_ = decltype(dmat->Info().weights_)();  // remove weight
-    HistogramCuts non_weighted = DeviceSketch(&cuda_ctx, dmat.get(), kBins, 0);
+    HistogramCuts non_weighted = DeviceSketch(&ctx, dmat.get(), kBins, 0);
     for (size_t i = 0; i < cuts.Values().size(); ++i) {
       ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
     }
@ -773,7 +782,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
     SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)};
     AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
                         &sketch_container);
-    sketch_container.MakeCuts(&weighted, info.IsColumnSplit());
+    sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit());
     ValidateCuts(weighted, dmat.get(), kBins);
   }
 }
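All of the test changes above follow one shape: construct a device context once with the test helper `MakeCUDACtx`, pass its address to every API that now requires one, and, where a host-side comparison is needed, convert the same variable in place (as `AdapterDeviceSketch` does with `ctx.MakeCPU()`). A sketch of the pattern in isolation, assuming the helpers and objects used in the tests above:

    auto ctx = MakeCUDACtx(0);   // CUDA device 0
    sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
    ctx = ctx.MakeCPU();         // reuse the same variable for the host path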
@ -86,7 +86,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
   }

   HistogramCuts distributed_cuts;
-  sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);
+  sketch_distributed.MakeCuts(&ctx, m->Info(), &distributed_cuts);

   // Generate cuts for single node environment
   collective::Finalize();
@ -117,7 +117,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
   }

   HistogramCuts single_node_cuts;
-  sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);
+  sketch_on_single_node.MakeCuts(&ctx, m->Info(), &single_node_cuts);

   auto const& sptrs = single_node_cuts.Ptrs();
   auto const& dptrs = distributed_cuts.Ptrs();
@ -220,7 +220,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
       }
     }

-    sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);
+    sketch_distributed.MakeCuts(&ctx, m->Info(), &distributed_cuts);
   }

   // Generate cuts for single node environment
@ -243,7 +243,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
       }
     }

-    sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);
+    sketch_on_single_node.MakeCuts(&ctx, m->Info(), &single_node_cuts);
   }

   auto const& sptrs = single_node_cuts.Ptrs();
@ -370,6 +370,7 @@ void TestAllReduceBasic() {
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
     auto const device = DeviceOrd::CUDA(GPUIDX);
+    auto ctx = MakeCUDACtx(device.ordinal);

     // Set up single node version;
     HostDeviceVector<FeatureType> ft({}, device);
@ -413,7 +414,7 @@ void TestAllReduceBasic() {
     AdapterDeviceSketch(adapter.Value(), n_bins, info,
                         std::numeric_limits<float>::quiet_NaN(),
                         &sketch_distributed);
-    sketch_distributed.AllReduce(false);
+    sketch_distributed.AllReduce(&ctx, false);
     sketch_distributed.Unique();

     ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
@ -517,6 +518,7 @@ void TestSameOnAllWorkers() {
                                MetaInfo const &info) {
     auto const rank = collective::GetRank();
     auto const device = DeviceOrd::CUDA(GPUIDX);
+    Context ctx = MakeCUDACtx(device.ordinal);
     HostDeviceVector<FeatureType> ft({}, device);
     SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, device);
     HostDeviceVector<float> storage({}, device);
@ -528,7 +530,7 @@ void TestSameOnAllWorkers() {
     AdapterDeviceSketch(adapter.Value(), n_bins, info,
                         std::numeric_limits<float>::quiet_NaN(),
                         &sketch_distributed);
-    sketch_distributed.AllReduce(false);
+    sketch_distributed.AllReduce(&ctx, false);
     sketch_distributed.Unique();
     TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
@ -73,6 +73,7 @@ void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_
   auto run = [&](auto rank) {
     Json config{JsonObject()};
     config["xgboost_communicator"] = String("federated");
+    config["federated_secure"] = false;
     config["federated_server_address"] = String(server_address);
     config["federated_world_size"] = world_size;
     config["federated_rank"] = rank;
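The test helper now sets `federated_secure` explicitly, presumably opting out of the secure federated path for plain tests; the surrounding keys are unchanged. Assuming the communicator reads the flag as a boolean (as written above), the resulting config looks, illustratively, like:

    {
      "xgboost_communicator": "federated",
      "federated_secure": false,
      "federated_server_address": "<server_address>",
      "federated_world_size": <world_size>,
      "federated_rank": <rank>
    }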
@ -2,6 +2,7 @@
  * Copyright (c) 2017-2023, XGBoost contributors
  */
 #include <gtest/gtest.h>
+#include <gmock/gmock.h>
 #include <xgboost/learner.h>  // for Learner
 #include <xgboost/logging.h>  // for LogCheck_NE, CHECK_NE, LogCheck_EQ
 #include <xgboost/objective.h>  // for ObjFunction
@ -81,7 +82,9 @@ TEST(Learner, ParameterValidation) {

   // whitespace
   learner->SetParam("tree method", "exact");
-  EXPECT_THROW(learner->Configure(), dmlc::Error);
+  EXPECT_THAT([&] { learner->Configure(); },
+              ::testing::ThrowsMessage<dmlc::Error>(
+                  ::testing::HasSubstr(R"("tree method" contains whitespace)")));
 }

 TEST(Learner, CheckGroup) {
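The assertion is tightened from "throws" to "throws with this message", which needs gmock's `ThrowsMessage` matcher and hence the new `<gmock/gmock.h>` include. A minimal self-contained version of the same pattern, using `std::runtime_error` in place of `dmlc::Error`:

    #include <gmock/gmock.h>
    #include <gtest/gtest.h>
    #include <stdexcept>

    TEST(Example, ThrowsWithMessage) {
      // ThrowsMessage<E>(m) matches a callable that throws E whose message
      // (what()) satisfies the inner matcher.
      auto fail = [] { throw std::runtime_error(R"("tree method" contains whitespace)"); };
      EXPECT_THAT(fail, ::testing::ThrowsMessage<std::runtime_error>(
                            ::testing::HasSubstr("contains whitespace")));
    }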
@ -19,14 +19,15 @@ auto ZeroParam() {
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
inline GradientQuantiser DummyRoundingFactor() {
|
inline GradientQuantiser DummyRoundingFactor(Context const* ctx) {
|
||||||
thrust::device_vector<GradientPair> gpair(1);
|
thrust::device_vector<GradientPair> gpair(1);
|
||||||
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000
|
||||||
return {dh::ToSpan(gpair), MetaInfo()};
|
return {ctx, dh::ToSpan(gpair), MetaInfo()};
|
||||||
}
|
}
|
||||||
|
|
||||||
thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
|
thrust::device_vector<GradientPairInt64> ConvertToInteger(Context const* ctx,
|
||||||
auto r = DummyRoundingFactor();
|
std::vector<GradientPairPrecise> x) {
|
||||||
|
auto r = DummyRoundingFactor(ctx);
|
||||||
std::vector<GradientPairInt64> y(x.size());
|
std::vector<GradientPairInt64> y(x.size());
|
||||||
for (std::size_t i = 0; i < x.size(); i++) {
|
for (std::size_t i = 0; i < x.size(); i++) {
|
||||||
y[i] = r.ToFixedPoint(GradientPair(x[i]));
|
y[i] = r.ToFixedPoint(GradientPair(x[i]));
|
||||||
@ -41,11 +42,12 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
|||||||
cuts_.cut_ptrs_.SetDevice(ctx.Device());
|
cuts_.cut_ptrs_.SetDevice(ctx.Device());
|
||||||
cuts_.cut_values_.SetDevice(ctx.Device());
|
cuts_.cut_values_.SetDevice(ctx.Device());
|
||||||
cuts_.min_vals_.SetDevice(ctx.Device());
|
cuts_.min_vals_.SetDevice(ctx.Device());
|
||||||
thrust::device_vector<GradientPairInt64> feature_histogram{ConvertToInteger(feature_histogram_)};
|
thrust::device_vector<GradientPairInt64> feature_histogram{
|
||||||
|
ConvertToInteger(&ctx, feature_histogram_)};
|
||||||
|
|
||||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||||
auto d_feature_types = dh::ToSpan(feature_types);
|
auto d_feature_types = dh::ToSpan(feature_types);
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor(&ctx);
|
||||||
EvaluateSplitInputs input{1, 0, quantiser.ToFixedPoint(parent_sum_), dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{1, 0, quantiser.ToFixedPoint(parent_sum_), dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
EvaluateSplitSharedInputs shared_inputs{param,
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
@ -60,7 +62,7 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
|||||||
|
|
||||||
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false,
|
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false,
|
||||||
ctx.Device());
|
ctx.Device());
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
|
|
||||||
ASSERT_EQ(result.thresh, 1);
|
ASSERT_EQ(result.thresh, 1);
|
||||||
this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
|
this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
|
||||||
@ -90,7 +92,7 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||||
cuts.SetCategorical(true, max_cat);
|
cuts.SetCategorical(true, max_cat);
|
||||||
d_feature_types = dh::ToSpan(feature_types);
|
d_feature_types = dh::ToSpan(feature_types);
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor(&ctx);
|
||||||
EvaluateSplitSharedInputs shared_inputs{
|
EvaluateSplitSharedInputs shared_inputs{
|
||||||
param,
|
param,
|
||||||
quantiser,
|
quantiser,
|
||||||
@ -108,10 +110,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
// -1.0s go right
|
// -1.0s go right
|
||||||
// -3.0s go left
|
// -3.0s go left
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(result.dir, kLeftDir);
|
EXPECT_EQ(result.dir, kLeftDir);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||||
@ -122,10 +124,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
// -1.0s go right
|
// -1.0s go right
|
||||||
// -3.0s go left
|
// -3.0s go left
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-7.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-7.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-3.0, 1.0}, {-3.0, 1.0}});
|
||||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(result.dir, kLeftDir);
|
EXPECT_EQ(result.dir, kLeftDir);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
||||||
@ -134,10 +136,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
{
|
{
|
||||||
// All -1.0, gain from splitting should be 0.0
|
// All -1.0, gain from splitting should be 0.0
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-3.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-3.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||||
EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{2, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
EXPECT_EQ(result.dir, kLeftDir);
|
EXPECT_EQ(result.dir, kLeftDir);
|
||||||
EXPECT_FLOAT_EQ(result.loss_chg, 0.0f);
|
EXPECT_FLOAT_EQ(result.loss_chg, 0.0f);
|
||||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||||
@ -147,10 +149,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
// value
|
// value
|
||||||
{
|
{
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 6.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 6.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||||
EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{3, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||||
EXPECT_EQ(result.dir, kLeftDir);
|
EXPECT_EQ(result.dir, kLeftDir);
|
||||||
@ -160,10 +162,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
// -1.0s go right
|
// -1.0s go right
|
||||||
// -3.0s go left
|
// -3.0s go left
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-3.0, 1.0}, {-1.0, 1.0}});
|
||||||
EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{4, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(result.dir, kLeftDir);
|
EXPECT_EQ(result.dir, kLeftDir);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("10100000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("10100000000000000000000000000000"));
|
||||||
@ -173,10 +175,10 @@ TEST(GpuHist, PartitionBasic) {
|
|||||||
// -1.0s go right
|
// -1.0s go right
|
||||||
// -3.0s go left
|
// -3.0s go left
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-5.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger({{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
auto feature_histogram = ConvertToInteger(&ctx, {{-3.0, 1.0}, {-1.0, 1.0}, {-3.0, 1.0}});
|
||||||
EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{5, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("01000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("01000000000000000000000000000000"));
|
||||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||||
@ -205,7 +207,7 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
|||||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||||
cuts.SetCategorical(true, max_cat);
|
cuts.SetCategorical(true, max_cat);
|
||||||
|
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor(&ctx);
|
||||||
EvaluateSplitSharedInputs shared_inputs{param,
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
quantiser,
|
quantiser,
|
||||||
d_feature_types,
|
d_feature_types,
|
||||||
@ -220,10 +222,10 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
|||||||
{
|
{
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger(
|
auto feature_histogram = ConvertToInteger(
|
||||||
{{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
&ctx, {{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(result.findex, 1);
|
EXPECT_EQ(result.findex, 1);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("11000000000000000000000000000000"));
|
||||||
@ -233,10 +235,10 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
|||||||
{
|
{
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||||
auto feature_histogram = ConvertToInteger(
|
auto feature_histogram = ConvertToInteger(
|
||||||
{{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
&ctx, {{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
||||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram)};
|
dh::ToSpan(feature_histogram)};
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
auto cats = std::bitset<32>(evaluator.GetHostNodeCats(input.nidx)[0]);
|
||||||
EXPECT_EQ(result.findex, 1);
|
EXPECT_EQ(result.findex, 1);
|
||||||
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
EXPECT_EQ(cats, std::bitset<32>("10000000000000000000000000000000"));
|
||||||
@ -266,7 +268,7 @@ TEST(GpuHist, PartitionTwoNodes) {
|
|||||||
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
*std::max_element(cuts.cut_values_.HostVector().begin(), cuts.cut_values_.HostVector().end());
|
||||||
cuts.SetCategorical(true, max_cat);
|
cuts.SetCategorical(true, max_cat);
|
||||||
|
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor(&ctx);
|
||||||
EvaluateSplitSharedInputs shared_inputs{param,
|
EvaluateSplitSharedInputs shared_inputs{param,
|
||||||
quantiser,
|
quantiser,
|
||||||
d_feature_types,
|
d_feature_types,
|
||||||
@ -283,15 +285,16 @@ TEST(GpuHist, PartitionTwoNodes) {
|
|||||||
{
|
{
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||||
auto feature_histogram_a = ConvertToInteger(
|
auto feature_histogram_a = ConvertToInteger(
|
||||||
{{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
&ctx, {{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||||
thrust::device_vector<EvaluateSplitInputs> inputs(2);
|
thrust::device_vector<EvaluateSplitInputs> inputs(2);
|
||||||
inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
|
inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram_a)};
|
dh::ToSpan(feature_histogram_a)};
|
||||||
auto feature_histogram_b = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
auto feature_histogram_b = ConvertToInteger(&ctx, {{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||||
inputs[1] = EvaluateSplitInputs{1, 0, parent_sum, dh::ToSpan(feature_set),
|
inputs[1] = EvaluateSplitInputs{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||||
dh::ToSpan(feature_histogram_b)};
|
dh::ToSpan(feature_histogram_b)};
|
||||||
thrust::device_vector<GPUExpandEntry> results(2);
|
thrust::device_vector<GPUExpandEntry> results(2);
|
||||||
evaluator.EvaluateSplits({0, 1}, 1, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(results));
|
evaluator.EvaluateSplits(&ctx, {0, 1}, 1, dh::ToSpan(inputs), shared_inputs,
|
||||||
|
dh::ToSpan(results));
|
||||||
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]),
|
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]),
|
||||||
std::bitset<32>("10000000000000000000000000000000"));
|
std::bitset<32>("10000000000000000000000000000000"));
|
||||||
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]),
|
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]),
|
||||||
@ -301,7 +304,7 @@ TEST(GpuHist, PartitionTwoNodes) {
|
|||||||
|
|
||||||
void TestEvaluateSingleSplit(bool is_categorical) {
|
void TestEvaluateSingleSplit(bool is_categorical) {
|
||||||
auto ctx = MakeCUDACtx(0);
|
auto ctx = MakeCUDACtx(0);
|
||||||
auto quantiser = DummyRoundingFactor();
|
auto quantiser = DummyRoundingFactor(&ctx);
|
||||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||||
TrainParam tparam = ZeroParam();
|
TrainParam tparam = ZeroParam();
|
||||||
GPUTrainingParam param{tparam};
|
GPUTrainingParam param{tparam};
|
||||||
@ -311,7 +314,8 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
|||||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||||
|
|
||||||
// Setup gradients so that second feature gets higher gain
|
// Setup gradients so that second feature gets higher gain
|
||||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
auto feature_histogram =
|
||||||
|
ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||||
|
|
||||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||||
common::Span<FeatureType> d_feature_types;
|
common::Span<FeatureType> d_feature_types;
|
||||||
@ -336,7 +340,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
|||||||
ctx.Device()};
|
ctx.Device()};
|
||||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
|
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false,
|
||||||
ctx.Device());
|
ctx.Device());
|
||||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
|
||||||
|
|
||||||
EXPECT_EQ(result.findex, 1);
|
EXPECT_EQ(result.findex, 1);
|
||||||
if (is_categorical) {
|
if (is_categorical) {
|
||||||
@@ -352,7 +356,8 @@ TEST(GpuHist, EvaluateSingleSplit) { TestEvaluateSingleSplit(false); }
 TEST(GpuHist, EvaluateSingleCategoricalSplit) { TestEvaluateSingleSplit(true); }

 TEST(GpuHist, EvaluateSingleSplitMissing) {
-  auto quantiser = DummyRoundingFactor();
+  auto ctx = MakeCUDACtx(0);
+  auto quantiser = DummyRoundingFactor(&ctx);
   auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{1.0, 1.5});
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};
@@ -361,7 +366,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
   thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2};
   thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0};
   thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
-  auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}});
+  auto feature_histogram = ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}});
   EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
                             dh::ToSpan(feature_histogram)};
   EvaluateSplitSharedInputs shared_inputs{param,
@@ -373,7 +378,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
                                           false};

   GPUHistEvaluator evaluator(tparam, feature_set.size(), FstCU());
-  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;

   EXPECT_EQ(result.findex, 0);
   EXPECT_EQ(result.fvalue, 1.0);
@@ -383,14 +388,15 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
 }

 TEST(GpuHist, EvaluateSingleSplitEmpty) {
+  auto ctx = MakeCUDACtx(0);
   TrainParam tparam = ZeroParam();
   GPUHistEvaluator evaluator(tparam, 1, FstCU());
   DeviceSplitCandidate result =
       evaluator
           .EvaluateSingleSplit(
-              EvaluateSplitInputs{},
+              &ctx, EvaluateSplitInputs{},
               EvaluateSplitSharedInputs{
-                  GPUTrainingParam(tparam), DummyRoundingFactor(), {}, {}, {}, {}, false})
+                  GPUTrainingParam(tparam), DummyRoundingFactor(&ctx), {}, {}, {}, {}, false})
           .split;
   EXPECT_EQ(result.findex, -1);
   EXPECT_LT(result.loss_chg, 0.0f);
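The `ConvertToInteger` test helper now takes the context first as well. A plausible reconstruction under that assumption (the real definition sits earlier in this test file, outside the hunks shown; the container and parameter types are inferred from the `dh::ToSpan(feature_histogram)` and `ToFixedPoint` usage above):

    // Hypothetical sketch: quantise host-side precise gradients into the
    // fixed-point device histogram format consumed by the split evaluator.
    inline thrust::device_vector<GradientPairInt64> ConvertToInteger(
        Context const *ctx, std::vector<GradientPairPrecise> const &values) {
      auto quantiser = DummyRoundingFactor(ctx);  // the helper needs the context too
      thrust::device_vector<GradientPairInt64> out(values.size());
      for (std::size_t i = 0; i < values.size(); ++i) {
        out[i] = quantiser.ToFixedPoint(values[i]);
      }
      return out;
    }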
@@ -398,7 +404,8 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {

 // Feature 0 has a better split, but the algorithm must select feature 1
 TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
-  auto quantiser = DummyRoundingFactor();
+  auto ctx = MakeCUDACtx(0);
+  auto quantiser = DummyRoundingFactor(&ctx);
   auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
@@ -408,7 +415,8 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
   thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2, 4};
   thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
   thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 10.0};
-  auto feature_histogram = ConvertToInteger({{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
+  auto feature_histogram =
+      ConvertToInteger(&ctx, {{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
   EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
                             dh::ToSpan(feature_histogram)};
   EvaluateSplitSharedInputs shared_inputs{param,
@@ -420,7 +428,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
                                           false};

   GPUHistEvaluator evaluator(tparam, feature_min_values.size(), FstCU());
-  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;

   EXPECT_EQ(result.findex, 1);
   EXPECT_EQ(result.fvalue, 11.0);
@@ -430,7 +438,8 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {

 // Features 0 and 1 have identical gain, the algorithm must select 0
 TEST(GpuHist, EvaluateSingleSplitBreakTies) {
-  auto quantiser = DummyRoundingFactor();
+  auto ctx = MakeCUDACtx(0);
+  auto quantiser = DummyRoundingFactor(&ctx);
   auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
@@ -440,7 +449,8 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
   thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2, 4};
   thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
   thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 10.0};
-  auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
+  auto feature_histogram =
+      ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
   EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
                             dh::ToSpan(feature_histogram)};
   EvaluateSplitSharedInputs shared_inputs{param,
@@ -452,15 +462,16 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
                                           false};

   GPUHistEvaluator evaluator(tparam, feature_min_values.size(), FstCU());
-  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;

   EXPECT_EQ(result.findex, 0);
   EXPECT_EQ(result.fvalue, 1.0);
 }

 TEST(GpuHist, EvaluateSplits) {
+  auto ctx = MakeCUDACtx(0);
   thrust::device_vector<DeviceSplitCandidate> out_splits(2);
-  auto quantiser = DummyRoundingFactor();
+  auto quantiser = DummyRoundingFactor(&ctx);
   auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
@@ -471,9 +482,9 @@ TEST(GpuHist, EvaluateSplits) {
   thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
   thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 0.0};
   auto feature_histogram_left =
-      ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
+      ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
   auto feature_histogram_right =
-      ConvertToInteger({{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
+      ConvertToInteger(&ctx, {{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
   EvaluateSplitInputs input_left{1, 0, parent_sum, dh::ToSpan(feature_set),
                                  dh::ToSpan(feature_histogram_left)};
   EvaluateSplitInputs input_right{2, 0, parent_sum, dh::ToSpan(feature_set),
@@ -514,7 +525,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
   evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, false, ctx.Device());

   // Convert the sample histogram to fixed point
-  auto quantiser = DummyRoundingFactor();
+  auto quantiser = DummyRoundingFactor(&ctx);
   thrust::host_vector<GradientPairInt64> h_hist;
   for (auto e : hist_[0]) {
     h_hist.push_back(quantiser.ToFixedPoint(e));
@@ -531,7 +542,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
                                           cuts_.cut_values_.ConstDeviceSpan(),
                                           cuts_.min_vals_.ConstDeviceSpan(),
                                           false};
-  auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+  auto split = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;
   ASSERT_NEAR(split.loss_chg, best_score_, 1e-2);
 }

@@ -541,7 +552,7 @@ namespace {
 void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {
   auto ctx = MakeCUDACtx(GPUIDX);
   auto rank = collective::GetRank();
-  auto quantiser = DummyRoundingFactor();
+  auto quantiser = DummyRoundingFactor(&ctx);
   auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};
@@ -552,8 +563,8 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {
   thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};

   // Setup gradients so that second feature gets higher gain
-  auto feature_histogram = rank == 0 ? ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}})
-                                     : ConvertToInteger({{-1.0, 0.5}, {1.0, 0.5}});
+  auto feature_histogram = rank == 0 ? ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}})
+                                     : ConvertToInteger(&ctx, {{-1.0, 0.5}, {1.0, 0.5}});

   dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
   common::Span<FeatureType> d_feature_types;
@@ -576,7 +587,7 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {

   GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), ctx.Device()};
   evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, ctx.Device());
-  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
+  DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split;

   EXPECT_EQ(result.findex, 1) << "rank: " << rank;
   if (is_categorical) {

@@ -37,7 +37,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
   FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
                                sizeof(GradientPairInt64));

-  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
+  auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
   BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
                          feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
                          d_histogram, quantiser);
@@ -51,7 +51,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
   dh::device_vector<GradientPairInt64> new_histogram(num_bins);
   auto d_new_histogram = dh::ToSpan(new_histogram);

-  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
+  auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
   BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(FstCU()),
                          feature_groups.DeviceAccessor(FstCU()), gpair.DeviceSpan(), ridx,
                          d_new_histogram, quantiser);
@@ -129,7 +129,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
   gpair.SetDevice(DeviceOrd::CUDA(0));
-  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
+  auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
   /**
    * Generate hist with cat data.
    */
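The `GradientQuantiser` constructor changes the same way at every call site above and in `TestBuildHist` further below. A hedged sketch of the constructor shape implied by those call sites (parameter types and names are assumptions):

    // Sketch: the quantiser derives its rounding factor from a reduction over
    // the gradient span, so it needs the Context to pick the right device and,
    // eventually, a context-aware allreduce in distributed training.
    GradientQuantiser(Context const *ctx, common::Span<GradientPair const> gpair,
                      MetaInfo const &info);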
@@ -181,7 +181,7 @@ void TestSyncHist(bool is_distributed) {

   histogram.Buffer().Reset(1, n_nodes, space, target_hists);
   // sync hist
-  histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
+  histogram.SyncHistogram(&ctx, &tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);

   using GHistRowT = common::GHistRow;
   auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
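The CPU histogram builder's `SyncHistogram` follows the same convention: context first, then the tree and the node lists. A hedged sketch of the presumed declaration, inferred from the call sites in the surrounding hunks (argument types may differ):

    // Sketch: histogram synchronization performs an allreduce in distributed
    // training, which is exactly the kind of collective operation this PR
    // prepares to route through the Context.
    void SyncHistogram(Context const *ctx, RegTree const *p_tree,
                       std::vector<bst_node_t> const &nodes_to_build,
                       std::vector<bst_node_t> const &nodes_to_trick);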
@@ -266,7 +266,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
     histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
                         linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
   }
-  histogram.SyncHistogram(&tree, nodes_to_build, {});
+  histogram.SyncHistogram(&ctx, &tree, nodes_to_build, {});

   // Check if number of histogram bins is correct
   ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
@@ -366,7 +366,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
                        linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
                        force_read_by_column);
   }
-  cat_hist.SyncHistogram(&tree, nodes_to_build, {});
+  cat_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});

   /**
    * Generate hist with one hot encoded data.
@@ -382,7 +382,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
                           linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
                           force_read_by_column);
   }
-  onehot_hist.SyncHistogram(&tree, nodes_to_build, {});
+  onehot_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});

   auto cat = cat_hist.Histogram()[0];
   auto onehot = onehot_hist.Histogram()[0];
@@ -451,7 +451,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
                             force_read_by_column);
       ++page_idx;
     }
-    multi_build.SyncHistogram(&tree, nodes, {});
+    multi_build.SyncHistogram(ctx, &tree, nodes, {});

     multi_page = multi_build.Histogram()[RegTree::kRoot];
   }
@@ -480,7 +480,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
     single_build.BuildHist(0, space, gmat, row_set_collection, nodes,
                            linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()),
                            force_read_by_column);
-    single_build.SyncHistogram(&tree, nodes, {});
+    single_build.SyncHistogram(ctx, &tree, nodes, {});

     single_page = single_build.Histogram()[RegTree::kRoot];
   }
@@ -570,7 +570,7 @@ class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
       CHECK_NE(partitioners.front()[tree.RightChild(best.nid)].Size(), 0);

       hist_builder.BuildHistLeftRight(
-          Xy.get(), &tree, partitioners, valid_candidates,
+          &ctx, Xy.get(), &tree, partitioners, valid_candidates,
          linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), batch);

       if (limit) {
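Every change in this patch is the same mechanical refactor: an explicit `Context const *` is threaded from the caller into each entry point instead of being fetched from global state. A self-contained toy (stand-in types only, not XGBoost's real classes) showing the before/after shape of such a call site:

    #include <cassert>

    struct Context {  // stand-in for xgboost::Context
      int device{0};
    };

    struct HistBuilder {
      // Before the refactor, a method like this would consult a global device
      // ordinal; after it, the caller's Context arrives as the first parameter,
      // which is what later lets the implementation swap in context-aware
      // collective operations without touching the call sites again.
      int Sync(Context const *ctx, int partial_sum) const {
        return partial_sum + ctx->device;
      }
    };

    int main() {
      Context ctx{1};
      HistBuilder builder;
      assert(builder.Sync(&ctx, 41) == 42);  // call sites now pass &ctx first
      return 0;
    }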
@@ -111,7 +111,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.hist.AllocateHistograms({0});

   maker.gpair = gpair.DeviceSpan();
-  maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair, MetaInfo());
+  maker.quantiser = std::make_unique<GradientQuantiser>(&ctx, maker.gpair, MetaInfo());
   maker.page = page.get();

   maker.InitFeatureGroupsOnce();
@@ -162,12 +162,6 @@ HistogramCutsWrapper GetHostCutMatrix () {
   return cmat;
 }

-inline GradientQuantiser DummyRoundingFactor() {
-  thrust::device_vector<GradientPair> gpair(1);
-  gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
-  return {dh::ToSpan(gpair), MetaInfo()};
-}
-
 void TestHistogramIndexImpl() {
   // Test if the compressed histogram index matches when using a sparse
   // dmatrix with and without using external memory
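The zero-argument `DummyRoundingFactor` deleted above is the helper that every `DummyRoundingFactor(&ctx)` call in the earlier hunks now resolves to in updated form. A hedged reconstruction that combines the deleted body with the new `GradientQuantiser(&ctx, ...)` constructor visible in this patch (the updated definition presumably lives in a shared test header not shown in this diff):

    inline GradientQuantiser DummyRoundingFactor(Context const *ctx) {
      thrust::device_vector<GradientPair> gpair(1);
      gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
      // The constructor now takes the context first, matching the
      // GradientQuantiser(&ctx, ...) call sites in the hunks above.
      return {ctx, dh::ToSpan(gpair), MetaInfo()};
    }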