Improve test coverage with predictor configuration. (#9354)

* Improve test coverage with predictor configuration.

- Test with ext memory.
- Test with QDM.
- Test with dart.
This commit is contained in:
Jiaming Yuan 2023-07-05 15:17:22 +08:00 committed by GitHub
parent 6c9c8a9001
commit 645037e376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 280 additions and 79 deletions

View File

@ -9,9 +9,10 @@
#include <xgboost/logging.h> // for CHECK_GE
#include <xgboost/parameter.h> // for XGBoostParameter
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string
#include <type_traits> // for invoke_result_t, is_same_v
namespace xgboost {
@ -152,6 +153,25 @@ struct Context : public XGBoostParameter<Context> {
ctx.gpu_id = kCpuId;
return ctx;
}
/**
* @brief Call function based on the current device.
*/
template <typename CPUFn, typename CUDAFn>
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
switch (this->Device().device) {
case DeviceOrd::kCPU:
return cpu_fn();
case DeviceOrd::kCUDA:
return cuda_fn();
default:
// Do not use the device name as this is likely an internal error, the name
// wouldn't be valid.
LOG(FATAL) << "Unknown device type:" << static_cast<std::int16_t>(this->Device().device);
break;
}
return std::invoke_result_t<CPUFn>();
}
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {

View File

@ -6,24 +6,22 @@
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/cache.h> // DMatrixCache
#include <xgboost/cache.h> // for DMatrixCache
#include <xgboost/context.h> // for Context
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <functional> // std::function
#include <memory>
#include <functional> // for function
#include <memory> // for shared_ptr
#include <string>
#include <thread> // for get_id
#include <utility> // for make_pair
#include <vector>
// Forward declarations
namespace xgboost {
namespace gbm {
namespace xgboost::gbm {
struct GBTreeModel;
} // namespace gbm
} // namespace xgboost
} // namespace xgboost::gbm
namespace xgboost {
/**

View File

@ -47,5 +47,9 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<bst_feature_t>::max() << " features or greater";
}
constexpr StringView InplacePredictProxy() {
return "Inplace predict accepts only DMatrixProxy as input.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
}
std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
std::size_t bytes = 0;
bytes += WriteHistogramCuts(page.cut, fo);
// indptr

View File

@ -1,10 +1,9 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include "gradient_index_page_source.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() {
this->WriteCache();
}
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data

View File

@ -18,7 +18,7 @@
#include <vector>
#include "../common/common.h"
#include "../common/error_msg.h" // for UnknownDevice
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
}
}
namespace {
inline void MismatchedDevices(Context const* booster, Context const* data) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << booster->DeviceName()
<< ", while the input data is on: " << data->DeviceName() << ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
)";
}
}; // namespace
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
// dispatch to const function.
@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
MismatchedDevices(this->ctx_, p_m->Ctx());
CHECK_EQ(out_preds->version, 0);
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
return;
}
if (this->ctx_->IsCPU()) {
this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else if (p_m->Ctx()->IsCUDA()) {
CHECK(this->gpu_predictor_);
this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else {
LOG(FATAL) << error::UnknownDevice();
bool known_type = this->ctx_->DispatchDevice(
[&, begin = tree_begin, end = tree_end] {
return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
},
[&, begin = tree_begin, end = tree_end] {
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
});
if (!known_type) {
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
CHECK(proxy) << error::InplacePredictProxy();
LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name();
}
}
@ -808,11 +822,9 @@ class Dart : public GBTree {
auto n_groups = model_.learner_model_param->num_output_group;
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
MismatchedDevices(ctx_, p_fmat->Ctx());
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
return;
@ -825,20 +837,15 @@ class Dart : public GBTree {
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
auto get_predictor = [&]() -> Predictor const* {
if (ctx_->IsCPU()) {
return cpu_predictor_.get();
} else if (ctx_->IsCUDA()) {
CHECK(this->gpu_predictor_);
return gpu_predictor_.get();
} else {
LOG(FATAL) << error::UnknownDevice();
return nullptr;
}
};
auto predict_impl = [&](size_t i) {
predts.predictions.Fill(0);
bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
bool success = this->ctx_->DispatchDevice(
[&] {
return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
},
[&] {
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
});
CHECK(success) << msg;
};
@ -846,7 +853,15 @@ class Dart : public GBTree {
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
predict_impl(i);
if (i == tree_begin) {
get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
this->ctx_->DispatchDevice(
[&] {
this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
},
[&] {
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
});
}
// Multiple the tree weight
auto w = this->weight_drop_.at(i);

View File

@ -16,6 +16,7 @@
#include "../common/bitfield.h" // for RBitField8
#include "../common/categorical.h" // for IsCat, Decision
#include "../common/common.h" // for DivRoundUp
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../common/math.h" // for CheckNAN
#include "../common/threading_utils.h" // for ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
@ -741,7 +742,7 @@ class CPUPredictor : public Predictor {
PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(proxy)<< error::InplacePredictProxy();
CHECK(!p_m->Info().IsColumnSplit())
<< "Inplace predict support for column-wise data split is not yet implemented.";
auto x = proxy->Adapter();

View File

@ -15,8 +15,9 @@
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
#include "../common/cuda_context.cuh"
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"
@ -989,7 +990,7 @@ class GPUPredictor : public xgboost::Predictor {
PredictionCacheEntry* out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(proxy) << error::InplacePredictProxy();
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict<data::CupyAdapter,

View File

@ -27,26 +27,31 @@
#include "xgboost/host_device_vector.h" // for HostDeviceVector
namespace xgboost::data {
TEST(GradientIndex, ExternalMemory) {
TEST(GradientIndex, ExternalMemoryBaseRowID) {
Context ctx;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
auto p_fmat = RandomDataGenerator{4096, 256, 0.5}
.Device(ctx.gpu_id)
.Batches(8)
.GenerateSparsePageDMatrix("cache", true);
std::vector<size_t> base_rowids;
std::vector<float> hessian(dmat->Info().num_row_, 1);
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
std::vector<float> hessian(p_fmat->Info().num_row_, 1);
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
base_rowids.push_back(page.base_rowid);
}
size_t i = 0;
for (auto const &page : dmat->GetBatches<SparsePage>()) {
std::size_t i = 0;
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
ASSERT_EQ(base_rowids[i], page.base_rowid);
++i;
}
base_rowids.clear();
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
base_rowids.push_back(page.base_rowid);
}
i = 0;
for (auto const &page : dmat->GetBatches<SparsePage>()) {
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
ASSERT_EQ(base_rowids[i], page.base_rowid);
++i;
}

View File

@ -76,9 +76,11 @@ TEST(SparsePageDMatrix, LoadFile) {
// allow caller to retain pages so they can process multiple pages at the same time.
template <typename Page>
void TestRetainPage() {
auto m = CreateSparsePageDMatrix(10000);
std::size_t n_batches = 4;
auto p_fmat = RandomDataGenerator{1024, 128, 0.5f}.Batches(n_batches).GenerateSparsePageDMatrix(
"cache", true);
Context ctx;
auto batches = m->GetBatches<Page>(&ctx);
auto batches = p_fmat->GetBatches<Page>(&ctx);
auto begin = batches.begin();
auto end = batches.end();
@ -94,7 +96,7 @@ void TestRetainPage() {
}
ASSERT_EQ(pages.back().Size(), (*it).Size());
}
ASSERT_GE(iterators.size(), 2);
ASSERT_GE(iterators.size(), n_batches);
for (size_t i = 0; i < iterators.size(); ++i) {
ASSERT_EQ((*iterators[i]).Size(), pages.at(i).Size());
@ -102,7 +104,7 @@ void TestRetainPage() {
}
// make sure it's const and the caller can not modify the content of page.
for (auto &page : m->GetBatches<Page>({&ctx})) {
for (auto &page : p_fmat->GetBatches<Page>({&ctx})) {
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
}
}

View File

@ -514,4 +514,86 @@ TEST(GBTree, PredictRange) {
dmlc::Error);
}
}
TEST(GBTree, InplacePredictionError) {
std::size_t n_samples{2048}, n_features{32};
auto test_ext_err = [&](std::string booster, Context const* ctx) {
std::shared_ptr<DMatrix> p_fmat =
RandomDataGenerator{n_samples, n_features, 0.5f}.Batches(2).GenerateSparsePageDMatrix(
"cache", true);
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParam("booster", booster);
ConfigLearnerByCtx(ctx, learner.get());
learner->Configure();
for (std::int32_t i = 0; i < 3; ++i) {
learner->UpdateOneIter(i, p_fmat);
}
HostDeviceVector<float>* out_predt;
ASSERT_THROW(
{
learner->InplacePredict(p_fmat, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
},
dmlc::Error);
};
{
Context ctx;
test_ext_err("gbtree", &ctx);
test_ext_err("dart", &ctx);
}
#if defined(XGBOOST_USE_CUDA)
{
auto ctx = MakeCUDACtx(0);
test_ext_err("gbtree", &ctx);
test_ext_err("dart", &ctx);
}
#endif // defined(XGBOOST_USE_CUDA)
auto test_qdm_err = [&](std::string booster, Context const* ctx) {
std::shared_ptr<DMatrix> p_fmat;
bst_bin_t max_bins = 16;
auto rng = RandomDataGenerator{n_samples, n_features, 0.5f}.Device(ctx->gpu_id).Bins(max_bins);
if (ctx->IsCPU()) {
p_fmat = rng.GenerateQuantileDMatrix(true);
} else {
#if defined(XGBOOST_USE_CUDA)
p_fmat = rng.GenerateDeviceDMatrix(true);
#else
CHECK(p_fmat);
#endif // defined(XGBOOST_USE_CUDA)
};
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParam("booster", booster);
learner->SetParam("max_bin", std::to_string(max_bins));
ConfigLearnerByCtx(ctx, learner.get());
learner->Configure();
for (std::int32_t i = 0; i < 3; ++i) {
learner->UpdateOneIter(i, p_fmat);
}
HostDeviceVector<float>* out_predt;
ASSERT_THROW(
{
learner->InplacePredict(p_fmat, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
},
dmlc::Error);
};
{
Context ctx;
test_qdm_err("gbtree", &ctx);
test_qdm_err("dart", &ctx);
}
#if defined(XGBOOST_USE_CUDA)
{
auto ctx = MakeCUDACtx(0);
test_qdm_err("gbtree", &ctx);
test_qdm_err("dart", &ctx);
}
#endif // defined(XGBOOST_USE_CUDA)
}
} // namespace xgboost

View File

@ -61,7 +61,6 @@ void TestInplaceFallback(Context const* ctx) {
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&out_predt, 0, 0);
auto output = testing::internal::GetCapturedStderr();
std::cout << "output:" << output << std::endl;
ASSERT_NE(output.find("Falling back"), std::string::npos);
// test when the contexts match

View File

@ -210,6 +210,16 @@ SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
// Make sure it's compile time constant.
static_assert(SimpleLCG::max() - SimpleLCG::min());
void RandomDataGenerator::GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const {
RandomDataGenerator{p_fmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
p_fmat->Info().labels.Data());
CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
if (device_ != Context::kCpuId) {
p_fmat->Info().labels.SetDevice(device_);
}
}
void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
CHECK(out);
@ -363,8 +373,9 @@ void RandomDataGenerator::GenerateCSR(
CHECK_EQ(columns->Size(), value->Size());
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
size_t classes) const {
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
bool float_label,
size_t classes) const {
HostDeviceVector<float> data;
HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns;
@ -406,10 +417,58 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
return out;
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix() {
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateSparsePageDMatrix(
std::string prefix, bool with_label) const {
CHECK_GE(this->rows_, this->n_batches_);
CHECK_GE(this->n_batches_, 1)
<< "Must set the n_batches before generating an external memory DMatrix.";
std::unique_ptr<ArrayIterForTest> iter;
if (device_ == Context::kCpuId) {
iter = std::make_unique<NumpyArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
} else {
#if defined(XGBOOST_USE_CUDA)
iter = std::make_unique<CudaArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
#else
CHECK(iter);
#endif // defined(XGBOOST_USE_CUDA)
}
std::unique_ptr<DMatrix> dmat{
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix)};
auto row_page_path =
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
// Loop over the batches and count the number of pages
std::size_t batch_count = 0;
bst_row_t row_count = 0;
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
batch_count++;
row_count += batch.Size();
CHECK_NE(batch.data.Size(), 0);
}
EXPECT_EQ(batch_count, n_batches_);
EXPECT_EQ(row_count, dmat->Info().num_row_);
if (with_label) {
RandomDataGenerator{dmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
dmat->Info().labels.Data());
CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_);
dmat->Info().labels.Reshape(this->rows_, this->n_targets_);
}
return dmat;
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
auto m = std::make_shared<data::IterativeDMatrix>(
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
if (with_label) {
this->GenerateLabels(m);
}
return m;
}

View File

@ -24,10 +24,13 @@ int CudaArrayIterForTest::Next() {
return 1;
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix() {
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) {
CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
auto m = std::make_shared<data::IterativeDMatrix>(
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
if (with_label) {
this->GenerateLabels(m);
}
return m;
}
} // namespace xgboost

View File

@ -238,15 +238,18 @@ class RandomDataGenerator {
bst_target_t n_targets_{1};
std::int32_t device_{Context::kCpuId};
std::size_t n_batches_{0};
std::uint64_t seed_{0};
SimpleLCG lcg_;
std::size_t bins_{0};
bst_bin_t bins_{0};
std::vector<FeatureType> ft_;
bst_cat_t max_cat_;
Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;
void GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const;
public:
RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
: rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {}
@ -263,12 +266,16 @@ class RandomDataGenerator {
device_ = d;
return *this;
}
RandomDataGenerator& Batches(std::size_t n_batches) {
n_batches_ = n_batches;
return *this;
}
RandomDataGenerator& Seed(uint64_t s) {
seed_ = s;
lcg_.Seed(seed_);
return *this;
}
RandomDataGenerator& Bins(size_t b) {
RandomDataGenerator& Bins(bst_bin_t b) {
bins_ = b;
return *this;
}
@ -309,12 +316,17 @@ class RandomDataGenerator {
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
HostDeviceVector<bst_feature_t>* columns) const;
std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false, bool float_label = true,
size_t classes = 1) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
bool float_label = true,
size_t classes = 1) const;
[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
bool with_label) const;
#if defined(XGBOOST_USE_CUDA)
std::shared_ptr<DMatrix> GenerateDeviceDMatrix();
std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
#endif
std::shared_ptr<DMatrix> GenerateQuantileDMatrix();
std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
};
// Generate an empty DMatrix, mostly for its meta info.
@ -443,11 +455,11 @@ class ArrayIterForTest {
size_t static constexpr Cols() { return 13; }
public:
std::string AsArray() const { return interface_; }
[[nodiscard]] std::string AsArray() const { return interface_; }
virtual int Next() = 0;
virtual void Reset() { iter_ = 0; }
size_t Iter() const { return iter_; }
[[nodiscard]] std::size_t Iter() const { return iter_; }
auto Proxy() -> decltype(proxy_) { return proxy_; }
explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches);

View File

@ -216,7 +216,7 @@ void TestUpdatePredictionCache(bool use_subsampling) {
TEST(CPUPredictor, GHistIndex) {
size_t constexpr kRows{128}, kCols{16}, kBins{64};
auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix();
auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix(false);
HostDeviceVector<float> storage(kRows * kCols);
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&storage);
auto adapter = data::ArrayAdapter(columnar.c_str());

View File

@ -123,7 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
auto ctx = MakeCUDACtx(0);
for (size_t bins = 2; bins < 258; bins += 16) {
size_t rows = bins * 16;
auto p_m = RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix();
auto p_m =
RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix(false);
ASSERT_FALSE(p_m->PageExists<SparsePage>());
TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
@ -133,7 +134,7 @@ TEST(GPUPredictor, EllpackBasic) {
TEST(GPUPredictor, EllpackTraining) {
size_t constexpr kRows { 128 }, kCols { 16 }, kBins { 64 };
auto p_ellpack =
RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix();
RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix(false);
HostDeviceVector<float> storage(kRows * kCols);
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}
.Device(0)
@ -219,7 +220,7 @@ TEST(GPUPredictor, ShapStump) {
gbm::GBTreeModel model(&mparam, &ctx);
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
trees.push_back(std::make_unique<RegTree>());
model.CommitModelGroup(std::move(trees), 0);
auto gpu_lparam = MakeCUDACtx(0);
@ -246,7 +247,7 @@ TEST(GPUPredictor, Shap) {
gbm::GBTreeModel model(&mparam, &ctx);
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
trees.push_back(std::make_unique<RegTree>());
trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
model.CommitModelGroup(std::move(trees), 0);