Check inf in data for all types of DMatrix. (#8911)
This commit is contained in:
parent
72e8331eab
commit
f186c87cf9
@ -84,9 +84,10 @@ public class BoosterTest {
|
||||
};
|
||||
|
||||
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
|
||||
ColumnVector[] df = new ColumnVector[12];
|
||||
for (int i = 0; i < 12; ++i) {
|
||||
df[i] = tmpTable.getColumn(i);
|
||||
ColumnVector[] df = new ColumnVector[10];
|
||||
// exclude the first two columns, they are label bounds and contain inf.
|
||||
for (int i = 2; i < 12; ++i) {
|
||||
df[i - 2] = tmpTable.getColumn(i);
|
||||
}
|
||||
try (Table X = new Table(df);) {
|
||||
ColumnVector[] labels = new ColumnVector[1];
|
||||
|
||||
@ -21,7 +21,7 @@ import java.io.File
|
||||
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.functions.{col, udf, when}
|
||||
import org.apache.spark.sql.types.{FloatType, StructField, StructType}
|
||||
|
||||
class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
@ -47,7 +47,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
// Get a model
|
||||
val model = new XGBoostClassifier(xgbParam)
|
||||
.fit(originalDf)
|
||||
@ -64,7 +65,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
|
||||
"features_cols" -> featureNames, "label_col" -> labelName)
|
||||
val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
|
||||
val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
|
||||
|
||||
@ -87,7 +89,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
||||
"num_round" -> 10, "num_workers" -> 1)
|
||||
val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
|
||||
val classifier = new XGBoostClassifier(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
@ -122,7 +125,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
||||
"num_round" -> 10, "num_workers" -> 1)
|
||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setHandleInvalid("keep")
|
||||
@ -144,7 +148,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
// transform on GPU
|
||||
withGpuSparkSession() { spark =>
|
||||
val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
|
||||
// Since CPU model does not know the information about the features cols that GPU transform
|
||||
// pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
|
||||
@ -174,7 +179,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
|
||||
"num_round" -> 10, "num_workers" -> 1)
|
||||
val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
|
||||
val classifier = new XGBoostClassifier(xgbParam)
|
||||
.setFeaturesCol(featureNames)
|
||||
@ -190,7 +196,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
|
||||
// transform on CPU
|
||||
withCpuSparkSession() { spark =>
|
||||
val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
|
||||
.csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
.csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
|
||||
.randomSplit(Array(0.7, 0.3), seed = 1)
|
||||
|
||||
val featureColName = "feature_col"
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
|
||||
@ -2,7 +2,10 @@
|
||||
from typing import Any, Generator, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.random import Generator as RNG
|
||||
|
||||
import xgboost
|
||||
from xgboost.data import pandas_pyarrow_mapper
|
||||
|
||||
|
||||
@ -179,3 +182,16 @@ def pd_arrow_dtypes() -> Generator:
|
||||
dtype=pd.ArrowDtype(pa.bool_()),
|
||||
)
|
||||
yield orig, df
|
||||
|
||||
|
||||
def check_inf(rng: RNG) -> None:
|
||||
"""Validate there's no inf in X."""
|
||||
X = rng.random(size=32).reshape(8, 4)
|
||||
y = rng.random(size=8)
|
||||
X[5, 2] = np.inf
|
||||
|
||||
with pytest.raises(ValueError, match="Input data contains `inf`"):
|
||||
xgboost.QuantileDMatrix(X, y)
|
||||
|
||||
with pytest.raises(ValueError, match="Input data contains `inf`"):
|
||||
xgboost.DMatrix(X, y)
|
||||
|
||||
@ -20,5 +20,9 @@ constexpr StringView GroupSize() {
|
||||
constexpr StringView LabelScoreSize() {
|
||||
return "The size of label doesn't match the size of prediction.";
|
||||
}
|
||||
|
||||
constexpr StringView InfInData() {
|
||||
return "Input data contains `inf` while `missing` is not set to `inf`";
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@ -10,13 +10,16 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/algorithm.h" // StableSort
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../collective/communicator.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/algorithm.h" // for StableSort
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/error_msg.h" // for InfInData
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/numeric.h" // Iota
|
||||
#include "../common/numeric.h" // for Iota
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/version.h"
|
||||
#include "../data/adapter.h"
|
||||
@ -1144,7 +1147,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
});
|
||||
}
|
||||
exec.Rethrow();
|
||||
CHECK(valid) << "Input data contains `inf` or `nan`";
|
||||
CHECK(valid) << error::InfInData();
|
||||
for (const auto & max : max_columns_vector) {
|
||||
max_columns = std::max(max_columns, max[0]);
|
||||
}
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
|
||||
#define XGBOOST_DATA_DEVICE_ADAPTER_H_
|
||||
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
|
||||
#include <thrust/logical.h> // for none_of
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
@ -213,6 +216,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Check there's no inf in data.
|
||||
*/
|
||||
template <typename AdapterBatchT>
|
||||
bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
auto value_iter = dh::MakeTransformIterator<float>(
|
||||
counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
|
||||
auto valid =
|
||||
thrust::none_of(value_iter, value_iter + batch.Size(),
|
||||
[is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
|
||||
return valid;
|
||||
}
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
@ -9,7 +9,7 @@
|
||||
#include "../common/random.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "device_adapter.cuh" // for HasInfInData
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -189,8 +189,7 @@ struct TupleScanOp {
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
|
||||
EllpackPageImpl* dst, int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ELLPACK matrix
|
||||
@ -201,6 +200,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
bool valid = data::HasInfInData(batch, is_valid);
|
||||
CHECK(valid) << error::InfInData();
|
||||
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) {
|
||||
@ -239,9 +241,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr);
|
||||
nullptr));
|
||||
#else
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
@ -249,9 +251,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
#endif
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr);
|
||||
cub::NullType(), batch.Size(), nullptr));
|
||||
#else
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
|
||||
@ -1,21 +1,23 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <cstddef> // std::size_t
|
||||
#include <algorithm> // for min
|
||||
#include <atomic> // for atomic
|
||||
#include <cinttypes> // for uint32_t
|
||||
#include <cstddef> // for size_t
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/error_msg.h" // for InfInData
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // common::MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
@ -62,6 +64,7 @@ class GHistIndexMatrix {
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
std::atomic<bool> valid{true};
|
||||
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
|
||||
auto line = batch.GetLine(i);
|
||||
size_t ibegin = row_ptr[rbegin + i]; // index of first entry for current block
|
||||
@ -70,6 +73,9 @@ class GHistIndexMatrix {
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
data::COOTuple elem = line.GetElement(j);
|
||||
if (is_valid(elem)) {
|
||||
if (XGBOOST_EXPECT((std::isinf(elem.value)), false)) {
|
||||
valid = false;
|
||||
}
|
||||
bst_bin_t bin_idx{-1};
|
||||
if (common::IsCat(ft, elem.column_idx)) {
|
||||
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
|
||||
@ -82,6 +88,8 @@ class GHistIndexMatrix {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
CHECK(valid) << error::InfInData();
|
||||
}
|
||||
|
||||
// Gather hit_count from all threads
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* \file simple_dmatrix.cuh
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||
#define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include "device_adapter.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include <thrust/scan.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/error_msg.h" // for InfInData
|
||||
#include "device_adapter.cuh" // for HasInfInData
|
||||
|
||||
namespace xgboost::data {
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct COOToEntryOp {
|
||||
@ -61,7 +62,11 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) {
|
||||
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
|
||||
SparsePage* page) {
|
||||
bool valid = HasInfInData(batch, IsValidFunctor{missing});
|
||||
CHECK(valid) << error::InfInData();
|
||||
|
||||
page->offset.SetDevice(device);
|
||||
page->data.SetDevice(device);
|
||||
page->offset.Resize(batch.NumRows() + 1);
|
||||
@ -73,6 +78,5 @@ size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missin
|
||||
|
||||
return num_nonzero_;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||
|
||||
@ -6,6 +6,7 @@ from hypothesis import given, settings, strategies
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.data import check_inf
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import test_quantile_dmatrix as tqd
|
||||
@ -153,3 +154,9 @@ class TestQuantileDMatrix:
|
||||
from_qdm = xgb.QuantileDMatrix(X, weight=w, ref=Xy_qdm)
|
||||
|
||||
assert tm.predictor_equal(from_qdm, from_dm)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_check_inf(self) -> None:
|
||||
import cupy as cp
|
||||
rng = cp.random.default_rng(1994)
|
||||
check_inf(rng)
|
||||
|
||||
@ -15,7 +15,7 @@ from xgboost.testing import (
|
||||
make_sparse_regression,
|
||||
predictor_equal,
|
||||
)
|
||||
from xgboost.testing.data import np_dtypes
|
||||
from xgboost.testing.data import check_inf, np_dtypes
|
||||
|
||||
|
||||
class TestQuantileDMatrix:
|
||||
@ -244,6 +244,10 @@ class TestQuantileDMatrix:
|
||||
from_dm = xgb.QuantileDMatrix(X, weight=w, ref=Xy)
|
||||
assert predictor_equal(from_qdm, from_dm)
|
||||
|
||||
def test_check_inf(self) -> None:
|
||||
rng = np.random.default_rng(1994)
|
||||
check_inf(rng)
|
||||
|
||||
# we don't test empty Quantile DMatrix in single node construction.
|
||||
@given(
|
||||
strategies.integers(1, 1000),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user