Refactor gpu_hist split evaluation (#5610)
* Refactor * Rewrite evaluate splits * Add more tests
This commit is contained in:
parent
dfcdfabf1f
commit
b9649e7b8e
2
cub
2
cub
@ -1 +1 @@
|
||||
Subproject commit b20808b1b04ec3d6a625e51fbc1eb76f337754ad
|
||||
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304
|
||||
@ -28,8 +28,8 @@ struct ValueConstraint {
|
||||
inline static void Init(tree::TrainParam *param, unsigned num_feature) {
|
||||
param->monotone_constraints.resize(num_feature, 0);
|
||||
}
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m, tree::GradStats stats) const {
|
||||
template <typename ParamT, typename GpairT>
|
||||
XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m, GpairT stats) const {
|
||||
double w = xgboost::tree::CalcWeight(param, stats);
|
||||
if (w < lower_bound) {
|
||||
return lower_bound;
|
||||
@ -63,9 +63,9 @@ struct ValueConstraint {
|
||||
return wleft >= wright ? gain : negative_infinity;
|
||||
}
|
||||
}
|
||||
|
||||
inline void SetChild(const tree::TrainParam ¶m, bst_uint split_index,
|
||||
tree::GradStats left, tree::GradStats right, ValueConstraint *cleft,
|
||||
template <typename GpairT>
|
||||
void SetChild(const tree::TrainParam ¶m, bst_uint split_index,
|
||||
GpairT left, GpairT right, ValueConstraint *cleft,
|
||||
ValueConstraint *cright) {
|
||||
int c = param.monotone_constraints.at(split_index);
|
||||
*cleft = *this;
|
||||
|
||||
261
src/tree/gpu_hist/evaluate_splits.cu
Normal file
261
src/tree/gpu_hist/evaluate_splits.cu
Normal file
@ -0,0 +1,261 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*/
|
||||
#include "evaluate_splits.cuh"
|
||||
#include <limits>
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
// With constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float LossChangeMissing(const GradientPairT& scan,
|
||||
const GradientPairT& missing,
|
||||
const GradientPairT& parent_sum,
|
||||
const GPUTrainingParam& param,
|
||||
int constraint,
|
||||
const ValueConstraint& value_constraint,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float parent_gain = CalcGain(param, parent_sum);
|
||||
float missing_left_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan), GradStats(parent_sum - scan));
|
||||
|
||||
if (missing_left_gain >= missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief
|
||||
*
|
||||
* \tparam ReduceT BlockReduce Type.
|
||||
* \tparam TempStorage Cub Shared memory
|
||||
*
|
||||
* \param begin
|
||||
* \param end
|
||||
* \param temp_storage Shared memory for intermediate result.
|
||||
*/
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT,
|
||||
typename GradientSumT>
|
||||
__device__ GradientSumT
|
||||
ReduceFeature(common::Span<const GradientSumT> feature_histogram,
|
||||
TempStorageT* temp_storage) {
|
||||
__shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
|
||||
GradientSumT& shared_sum = uninitialized_sum.Alias();
|
||||
|
||||
GradientSumT local_sum = GradientSumT();
|
||||
// For loop sums features into one block size
|
||||
auto begin = feature_histogram.data();
|
||||
auto end = begin + feature_histogram.size();
|
||||
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
|
||||
bool thread_active = itr + threadIdx.x < end;
|
||||
// Scan histogram
|
||||
GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
|
||||
local_sum += bin;
|
||||
}
|
||||
local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
|
||||
// Reduction result is stored in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
shared_sum = local_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
return shared_sum;
|
||||
}
|
||||
|
||||
/*! \brief Find the thread with best gain. */
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
|
||||
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
|
||||
__device__ void EvaluateFeature(
|
||||
int fidx, EvaluateSplitInputs<GradientSumT> inputs,
|
||||
DeviceSplitCandidate* best_split, // shared memory storing best split
|
||||
TempStorageT* temp_storage // temp memory for cub operations
|
||||
) {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin
|
||||
uint32_t gidx_end =
|
||||
inputs.feature_segments[fidx + 1]; // end bin for i^th feature
|
||||
|
||||
// Sum histogram bins for current feature
|
||||
GradientSumT const feature_sum =
|
||||
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
|
||||
inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin),
|
||||
temp_storage);
|
||||
|
||||
GradientSumT const missing = inputs.parent_sum - feature_sum;
|
||||
float const null_gain = -std::numeric_limits<bst_float>::infinity();
|
||||
|
||||
SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
|
||||
scan_begin += BLOCK_THREADS) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
// Gradient value for current bin.
|
||||
GradientSumT bin = thread_active
|
||||
? inputs.gradient_histogram[scan_begin + threadIdx.x]
|
||||
: GradientSumT();
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = null_gain;
|
||||
if (thread_active) {
|
||||
gain = LossChangeMissing(bin, missing, inputs.parent_sum, inputs.param,
|
||||
inputs.monotonic_constraints[fidx],
|
||||
inputs.value_constraint, missing_left);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find thread with best gain
|
||||
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
|
||||
cub::KeyValuePair<int, float> best =
|
||||
MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
|
||||
|
||||
__shared__ cub::KeyValuePair<int, float> block_max;
|
||||
if (threadIdx.x == 0) {
|
||||
block_max = best;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Best thread updates split
|
||||
if (threadIdx.x == block_max.key) {
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue;
|
||||
if (split_gidx < static_cast<int>(gidx_begin)) {
|
||||
fvalue = inputs.min_fvalue[fidx];
|
||||
} else {
|
||||
fvalue = inputs.feature_values[split_gidx];
|
||||
}
|
||||
GradientSumT left = missing_left ? bin + missing : bin;
|
||||
GradientSumT right = inputs.parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
|
||||
fidx, GradientPair(left), GradientPair(right),
|
||||
inputs.param);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS, typename GradientSumT>
|
||||
__global__ void EvaluateSplitsKernel(
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right,
|
||||
common::Span<DeviceSplitCandidate> out_candidates) {
|
||||
// KeyValuePair here used as threadIdx.x -> gain_value
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT =
|
||||
cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
|
||||
|
||||
using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename MaxReduceT::TempStorage max_reduce;
|
||||
typename SumReduceT::TempStorage sum_reduce;
|
||||
};
|
||||
|
||||
// Aligned && shared storage for best_split
|
||||
__shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
|
||||
DeviceSplitCandidate& best_split = uninitialized_split.Alias();
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
best_split = DeviceSplitCandidate();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// If this block is working on the left or right node
|
||||
bool is_left = blockIdx.x < left.feature_set.size();
|
||||
EvaluateSplitInputs<GradientSumT>& inputs = is_left ? left : right;
|
||||
|
||||
// One block for each feature. Features are sampled, so fidx != blockIdx.x
|
||||
int fidx = inputs.feature_set[is_left ? blockIdx.x
|
||||
: blockIdx.x - left.feature_set.size()];
|
||||
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
|
||||
fidx, inputs, &best_split, &temp_storage);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// Record best loss for each feature
|
||||
out_candidates[blockIdx.x] = best_split;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
|
||||
const DeviceSplitCandidate& b) {
|
||||
return b.loss_chg > a.loss_chg ? b : a;
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right) {
|
||||
size_t combined_num_features =
|
||||
left.feature_set.size() + right.feature_set.size();
|
||||
dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits(
|
||||
combined_num_features);
|
||||
// One block for each feature
|
||||
uint32_t constexpr kBlockThreads = 256;
|
||||
dh::LaunchKernel {uint32_t(combined_num_features), kBlockThreads, 0}(
|
||||
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right,
|
||||
dh::ToSpan(feature_best_splits));
|
||||
|
||||
// Reduce to get best candidate for left and right child over all features
|
||||
auto reduce_offset =
|
||||
dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu),
|
||||
[=] __device__(size_t idx) -> size_t {
|
||||
if (idx == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (idx == 1) {
|
||||
return left.feature_set.size();
|
||||
}
|
||||
if (idx == 2) {
|
||||
return combined_num_features;
|
||||
}
|
||||
return 0;
|
||||
});
|
||||
size_t temp_storage_bytes = 0;
|
||||
cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes,
|
||||
feature_best_splits.data(), out_splits.data(),
|
||||
2, reduce_offset, reduce_offset + 1);
|
||||
dh::TemporaryArray<int8_t> temp(temp_storage_bytes);
|
||||
cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes,
|
||||
feature_best_splits.data(), out_splits.data(),
|
||||
2, reduce_offset, reduce_offset + 1);
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
|
||||
EvaluateSplitInputs<GradientSumT> input) {
|
||||
EvaluateSplits(out_split, input, {});
|
||||
}
|
||||
|
||||
template void EvaluateSplits<GradientPair>(
|
||||
common::Span<DeviceSplitCandidate> out_splits,
|
||||
EvaluateSplitInputs<GradientPair> left,
|
||||
EvaluateSplitInputs<GradientPair> right);
|
||||
template void EvaluateSplits<GradientPairPrecise>(
|
||||
common::Span<DeviceSplitCandidate> out_splits,
|
||||
EvaluateSplitInputs<GradientPairPrecise> left,
|
||||
EvaluateSplitInputs<GradientPairPrecise> right);
|
||||
template void EvaluateSingleSplit<GradientPair>(
|
||||
common::Span<DeviceSplitCandidate> out_split,
|
||||
EvaluateSplitInputs<GradientPair> input);
|
||||
template void EvaluateSingleSplit<GradientPairPrecise>(
|
||||
common::Span<DeviceSplitCandidate> out_split,
|
||||
EvaluateSplitInputs<GradientPairPrecise> input);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
37
src/tree/gpu_hist/evaluate_splits.cuh
Normal file
37
src/tree/gpu_hist/evaluate_splits.cuh
Normal file
@ -0,0 +1,37 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef EVALUATE_SPLITS_CUH_
|
||||
#define EVALUATE_SPLITS_CUH_
|
||||
#include <xgboost/span.h>
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "../constraints.cuh"
|
||||
#include "../updater_gpu_common.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
template <typename GradientSumT>
|
||||
struct EvaluateSplitInputs {
|
||||
int nidx;
|
||||
GradientSumT parent_sum;
|
||||
GPUTrainingParam param;
|
||||
common::Span<const bst_feature_t> feature_set;
|
||||
common::Span<const uint32_t> feature_segments;
|
||||
common::Span<const float> feature_values;
|
||||
common::Span<const float> min_fvalue;
|
||||
common::Span<const GradientSumT> gradient_histogram;
|
||||
ValueConstraint value_constraint;
|
||||
common::Span<const int> monotonic_constraints;
|
||||
};
|
||||
template <typename GradientSumT>
|
||||
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
|
||||
EvaluateSplitInputs<GradientSumT> left,
|
||||
EvaluateSplitInputs<GradientSumT> right);
|
||||
template <typename GradientSumT>
|
||||
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
|
||||
EvaluateSplitInputs<GradientSumT> input);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // EVALUATE_SPLITS_CUH_
|
||||
@ -29,6 +29,7 @@
|
||||
#include "gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
#include "gpu_hist/histogram.cuh"
|
||||
#include "gpu_hist/evaluate_splits.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -108,188 +109,6 @@ inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
|
||||
}
|
||||
}
|
||||
|
||||
// With constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(
|
||||
const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
|
||||
const float& parent_gain, const GPUTrainingParam& param, int constraint,
|
||||
const ValueConstraint& value_constraint,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan), GradStats(parent_sum - scan));
|
||||
|
||||
if (missing_left_gain >= missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief
|
||||
*
|
||||
* \tparam ReduceT BlockReduce Type.
|
||||
* \tparam TempStorage Cub Shared memory
|
||||
*
|
||||
* \param begin
|
||||
* \param end
|
||||
* \param temp_storage Shared memory for intermediate result.
|
||||
*/
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename TempStorageT, typename GradientSumT>
|
||||
__device__ GradientSumT ReduceFeature(common::Span<const GradientSumT> feature_histogram,
|
||||
TempStorageT* temp_storage) {
|
||||
__shared__ cub::Uninitialized<GradientSumT> uninitialized_sum;
|
||||
GradientSumT& shared_sum = uninitialized_sum.Alias();
|
||||
|
||||
GradientSumT local_sum = GradientSumT();
|
||||
// For loop sums features into one block size
|
||||
auto begin = feature_histogram.data();
|
||||
auto end = begin + feature_histogram.size();
|
||||
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
|
||||
bool thread_active = itr + threadIdx.x < end;
|
||||
// Scan histogram
|
||||
GradientSumT bin = thread_active ? *(itr + threadIdx.x) : GradientSumT();
|
||||
local_sum += bin;
|
||||
}
|
||||
local_sum = ReduceT(temp_storage->sum_reduce).Reduce(local_sum, cub::Sum());
|
||||
// Reduction result is stored in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
shared_sum = local_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
return shared_sum;
|
||||
}
|
||||
|
||||
/*! \brief Find the thread with best gain. */
|
||||
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
|
||||
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
|
||||
__device__ void EvaluateFeature(
|
||||
int fidx, common::Span<const GradientSumT> node_histogram,
|
||||
const EllpackDeviceAccessor& matrix,
|
||||
DeviceSplitCandidate* best_split, // shared memory storing best split
|
||||
const DeviceNodeStats& node, const GPUTrainingParam& param,
|
||||
TempStorageT* temp_storage, // temp memory for cub operations
|
||||
int constraint, // monotonic_constraints
|
||||
const ValueConstraint& value_constraint) {
|
||||
// Use pointer from cut to indicate begin and end of bins for each feature.
|
||||
uint32_t gidx_begin = matrix.feature_segments[fidx]; // begining bin
|
||||
uint32_t gidx_end = matrix.feature_segments[fidx + 1]; // end bin for i^th feature
|
||||
|
||||
// Sum histogram bins for current feature
|
||||
GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(
|
||||
node_histogram.subspan(gidx_begin, gidx_end - gidx_begin), temp_storage);
|
||||
|
||||
GradientSumT const parent_sum = GradientSumT(node.sum_gradients);
|
||||
GradientSumT const missing = parent_sum - feature_sum;
|
||||
float const null_gain = -std::numeric_limits<bst_float>::infinity();
|
||||
|
||||
SumCallbackOp<GradientSumT> prefix_op =
|
||||
SumCallbackOp<GradientSumT>();
|
||||
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
|
||||
scan_begin += BLOCK_THREADS) {
|
||||
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
|
||||
|
||||
// Gradient value for current bin.
|
||||
GradientSumT bin =
|
||||
thread_active ? node_histogram[scan_begin + threadIdx.x] : GradientSumT();
|
||||
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
|
||||
// Whether the gradient of missing values is put to the left side.
|
||||
bool missing_left = true;
|
||||
float gain = null_gain;
|
||||
if (thread_active) {
|
||||
gain = LossChangeMissing(bin, missing, parent_sum, node.root_gain, param,
|
||||
constraint, value_constraint, missing_left);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find thread with best gain
|
||||
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
|
||||
cub::KeyValuePair<int, float> best =
|
||||
MaxReduceT(temp_storage->max_reduce).Reduce(tuple, cub::ArgMax());
|
||||
|
||||
__shared__ cub::KeyValuePair<int, float> block_max;
|
||||
if (threadIdx.x == 0) {
|
||||
block_max = best;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Best thread updates split
|
||||
if (threadIdx.x == block_max.key) {
|
||||
int split_gidx = (scan_begin + threadIdx.x) - 1;
|
||||
float fvalue;
|
||||
if (split_gidx < static_cast<int>(gidx_begin)) {
|
||||
fvalue = matrix.min_fvalue[fidx];
|
||||
} else {
|
||||
fvalue = matrix.gidx_fvalue_map[split_gidx];
|
||||
}
|
||||
GradientSumT left = missing_left ? bin + missing : bin;
|
||||
GradientSumT right = parent_sum - left;
|
||||
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
|
||||
fidx, GradientPair(left), GradientPair(right), param);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS, typename GradientSumT>
|
||||
__global__ void EvaluateSplitKernel(
|
||||
common::Span<const GradientSumT> node_histogram, // histogram for gradients
|
||||
common::Span<const bst_feature_t> feature_set, // Selected features
|
||||
DeviceNodeStats node,
|
||||
xgboost::EllpackDeviceAccessor matrix,
|
||||
GPUTrainingParam gpu_param,
|
||||
common::Span<DeviceSplitCandidate> split_candidates, // resulting split
|
||||
ValueConstraint value_constraint,
|
||||
common::Span<int> d_monotonic_constraints) {
|
||||
// KeyValuePair here used as threadIdx.x -> gain_value
|
||||
using ArgMaxT = cub::KeyValuePair<int, float>;
|
||||
using BlockScanT =
|
||||
cub::BlockScan<GradientSumT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using MaxReduceT = cub::BlockReduce<ArgMaxT, BLOCK_THREADS>;
|
||||
|
||||
using SumReduceT = cub::BlockReduce<GradientSumT, BLOCK_THREADS>;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
typename MaxReduceT::TempStorage max_reduce;
|
||||
typename SumReduceT::TempStorage sum_reduce;
|
||||
};
|
||||
|
||||
// Aligned && shared storage for best_split
|
||||
__shared__ cub::Uninitialized<DeviceSplitCandidate> uninitialized_split;
|
||||
DeviceSplitCandidate& best_split = uninitialized_split.Alias();
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
best_split = DeviceSplitCandidate();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// One block for each feature. Features are sampled, so fidx != blockIdx.x
|
||||
int fidx = feature_set[blockIdx.x];
|
||||
|
||||
int constraint = d_monotonic_constraints[fidx];
|
||||
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
|
||||
fidx, node_histogram, matrix, &best_split, node, gpu_param, &temp_storage,
|
||||
constraint, value_constraint);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// Record best loss for each feature
|
||||
split_candidates[blockIdx.x] = best_split;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \struct DeviceHistogram
|
||||
*
|
||||
@ -411,24 +230,19 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist{};
|
||||
|
||||
/*! \brief Gradient pair for each row. */
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::caching_device_vector<int> monotone_constraints;
|
||||
dh::caching_device_vector<bst_float> prediction_cache;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> host_node_sum_gradients;
|
||||
dh::caching_device_vector<GradientPair> node_sum_gradients;
|
||||
bst_uint n_rows;
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
|
||||
TrainParam param;
|
||||
bool deterministic_histogram;
|
||||
|
||||
GradientSumT histogram_rounding;
|
||||
|
||||
dh::PinnedMemory pinned_memory;
|
||||
|
||||
std::vector<cudaStream_t> streams{};
|
||||
|
||||
common::Monitor monitor;
|
||||
@ -453,22 +267,24 @@ struct GPUHistMakerDevice {
|
||||
BatchParam _batch_param)
|
||||
: device_id(_device_id),
|
||||
page(_page),
|
||||
n_rows(_n_rows),
|
||||
param(std::move(_param)),
|
||||
column_sampler(column_sampler_seed),
|
||||
interaction_constraints(param, n_features),
|
||||
deterministic_histogram{deterministic_histogram},
|
||||
batch_param(_batch_param) {
|
||||
sampler.reset(new GradientBasedSampler(page,
|
||||
n_rows,
|
||||
batch_param,
|
||||
param.subsample,
|
||||
param.sampling_method));
|
||||
sampler.reset(new GradientBasedSampler(
|
||||
page, _n_rows, batch_param, param.subsample, param.sampling_method));
|
||||
if (!param.monotone_constraints.empty()) {
|
||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
}
|
||||
node_sum_gradients.resize(param.MaxNodes());
|
||||
|
||||
// Init histogram
|
||||
hist.Init(device_id, page->Cuts().TotalBins());
|
||||
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
|
||||
}
|
||||
|
||||
void InitHistogram();
|
||||
|
||||
~GPUHistMakerDevice() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
for (auto& stream : streams) {
|
||||
@ -507,11 +323,10 @@ struct GPUHistMakerDevice {
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
this->interaction_constraints.Reset();
|
||||
std::fill(host_node_sum_gradients.begin(), host_node_sum_gradients.end(),
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
|
||||
auto sample = sampler->Sample(dh_gpair->DeviceSpan(), dmat);
|
||||
n_rows = sample.sample_rows;
|
||||
page = sample.page;
|
||||
gpair = sample.gpair;
|
||||
|
||||
@ -522,74 +337,87 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
row_partitioner.reset(); // Release the device memory first before reallocating
|
||||
row_partitioner.reset(new RowPartitioner(device_id, n_rows));
|
||||
row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows));
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
std::vector<DeviceSplitCandidate> EvaluateSplits(
|
||||
std::vector<int> nidxs, const RegTree& tree,
|
||||
size_t num_columns) {
|
||||
auto result_all = pinned_memory.GetSpan<DeviceSplitCandidate>(nidxs.size());
|
||||
|
||||
// Work out cub temporary memory requirement
|
||||
DeviceSplitCandidate EvaluateRootSplit(GradientPair root_sum) {
|
||||
int nidx = 0;
|
||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
|
||||
GPUTrainingParam gpu_param(param);
|
||||
DeviceSplitCandidateReduceOp op(gpu_param);
|
||||
|
||||
dh::TemporaryArray<DeviceSplitCandidate> d_result_all(nidxs.size());
|
||||
dh::TemporaryArray<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);
|
||||
|
||||
auto& streams = this->GetStreams(nidxs.size());
|
||||
for (auto i = 0ull; i < nidxs.size(); i++) {
|
||||
auto nidx = nidxs[i];
|
||||
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
|
||||
p_feature_set->SetDevice(device_id);
|
||||
common::Span<bst_feature_t> d_sampled_features =
|
||||
p_feature_set->DeviceSpan();
|
||||
common::Span<bst_feature_t> d_feature_set =
|
||||
interaction_constraints.Query(d_sampled_features, nidx);
|
||||
common::Span<DeviceSplitCandidate> d_split_candidates(
|
||||
split_candidates_all.data().get() + i * num_columns,
|
||||
d_feature_set.size());
|
||||
|
||||
DeviceNodeStats node(host_node_sum_gradients[nidx], nidx, param);
|
||||
|
||||
common::Span<DeviceSplitCandidate> d_result(d_result_all.data().get() + i, 1);
|
||||
if (d_feature_set.empty()) {
|
||||
// Acting as a device side constructor for DeviceSplitCandidate.
|
||||
// DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this
|
||||
// candidate.
|
||||
auto worst_candidate = DeviceSplitCandidate();
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_result.data(), &worst_candidate,
|
||||
sizeof(DeviceSplitCandidate),
|
||||
cudaMemcpyHostToDevice));
|
||||
continue;
|
||||
}
|
||||
|
||||
// One block for each feature
|
||||
uint32_t constexpr kBlockThreads = 256;
|
||||
dh::LaunchKernel {uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]} (
|
||||
EvaluateSplitKernel<kBlockThreads, GradientSumT>,
|
||||
hist.GetNodeHistogram(nidx), d_feature_set, node, page->GetDeviceAccessor(device_id),
|
||||
gpu_param, d_split_candidates, node_value_constraints[nidx],
|
||||
dh::ToSpan(monotone_constraints));
|
||||
|
||||
// Reduce over features to find best feature
|
||||
size_t cub_bytes = 0;
|
||||
cub::DeviceReduce::Reduce(nullptr,
|
||||
cub_bytes, d_split_candidates.data(),
|
||||
d_result.data(), d_split_candidates.size(), op,
|
||||
DeviceSplitCandidate(), streams[i]);
|
||||
dh::TemporaryArray<char> cub_temp(cub_bytes);
|
||||
cub::DeviceReduce::Reduce(reinterpret_cast<void*>(cub_temp.data().get()),
|
||||
cub_bytes, d_split_candidates.data(),
|
||||
d_result.data(), d_split_candidates.size(), op,
|
||||
DeviceSplitCandidate(), streams[i]);
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data().get(),
|
||||
sizeof(DeviceSplitCandidate) * d_result_all.size(),
|
||||
auto sampled_features = column_sampler.GetFeatureSet(0);
|
||||
sampled_features->SetDevice(device_id);
|
||||
common::Span<bst_feature_t> feature_set =
|
||||
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
|
||||
auto matrix = page->GetDeviceAccessor(device_id);
|
||||
EvaluateSplitInputs<GradientSumT> inputs{
|
||||
nidx,
|
||||
{root_sum.GetGrad(), root_sum.GetHess()},
|
||||
gpu_param,
|
||||
feature_set,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
hist.GetNodeHistogram(nidx),
|
||||
node_value_constraints[nidx],
|
||||
dh::ToSpan(monotone_constraints)};
|
||||
EvaluateSingleSplit(dh::ToSpan(splits_out), inputs);
|
||||
std::vector<DeviceSplitCandidate> result(1);
|
||||
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
|
||||
sizeof(DeviceSplitCandidate) * splits_out.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
return std::vector<DeviceSplitCandidate>(result_all.begin(), result_all.end());
|
||||
return result.front();
|
||||
}
|
||||
|
||||
std::vector<DeviceSplitCandidate> EvaluateLeftRightSplits(
|
||||
ExpandEntry candidate, int left_nidx, int right_nidx,
|
||||
const RegTree& tree) {
|
||||
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
|
||||
GPUTrainingParam gpu_param(param);
|
||||
auto left_sampled_features =
|
||||
column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
|
||||
left_sampled_features->SetDevice(device_id);
|
||||
common::Span<bst_feature_t> left_feature_set =
|
||||
interaction_constraints.Query(left_sampled_features->DeviceSpan(),
|
||||
left_nidx);
|
||||
auto right_sampled_features =
|
||||
column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
|
||||
right_sampled_features->SetDevice(device_id);
|
||||
common::Span<bst_feature_t> right_feature_set =
|
||||
interaction_constraints.Query(right_sampled_features->DeviceSpan(),
|
||||
left_nidx);
|
||||
auto matrix = page->GetDeviceAccessor(device_id);
|
||||
|
||||
EvaluateSplitInputs<GradientSumT> left{left_nidx,
|
||||
{candidate.split.left_sum.GetGrad(),
|
||||
candidate.split.left_sum.GetHess()},
|
||||
gpu_param,
|
||||
left_feature_set,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
hist.GetNodeHistogram(left_nidx),
|
||||
node_value_constraints[left_nidx],
|
||||
dh::ToSpan(monotone_constraints)};
|
||||
EvaluateSplitInputs<GradientSumT> right{
|
||||
right_nidx,
|
||||
{candidate.split.right_sum.GetGrad(),
|
||||
candidate.split.right_sum.GetHess()},
|
||||
gpu_param,
|
||||
right_feature_set,
|
||||
matrix.feature_segments,
|
||||
matrix.gidx_fvalue_map,
|
||||
matrix.min_fvalue,
|
||||
hist.GetNodeHistogram(right_nidx),
|
||||
node_value_constraints[right_nidx],
|
||||
dh::ToSpan(monotone_constraints)};
|
||||
EvaluateSplits(dh::ToSpan(splits_out), left, right);
|
||||
std::vector<DeviceSplitCandidate> result(2);
|
||||
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
|
||||
sizeof(DeviceSplitCandidate) * splits_out.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
return result;
|
||||
}
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
@ -704,13 +532,14 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
CalcWeightTrainParam param_d(param);
|
||||
dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
|
||||
|
||||
dh::safe_cuda(
|
||||
cudaMemcpyAsync(node_sum_gradients.data().get(), host_node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * host_node_sum_gradients.size(),
|
||||
cudaMemcpyAsync(device_node_sum_gradients.data().get(), node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = row_partitioner->GetPosition();
|
||||
auto d_node_sum_gradients = node_sum_gradients.data().get();
|
||||
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
|
||||
auto d_prediction_cache = prediction_cache.data().get();
|
||||
|
||||
dh::LaunchN(
|
||||
@ -775,63 +604,56 @@ struct GPUHistMakerDevice {
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
GradStats left_stats{};
|
||||
left_stats.Add(candidate.split.left_sum);
|
||||
GradStats right_stats{};
|
||||
right_stats.Add(candidate.split.right_sum);
|
||||
GradStats parent_sum{};
|
||||
parent_sum.Add(left_stats);
|
||||
parent_sum.Add(right_stats);
|
||||
node_value_constraints.resize(tree.GetNodes().size());
|
||||
auto base_weight = node_value_constraints[candidate.nid].CalcWeight(param, parent_sum);
|
||||
auto left_weight =
|
||||
node_value_constraints[candidate.nid].CalcWeight(param, left_stats)*param.learning_rate;
|
||||
auto right_weight =
|
||||
node_value_constraints[candidate.nid].CalcWeight(param, right_stats)*param.learning_rate;
|
||||
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
|
||||
auto base_weight = node_value_constraints[candidate.nid].CalcWeight(
|
||||
param, parent_sum);
|
||||
auto left_weight = node_value_constraints[candidate.nid].CalcWeight(
|
||||
param, candidate.split.left_sum) *
|
||||
param.learning_rate;
|
||||
auto right_weight = node_value_constraints[candidate.nid].CalcWeight(
|
||||
param, candidate.split.right_sum) *
|
||||
param.learning_rate;
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.sum_hess,
|
||||
left_stats.GetHess(), right_stats.GetHess());
|
||||
candidate.split.loss_chg, parent_sum.GetHess(),
|
||||
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
|
||||
// Set up child constraints
|
||||
node_value_constraints.resize(tree.GetNodes().size());
|
||||
node_value_constraints[candidate.nid].SetChild(
|
||||
param, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
|
||||
param, tree[candidate.nid].SplitIndex(), candidate.split.left_sum,
|
||||
candidate.split.right_sum,
|
||||
&node_value_constraints[tree[candidate.nid].LeftChild()],
|
||||
&node_value_constraints[tree[candidate.nid].RightChild()]);
|
||||
host_node_sum_gradients[tree[candidate.nid].LeftChild()] =
|
||||
node_sum_gradients[tree[candidate.nid].LeftChild()] =
|
||||
candidate.split.left_sum;
|
||||
host_node_sum_gradients[tree[candidate.nid].RightChild()] =
|
||||
node_sum_gradients[tree[candidate.nid].RightChild()] =
|
||||
candidate.split.right_sum;
|
||||
|
||||
interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
|
||||
interaction_constraints.Split(
|
||||
candidate.nid, tree[candidate.nid].SplitIndex(),
|
||||
tree[candidate.nid].LeftChild(),
|
||||
tree[candidate.nid].RightChild());
|
||||
}
|
||||
|
||||
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) {
|
||||
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
|
||||
constexpr bst_node_t kRootNIdx = 0;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
GradientPair root_sum = thrust::reduce(
|
||||
thrust::cuda::par(alloc),
|
||||
thrust::device_ptr<GradientPair const>(gpair.data()),
|
||||
thrust::device_ptr<GradientPair const>(gpair.data() + gpair.size()));
|
||||
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data().get(), &root_sum, sizeof(root_sum),
|
||||
cudaMemcpyHostToDevice));
|
||||
reducer->AllReduceSum(
|
||||
reinterpret_cast<float*>(node_sum_gradients.data().get()),
|
||||
reinterpret_cast<float*>(node_sum_gradients.data().get()), 2);
|
||||
reducer->Synchronize();
|
||||
dh::safe_cuda(cudaMemcpyAsync(host_node_sum_gradients.data(),
|
||||
node_sum_gradients.data().get(), sizeof(GradientPair),
|
||||
cudaMemcpyDeviceToHost));
|
||||
rabit::Allreduce<rabit::op::Sum, float>(reinterpret_cast<float*>(&root_sum),
|
||||
2);
|
||||
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, reducer);
|
||||
|
||||
// Remember root stats
|
||||
p_tree->Stat(kRootNIdx).sum_hess = host_node_sum_gradients[kRootNIdx].GetHess();
|
||||
auto weight = CalcWeight(param, host_node_sum_gradients[kRootNIdx]);
|
||||
node_sum_gradients[kRootNIdx] = root_sum;
|
||||
p_tree->Stat(kRootNIdx).sum_hess = root_sum.GetHess();
|
||||
auto weight = CalcWeight(param, root_sum);
|
||||
p_tree->Stat(kRootNIdx).base_weight = weight;
|
||||
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
|
||||
|
||||
@ -839,9 +661,9 @@ struct GPUHistMakerDevice {
|
||||
node_value_constraints.resize(p_tree->GetNodes().size());
|
||||
|
||||
// Generate first split
|
||||
auto split = this->EvaluateSplits({kRootNIdx}, *p_tree, num_columns);
|
||||
auto split = this->EvaluateRootSplit(root_sum);
|
||||
qexpand->push(
|
||||
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split.at(0), 0));
|
||||
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
|
||||
@ -853,7 +675,7 @@ struct GPUHistMakerDevice {
|
||||
monitor.StopCuda("Reset");
|
||||
|
||||
monitor.StartCuda("InitRoot");
|
||||
this->InitRoot(p_tree, reducer, p_fmat->Info().num_col_);
|
||||
this->InitRoot(p_tree, reducer);
|
||||
monitor.StopCuda("InitRoot");
|
||||
|
||||
auto timestamp = qexpand->size();
|
||||
@ -883,8 +705,9 @@ struct GPUHistMakerDevice {
|
||||
monitor.StopCuda("BuildHist");
|
||||
|
||||
monitor.StartCuda("EvaluateSplits");
|
||||
auto splits = this->EvaluateSplits({left_child_nidx, right_child_nidx},
|
||||
*p_tree, p_fmat->Info().num_col_);
|
||||
auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx,
|
||||
right_child_nidx,
|
||||
*p_tree);
|
||||
monitor.StopCuda("EvaluateSplits");
|
||||
|
||||
qexpand->push(ExpandEntry(left_child_nidx,
|
||||
@ -902,19 +725,6 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
|
||||
if (!param.monotone_constraints.empty()) {
|
||||
// Copy assigning an empty vector causes an exception in MSVC debug builds
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
}
|
||||
host_node_sum_gradients.resize(param.MaxNodes());
|
||||
node_sum_gradients.resize(param.MaxNodes());
|
||||
|
||||
// Init histogram
|
||||
hist.Init(device_id, page->Cuts().TotalBins());
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
class GPUHistMakerSpecialised {
|
||||
public:
|
||||
@ -984,11 +794,6 @@ class GPUHistMakerSpecialised {
|
||||
hist_maker_param_.deterministic_histogram,
|
||||
batch_param));
|
||||
|
||||
monitor_.StartCuda("InitHistogram");
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
maker->InitHistogram();
|
||||
monitor_.StopCuda("InitHistogram");
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
initialised_ = true;
|
||||
}
|
||||
|
||||
222
tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
Normal file
222
tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
Normal file
@ -0,0 +1,222 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
|
||||
#include "../../helpers.h"
|
||||
#include "../../histogram_helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplit) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(0.0, 1.0);
|
||||
GPUTrainingParam param{};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 0.0};
|
||||
// Setup gradients so that second feature gets higher gain
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
|
||||
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSingleSplit(dh::ToSpan(out_splits), input);
|
||||
|
||||
DeviceSplitCandidate result = out_splits[0];
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
|
||||
parent_sum.GetGrad());
|
||||
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(),
|
||||
parent_sum.GetHess());
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitMissing) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(1.0, 1.5);
|
||||
GPUTrainingParam param{};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2};
|
||||
thrust::device_vector<float> feature_values = std::vector<float>{1.0, 2.0};
|
||||
thrust::device_vector<float> feature_min_values = std::vector<float>{0.0};
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{{-0.5, 0.5}, {0.5, 0.5}};
|
||||
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSingleSplit(dh::ToSpan(out_splits), input);
|
||||
|
||||
DeviceSplitCandidate result = out_splits[0];
|
||||
EXPECT_EQ(result.findex, 0);
|
||||
EXPECT_EQ(result.fvalue, 1.0);
|
||||
EXPECT_EQ(result.dir, kRightDir);
|
||||
EXPECT_EQ(result.left_sum, GradientPair(-0.5, 0.5));
|
||||
EXPECT_EQ(result.right_sum, GradientPair(1.5, 1.0));
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSingleSplitEmpty) {
|
||||
DeviceSplitCandidate nonzeroed;
|
||||
nonzeroed.findex = 1;
|
||||
nonzeroed.loss_chg = 1.0;
|
||||
|
||||
thrust::device_vector<DeviceSplitCandidate> out_split(1);
|
||||
out_split[0] = nonzeroed;
|
||||
EvaluateSingleSplit(dh::ToSpan(out_split),
|
||||
EvaluateSplitInputs<GradientPair>{});
|
||||
DeviceSplitCandidate result = out_split[0];
|
||||
EXPECT_EQ(result.findex, -1);
|
||||
EXPECT_LT(result.loss_chg, 0.0f);
|
||||
}
|
||||
|
||||
// Feature 0 has a better split, but the algorithm must select feature 1
|
||||
TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(0.0, 1.0);
|
||||
GPUTrainingParam param{};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{
|
||||
{-10.0, 0.5}, {10.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
thrust::device_vector<int> monotonic_constraints(2, 0);
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSingleSplit(dh::ToSpan(out_splits), input);
|
||||
|
||||
DeviceSplitCandidate result = out_splits[0];
|
||||
EXPECT_EQ(result.findex, 1);
|
||||
EXPECT_EQ(result.fvalue, 11.0);
|
||||
EXPECT_EQ(result.left_sum, GradientPair(-0.5, 0.5));
|
||||
EXPECT_EQ(result.right_sum, GradientPair(0.5, 0.5));
|
||||
}
|
||||
|
||||
// Features 0 and 1 have identical gain, the algorithm must select 0
|
||||
TEST(GpuHist, EvaluateSingleSplitBreakTies) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
|
||||
GradientPair parent_sum(0.0, 1.0);
|
||||
GPUTrainingParam param{};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 10.0};
|
||||
thrust::device_vector<GradientPair> feature_histogram =
|
||||
std::vector<GradientPair>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
thrust::device_vector<int> monotonic_constraints(2, 0);
|
||||
EvaluateSplitInputs<GradientPair> input{1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSingleSplit(dh::ToSpan(out_splits), input);
|
||||
|
||||
DeviceSplitCandidate result = out_splits[0];
|
||||
EXPECT_EQ(result.findex, 0);
|
||||
EXPECT_EQ(result.fvalue, 1.0);
|
||||
}
|
||||
|
||||
TEST(GpuHist, EvaluateSplits) {
|
||||
thrust::device_vector<DeviceSplitCandidate> out_splits(2);
|
||||
GradientPair parent_sum(0.0, 1.0);
|
||||
GPUTrainingParam param{};
|
||||
|
||||
thrust::device_vector<bst_feature_t> feature_set =
|
||||
std::vector<bst_feature_t>{0, 1};
|
||||
thrust::device_vector<uint32_t> feature_segments =
|
||||
std::vector<bst_row_t>{0, 2, 4};
|
||||
thrust::device_vector<float> feature_values =
|
||||
std::vector<float>{1.0, 2.0, 11.0, 12.0};
|
||||
thrust::device_vector<float> feature_min_values =
|
||||
std::vector<float>{0.0, 0.0};
|
||||
thrust::device_vector<GradientPair> feature_histogram_left =
|
||||
std::vector<GradientPair>{
|
||||
{-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
|
||||
thrust::device_vector<GradientPair> feature_histogram_right =
|
||||
std::vector<GradientPair>{
|
||||
{-1.0, 0.5}, {1.0, 0.5}, {-0.5, 0.5}, {0.5, 0.5}};
|
||||
thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
|
||||
EvaluateSplitInputs<GradientPair> input_left{
|
||||
1,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram_left),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSplitInputs<GradientPair> input_right{
|
||||
2,
|
||||
parent_sum,
|
||||
param,
|
||||
dh::ToSpan(feature_set),
|
||||
dh::ToSpan(feature_segments),
|
||||
dh::ToSpan(feature_values),
|
||||
dh::ToSpan(feature_min_values),
|
||||
dh::ToSpan(feature_histogram_right),
|
||||
ValueConstraint(),
|
||||
dh::ToSpan(monotonic_constraints)};
|
||||
EvaluateSplits(dh::ToSpan(out_splits), input_left, input_right);
|
||||
|
||||
DeviceSplitCandidate result_left = out_splits[0];
|
||||
EXPECT_EQ(result_left.findex, 1);
|
||||
EXPECT_EQ(result_left.fvalue, 11.0);
|
||||
|
||||
DeviceSplitCandidate result_right = out_splits[1];
|
||||
EXPECT_EQ(result_right.findex, 0);
|
||||
EXPECT_EQ(result_right.fvalue, 1.0);
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
@ -41,9 +41,9 @@ void VerifySampling(size_t page_size,
|
||||
EXPECT_EQ(sample.page->n_rows, kRows);
|
||||
EXPECT_EQ(sample.gpair.size(), kRows);
|
||||
} else {
|
||||
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.016);
|
||||
EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.016f);
|
||||
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.016f);
|
||||
EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
|
||||
EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f);
|
||||
EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
|
||||
}
|
||||
|
||||
GradientPair sum_sampled_gpair{};
|
||||
|
||||
@ -82,8 +82,6 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
BatchParam batch_param{};
|
||||
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols,
|
||||
true, batch_param);
|
||||
maker.InitHistogram();
|
||||
|
||||
xgboost::SimpleLCG gen;
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||
HostDeviceVector<GradientPair> gpair(kNRows);
|
||||
@ -150,7 +148,7 @@ HistogramCutsWrapper GetHostCutMatrix () {
|
||||
}
|
||||
|
||||
// TODO(trivialfis): This test is over simplified.
|
||||
TEST(GpuHist, EvaluateSplits) {
|
||||
TEST(GpuHist, EvaluateRootSplit) {
|
||||
constexpr int kNRows = 16;
|
||||
constexpr int kNCols = 8;
|
||||
|
||||
@ -182,7 +180,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
GPUHistMakerDevice<GradientPairPrecise>
|
||||
maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
|
||||
// Initialize GPUHistMakerDevice::node_sum_gradients
|
||||
maker.host_node_sum_gradients = {{6.4f, 12.8f}};
|
||||
maker.node_sum_gradients = {};
|
||||
|
||||
// Initialize GPUHistMakerDevice::cut
|
||||
auto cmat = GetHostCutMatrix();
|
||||
@ -222,12 +220,10 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
maker.node_value_constraints[0].lower_bound = -1.0;
|
||||
maker.node_value_constraints[0].upper_bound = 1.0;
|
||||
|
||||
std::vector<DeviceSplitCandidate> res = maker.EvaluateSplits({0, 0 }, tree, kNCols);
|
||||
DeviceSplitCandidate res = maker.EvaluateRootSplit({6.4f, 12.8f});
|
||||
|
||||
ASSERT_EQ(res[0].findex, 7);
|
||||
ASSERT_EQ(res[1].findex, 7);
|
||||
ASSERT_NEAR(res[0].fvalue, 0.26, xgboost::kRtEps);
|
||||
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
|
||||
ASSERT_EQ(res.findex, 7);
|
||||
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
|
||||
}
|
||||
|
||||
void TestHistogramIndexImpl() {
|
||||
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_predict import run_threaded_predict # noqa
|
||||
@ -34,12 +35,13 @@ class TestGPUPredict(unittest.TestCase):
|
||||
param = {
|
||||
"objective": "binary:logistic",
|
||||
"predictor": "gpu_predictor",
|
||||
'eval_metric': 'auc',
|
||||
'tree_method': 'gpu_hist'
|
||||
'eval_metric': 'logloss',
|
||||
'tree_method': 'gpu_hist',
|
||||
'max_depth': 1
|
||||
}
|
||||
bst = xgb.train(param, dtrain, iterations, evals=watchlist,
|
||||
evals_result=res)
|
||||
assert self.non_decreasing(res["train"]["auc"])
|
||||
assert self.non_increasing(res["train"]["logloss"])
|
||||
gpu_pred_train = bst.predict(dtrain, output_margin=True)
|
||||
gpu_pred_test = bst.predict(dtest, output_margin=True)
|
||||
gpu_pred_val = bst.predict(dval, output_margin=True)
|
||||
@ -57,8 +59,8 @@ class TestGPUPredict(unittest.TestCase):
|
||||
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test,
|
||||
rtol=1e-6)
|
||||
|
||||
def non_decreasing(self, L):
|
||||
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
|
||||
def non_increasing(self, L):
|
||||
return all((y - x) < 0.001 for x, y in zip(L, L[1:]))
|
||||
|
||||
# Test case for a bug where multiple batch predictions made on a
|
||||
# test set produce incorrect results
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user