Optimisations for gpu_hist. (#4248)

* Optimisations for gpu_hist.

* Use streams to overlap operations.

* ColumnSampler now uses HostDeviceVector so that feature vectors are not repeatedly copied to the device.
This commit is contained in:
Rory Mitchell
2019-03-20 13:30:06 +13:00
committed by GitHub
parent 7814183199
commit 00465d243d
8 changed files with 278 additions and 119 deletions

View File

@@ -304,11 +304,13 @@ TEST(GpuHist, EvaluateSplits) {
hist_maker.node_value_constraints_[0].lower_bound = -1.0;
hist_maker.node_value_constraints_[0].upper_bound = 1.0;
DeviceSplitCandidate res =
hist_maker.EvaluateSplit(0, &tree);
std::vector<DeviceSplitCandidate> res =
hist_maker.EvaluateSplits({ 0,0 }, &tree);
ASSERT_EQ(res.findex, 7);
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
ASSERT_EQ(res[0].findex, 7);
ASSERT_EQ(res[1].findex, 7);
ASSERT_NEAR(res[0].fvalue, 0.26, xgboost::kRtEps);
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
}
TEST(GpuHist, ApplySplit) {
@@ -400,7 +402,9 @@ TEST(GpuHist, ApplySplit) {
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
std::vector<int64_t> left_count = {
std::count(position_in.begin(), position_in.end(), left_idx)};
thrust::device_vector<int64_t> d_left_count = left_count;
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());
@@ -413,7 +417,7 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, left_count);
right_idx, d_left_count.data().get());
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;
@@ -421,9 +425,9 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {