Remove omp_get_max_threads in objective. (#7589)

This commit is contained in:
Jiaming Yuan
2022-01-24 04:35:49 +08:00
committed by GitHub
parent 5817840858
commit 6967ef7267
11 changed files with 76 additions and 74 deletions

View File

@@ -1,3 +1,6 @@
/*!
* Copyright 2018-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/base.h>
#include <xgboost/span.h>
@@ -42,7 +45,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
out_vec.Fill(0);
Transform<>::Init(TestTransformRange<bst_float>{},
Range{0, static_cast<Range::DifferenceType>(size)},
Range{0, static_cast<Range::DifferenceType>(size)}, common::OmpGetNumThreads(0),
TRANSFORM_GPU)
.Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector();
@@ -55,11 +58,14 @@ TEST(TransformDeathTest, Exception) {
size_t const kSize {16};
std::vector<bst_float> h_in(kSize);
const HostDeviceVector<bst_float> in_vec{h_in, -1};
EXPECT_DEATH({
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
Range(0, static_cast<Range::DifferenceType>(kSize)), -1)
.Eval(&in_vec);
}, "");
EXPECT_DEATH(
{
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
Range(0, static_cast<Range::DifferenceType>(kSize)),
common::OmpGetNumThreads(0), -1)
.Eval(&in_vec);
},
"");
}
#endif

View File

@@ -1,4 +1,7 @@
// This converts all tests from CPU to GPU.
/*!
* Copyright 2018-2022 by XGBoost Contributors
* \brief This converts all tests from CPU to GPU.
*/
#include "test_transform_range.cc"
#if defined(XGBOOST_USE_NCCL)
@@ -22,13 +25,13 @@ TEST(Transform, MGPU_SpecifiedGpuId) { // NOLINT
const HostDeviceVector<bst_float> in_vec {h_in, device};
HostDeviceVector<bst_float> out_vec {h_out, device};
ASSERT_NO_THROW(
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, device)
.Eval(&out_vec, &in_vec));
ASSERT_NO_THROW(Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
common::OmpGetNumThreads(0), device)
.Eval(&out_vec, &in_vec));
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
} // namespace common
} // namespace xgboost
#endif
#endif