Check for invalid data. (#6742)

This commit is contained in:
Jiaming Yuan
2021-03-04 14:37:20 +08:00
committed by GitHub
parent a9b4a95225
commit f20074e826
2 changed files with 19 additions and 6 deletions

View File

@@ -898,11 +898,12 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
return max_columns;
}
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
dmlc::OMPException exc;
dmlc::OMPException exec;
std::atomic<bool> valid{true};
// First-pass over the batch counting valid elements
#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
@@ -912,7 +913,10 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
for (size_t i = begin; i < end; ++i) {
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
data::COOTuple const& element = line.GetElement(j);
if (!std::isinf(missing) && std::isinf(element.value)) {
valid = false;
}
const size_t key = element.row_idx - base_rowid;
CHECK_GE(key, builder_base_row_offset);
max_columns_local =
@@ -927,7 +931,8 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
CHECK(valid) << "Input data contains `inf` or `nan`";
for (const auto & max : max_columns_vector) {
max_columns = std::max(max_columns, max[0]);
}
@@ -938,7 +943,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
@@ -954,7 +959,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
omp_set_num_threads(nthread_original);
return max_columns;