Check for invalid data: reject `inf` input values when `missing` is not `inf`. (#6742)

This commit is contained in:
Jiaming Yuan 2021-03-04 14:37:20 +08:00 committed by GitHub
parent a9b4a95225
commit f20074e826
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 6 deletions

View File

@ -898,11 +898,12 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
return max_columns;
}
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
dmlc::OMPException exc;
dmlc::OMPException exec;
std::atomic<bool> valid{true};
// First-pass over the batch counting valid elements
#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
@ -912,7 +913,10 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
for (size_t i = begin; i < end; ++i) {
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
data::COOTuple const& element = line.GetElement(j);
if (!std::isinf(missing) && std::isinf(element.value)) {
valid = false;
}
const size_t key = element.row_idx - base_rowid;
CHECK_GE(key, builder_base_row_offset);
max_columns_local =
@ -927,7 +931,8 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
CHECK(valid) << "Input data contains `inf` or `nan`";
for (const auto & max : max_columns_vector) {
max_columns = std::max(max_columns, max[0]);
}
@ -938,7 +943,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
@ -954,7 +959,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
omp_set_num_threads(nthread_original);
return max_columns;

View File

@ -116,6 +116,14 @@ TEST(SimpleDMatrix, MissingData) {
CHECK_EQ(dmat->Info().num_nonzero_, 2);
dmat.reset(new data::SimpleDMatrix(&adapter, 1.0, 1));
CHECK_EQ(dmat->Info().num_nonzero_, 1);
{
data[1] = std::numeric_limits<float>::infinity();
data::DenseAdapter adapter(data.data(), data.size(), 1);
EXPECT_THROW(data::SimpleDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1),
dmlc::Error);
}
}
TEST(SimpleDMatrix, EmptyRow) {