Catch exception in transform function omp context. (#4960)
This commit is contained in:
parent
010b8f1428
commit
4771bb0d41
@ -5,6 +5,8 @@
|
|||||||
#define XGBOOST_COMMON_TRANSFORM_H_
|
#define XGBOOST_COMMON_TRANSFORM_H_
|
||||||
|
|
||||||
#include <dmlc/omp.h>
|
#include <dmlc/omp.h>
|
||||||
|
#include <dmlc/common.h>
|
||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -148,10 +150,12 @@ class Transform {
|
|||||||
template <typename... HDV>
|
template <typename... HDV>
|
||||||
void LaunchCPU(Functor func, HDV*... vectors) const {
|
void LaunchCPU(Functor func, HDV*... vectors) const {
|
||||||
omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
|
omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
|
||||||
|
dmlc::OMPException omp_exc;
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (omp_ulong idx = 0; idx < end; ++idx) {
|
for (omp_ulong idx = 0; idx < end; ++idx) {
|
||||||
func(idx, UnpackHDV(vectors)...);
|
omp_exc.Run(func, idx, UnpackHDV(vectors)...);
|
||||||
}
|
}
|
||||||
|
omp_exc.Rethrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -50,13 +50,26 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
|
|||||||
out_vec.Fill(0);
|
out_vec.Fill(0);
|
||||||
|
|
||||||
Transform<>::Init(TestTransformRange<bst_float>{},
|
Transform<>::Init(TestTransformRange<bst_float>{},
|
||||||
Range{0, static_cast<Range::DifferenceType>(size)},
|
Range{0, static_cast<Range::DifferenceType>(size)},
|
||||||
TRANSFORM_GPU)
|
TRANSFORM_GPU)
|
||||||
.Eval(&out_vec, &in_vec);
|
.Eval(&out_vec, &in_vec);
|
||||||
std::vector<bst_float> res = out_vec.HostVector();
|
std::vector<bst_float> res = out_vec.HostVector();
|
||||||
|
|
||||||
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
|
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(__CUDACC__)
|
||||||
|
TEST(Transform, Exception) {
|
||||||
|
size_t const kSize {16};
|
||||||
|
std::vector<bst_float> h_in(kSize);
|
||||||
|
const HostDeviceVector<bst_float> in_vec{h_in, -1};
|
||||||
|
EXPECT_ANY_THROW({
|
||||||
|
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
|
||||||
|
Range(0, static_cast<Range::DifferenceType>(kSize)), -1)
|
||||||
|
.Eval(&in_vec);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user