diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 494fb7d1c..136fbb743 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { : columns_(columns), num_rows_(num_rows) {} size_t Size() const { return num_rows_ * columns_.size(); } - __device__ COOTuple GetElement(size_t idx) const { + __device__ __forceinline__ COOTuple GetElement(size_t idx) const { size_t column_idx = idx % columns_.size(); size_t row_idx = idx / columns_.size(); auto const& column = columns_[column_idx]; @@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, * \brief Check there's no inf in data. */ template -bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { +bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { auto counting = thrust::make_counting_iterator(0llu); - auto value_iter = dh::MakeTransformIterator( - counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; }); - auto valid = - thrust::none_of(value_iter, value_iter + batch.Size(), - [is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); }); + auto value_iter = dh::MakeTransformIterator(counting, [=] XGBOOST_DEVICE(std::size_t idx) { + auto v = batch.GetElement(idx).value; + if (!is_valid(v)) { + // discard the invalid elements. + return true; + } + // check that there's no inf in data. + return !std::isinf(v); + }); + dh::XGBCachingDeviceAllocator alloc; + // The default implementation in thrust optimizes any_of/none_of/all_of by using small + // intervals to early stop. But we expect all data to be valid here, using small + // intervals only decreases performance due to excessive kernel launch and stream + // synchronization. + auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true, + thrust::logical_and<>{}); return valid; } }; // namespace data diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 4409a7ebb..aa218fa31 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span( diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index 63310a929..e2c0ae347 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, template size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) { - bool valid = HasInfInData(batch, IsValidFunctor{missing}); + bool valid = NoInfInData(batch, IsValidFunctor{missing}); CHECK(valid) << error::InfInData(); page->offset.SetDevice(device);