Support column split in GPU evaluate splits (#9511)
This commit is contained in:
parent
8c10af45a0
commit
6103dca0bb
@ -57,6 +57,20 @@ inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
|
||||
@ -27,6 +27,17 @@ class DeviceCommunicator {
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
|
||||
@ -28,12 +28,26 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.reserve(size);
|
||||
host_buffer_.resize(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
cudaMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (world_size_ == 1) {
|
||||
@ -49,7 +63,7 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.reserve(total_bytes);
|
||||
host_buffer_.resize(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
|
||||
@ -178,6 +178,17 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
|
||||
@ -29,6 +29,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override;
|
||||
void Synchronize() override;
|
||||
|
||||
@ -5,8 +5,8 @@
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "../../collective/communicator-inl.cuh"
|
||||
#include "../../common/categorical.h"
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "evaluate_splits.cuh"
|
||||
#include "expand_entry.cuh"
|
||||
@ -409,6 +409,23 @@ void GPUHistEvaluator::EvaluateSplits(
|
||||
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
|
||||
evaluator, out_splits);
|
||||
|
||||
if (is_column_split_) {
|
||||
// With column-wise data split, we gather the split candidates from all the workers and find the
|
||||
// global best candidates.
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
|
||||
auto all_candidates = dh::ToSpan(all_candidate_storage);
|
||||
collective::AllGather(device_, out_splits.data(), all_candidates.data(),
|
||||
out_splits.size() * sizeof(DeviceSplitCandidate));
|
||||
|
||||
// Reduce to get the best candidate from all workers.
|
||||
dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
|
||||
for (auto rank = 0; rank < world_size; rank++) {
|
||||
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
|
||||
auto d_entries = out_entries;
|
||||
auto device_cats_accessor = this->DeviceCatStorage(nidx);
|
||||
|
||||
@ -83,6 +83,9 @@ class GPUHistEvaluator {
|
||||
// Number of elements of categorical storage type
|
||||
// needed to hold categoricals for a single mode
|
||||
std::size_t node_categorical_storage_size_ = 0;
|
||||
// Is the data split column-wise?
|
||||
bool is_column_split_ = false;
|
||||
int32_t device_;
|
||||
|
||||
// Copy the categories from device to host asynchronously.
|
||||
void CopyToHost( const std::vector<bst_node_t>& nidx);
|
||||
@ -136,7 +139,8 @@ class GPUHistEvaluator {
|
||||
* \brief Reset the evaluator, should be called before any use.
|
||||
*/
|
||||
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const ¶m, int32_t device);
|
||||
bst_feature_t n_features, TrainParam const ¶m, bool is_column_split,
|
||||
int32_t device);
|
||||
|
||||
/**
|
||||
* \brief Get host category storage for nidx. Different from the internal version, this
|
||||
|
||||
@ -14,10 +14,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
|
||||
common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const ¶m,
|
||||
int32_t device) {
|
||||
void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
|
||||
bst_feature_t n_features, TrainParam const ¶m,
|
||||
bool is_column_split, int32_t device) {
|
||||
param_ = param;
|
||||
tree_evaluator_ = TreeEvaluator{param, n_features, device};
|
||||
has_categoricals_ = cuts.HasCategorical();
|
||||
@ -65,6 +64,8 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
|
||||
return fidx;
|
||||
});
|
||||
}
|
||||
is_column_split_ = is_column_split;
|
||||
device_ = device;
|
||||
}
|
||||
|
||||
common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
|
||||
|
||||
@ -242,7 +242,8 @@ struct GPUHistMakerDevice {
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
|
||||
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
|
||||
dmat->Info().IsColumnSplit(), ctx_->gpu_id);
|
||||
|
||||
quantiser.reset(new GradientQuantiser(this->gpair));
|
||||
|
||||
|
||||
@ -11,8 +11,8 @@
|
||||
#include "../../../plugin/federated/federated_communicator.h"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/collective/device_communicator_adapter.cuh"
|
||||
#include "./helpers.h"
|
||||
#include "../helpers.h"
|
||||
#include "./helpers.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@ -45,6 +45,28 @@ TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyAllGather() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GPUIDX;
|
||||
common::SetDevice(device);
|
||||
thrust::device_vector<double> send_buffer(1, rank);
|
||||
thrust::device_vector<double> receive_buffer(world_size, 0);
|
||||
collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(),
|
||||
sizeof(double));
|
||||
thrust::host_vector<double> host_buffer = receive_buffer;
|
||||
EXPECT_EQ(host_buffer.size(), world_size);
|
||||
for (auto i = 0; i < world_size; i++) {
|
||||
EXPECT_EQ(host_buffer[i], i);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(FederatedAdapterTest, MGPUAllGather) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyAllGatherV() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
|
||||
@ -2,24 +2,23 @@
|
||||
* Copyright 2020-2022 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
|
||||
#include "../../helpers.h"
|
||||
#include "../../histogram_helpers.h"
|
||||
#include "../test_evaluate_splits.h" // TestPartitionBasedSplit
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
namespace {
|
||||
auto ZeroParam() {
|
||||
auto args = Args{{"min_child_weight", "0"},
|
||||
{"lambda", "0"}};
|
||||
auto args = Args{{"min_child_weight", "0"}, {"lambda", "0"}};
|
||||
TrainParam tparam;
|
||||
tparam.UpdateAllowUnknown(args);
|
||||
return tparam;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
inline GradientQuantiser DummyRoundingFactor() {
|
||||
@ -37,7 +36,6 @@ thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPa
|
||||
return y;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
|
||||
GPUTrainingParam param{param_};
|
||||
@ -61,12 +59,13 @@ TEST_F(TestCategoricalSplitWithMissing, GPUHistEvaluator) {
|
||||
|
||||
GPUHistEvaluator evaluator{param_, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
|
||||
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, 0);
|
||||
evaluator.Reset(cuts_, dh::ToSpan(feature_types), feature_set.size(), param_, false, 0);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
|
||||
ASSERT_EQ(result.thresh, 1);
|
||||
this->CheckResult(result.loss_chg, result.findex, result.fvalue, result.is_cat,
|
||||
result.dir == kLeftDir, quantiser.ToFloatingPoint(result.left_sum), quantiser.ToFloatingPoint(result.right_sum));
|
||||
result.dir == kLeftDir, quantiser.ToFloatingPoint(result.left_sum),
|
||||
quantiser.ToFloatingPoint(result.right_sum));
|
||||
}
|
||||
|
||||
TEST(GpuHist, PartitionBasic) {
|
||||
@ -102,7 +101,7 @@ TEST(GpuHist, PartitionBasic) {
|
||||
};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, 0);
|
||||
|
||||
{
|
||||
// -1.0s go right
|
||||
@ -143,7 +142,8 @@ TEST(GpuHist, PartitionBasic) {
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
// With 3.0/3.0 missing values
|
||||
// Forward, first 2 categories are selected, while the last one go to left along with missing value
|
||||
// Forward, first 2 categories are selected, while the last one go to left along with missing
|
||||
// value
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 6.0});
|
||||
auto feature_histogram = ConvertToInteger({{-1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}});
|
||||
@ -213,11 +213,12 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, 0);
|
||||
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
auto feature_histogram = ConvertToInteger(
|
||||
{{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
EvaluateSplitInputs input{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -229,7 +230,8 @@ TEST(GpuHist, PartitionTwoFeatures) {
|
||||
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram = ConvertToInteger({ {-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
||||
auto feature_histogram = ConvertToInteger(
|
||||
{{-2.0, 1.0}, {-2.0, 1.0}, {-2.0, 1.0}, {-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
@ -271,12 +273,12 @@ TEST(GpuHist, PartitionTwoNodes) {
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, 0);
|
||||
|
||||
{
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{-6.0, 3.0});
|
||||
auto feature_histogram_a = ConvertToInteger({{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0},
|
||||
{-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
auto feature_histogram_a = ConvertToInteger(
|
||||
{{-1.0, 1.0}, {-2.5, 1.0}, {-2.5, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {-4.0, 1.0}});
|
||||
thrust::device_vector<EvaluateSplitInputs> inputs(2);
|
||||
inputs[0] = EvaluateSplitInputs{0, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_a)};
|
||||
@ -304,8 +306,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(),
|
||||
FeatureType::kCategorical);
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||
common::Span<FeatureType> d_feature_types;
|
||||
if (is_categorical) {
|
||||
auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(),
|
||||
@ -324,9 +325,8 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
cuts.min_vals_.ConstDeviceSpan(),
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{
|
||||
tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, false, 0);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
@ -338,31 +338,23 @@ void TestEvaluateSingleSplit(bool is_categorical) {
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum);
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplit) {
|
||||
TestEvaluateSingleSplit(false);
|
||||
}
|
||||
TEST(GpuHist, EvaluateSingleSplit) { TestEvaluateSingleSplit(false); }
|
||||
|
||||
TEST(GpuHist, EvaluateSingleCategoricalSplit) {
|
||||
TestEvaluateSingleSplit(true);
|
||||
}
|
||||
TEST(GpuHist, EvaluateSingleCategoricalSplit) { TestEvaluateSingleSplit(true); }
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{1.0, 1.5});
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{1.0, 1.5});
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2};
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0};
|
||||
thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
|
||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{param,
|
||||
quantiser,
|
||||
{},
|
||||
@ -377,7 +369,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
EXPECT_EQ(result.findex, 0);
|
||||
EXPECT_EQ(result.fvalue, 1.0);
|
||||
EXPECT_EQ(result.dir, kRightDir);
|
||||
EXPECT_EQ(result.left_sum,quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.left_sum, quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.right_sum, quantiser.ToFixedPoint(GradientPairPrecise(1.5, 1.0)));
|
||||
}
|
||||
|
||||
@ -398,24 +390,18 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
||||
// Feature 0 has a better split, but the algorithm must select feature 1
|
||||
TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
auto feature_histogram = ConvertToInteger({ {-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{1};
|
||||
thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 10.0};
|
||||
auto feature_histogram = ConvertToInteger({{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{param,
|
||||
quantiser,
|
||||
{},
|
||||
@ -429,31 +415,25 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
EXPECT_EQ(result.left_sum,quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.left_sum, quantiser.ToFixedPoint(GradientPairPrecise(-0.5, 0.5)));
|
||||
EXPECT_EQ(result.right_sum, quantiser.ToFixedPoint(GradientPairPrecise(0.5, 0.5)));
|
||||
}
|
||||
|
||||
// Features 0 and 1 have identical gain, the algorithm must select 0
|
||||
TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
auto feature_histogram = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 10.0};
|
||||
auto feature_histogram = ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{param,
|
||||
quantiser,
|
||||
{},
|
||||
@ -463,7 +443,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input,shared_inputs).split;
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 0);
|
||||
EXPECT_EQ(result.fvalue, 1.0);
|
||||
@ -477,41 +457,31 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
tparam.UpdateAllowUnknown(Args{});
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 0.0};
|
||||
auto feature_histogram_left = ConvertToInteger({ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
auto feature_histogram_right = ConvertToInteger({ {-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input_left{
|
||||
1,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_left)};
|
||||
EvaluateSplitInputs input_right{
|
||||
2,0,
|
||||
parent_sum,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_right)};
|
||||
EvaluateSplitSharedInputs shared_inputs{
|
||||
param,
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
false
|
||||
};
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments = std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0, 0.0};
|
||||
auto feature_histogram_left =
|
||||
ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}});
|
||||
auto feature_histogram_right =
|
||||
ConvertToInteger({{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}});
|
||||
EvaluateSplitInputs input_left{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_left)};
|
||||
EvaluateSplitInputs input_right{2, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram_right)};
|
||||
EvaluateSplitSharedInputs shared_inputs{param,
|
||||
quantiser,
|
||||
{},
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{
|
||||
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
|
||||
dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input_left,input_right};
|
||||
evaluator.LaunchEvaluateSplits(input_left.feature_set.size(),dh::ToSpan(inputs),shared_inputs, evaluator.GetEvaluator(),
|
||||
dh::ToSpan(out_splits));
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
|
||||
dh::device_vector<EvaluateSplitInputs> inputs =
|
||||
std::vector<EvaluateSplitInputs>{input_left, input_right};
|
||||
evaluator.LaunchEvaluateSplits(input_left.feature_set.size(), dh::ToSpan(inputs), shared_inputs,
|
||||
evaluator.GetEvaluator(), dh::ToSpan(out_splits));
|
||||
|
||||
DeviceSplitCandidate result_left = out_splits[0];
|
||||
EXPECT_EQ(result_left.findex, 1);
|
||||
@ -530,18 +500,19 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
||||
cuts_.cut_values_.SetDevice(0);
|
||||
cuts_.min_vals_.SetDevice(0);
|
||||
|
||||
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0);
|
||||
evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, false, 0);
|
||||
|
||||
// Convert the sample histogram to fixed point
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
thrust::host_vector<GradientPairInt64> h_hist;
|
||||
for(auto e: hist_[0]){
|
||||
for (auto e : hist_[0]) {
|
||||
h_hist.push_back(quantiser.ToFixedPoint(e));
|
||||
}
|
||||
dh::device_vector<GradientPairInt64> d_hist = h_hist;
|
||||
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
|
||||
|
||||
EvaluateSplitInputs input{0, 0, quantiser.ToFixedPoint(total_gpair_), dh::ToSpan(feature_set), dh::ToSpan(d_hist)};
|
||||
EvaluateSplitInputs input{0, 0, quantiser.ToFixedPoint(total_gpair_), dh::ToSpan(feature_set),
|
||||
dh::ToSpan(d_hist)};
|
||||
EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param_},
|
||||
quantiser,
|
||||
dh::ToSpan(ft),
|
||||
@ -552,5 +523,65 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
|
||||
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
ASSERT_NEAR(split.loss_chg, best_score_, 1e-2);
|
||||
}
|
||||
|
||||
class MGPUHistTest : public BaseMGPUTest {};
|
||||
|
||||
namespace {
|
||||
void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) {
|
||||
auto rank = collective::GetRank();
|
||||
auto quantiser = DummyRoundingFactor();
|
||||
auto parent_sum = quantiser.ToFixedPoint(GradientPairPrecise{0.0, 1.0});
|
||||
TrainParam tparam = ZeroParam();
|
||||
GPUTrainingParam param{tparam};
|
||||
|
||||
common::HistogramCuts cuts{rank == 0
|
||||
? MakeCutsForTest({1.0, 2.0}, {0, 2, 2}, {0.0, 0.0}, GPUIDX)
|
||||
: MakeCutsForTest({11.0, 12.0}, {0, 0, 2}, {0.0, 0.0}, GPUIDX)};
|
||||
thrust::device_vector<bst_feature_t> feature_set = std::vector<bst_feature_t>{0, 1};
|
||||
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
auto feature_histogram = rank == 0 ? ConvertToInteger({{-0.5, 0.5}, {0.5, 0.5}})
|
||||
: ConvertToInteger({{-1.0, 0.5}, {1.0, 0.5}});
|
||||
|
||||
dh::device_vector<FeatureType> feature_types(feature_set.size(), FeatureType::kCategorical);
|
||||
common::Span<FeatureType> d_feature_types;
|
||||
if (is_categorical) {
|
||||
auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(),
|
||||
cuts.cut_values_.HostVector().end());
|
||||
cuts.SetCategorical(true, max_cat);
|
||||
d_feature_types = dh::ToSpan(feature_types);
|
||||
}
|
||||
|
||||
EvaluateSplitInputs input{1, 0, parent_sum, dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_histogram)};
|
||||
EvaluateSplitSharedInputs shared_inputs{param,
|
||||
quantiser,
|
||||
d_feature_types,
|
||||
cuts.cut_ptrs_.ConstDeviceSpan(),
|
||||
cuts.cut_values_.ConstDeviceSpan(),
|
||||
cuts.min_vals_.ConstDeviceSpan(),
|
||||
false};
|
||||
|
||||
GPUHistEvaluator evaluator{tparam, static_cast<bst_feature_t>(feature_set.size()), GPUIDX};
|
||||
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, GPUIDX);
|
||||
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
|
||||
|
||||
EXPECT_EQ(result.findex, 1) << "rank: " << rank;
|
||||
if (is_categorical) {
|
||||
ASSERT_TRUE(std::isnan(result.fvalue));
|
||||
} else {
|
||||
EXPECT_EQ(result.fvalue, 11.0) << "rank: " << rank;
|
||||
}
|
||||
EXPECT_EQ(result.left_sum + result.right_sum, parent_sum) << "rank: " << rank;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleSplit) {
|
||||
DoTest(VerifyColumnSplitEvaluateSingleSplit, false);
|
||||
}
|
||||
|
||||
TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleCategoricalSplit) {
|
||||
DoTest(VerifyColumnSplitEvaluateSingleSplit, true);
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user