Improve update position function for gpu_hist (#3895)
parent 143475b27b
commit 7af0946ac1
@@ -766,7 +766,8 @@ typename std::iterator_traits<T>::value_type SumReduction(
dh::CubMemory &tmp_mem, T in, int nVals) {
using ValueT = typename std::iterator_traits<T>::value_type;
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
ValueT *dummy_out = nullptr;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
// Allocate small extra memory for the return value
tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT));
auto ptr = reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage) + 1;
@@ -1074,4 +1075,71 @@ xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
using IndexT = typename xgboost::common::Span<T>::index_type;
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
}

template <typename FunctionT>
class LauncherItr {
public:
int idx;
FunctionT f;
XGBOOST_DEVICE LauncherItr() : idx(0) {}
XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
XGBOOST_DEVICE LauncherItr &operator=(int output) {
f(idx, output);
return *this;
}
};

/**
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
* with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*
* \tparam FunctionT Type of the function t.
*/
template <typename FunctionT>
class DiscardLambdaItr {
public:
// Required iterator traits
using self_type = DiscardLambdaItr; // NOLINT
using difference_type = ptrdiff_t; // NOLINT
using value_type = void; // NOLINT
using pointer = value_type *; // NOLINT
using reference = LauncherItr<FunctionT>; // NOLINT
using iterator_category = typename thrust::detail::iterator_facade_category<
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type; // NOLINT
private:
difference_type offset_;
FunctionT f_;
public:
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
: offset_(offset), f_(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset_ + b, f_);
}
XGBOOST_DEVICE self_type operator++() {
offset_++;
return *this;
}
XGBOOST_DEVICE self_type operator++(int) {
self_type retval = *this;
offset_++;
return retval;
}
XGBOOST_DEVICE self_type &operator+=(const int &b) {
offset_ += b;
return *this;
}
XGBOOST_DEVICE reference operator*() const {
return LauncherItr<FunctionT>(offset_, f_);
}
XGBOOST_DEVICE reference operator[](int idx) {
self_type offset = (*this) + idx;
return *offset;
}
};

} // namespace dh
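A minimal usage sketch of the iterator added above (an illustration only: the function name, its arguments, and the header path are assumptions, and nvcc must be invoked with --expt-extended-lambda for the device lambda). DiscardLambdaItr lets a CUB scan hand each per-element result to a device lambda instead of writing it through an output iterator, which is what the new SortPosition below relies on:

#include <cstddef>
#include <cub/cub.cuh>
#include "device_helpers.cuh"  // assumed header for dh::DiscardLambdaItr and dh::CubMemory

// Illustrative only: exclusive-sum d_in and pass each result to a device lambda.
inline void ScanIntoLambda(const int* d_in, int* d_out, int n, dh::CubMemory* tmp) {
  auto write_out = [=] __device__(size_t idx, int scan_value) {
    d_out[idx] = scan_value;  // SortPosition computes a scatter address here instead
  };
  dh::DiscardLambdaItr<decltype(write_out)> out_itr(write_out);
  size_t bytes = 0;
  // First call only queries the required temporary storage size.
  cub::DeviceScan::ExclusiveSum(nullptr, bytes, d_in, out_itr, n);
  tmp->LazyAllocate(bytes);
  cub::DeviceScan::ExclusiveSum(tmp->d_temp_storage, tmp->temp_storage_bytes,
                                d_in, out_itr, n);
}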
@@ -380,6 +380,53 @@ struct Segment {
size_t Size() const { return end - begin; }
};

/** \brief Returns one if the left node index is encountered, otherwise
* returns zero. */
struct IndicateLeftTransform {
int left_nidx;
explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(const int& x) const {
return x == left_nidx ? 1 : 0;
}
};

/**
* \brief Optimised routine for sorting key value pairs into left and right
* segments. Based on a single pass of exclusive scan, uses iterators to
* redirect inputs and outputs.
*/
void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
common::Span<int> position_out, common::Span<bst_uint> ridx,
common::Span<bst_uint> ridx_out, int left_nidx,
int right_nidx, int64_t left_count) {
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
scatter_address = (idx - ex_scan_result) + left_count;
}
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT

IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<int, IndicateLeftTransform, int*> in_itr(
d_position_in, conversion_op);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size());
temp_memory->LazyAllocate(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
temp_memory->temp_storage_bytes, in_itr,
out_itr, position.size());
}
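To make the single-pass idea concrete, here is a small host-side sketch of the same scatter-address arithmetic (hand-rolled for illustration; nothing below is part of the patch). Rows flagged as belonging to the left node are packed into addresses 0 .. left_count-1 in their original order, and every other row is packed after them, also in order:

#include <cstddef>
#include <vector>

// Host-side version of the address computation performed in write_results above.
std::vector<size_t> ScatterAddresses(const std::vector<int>& position,
                                     int left_nidx, size_t left_count) {
  std::vector<size_t> out(position.size());
  size_t ex_scan = 0;  // left rows seen so far == exclusive scan of the 0/1 flags
  for (size_t idx = 0; idx < position.size(); ++idx) {
    out[idx] = (position[idx] == left_nidx)
                   ? ex_scan                        // left rows fill the front of the segment
                   : (idx - ex_scan) + left_count;  // remaining rows follow, order preserved
    if (position[idx] == left_nidx) ++ex_scan;
  }
  return out;
}
// For position = {1, 2, 1, 2, 1}, left_nidx = 1, left_count = 3 this yields
// {0, 3, 1, 4, 2}: a stable partition of the segment into [left | right].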

struct DeviceShard;

struct GPUHistBuilderBase {
@@ -440,26 +487,22 @@ struct DeviceShard {
TrainParam param;
bool prediction_cache_initialised;

int64_t* tmp_pinned; // Small amount of staging memory

dh::CubMemory temp_memory;

std::unique_ptr<GPUHistBuilderBase> hist_builder;

// TODO(canonizer): do add support multi-batch DMatrix here
DeviceShard(int device_id,
bst_uint row_begin, bst_uint row_end, TrainParam _param) :
device_id_(device_id),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins(0),
null_gidx_value(0),
param(_param),
prediction_cache_initialised(false),
tmp_pinned(nullptr)
{}
DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
TrainParam _param)
: device_id_(device_id),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins(0),
null_gidx_value(0),
param(_param),
prediction_cache_initialised(false) {}

/* Init row_ptrs and row_stride */
void InitRowPtrs(const SparsePage& row_batch) {
@@ -495,7 +538,6 @@ struct DeviceShard {
void CreateHistIndices(const SparsePage& row_batch);

~DeviceShard() {
dh::safe_cuda(cudaFreeHost(tmp_pinned));
}

// Reset values for each update iteration
@@ -587,29 +629,18 @@ struct DeviceShard {
hist.HistogramExists(nidx_parent);
}

/*! \brief Count how many rows are assigned to left node. */
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
}

void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx,
int64_t split_gidx, bool default_dir_left, bool is_dense,
int fidx_begin, // cut.row_ptr[fidx]
int fidx_end) { // cut.row_ptr[fidx + 1]
dh::safe_cuda(cudaSetDevice(device_id_));
auto d_left_count = temp_memory.GetSpan<int64_t>(1);
dh::safe_cuda(cudaMemset(d_left_count.data(), 0, sizeof(int64_t)));
Segment segment = ridx_segments[nidx];
bst_uint* d_ridx = ridx.Current();
int* d_position = position.Current();
common::CompressedIterator<uint32_t> d_gidx = gidx;
size_t row_stride = this->row_stride;
// Launch 1 thread for each row
dh::LaunchN<1, 512>(
dh::LaunchN<1, 128>(
device_id_, segment.Size(), [=] __device__(bst_uint idx) {
idx += segment.begin;
bst_uint ridx = d_ridx[idx];
@@ -634,13 +665,16 @@ struct DeviceShard {
position = default_dir_left ? left_nidx : right_nidx;
}

CountLeft(d_left_count.data(), position, left_nidx);
d_position[idx] = position;
});
dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count.data(), sizeof(int64_t),
cudaMemcpyDeviceToHost));
auto left_count = *tmp_pinned;
SortPosition(segment, left_nidx, right_nidx);
IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<int, IndicateLeftTransform, int*> left_itr(
d_position + segment.begin, conversion_op);
int left_count = dh::SumReduction(temp_memory, left_itr, segment.Size());
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);

SortPositionAndCopy(segment, left_nidx, right_nidx, left_count);

ridx_segments[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
@@ -649,25 +683,15 @@ struct DeviceShard {
}

/*! \brief Sort row indices according to position. */
void SortPosition(const Segment& segment, int left_nidx, int right_nidx) {
int min_bits = 0;
int max_bits = static_cast<int>(
std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)));

size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes,
position.Current() + segment.begin, position.other() + segment.begin,
ridx.Current() + segment.begin, ridx.other() + segment.begin,
segment.Size(), min_bits, max_bits);

temp_memory.LazyAllocate(temp_storage_bytes);

cub::DeviceRadixSort::SortPairs(
temp_memory.d_temp_storage, temp_memory.temp_storage_bytes,
position.Current() + segment.begin, position.other() + segment.begin,
ridx.Current() + segment.begin, ridx.other() + segment.begin,
segment.Size(), min_bits, max_bits);
void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx,
size_t left_count) {
SortPosition(
&temp_memory,
common::Span<int>(position.Current() + segment.begin, segment.Size()),
common::Span<int>(position.other() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, left_count);
// Copy back key
dh::safe_cuda(cudaMemcpy(
position.Current() + segment.begin, position.other() + segment.begin,
@@ -823,8 +847,6 @@ inline void DeviceShard::InitCompressedData(

// Init histogram
hist.Init(device_id_, hmat.row_ptr.back());

dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
}

inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
@@ -327,8 +327,6 @@ TEST(GpuHist, ApplySplit) {
shard->row_stride = n_cols;
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
shard->ridx.CurrentDVec().tend());
// Free inside DeviceShard
dh::safe_cuda(cudaMallocHost(&(shard->tmp_pinned), sizeof(int64_t)));
// Initialize GPUHistMaker
hist_maker.param_ = param;
RegTree tree;
@@ -389,5 +387,44 @@ TEST(GpuHist, ApplySplit) {
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}

void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());

thrust::device_vector<bst_uint> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
thrust::device_vector<bst_uint> ridx_out(ridx.size());
dh::CubMemory tmp;
SortPosition(
&tmp, common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, left_count);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;

// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));

// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}

TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
} // namespace tree
} // namespace xgboost
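For the first SortPosition test case above, the expected result can be worked out by hand (an illustration only, not part of the test): left rows keep their relative order at the front of the segment, right rows follow, and ridx_out records which input row landed in each slot.

// position_in = {1, 2, 1, 2, 1}, left_idx = 1, right_idx = 2, left_count = 3
// scatter addresses       : {0, 3, 1, 4, 2}
// position_out  (sorted)  : {1, 1, 1, 2, 2}
// ridx_out (stable order) : {0, 2, 4, 1, 3}
// and position_out[i] == position_in[ridx_out[i]] holds for every i.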