From 053aababd480380ab1223ec19d4b9f28d2ce88d7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 27 May 2023 01:36:58 +0800 Subject: [PATCH] Avoid thrust logical operation. (#9199) Thrust implementation of `thrust::all_of/any_of/none_of` adopts an early stopping strategy to bailout early by dividing the input into small batches. This is not ideal for data validation as we expect all data to be valid. The strategy leads to excessive kernel launches and stream synchronization. * Use reduce from dh instead. --- src/data/device_adapter.cuh | 25 ++++++++++++++++++------- src/data/ellpack_page.cu | 2 +- src/data/simple_dmatrix.cuh | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 494fb7d1c..136fbb743 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { : columns_(columns), num_rows_(num_rows) {} size_t Size() const { return num_rows_ * columns_.size(); } - __device__ COOTuple GetElement(size_t idx) const { + __device__ __forceinline__ COOTuple GetElement(size_t idx) const { size_t column_idx = idx % columns_.size(); size_t row_idx = idx / columns_.size(); auto const& column = columns_[column_idx]; @@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, * \brief Check there's no inf in data. */ template -bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { +bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { auto counting = thrust::make_counting_iterator(0llu); - auto value_iter = dh::MakeTransformIterator( - counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; }); - auto valid = - thrust::none_of(value_iter, value_iter + batch.Size(), - [is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); }); + auto value_iter = dh::MakeTransformIterator(counting, [=] XGBOOST_DEVICE(std::size_t idx) { + auto v = batch.GetElement(idx).value; + if (!is_valid(v)) { + // discard the invalid elements. + return true; + } + // check that there's no inf in data. + return !std::isinf(v); + }); + dh::XGBCachingDeviceAllocator alloc; + // The default implementation in thrust optimizes any_of/none_of/all_of by using small + // intervals to early stop. But we expect all data to be valid here, using small + // intervals only decreases performance due to excessive kernel launch and stream + // synchronization. + auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true, + thrust::logical_and<>{}); return valid; } }; // namespace data diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 4409a7ebb..aa218fa31 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span( diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index 63310a929..e2c0ae347 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, template size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) { - bool valid = HasInfInData(batch, IsValidFunctor{missing}); + bool valid = NoInfInData(batch, IsValidFunctor{missing}); CHECK(valid) << error::InfInData(); page->offset.SetDevice(device);