Check inf in data for all types of DMatrix. (#8911)

This commit is contained in:
Jiaming Yuan
2023-03-15 11:24:35 +08:00
committed by GitHub
parent 72e8331eab
commit f186c87cf9
11 changed files with 118 additions and 45 deletions

View File

@@ -10,13 +10,16 @@
#include <cstring>
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // StableSort
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../collective/communicator.h"
#include "../common/common.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for InfInData
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h" // Iota
#include "../common/numeric.h" // for Iota
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
@@ -1144,7 +1147,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
});
}
exec.Rethrow();
CHECK(valid) << "Input data contains `inf` or `nan`";
CHECK(valid) << error::InfInData();
for (const auto & max : max_columns_vector) {
max_columns = std::max(max_columns, max[0]);
}

View File

@@ -4,7 +4,10 @@
*/
#ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
#define XGBOOST_DATA_DEVICE_ADAPTER_H_
#include <cstddef> // for size_t
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/logical.h> // for none_of
#include <cstddef> // for size_t
#include <limits>
#include <memory>
#include <string>
@@ -213,6 +216,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
static_cast<std::size_t>(0), thrust::maximum<size_t>());
return row_stride;
}
/**
* \brief Check there's no inf in data.
*/
template <typename AdapterBatchT>
bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
auto counting = thrust::make_counting_iterator(0llu);
auto value_iter = dh::MakeTransformIterator<float>(
counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
auto valid =
thrust::none_of(value_iter, value_iter + batch.Size(),
[is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
return valid;
}
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019-2022 XGBoost contributors
/**
* Copyright 2019-2023 by XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
@@ -9,7 +9,7 @@
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "device_adapter.cuh" // for HasInfInData
#include "gradient_index.h"
#include "xgboost/data.h"
@@ -189,9 +189,8 @@ struct TupleScanOp {
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
void CopyDataToEllpack(const AdapterBatchT &batch,
common::Span<FeatureType const> feature_types,
EllpackPageImpl *dst, int device_idx, float missing) {
void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
EllpackPageImpl* dst, int device_idx, float missing) {
// Some witchcraft happens here
// The goal is to copy valid elements out of the input to an ELLPACK matrix
// with a given row stride, using no extra working memory Standard stream
@@ -201,6 +200,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
// correct output position
auto counting = thrust::make_counting_iterator(0llu);
data::IsValidFunctor is_valid(missing);
bool valid = data::HasInfInData(batch, is_valid);
CHECK(valid) << error::InfInData();
auto key_iter = dh::MakeTransformIterator<size_t>(
counting,
[=] __device__(size_t idx) {
@@ -239,9 +241,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, int64_t>;
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr);
dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr));
#else
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
@@ -249,9 +251,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
#endif
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
#if THRUST_MAJOR_VERSION >= 2
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr);
dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr));
#else
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),

View File

@@ -1,21 +1,23 @@
/*!
* Copyright 2017-2022 by XGBoost Contributors
/**
* Copyright 2017-2023 by XGBoost Contributors
* \brief Data type for fast histogram aggregation.
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <algorithm> // std::min
#include <cinttypes> // std::uint32_t
#include <cstddef> // std::size_t
#include <algorithm> // for min
#include <atomic> // for atomic
#include <cinttypes> // for uint32_t
#include <cstddef> // for size_t
#include <memory>
#include <vector>
#include "../common/categorical.h"
#include "../common/error_msg.h" // for InfInData
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // common::MakeIndexTransformIter
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
#include "adapter.h"
#include "proxy_dmatrix.h"
#include "xgboost/base.h"
@@ -62,6 +64,7 @@ class GHistIndexMatrix {
BinIdxType* index_data = index_data_span.data();
auto const& ptrs = cut.Ptrs();
auto const& values = cut.Values();
std::atomic<bool> valid{true};
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
auto line = batch.GetLine(i);
size_t ibegin = row_ptr[rbegin + i]; // index of first entry for current block
@@ -70,6 +73,9 @@ class GHistIndexMatrix {
for (size_t j = 0; j < line.Size(); ++j) {
data::COOTuple elem = line.GetElement(j);
if (is_valid(elem)) {
if (XGBOOST_EXPECT((std::isinf(elem.value)), false)) {
valid = false;
}
bst_bin_t bin_idx{-1};
if (common::IsCat(ft, elem.column_idx)) {
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
@@ -82,6 +88,8 @@ class GHistIndexMatrix {
}
}
});
CHECK(valid) << error::InfInData();
}
// Gather hit_count from all threads

View File

@@ -1,18 +1,19 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
/**
* Copyright 2019-2023 by XGBoost Contributors
* \file simple_dmatrix.cuh
*/
#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
#define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
#include <thrust/copy.h>
#include <thrust/scan.h>
#include <thrust/execution_policy.h>
#include "device_adapter.cuh"
#include "../common/device_helpers.cuh"
#include <thrust/scan.h>
namespace xgboost {
namespace data {
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h" // for InfInData
#include "device_adapter.cuh" // for HasInfInData
namespace xgboost::data {
template <typename AdapterBatchT>
struct COOToEntryOp {
@@ -61,7 +62,11 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
}
template <typename AdapterBatchT>
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) {
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
SparsePage* page) {
bool valid = HasInfInData(batch, IsValidFunctor{missing});
CHECK(valid) << error::InfInData();
page->offset.SetDevice(device);
page->data.SetDevice(device);
page->offset.Resize(batch.NumRows() + 1);
@@ -73,6 +78,5 @@ size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missin
return num_nonzero_;
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_