Improve update position function for gpu_hist (#3895)
This commit is contained in:
@@ -327,8 +327,6 @@ TEST(GpuHist, ApplySplit) {
|
||||
shard->row_stride = n_cols;
|
||||
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
|
||||
shard->ridx.CurrentDVec().tend());
|
||||
// Free inside DeviceShard
|
||||
dh::safe_cuda(cudaMallocHost(&(shard->tmp_pinned), sizeof(int64_t)));
|
||||
// Initialize GPUHistMaker
|
||||
hist_maker.param_ = param;
|
||||
RegTree tree;
|
||||
@@ -389,5 +387,44 @@ TEST(GpuHist, ApplySplit) {
|
||||
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
|
||||
}
|
||||
|
||||
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
int right_idx) {
|
||||
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
|
||||
thrust::device_vector<int> position = position_in;
|
||||
thrust::device_vector<int> position_out(position.size());
|
||||
|
||||
thrust::device_vector<bst_uint> ridx(position.size());
|
||||
thrust::sequence(ridx.begin(), ridx.end());
|
||||
thrust::device_vector<bst_uint> ridx_out(ridx.size());
|
||||
dh::CubMemory tmp;
|
||||
SortPosition(
|
||||
&tmp, common::Span<int>(position.data().get(), position.size()),
|
||||
common::Span<int>(position_out.data().get(), position_out.size()),
|
||||
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
|
||||
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
|
||||
right_idx, left_count);
|
||||
thrust::host_vector<int> position_result = position_out;
|
||||
thrust::host_vector<int> ridx_result = ridx_out;
|
||||
|
||||
// Check position is sorted
|
||||
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
|
||||
// Check row indices are sorted inside left and right segment
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
|
||||
|
||||
// Check key value pairs are the same
|
||||
for (auto i = 0ull; i < ridx_result.size(); i++) {
|
||||
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GpuHist, SortPosition) {
|
||||
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
|
||||
TestSortPosition({1, 1, 1, 1}, 1, 2);
|
||||
TestSortPosition({2, 2, 2, 2}, 1, 2);
|
||||
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user