Avoid thrust logical operation. (#9199)
Thrust implementation of `thrust::all_of/any_of/none_of` adopts an early stopping strategy to bailout early by dividing the input into small batches. This is not ideal for data validation as we expect all data to be valid. The strategy leads to excessive kernel launches and stream synchronization. * Use reduce from dh instead.
This commit is contained in:
parent
614f47c477
commit
053aababd4
@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
|||||||
: columns_(columns),
|
: columns_(columns),
|
||||||
num_rows_(num_rows) {}
|
num_rows_(num_rows) {}
|
||||||
size_t Size() const { return num_rows_ * columns_.size(); }
|
size_t Size() const { return num_rows_ * columns_.size(); }
|
||||||
__device__ COOTuple GetElement(size_t idx) const {
|
__device__ __forceinline__ COOTuple GetElement(size_t idx) const {
|
||||||
size_t column_idx = idx % columns_.size();
|
size_t column_idx = idx % columns_.size();
|
||||||
size_t row_idx = idx / columns_.size();
|
size_t row_idx = idx / columns_.size();
|
||||||
auto const& column = columns_[column_idx];
|
auto const& column = columns_[column_idx];
|
||||||
@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
|||||||
* \brief Check there's no inf in data.
|
* \brief Check there's no inf in data.
|
||||||
*/
|
*/
|
||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
|
bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
|
||||||
auto counting = thrust::make_counting_iterator(0llu);
|
auto counting = thrust::make_counting_iterator(0llu);
|
||||||
auto value_iter = dh::MakeTransformIterator<float>(
|
auto value_iter = dh::MakeTransformIterator<bool>(counting, [=] XGBOOST_DEVICE(std::size_t idx) {
|
||||||
counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
|
auto v = batch.GetElement(idx).value;
|
||||||
auto valid =
|
if (!is_valid(v)) {
|
||||||
thrust::none_of(value_iter, value_iter + batch.Size(),
|
// discard the invalid elements.
|
||||||
[is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
|
return true;
|
||||||
|
}
|
||||||
|
// check that there's no inf in data.
|
||||||
|
return !std::isinf(v);
|
||||||
|
});
|
||||||
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
// The default implementation in thrust optimizes any_of/none_of/all_of by using small
|
||||||
|
// intervals to early stop. But we expect all data to be valid here, using small
|
||||||
|
// intervals only decreases performance due to excessive kernel launch and stream
|
||||||
|
// synchronization.
|
||||||
|
auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true,
|
||||||
|
thrust::logical_and<>{});
|
||||||
return valid;
|
return valid;
|
||||||
}
|
}
|
||||||
}; // namespace data
|
}; // namespace data
|
||||||
|
|||||||
@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
|
|||||||
// correct output position
|
// correct output position
|
||||||
auto counting = thrust::make_counting_iterator(0llu);
|
auto counting = thrust::make_counting_iterator(0llu);
|
||||||
data::IsValidFunctor is_valid(missing);
|
data::IsValidFunctor is_valid(missing);
|
||||||
bool valid = data::HasInfInData(batch, is_valid);
|
bool valid = data::NoInfInData(batch, is_valid);
|
||||||
CHECK(valid) << error::InfInData();
|
CHECK(valid) << error::InfInData();
|
||||||
|
|
||||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||||
|
|||||||
@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
|||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
|
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
|
||||||
SparsePage* page) {
|
SparsePage* page) {
|
||||||
bool valid = HasInfInData(batch, IsValidFunctor{missing});
|
bool valid = NoInfInData(batch, IsValidFunctor{missing});
|
||||||
CHECK(valid) << error::InfInData();
|
CHECK(valid) << error::InfInData();
|
||||||
|
|
||||||
page->offset.SetDevice(device);
|
page->offset.SetDevice(device);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user