Check inf in data for all types of DMatrix. (#8911)

Jiaming Yuan 2023-03-15 11:24:35 +08:00 committed by GitHub
parent 72e8331eab
commit f186c87cf9
11 changed files with 118 additions and 45 deletions

View File

@@ -84,9 +84,10 @@ public class BoosterTest {
       };
       try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
-        ColumnVector[] df = new ColumnVector[12];
-        for (int i = 0; i < 12; ++i) {
-          df[i] = tmpTable.getColumn(i);
+        ColumnVector[] df = new ColumnVector[10];
+        // exclude the first two columns, they are label bounds and contain inf.
+        for (int i = 2; i < 12; ++i) {
+          df[i - 2] = tmpTable.getColumn(i);
         }
         try (Table X = new Table(df);) {
           ColumnVector[] labels = new ColumnVector[1];

View File

@@ -21,7 +21,7 @@ import java.io.File
 import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.functions.{col, udf, when}
 import org.apache.spark.sql.types.{FloatType, StructField, StructType}

 class GpuXGBoostClassifierSuite extends GpuTestSuite {
@@ -47,7 +47,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
       "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
       "features_cols" -> featureNames, "label_col" -> labelName)
     val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
-      .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+      .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+      .randomSplit(Array(0.7, 0.3), seed = 1)
     // Get a model
     val model = new XGBoostClassifier(xgbParam)
       .fit(originalDf)
@@ -64,7 +65,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
       "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
       "features_cols" -> featureNames, "label_col" -> labelName)
     val Array(originalDf, testDf) = spark.read.option("header", "true").schema(schema)
-      .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+      .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+      .randomSplit(Array(0.7, 0.3), seed = 1)
     val getWeightFromF1 = udf({ f1: Float => if (f1.toInt % 2 == 0) 1.0f else 0.001f })
     val dfWithWeight = originalDf.withColumn("weight", getWeightFromF1(col("f1")))
@@ -87,7 +89,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
       "num_round" -> 10, "num_workers" -> 1)
     val Array(rawInput, testDf) = spark.read.option("header", "true").schema(schema)
-      .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+      .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+      .randomSplit(Array(0.7, 0.3), seed = 1)
     val classifier = new XGBoostClassifier(xgbParam)
       .setFeaturesCol(featureNames)
@@ -122,7 +125,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
       "num_round" -> 10, "num_workers" -> 1)
     val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
-      .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+      .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+      .randomSplit(Array(0.7, 0.3), seed = 1)
     val vectorAssembler = new VectorAssembler()
       .setHandleInvalid("keep")
@@ -144,7 +148,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     // transform on GPU
     withGpuSparkSession() { spark =>
       val Array(_, testDf) = spark.read.option("header", "true").schema(schema)
-        .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+        .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+        .randomSplit(Array(0.7, 0.3), seed = 1)
       // Since CPU model does not know the information about the features cols that GPU transform
       // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model
@@ -174,7 +179,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "binary:logistic",
       "num_round" -> 10, "num_workers" -> 1)
     val Array(rawInput, _) = spark.read.option("header", "true").schema(schema)
-      .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+      .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+      .randomSplit(Array(0.7, 0.3), seed = 1)
     val classifier = new XGBoostClassifier(xgbParam)
       .setFeaturesCol(featureNames)
@@ -190,7 +196,8 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite {
     // transform on CPU
     withCpuSparkSession() { spark =>
       val Array(_, rawInput) = spark.read.option("header", "true").schema(schema)
-        .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1)
+        .csv(dataPath).withColumn("f2", when(col("f2").isin(Float.PositiveInfinity), 0))
+        .randomSplit(Array(0.7, 0.3), seed = 1)
       val featureColName = "feature_col"
       val vectorAssembler = new VectorAssembler()

View File

@@ -2,7 +2,10 @@
 from typing import Any, Generator, Tuple, Union

 import numpy as np
+import pytest
+from numpy.random import Generator as RNG

+import xgboost
 from xgboost.data import pandas_pyarrow_mapper
@@ -179,3 +182,16 @@ def pd_arrow_dtypes() -> Generator:
         dtype=pd.ArrowDtype(pa.bool_()),
     )
     yield orig, df
+
+
+def check_inf(rng: RNG) -> None:
+    """Validate there's no inf in X."""
+    X = rng.random(size=32).reshape(8, 4)
+    y = rng.random(size=8)
+    X[5, 2] = np.inf
+
+    with pytest.raises(ValueError, match="Input data contains `inf`"):
+        xgboost.QuantileDMatrix(X, y)
+
+    with pytest.raises(ValueError, match="Input data contains `inf`"):
+        xgboost.DMatrix(X, y)
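
Note: for reference, this is the user-facing behavior the fixture pins down. A minimal repro sketch, assuming an XGBoost build that includes this change; the error text comes from error_msg.h below:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(1994)
X = rng.random(size=32).reshape(8, 4)
y = rng.random(size=8)
X[5, 2] = np.inf  # a single inf anywhere in the data should now be rejected

try:
    xgb.DMatrix(X, y)
except ValueError as e:
    print(e)  # expected to mention: Input data contains `inf` ...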

View File

@@ -20,5 +20,9 @@ constexpr StringView GroupSize() {
 constexpr StringView LabelScoreSize() {
   return "The size of label doesn't match the size of prediction.";
 }
+
+constexpr StringView InfInData() {
+  return "Input data contains `inf` while `missing` is not set to `inf`";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
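
Note: the message wording points at the escape hatch: inf is only rejected when it is not the designated missing value. A hedged sketch of that path, assuming the `missing` argument is honored the same way across DMatrix flavors (which is what the `IsValidFunctor` checks below imply):

import numpy as np
import xgboost as xgb

X = np.array([[1.0, np.inf], [2.0, 3.0]], dtype=np.float32)
y = np.array([0.0, 1.0], dtype=np.float32)

# Declaring inf as the missing-value marker makes the inf entry "invalid",
# so it is skipped as missing and the new check should not fire.
dtrain = xgb.DMatrix(X, y, missing=np.inf)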

View File

@@ -10,13 +10,16 @@
 #include <cstring>

 #include "../collective/communicator-inl.h"
-#include "../common/algorithm.h"  // StableSort
-#include "../common/api_entry.h"  // XGBAPIThreadLocalEntry
+#include "../collective/communicator.h"
+#include "../common/common.h"
+#include "../common/algorithm.h"  // for StableSort
+#include "../common/api_entry.h"  // for XGBAPIThreadLocalEntry
+#include "../common/error_msg.h"  // for InfInData
 #include "../common/group_data.h"
 #include "../common/io.h"
 #include "../common/linalg_op.h"
 #include "../common/math.h"
-#include "../common/numeric.h"  // Iota
+#include "../common/numeric.h"  // for Iota
 #include "../common/threading_utils.h"
 #include "../common/version.h"
 #include "../data/adapter.h"
@@ -1144,7 +1147,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
     });
   }
   exec.Rethrow();
-  CHECK(valid) << "Input data contains `inf` or `nan`";
+  CHECK(valid) << error::InfInData();
   for (const auto & max : max_columns_vector) {
     max_columns = std::max(max_columns, max[0]);
   }

View File

@@ -4,7 +4,10 @@
  */
 #ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
 #define XGBOOST_DATA_DEVICE_ADAPTER_H_
-#include <cstddef>  // for size_t
+#include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
+#include <thrust/logical.h>                     // for none_of
+
+#include <cstddef>  // for size_t
 #include <limits>
 #include <memory>
 #include <string>
@@ -213,6 +216,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
                     static_cast<std::size_t>(0), thrust::maximum<size_t>());
   return row_stride;
 }
+
+/**
+ * \brief Check there's no inf in data.
+ */
+template <typename AdapterBatchT>
+bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
+  auto counting = thrust::make_counting_iterator(0llu);
+  auto value_iter = dh::MakeTransformIterator<float>(
+      counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
+  auto valid =
+      thrust::none_of(value_iter, value_iter + batch.Size(),
+                      [is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
+  return valid;
+}
 };  // namespace data
 }  // namespace xgboost
 #endif  // XGBOOST_DATA_DEVICE_ADAPTER_H_
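
Note the return convention: despite its name, HasInfInData as written returns true when no valid element is inf, and every caller in this commit treats the result as a validity flag (bool valid = HasInfInData(...); CHECK(valid) << ...). A NumPy analogue of the predicate, for intuition only; the name is changed here to match what it computes:

import numpy as np

def no_inf_in_valid_data(values: np.ndarray, missing: float) -> bool:
    # An element is "valid" when it is neither NaN nor the missing marker,
    # mirroring IsValidFunctor; flag any valid element that is +/-inf.
    is_valid = ~np.isnan(values) & (values != missing)
    return not np.any(is_valid & np.isinf(values))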

View File

@@ -1,5 +1,5 @@
-/*!
- * Copyright 2019-2022 XGBoost contributors
+/**
+ * Copyright 2019-2023 by XGBoost contributors
  */
 #include <thrust/iterator/discard_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>
@@ -9,7 +9,7 @@
 #include "../common/random.h"
 #include "../common/transform_iterator.h"  // MakeIndexTransformIter
 #include "./ellpack_page.cuh"
-#include "device_adapter.cuh"
+#include "device_adapter.cuh"  // for HasInfInData
 #include "gradient_index.h"
 #include "xgboost/data.h"
@@ -189,9 +189,8 @@ struct TupleScanOp {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT &batch,
-                       common::Span<FeatureType const> feature_types,
-                       EllpackPageImpl *dst, int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
+                       EllpackPageImpl* dst, int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -201,6 +200,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
   // correct output position
   auto counting = thrust::make_counting_iterator(0llu);
   data::IsValidFunctor is_valid(missing);
+  bool valid = data::HasInfInData(batch, is_valid);
+  CHECK(valid) << error::InfInData();
+
   auto key_iter = dh::MakeTransformIterator<size_t>(
       counting,
       [=] __device__(size_t idx) {
@@ -239,9 +241,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
       cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
                         TupleScanOp<Tuple>, cub::NullType, int64_t>;
 #if THRUST_MAJOR_VERSION >= 2
-  DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
-                         TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
-                         nullptr);
+  dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
+                                       TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
+                                       nullptr));
 #else
   DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
                          TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
@@ -249,9 +251,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
 #endif
   dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
 #if THRUST_MAJOR_VERSION >= 2
-  DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
-                         key_value_index_iter, out, TupleScanOp<Tuple>(),
-                         cub::NullType(), batch.Size(), nullptr);
+  dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
+                                       key_value_index_iter, out, TupleScanOp<Tuple>(),
+                                       cub::NullType(), batch.Size(), nullptr));
 #else
   DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
                          key_value_index_iter, out, TupleScanOp<Tuple>(),

View File

@@ -1,21 +1,23 @@
-/*!
- * Copyright 2017-2022 by XGBoost Contributors
+/**
+ * Copyright 2017-2023 by XGBoost Contributors
  * \brief Data type for fast histogram aggregation.
  */
 #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
 #define XGBOOST_DATA_GRADIENT_INDEX_H_
-#include <algorithm>  // std::min
-#include <cinttypes>  // std::uint32_t
-#include <cstddef>    // std::size_t
+#include <algorithm>  // for min
+#include <atomic>     // for atomic
+#include <cinttypes>  // for uint32_t
+#include <cstddef>    // for size_t
 #include <memory>
 #include <vector>

 #include "../common/categorical.h"
+#include "../common/error_msg.h"  // for InfInData
 #include "../common/hist_util.h"
 #include "../common/numeric.h"
 #include "../common/threading_utils.h"
-#include "../common/transform_iterator.h"  // common::MakeIndexTransformIter
+#include "../common/transform_iterator.h"  // for MakeIndexTransformIter
 #include "adapter.h"
 #include "proxy_dmatrix.h"
 #include "xgboost/base.h"
@@ -62,6 +64,7 @@ class GHistIndexMatrix {
     BinIdxType* index_data = index_data_span.data();
     auto const& ptrs = cut.Ptrs();
     auto const& values = cut.Values();
+    std::atomic<bool> valid{true};
     common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
       auto line = batch.GetLine(i);
       size_t ibegin = row_ptr[rbegin + i];  // index of first entry for current block
@@ -70,6 +73,9 @@ class GHistIndexMatrix {
       for (size_t j = 0; j < line.Size(); ++j) {
         data::COOTuple elem = line.GetElement(j);
         if (is_valid(elem)) {
+          if (XGBOOST_EXPECT((std::isinf(elem.value)), false)) {
+            valid = false;
+          }
           bst_bin_t bin_idx{-1};
           if (common::IsCat(ft, elem.column_idx)) {
             bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values);
@@ -82,6 +88,8 @@ class GHistIndexMatrix {
         }
       }
     });
+
+    CHECK(valid) << error::InfInData();
   }

   // Gather hit_count from all threads
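
Note: the CPU path records the violation through a shared std::atomic<bool> inside the parallel loop and raises once after the loop, instead of throwing from worker threads. A rough Python sketch of that record-then-raise pattern, under the assumption of a simple thread pool standing in for ParallelFor:

import math
from concurrent.futures import ThreadPoolExecutor

def build_index(rows):
    valid = True  # shared flag, standing in for std::atomic<bool>

    def handle(row):
        nonlocal valid
        for value in row:
            if math.isinf(value):
                valid = False  # record the violation, keep processing

    with ThreadPoolExecutor() as pool:
        list(pool.map(handle, rows))
    # Raise once, outside the parallel region, like CHECK(valid).
    if not valid:
        raise ValueError("Input data contains `inf`")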

View File

@@ -1,18 +1,19 @@
-/*!
- * Copyright 2019-2021 by XGBoost Contributors
+/**
+ * Copyright 2019-2023 by XGBoost Contributors
  * \file simple_dmatrix.cuh
  */
 #ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
 #define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
 #include <thrust/copy.h>
-#include <thrust/scan.h>
 #include <thrust/execution_policy.h>
-#include "device_adapter.cuh"
-#include "../common/device_helpers.cuh"
+#include <thrust/scan.h>

-namespace xgboost {
-namespace data {
+#include "../common/device_helpers.cuh"
+#include "../common/error_msg.h"  // for InfInData
+#include "device_adapter.cuh"     // for HasInfInData
+
+namespace xgboost::data {

 template <typename AdapterBatchT>
 struct COOToEntryOp {
@@ -61,7 +62,11 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
 }

 template <typename AdapterBatchT>
-size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) {
+size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
+                        SparsePage* page) {
+  bool valid = HasInfInData(batch, IsValidFunctor{missing});
+  CHECK(valid) << error::InfInData();
+
   page->offset.SetDevice(device);
   page->data.SetDevice(device);
   page->offset.Resize(batch.NumRows() + 1);
@@ -73,6 +78,5 @@ size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missin
   return num_nonzero_;
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 #endif  // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_

View File

@@ -6,6 +6,7 @@ from hypothesis import given, settings, strategies

 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.data import check_inf

 sys.path.append("tests/python")
 import test_quantile_dmatrix as tqd
@@ -153,3 +154,9 @@ class TestQuantileDMatrix:
         from_qdm = xgb.QuantileDMatrix(X, weight=w, ref=Xy_qdm)
         assert tm.predictor_equal(from_qdm, from_dm)
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_check_inf(self) -> None:
+        import cupy as cp
+        rng = cp.random.default_rng(1994)
+        check_inf(rng)

View File

@@ -15,7 +15,7 @@ from xgboost.testing import (
     make_sparse_regression,
     predictor_equal,
 )
-from xgboost.testing.data import np_dtypes
+from xgboost.testing.data import check_inf, np_dtypes

 class TestQuantileDMatrix:
@@ -244,6 +244,10 @@ class TestQuantileDMatrix:
         from_dm = xgb.QuantileDMatrix(X, weight=w, ref=Xy)
         assert predictor_equal(from_qdm, from_dm)

+    def test_check_inf(self) -> None:
+        rng = np.random.default_rng(1994)
+        check_inf(rng)
+
     # we don't test empty Quantile DMatrix in single node construction.
     @given(
         strategies.integers(1, 1000),