Avoid thrust logical operation. (#9199)
Thrust implementation of `thrust::all_of/any_of/none_of` adopts an early stopping strategy to bailout early by dividing the input into small batches. This is not ideal for data validation as we expect all data to be valid. The strategy leads to excessive kernel launches and stream synchronization. * Use reduce from dh instead.
This commit is contained in:
parent
614f47c477
commit
053aababd4
@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
: columns_(columns),
|
||||
num_rows_(num_rows) {}
|
||||
size_t Size() const { return num_rows_ * columns_.size(); }
|
||||
__device__ COOTuple GetElement(size_t idx) const {
|
||||
__device__ __forceinline__ COOTuple GetElement(size_t idx) const {
|
||||
size_t column_idx = idx % columns_.size();
|
||||
size_t row_idx = idx / columns_.size();
|
||||
auto const& column = columns_[column_idx];
|
||||
@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||
* \brief Check there's no inf in data.
|
||||
*/
|
||||
template <typename AdapterBatchT>
|
||||
bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
|
||||
bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
auto value_iter = dh::MakeTransformIterator<float>(
|
||||
counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
|
||||
auto valid =
|
||||
thrust::none_of(value_iter, value_iter + batch.Size(),
|
||||
[is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
|
||||
auto value_iter = dh::MakeTransformIterator<bool>(counting, [=] XGBOOST_DEVICE(std::size_t idx) {
|
||||
auto v = batch.GetElement(idx).value;
|
||||
if (!is_valid(v)) {
|
||||
// discard the invalid elements.
|
||||
return true;
|
||||
}
|
||||
// check that there's no inf in data.
|
||||
return !std::isinf(v);
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
// The default implementation in thrust optimizes any_of/none_of/all_of by using small
|
||||
// intervals to early stop. But we expect all data to be valid here, using small
|
||||
// intervals only decreases performance due to excessive kernel launch and stream
|
||||
// synchronization.
|
||||
auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true,
|
||||
thrust::logical_and<>{});
|
||||
return valid;
|
||||
}
|
||||
}; // namespace data
|
||||
|
||||
@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
bool valid = data::HasInfInData(batch, is_valid);
|
||||
bool valid = data::NoInfInData(batch, is_valid);
|
||||
CHECK(valid) << error::InfInData();
|
||||
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
|
||||
@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
template <typename AdapterBatchT>
|
||||
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
|
||||
SparsePage* page) {
|
||||
bool valid = HasInfInData(batch, IsValidFunctor{missing});
|
||||
bool valid = NoInfInData(batch, IsValidFunctor{missing});
|
||||
CHECK(valid) << error::InfInData();
|
||||
|
||||
page->offset.SetDevice(device);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user