Check inf in data for all types of DMatrix. (#8911)
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <cstddef> // std::size_t
|
||||
#include <algorithm> // for min
|
||||
#include <atomic> // for atomic
|
||||
#include <cinttypes> // for uint32_t
|
||||
#include <cstddef> // for size_t
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/error_msg.h" // for InfInData
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // common::MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
@@ -62,6 +64,7 @@ class GHistIndexMatrix {
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
std::atomic<bool> valid{true};
|
||||
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
|
||||
auto line = batch.GetLine(i);
|
||||
size_t ibegin = row_ptr[rbegin + i]; // index of first entry for current block
|
||||
@@ -70,6 +73,9 @@ class GHistIndexMatrix {
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
data::COOTuple elem = line.GetElement(j);
|
||||
if (is_valid(elem)) {
|
||||
if (XGBOOST_EXPECT((std::isinf(elem.value)), false)) {
|
||||
valid = false;
|
||||
}
|
||||
bst_bin_t bin_idx{-1};
|
||||
if (common::IsCat(ft, elem.column_idx)) {
|
||||
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
|
||||
@@ -82,6 +88,8 @@ class GHistIndexMatrix {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
CHECK(valid) << error::InfInData();
|
||||
}
|
||||
|
||||
// Gather hit_count from all threads
|
||||
|
||||
Reference in New Issue
Block a user