diff --git a/src/data/data.cc b/src/data/data.cc index 2d5996331..f99d3368e 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -898,11 +898,12 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread return max_columns; } std::vector> max_columns_vector(nthread); - dmlc::OMPException exc; + dmlc::OMPException exec; + std::atomic valid{true}; // First-pass over the batch counting valid elements #pragma omp parallel num_threads(nthread) { - exc.Run([&]() { + exec.Run([&]() { int tid = omp_get_thread_num(); size_t begin = tid*thread_size; size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size; @@ -912,7 +913,10 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread for (size_t i = begin; i < end; ++i) { auto line = batch.GetLine(i); for (auto j = 0ull; j < line.Size(); j++) { - auto element = line.GetElement(j); + data::COOTuple const& element = line.GetElement(j); + if (!std::isinf(missing) && std::isinf(element.value)) { + valid = false; + } const size_t key = element.row_idx - base_rowid; CHECK_GE(key, builder_base_row_offset); max_columns_local = @@ -927,7 +931,8 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread } }); } - exc.Rethrow(); + exec.Rethrow(); + CHECK(valid) << "Input data contains `inf` or `nan`"; for (const auto & max : max_columns_vector) { max_columns = std::max(max_columns, max[0]); } @@ -938,7 +943,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread #pragma omp parallel num_threads(nthread) { - exc.Run([&]() { + exec.Run([&]() { int tid = omp_get_thread_num(); size_t begin = tid*thread_size; size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size; @@ -954,7 +959,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread } }); } - exc.Rethrow(); + exec.Rethrow(); omp_set_num_threads(nthread_original); return max_columns; diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 3147395a6..f777b00e2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -116,6 +116,14 @@ TEST(SimpleDMatrix, MissingData) { CHECK_EQ(dmat->Info().num_nonzero_, 2); dmat.reset(new data::SimpleDMatrix(&adapter, 1.0, 1)); CHECK_EQ(dmat->Info().num_nonzero_, 1); + + { + data[1] = std::numeric_limits::infinity(); + data::DenseAdapter adapter(data.data(), data.size(), 1); + EXPECT_THROW(data::SimpleDMatrix dmat( + &adapter, std::numeric_limits::quiet_NaN(), -1), + dmlc::Error); + } } TEST(SimpleDMatrix, EmptyRow) {