Use double for GPU Hist node sum. (#7507)
parent eabec370e4, commit 7f399eac8b
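Why the accumulator width matters here: a node's gradient sum folds in one gradient pair per training row, and with tens of millions of rows a single-precision running sum starts dropping low-order bits long before the loop finishes, which can flip a close split decision. A minimal standalone sketch of the failure mode (plain C++; the row count and per-row gradient are illustrative values, not taken from the commit):

#include <cstdio>

int main() {
  const long n = 30000000;  // 30M rows; illustrative
  const float g = 0.1f;     // per-row gradient; 0.1 has no exact binary form
  float sum_f = 0.0f;       // single precision, as the old GradientPair sums
  double sum_d = 0.0;       // double precision, as GradientPairPrecise sums
  for (long i = 0; i < n; ++i) {
    sum_f += g;
    sum_d += g;
  }
  // Once sum_f reaches 2^21 = 2097152, adding 0.1 falls below half an ulp
  // and is rounded away, so the float accumulator stalls there for good.
  std::printf("float:  %.2f\n", sum_f);  // ~2097152, about 30% short
  std::printf("double: %.2f\n", sum_d);  // ~3000000, essentially exact
  return 0;
}

Promoting the node-level sums to double removes this drift, while the per-bin histogram entries, which individually stay small, remain in the narrower GradientSumT; the diff below converts at the boundary with GradientPairPrecise{bin}.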
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020 by XGBoost Contributors
+ * Copyright 2020-2021 by XGBoost Contributors
  */
 #include <limits>
 #include "evaluate_splits.cuh"
@@ -9,15 +9,13 @@ namespace xgboost {
 namespace tree {

 // With constraints
-template <typename GradientPairT>
-XGBOOST_DEVICE float
-LossChangeMissing(const GradientPairT &scan, const GradientPairT &missing,
-                  const GradientPairT &parent_sum,
-                  const GPUTrainingParam &param,
-                  bst_node_t nidx,
-                  bst_feature_t fidx,
-                  TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-                  bool &missing_left_out) {  // NOLINT
+XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
+                                       const GradientPairPrecise &missing,
+                                       const GradientPairPrecise &parent_sum,
+                                       const GPUTrainingParam &param, bst_node_t nidx,
+                                       bst_feature_t fidx,
+                                       TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
+                                       bool &missing_left_out) {  // NOLINT
   float parent_gain = CalcGain(param, parent_sum);
   float missing_left_gain =
       evaluator.CalcSplitGain(param, nidx, fidx, GradStats(scan + missing),
@@ -72,32 +70,32 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
 }

 template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
-  GradientSumT __device__ operator()(
-      bool thread_active, uint32_t scan_begin,
-      SumCallbackOp<GradientSumT>*,
-      GradientSumT const &missing,
-      EvaluateSplitInputs<GradientSumT> const &inputs, TempStorageT *) {
+  GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
+                                     SumCallbackOp<GradientSumT> *,
+                                     GradientPairPrecise const &missing,
+                                     EvaluateSplitInputs<GradientSumT> const &inputs,
+                                     TempStorageT *) {
     GradientSumT bin = thread_active
                            ? inputs.gradient_histogram[scan_begin + threadIdx.x]
                            : GradientSumT();
-    auto rest = inputs.parent_sum - bin - missing;
-    return rest;
+    auto rest = inputs.parent_sum - GradientPairPrecise(bin) - missing;
+    return GradientSumT{rest};
   }
 };

 template <typename GradientSumT>
 struct UpdateOneHot {
   void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
-                             bst_feature_t fidx, GradientSumT const &missing,
+                             bst_feature_t fidx, GradientPairPrecise const &missing,
                              GradientSumT const &bin,
                              EvaluateSplitInputs<GradientSumT> const &inputs,
                              DeviceSplitCandidate *best_split) {
     int split_gidx = (scan_begin + threadIdx.x);
     float fvalue = inputs.feature_values[split_gidx];
-    GradientSumT left = missing_left ? bin + missing : bin;
-    GradientSumT right = inputs.parent_sum - left;
-    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx,
-                       GradientPair(left), GradientPair(right), true,
-                       inputs.param);
+    GradientPairPrecise left =
+        missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
+    GradientPairPrecise right = inputs.parent_sum - left;
+    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, true,
+                       inputs.param);
   }
 };
@@ -105,8 +103,8 @@ struct UpdateOneHot {
 template <typename GradientSumT, typename TempStorageT, typename ScanT>
 struct NumericBin {
   GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
-                                     SumCallbackOp<GradientSumT>* prefix_callback,
-                                     GradientSumT const &missing,
+                                     SumCallbackOp<GradientSumT> *prefix_callback,
+                                     GradientPairPrecise const &missing,
                                      EvaluateSplitInputs<GradientSumT> inputs,
                                      TempStorageT *temp_storage) {
     GradientSumT bin = thread_active
@@ -120,7 +118,7 @@ struct NumericBin {
 template <typename GradientSumT>
 struct UpdateNumeric {
   void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
-                             bst_feature_t fidx, GradientSumT const &missing,
+                             bst_feature_t fidx, GradientPairPrecise const &missing,
                              GradientSumT const &bin,
                              EvaluateSplitInputs<GradientSumT> const &inputs,
                              DeviceSplitCandidate *best_split) {
@@ -133,11 +131,11 @@ struct UpdateNumeric {
     } else {
       fvalue = inputs.feature_values[split_gidx];
     }
-    GradientSumT left = missing_left ? bin + missing : bin;
-    GradientSumT right = inputs.parent_sum - left;
-    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
-                       fidx, GradientPair(left), GradientPair(right),
-                       false, inputs.param);
+    GradientPairPrecise left =
+        missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
+    GradientPairPrecise right = inputs.parent_sum - left;
+    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, false,
+                       inputs.param);
   }
 };

@@ -164,7 +162,7 @@ __device__ void EvaluateFeature(
       ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
           feature_hist, temp_storage);

-  GradientSumT const missing = inputs.parent_sum - feature_sum;
+  GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum};
   float const null_gain = -std::numeric_limits<bst_float>::infinity();

   SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
@@ -177,11 +175,8 @@ __device__ void EvaluateFeature(
     bool missing_left = true;
     float gain = null_gain;
     if (thread_active) {
-      gain = LossChangeMissing(bin, missing, inputs.parent_sum, inputs.param,
-                               inputs.nidx,
-                               fidx,
-                               evaluator,
-                               missing_left);
+      gain = LossChangeMissing(GradientPairPrecise{bin}, missing, inputs.parent_sum, inputs.param,
+                               inputs.nidx, fidx, evaluator, missing_left);
     }

     __syncthreads();
@@ -15,7 +15,7 @@ namespace tree {
 template <typename GradientSumT>
 struct EvaluateSplitInputs {
   int nidx;
-  GradientSumT parent_sum;
+  GradientPairPrecise parent_sum;
   GPUTrainingParam param;
   common::Span<const bst_feature_t> feature_set;
   common::Span<FeatureType const> feature_types;
@@ -61,8 +61,8 @@ struct DeviceSplitCandidate {
   float fvalue {0};
   bool is_cat { false };

-  GradientPair left_sum;
-  GradientPair right_sum;
+  GradientPairPrecise left_sum;
+  GradientPairPrecise right_sum;

   XGBOOST_DEVICE DeviceSplitCandidate() {}  // NOLINT

@@ -78,8 +78,8 @@ struct DeviceSplitCandidate {

   XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in,
                              float fvalue_in, int findex_in,
-                             GradientPair left_sum_in,
-                             GradientPair right_sum_in,
+                             GradientPairPrecise left_sum_in,
+                             GradientPairPrecise right_sum_in,
                              bool cat,
                              const GPUTrainingParam& param) {
     if (loss_chg_in > loss_chg &&
@@ -173,7 +173,7 @@ struct GPUHistMakerDevice {
   dh::caching_device_vector<int> monotone_constraints;

   /*! \brief Sum gradient for each node. */
-  std::vector<GradientPair> node_sum_gradients;
+  std::vector<GradientPairPrecise> node_sum_gradients;

   TrainParam param;

@@ -239,8 +239,7 @@ struct GPUHistMakerDevice {
     dh::safe_cuda(cudaSetDevice(device_id));
     tree_evaluator = TreeEvaluator(param, dmat->Info().num_col_, device_id);
     this->interaction_constraints.Reset();
-    std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
-              GradientPair());
+    std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});

     if (d_gpair.size() != dh_gpair->Size()) {
       d_gpair.resize(dh_gpair->Size());
@@ -260,7 +259,7 @@ struct GPUHistMakerDevice {
   }


-  DeviceSplitCandidate EvaluateRootSplit(GradientPair root_sum) {
+  DeviceSplitCandidate EvaluateRootSplit(GradientPairPrecise root_sum) {
     int nidx = RegTree::kRoot;
     dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
     GPUTrainingParam gpu_param(param);
@@ -269,16 +268,15 @@ struct GPUHistMakerDevice {
     common::Span<bst_feature_t> feature_set =
         interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
     auto matrix = page->GetDeviceAccessor(device_id);
-    EvaluateSplitInputs<GradientSumT> inputs{
-        nidx,
-        {root_sum.GetGrad(), root_sum.GetHess()},
-        gpu_param,
-        feature_set,
-        feature_types,
-        matrix.feature_segments,
-        matrix.gidx_fvalue_map,
-        matrix.min_fvalue,
-        hist.GetNodeHistogram(nidx)};
+    EvaluateSplitInputs<GradientSumT> inputs{nidx,
+                                             root_sum,
+                                             gpu_param,
+                                             feature_set,
+                                             feature_types,
+                                             matrix.feature_segments,
+                                             matrix.gidx_fvalue_map,
+                                             matrix.min_fvalue,
+                                             hist.GetNodeHistogram(nidx)};
     auto gain_calc = tree_evaluator.GetEvaluator<GPUTrainingParam>();
     EvaluateSingleSplit(dh::ToSpan(splits_out), gain_calc, inputs);
     std::vector<DeviceSplitCandidate> result(1);
@@ -307,28 +305,24 @@ struct GPUHistMakerDevice {
                                       left_nidx);
     auto matrix = page->GetDeviceAccessor(device_id);

-    EvaluateSplitInputs<GradientSumT> left{
-        left_nidx,
-        {candidate.split.left_sum.GetGrad(),
-         candidate.split.left_sum.GetHess()},
-        gpu_param,
-        left_feature_set,
-        feature_types,
-        matrix.feature_segments,
-        matrix.gidx_fvalue_map,
-        matrix.min_fvalue,
-        hist.GetNodeHistogram(left_nidx)};
-    EvaluateSplitInputs<GradientSumT> right{
-        right_nidx,
-        {candidate.split.right_sum.GetGrad(),
-         candidate.split.right_sum.GetHess()},
-        gpu_param,
-        right_feature_set,
-        feature_types,
-        matrix.feature_segments,
-        matrix.gidx_fvalue_map,
-        matrix.min_fvalue,
-        hist.GetNodeHistogram(right_nidx)};
+    EvaluateSplitInputs<GradientSumT> left{left_nidx,
+                                           candidate.split.left_sum,
+                                           gpu_param,
+                                           left_feature_set,
+                                           feature_types,
+                                           matrix.feature_segments,
+                                           matrix.gidx_fvalue_map,
+                                           matrix.min_fvalue,
+                                           hist.GetNodeHistogram(left_nidx)};
+    EvaluateSplitInputs<GradientSumT> right{right_nidx,
+                                            candidate.split.right_sum,
+                                            gpu_param,
+                                            right_feature_set,
+                                            feature_types,
+                                            matrix.feature_segments,
+                                            matrix.gidx_fvalue_map,
+                                            matrix.min_fvalue,
+                                            hist.GetNodeHistogram(right_nidx)};

     auto d_splits_out = dh::ToSpan(splits_out);
     EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator<GPUTrainingParam>(), left, right);
     dh::TemporaryArray<GPUExpandEntry> entries(2);
@@ -502,12 +496,11 @@ struct GPUHistMakerDevice {
     auto d_ridx = row_partitioner->GetRows();

     GPUTrainingParam param_d(param);
-    dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
+    dh::TemporaryArray<GradientPairPrecise> device_node_sum_gradients(node_sum_gradients.size());

-    dh::safe_cuda(
-        cudaMemcpyAsync(device_node_sum_gradients.data().get(), node_sum_gradients.data(),
-                        sizeof(GradientPair) * node_sum_gradients.size(),
-                        cudaMemcpyHostToDevice));
+    dh::safe_cuda(cudaMemcpyAsync(device_node_sum_gradients.data().get(), node_sum_gradients.data(),
+                                  sizeof(GradientPairPrecise) * node_sum_gradients.size(),
+                                  cudaMemcpyHostToDevice));
     auto d_position = row_partitioner->GetPosition();
     auto d_node_sum_gradients = device_node_sum_gradients.data().get();
     auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
@@ -623,13 +616,12 @@ struct GPUHistMakerDevice {
   GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
-    GradientPair root_sum = dh::Reduce(
-        thrust::cuda::par(alloc),
-        thrust::device_ptr<GradientPair const>(gpair.data()),
-        thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()),
-        GradientPair{}, thrust::plus<GradientPair>{});
-    rabit::Allreduce<rabit::op::Sum, float>(reinterpret_cast<float*>(&root_sum),
-                                            2);
+    auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
+        dh::tbegin(gpair), [] __device__(auto const& gpair) { return GradientPairPrecise{gpair}; });
+    GradientPairPrecise root_sum =
+        dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
+                   GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
+    rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);

     this->BuildHist(kRootNIdx);
     this->AllReduceHist(kRootNIdx, reducer);
@@ -17,7 +17,7 @@ auto ZeroParam() {

 void TestEvaluateSingleSplit(bool is_categorical) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
-  GradientPair parent_sum(0.0, 1.0);
+  GradientPairPrecise parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};

@@ -73,7 +73,7 @@ TEST(GpuHist, EvaluateCategoricalSplit) {

 TEST(GpuHist, EvaluateSingleSplitMissing) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
-  GradientPair parent_sum(1.0, 1.5);
+  GradientPairPrecise parent_sum(1.0, 1.5);
   TrainParam tparam = ZeroParam();
   GPUTrainingParam param{tparam};

@@ -104,8 +104,8 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
   EXPECT_EQ(result.findex, 0);
   EXPECT_EQ(result.fvalue, 1.0);
   EXPECT_EQ(result.dir, kRightDir);
-  EXPECT_EQ(result.left_sum, GradientPair(-0.5, 0.5));
-  EXPECT_EQ(result.right_sum, GradientPair(1.5, 1.0));
+  EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
+  EXPECT_EQ(result.right_sum, GradientPairPrecise(1.5, 1.0));
 }

 TEST(GpuHist, EvaluateSingleSplitEmpty) {
@@ -130,7 +130,7 @@ TEST(GpuHist, EvaluateSingleSplitEmpty) {
 // Feature 0 has a better split, but the algorithm must select feature 1
 TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
-  GradientPair parent_sum(0.0, 1.0);
+  GradientPairPrecise parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
   GPUTrainingParam param{tparam};
@@ -164,14 +164,14 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
   DeviceSplitCandidate result = out_splits[0];
   EXPECT_EQ(result.findex, 1);
   EXPECT_EQ(result.fvalue, 11.0);
-  EXPECT_EQ(result.left_sum, GradientPair(-0.5, 0.5));
-  EXPECT_EQ(result.right_sum, GradientPair(0.5, 0.5));
+  EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
+  EXPECT_EQ(result.right_sum, GradientPairPrecise(0.5, 0.5));
 }

 // Features 0 and 1 have identical gain, the algorithm must select 0
 TEST(GpuHist, EvaluateSingleSplitBreakTies) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
-  GradientPair parent_sum(0.0, 1.0);
+  GradientPairPrecise parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
   GPUTrainingParam param{tparam};
@@ -209,7 +209,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {

 TEST(GpuHist, EvaluateSplits) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(2);
-  GradientPair parent_sum(0.0, 1.0);
+  GradientPairPrecise parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
   tparam.UpdateAllowUnknown(Args{});
   GPUTrainingParam param{tparam};
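One hunk deserves a closing note: InitRoot no longer reduces the raw GradientPair array in single precision and widens afterwards; it widens each element through a transform iterator, so every partial sum inside the reduction is already double, and the matching rabit::Allreduce now sums two doubles instead of two floats. A sketch of the same pattern against plain Thrust (a float scalar stands in for the gradient-pair struct; dh::MakeTransformIterator and dh::Reduce in the diff look like thin wrappers over these calls, but that mapping is my reading, not something the commit states):

#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>

// Widen each single-precision element before it enters the reduction tree.
struct ToDouble {
  __host__ __device__ double operator()(float g) const {
    return static_cast<double>(g);
  }
};

double SumGradients(thrust::device_vector<float> const &grad) {
  auto it = thrust::make_transform_iterator(grad.begin(), ToDouble{});
  // Every partial sum is carried in double; the float data is never copied,
  // only converted on the fly as the reduction reads it.
  return thrust::reduce(it, it + grad.size(), 0.0, thrust::plus<double>{});
}

Doing the conversion inside the iterator keeps memory traffic in float while the arithmetic happens in double, the same trade the rest of the diff makes wherever it writes GradientPairPrecise{bin}.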