Catch exception in transform function omp context. (#4960)

This commit is contained in:
Jiaming Yuan 2019-10-21 17:03:38 +08:00 committed by GitHub
parent 010b8f1428
commit 4771bb0d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -5,6 +5,8 @@
#define XGBOOST_COMMON_TRANSFORM_H_
#include <dmlc/omp.h>
#include <dmlc/common.h>
#include <xgboost/data.h>
#include <utility>
#include <vector>
@ -148,10 +150,12 @@ class Transform {
template <typename... HDV>
void LaunchCPU(Functor func, HDV*... vectors) const {
omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
dmlc::OMPException omp_exc;
#pragma omp parallel for schedule(static)
for (omp_ulong idx = 0; idx < end; ++idx) {
func(idx, UnpackHDV(vectors)...);
omp_exc.Run(func, idx, UnpackHDV(vectors)...);
}
omp_exc.Rethrow();
}
private:

View File

@ -50,13 +50,26 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
out_vec.Fill(0);
Transform<>::Init(TestTransformRange<bst_float>{},
Range{0, static_cast<Range::DifferenceType>(size)},
TRANSFORM_GPU)
Range{0, static_cast<Range::DifferenceType>(size)},
TRANSFORM_GPU)
.Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
#if !defined(__CUDACC__)
TEST(Transform, Exception) {
size_t const kSize {16};
std::vector<bst_float> h_in(kSize);
const HostDeviceVector<bst_float> in_vec{h_in, -1};
EXPECT_ANY_THROW({
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
Range(0, static_cast<Range::DifferenceType>(kSize)), -1)
.Eval(&in_vec);
});
}
#endif
} // namespace common
} // namespace xgboost