Improve update position function for gpu_hist (#3895)

This commit is contained in:
Rory Mitchell
2018-11-14 19:33:29 +13:00
committed by GitHub
parent 143475b27b
commit 7af0946ac1
3 changed files with 184 additions and 57 deletions

View File

@@ -327,8 +327,6 @@ TEST(GpuHist, ApplySplit) {
shard->row_stride = n_cols;
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
shard->ridx.CurrentDVec().tend());
// Free inside DeviceShard
dh::safe_cuda(cudaMallocHost(&(shard->tmp_pinned), sizeof(int64_t)));
// Initialize GPUHistMaker
hist_maker.param_ = param;
RegTree tree;
@@ -389,5 +387,44 @@ TEST(GpuHist, ApplySplit) {
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());
thrust::device_vector<bst_uint> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
thrust::device_vector<bst_uint> ridx_out(ridx.size());
dh::CubMemory tmp;
SortPosition(
&tmp, common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, left_count);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;
// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}
TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
} // namespace tree
} // namespace xgboost