Optimisations for gpu_hist. (#4248)
* Optimisations for gpu_hist.
* Use streams to overlap operations.
* ColumnSampler now uses HostDeviceVector to prevent repeatedly copying feature vectors to the device.
This commit is contained in:
@@ -11,38 +11,40 @@ TEST(ColumnSampler, Test) {
|
||||
// No node sampling
|
||||
cs.Init(n, 1.0f, 0.5f, 0.5f);
|
||||
auto set0 = *cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set0.size(), 32);
|
||||
ASSERT_EQ(set0.Size(), 32);
|
||||
|
||||
auto set1 = *cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set0, set1);
|
||||
|
||||
ASSERT_EQ(set0.HostVector(), set1.HostVector());
|
||||
|
||||
auto set2 = *cs.GetFeatureSet(1);
|
||||
ASSERT_NE(set1, set2);
|
||||
ASSERT_EQ(set2.size(), 32);
|
||||
ASSERT_NE(set1.HostVector(), set2.HostVector());
|
||||
ASSERT_EQ(set2.Size(), 32);
|
||||
|
||||
// Node sampling
|
||||
cs.Init(n, 0.5f, 1.0f, 0.5f);
|
||||
auto set3 = *cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set3.size(), 32);
|
||||
ASSERT_EQ(set3.Size(), 32);
|
||||
|
||||
auto set4 = *cs.GetFeatureSet(0);
|
||||
ASSERT_NE(set3, set4);
|
||||
ASSERT_EQ(set4.size(), 32);
|
||||
|
||||
ASSERT_NE(set3.HostVector(), set4.HostVector());
|
||||
ASSERT_EQ(set4.Size(), 32);
|
||||
|
||||
// No level or node sampling, should be the same at different depth
|
||||
cs.Init(n, 1.0f, 1.0f, 0.5f);
|
||||
ASSERT_EQ(*cs.GetFeatureSet(0), *cs.GetFeatureSet(1));
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
|
||||
|
||||
cs.Init(n, 1.0f, 1.0f, 1.0f);
|
||||
auto set5 = *cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5.size(), n);
|
||||
ASSERT_EQ(set5.Size(), n);
|
||||
cs.Init(n, 1.0f, 1.0f, 1.0f);
|
||||
auto set6 = *cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5, set6);
|
||||
ASSERT_EQ(set5.HostVector(), set6.HostVector());
|
||||
|
||||
// Should always be a minimum of one feature
|
||||
cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->size(), 1);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
|
||||
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@@ -304,11 +304,13 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
hist_maker.node_value_constraints_[0].lower_bound = -1.0;
|
||||
hist_maker.node_value_constraints_[0].upper_bound = 1.0;
|
||||
|
||||
DeviceSplitCandidate res =
|
||||
hist_maker.EvaluateSplit(0, &tree);
|
||||
std::vector<DeviceSplitCandidate> res =
|
||||
hist_maker.EvaluateSplits({ 0,0 }, &tree);
|
||||
|
||||
ASSERT_EQ(res.findex, 7);
|
||||
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
|
||||
ASSERT_EQ(res[0].findex, 7);
|
||||
ASSERT_EQ(res[1].findex, 7);
|
||||
ASSERT_NEAR(res[0].fvalue, 0.26, xgboost::kRtEps);
|
||||
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
|
||||
}
|
||||
|
||||
TEST(GpuHist, ApplySplit) {
|
||||
@@ -400,7 +402,9 @@ TEST(GpuHist, ApplySplit) {
|
||||
|
||||
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
int right_idx) {
|
||||
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
|
||||
std::vector<int64_t> left_count = {
|
||||
std::count(position_in.begin(), position_in.end(), left_idx)};
|
||||
thrust::device_vector<int64_t> d_left_count = left_count;
|
||||
thrust::device_vector<int> position = position_in;
|
||||
thrust::device_vector<int> position_out(position.size());
|
||||
|
||||
@@ -413,7 +417,7 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
common::Span<int>(position_out.data().get(), position_out.size()),
|
||||
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
|
||||
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
|
||||
right_idx, left_count);
|
||||
right_idx, d_left_count.data().get());
|
||||
thrust::host_vector<int> position_result = position_out;
|
||||
thrust::host_vector<int> ridx_result = ridx_out;
|
||||
|
||||
@@ -421,9 +425,9 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
|
||||
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
|
||||
// Check row indices are sorted inside left and right segment
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
|
||||
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
|
||||
EXPECT_TRUE(
|
||||
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));
|
||||
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
|
||||
|
||||
// Check key value pairs are the same
|
||||
for (auto i = 0ull; i < ridx_result.size(); i++) {
|
||||
|
||||
Reference in New Issue
Block a user