Check for invalid data. (#6742)
This commit is contained in:
parent
a9b4a95225
commit
f20074e826
@ -898,11 +898,12 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
return max_columns;
|
||||
}
|
||||
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
|
||||
dmlc::OMPException exc;
|
||||
dmlc::OMPException exec;
|
||||
std::atomic<bool> valid{true};
|
||||
// First-pass over the batch counting valid elements
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
exec.Run([&]() {
|
||||
int tid = omp_get_thread_num();
|
||||
size_t begin = tid*thread_size;
|
||||
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
|
||||
@ -912,7 +913,10 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
for (size_t i = begin; i < end; ++i) {
|
||||
auto line = batch.GetLine(i);
|
||||
for (auto j = 0ull; j < line.Size(); j++) {
|
||||
auto element = line.GetElement(j);
|
||||
data::COOTuple const& element = line.GetElement(j);
|
||||
if (!std::isinf(missing) && std::isinf(element.value)) {
|
||||
valid = false;
|
||||
}
|
||||
const size_t key = element.row_idx - base_rowid;
|
||||
CHECK_GE(key, builder_base_row_offset);
|
||||
max_columns_local =
|
||||
@ -927,7 +931,8 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
exec.Rethrow();
|
||||
CHECK(valid) << "Input data contains `inf` or `nan`";
|
||||
for (const auto & max : max_columns_vector) {
|
||||
max_columns = std::max(max_columns, max[0]);
|
||||
}
|
||||
@ -938,7 +943,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
|
||||
#pragma omp parallel num_threads(nthread)
|
||||
{
|
||||
exc.Run([&]() {
|
||||
exec.Run([&]() {
|
||||
int tid = omp_get_thread_num();
|
||||
size_t begin = tid*thread_size;
|
||||
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
|
||||
@ -954,7 +959,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
exec.Rethrow();
|
||||
omp_set_num_threads(nthread_original);
|
||||
|
||||
return max_columns;
|
||||
|
||||
@ -116,6 +116,14 @@ TEST(SimpleDMatrix, MissingData) {
|
||||
CHECK_EQ(dmat->Info().num_nonzero_, 2);
|
||||
dmat.reset(new data::SimpleDMatrix(&adapter, 1.0, 1));
|
||||
CHECK_EQ(dmat->Info().num_nonzero_, 1);
|
||||
|
||||
{
|
||||
data[1] = std::numeric_limits<float>::infinity();
|
||||
data::DenseAdapter adapter(data.data(), data.size(), 1);
|
||||
EXPECT_THROW(data::SimpleDMatrix dmat(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), -1),
|
||||
dmlc::Error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, EmptyRow) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user