Combine thread launches into single launch per tree for gpu_hist (#4343)
* Combine thread launches into a single launch per tree for the gpu_hist algorithm.
* Address deprecation warning.
* Add manual column sampler constructor.
* Turn off OpenMP dynamic threads to get a guaranteed number of threads.
* Enable OpenMP in CUDA code.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include <valarray>
|
||||
#include "../../../src/common/random.h"
|
||||
#include "../helpers.h"
|
||||
#include "gtest/gtest.h"
|
||||
@@ -33,7 +34,8 @@ TEST(ColumnSampler, Test) {
|
||||
|
||||
// No level or node sampling, should be the same at different depth
|
||||
cs.Init(n, 1.0f, 1.0f, 0.5f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
|
||||
cs.GetFeatureSet(1)->HostVector());
|
||||
|
||||
cs.Init(n, 1.0f, 1.0f, 1.0f);
|
||||
auto set5 = *cs.GetFeatureSet(0);
|
||||
@@ -45,7 +47,34 @@ TEST(ColumnSampler, Test) {
|
||||
// Should always be a minimum of one feature
|
||||
cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
|
||||
}
|
||||
|
||||
// Test if different threads using the same seed produce the same result
|
||||
TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
const int64_t num_threads = 100;
|
||||
int n = 128;
|
||||
int iterations = 10;
|
||||
int levels = 5;
|
||||
std::vector<int> reference_result;
|
||||
bool success =
|
||||
true; // Cannot use google test asserts in multithreaded region
|
||||
#pragma omp parallel num_threads(num_threads)
|
||||
{
|
||||
for (auto j = 0ull; j < iterations; j++) {
|
||||
ColumnSampler cs(j);
|
||||
cs.Init(n, 0.5f, 0.5f, 0.5f);
|
||||
for (auto level = 0ull; level < levels; level++) {
|
||||
auto result = cs.GetFeatureSet(level)->ConstHostVector();
|
||||
#pragma omp single
|
||||
{ reference_result = result; }
|
||||
if (result != reference_result) {
|
||||
success = false;
|
||||
}
|
||||
#pragma omp barrier
|
||||
}
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(success);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user