Improve test coverage with predictor configuration. (#9354)
* Improve test coverage with predictor configuration. - Test with ext memory. - Test with QDM. - Test with dart.
This commit is contained in:
parent
6c9c8a9001
commit
645037e376
@ -9,9 +9,10 @@
|
||||
#include <xgboost/logging.h> // for CHECK_GE
|
||||
#include <xgboost/parameter.h> // for XGBoostParameter
|
||||
|
||||
#include <cstdint> // for int16_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string, to_string
|
||||
#include <cstdint> // for int16_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string, to_string
|
||||
#include <type_traits> // for invoke_result_t, is_same_v
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -152,6 +153,25 @@ struct Context : public XGBoostParameter<Context> {
|
||||
ctx.gpu_id = kCpuId;
|
||||
return ctx;
|
||||
}
|
||||
/**
|
||||
* @brief Call function based on the current device.
|
||||
*/
|
||||
template <typename CPUFn, typename CUDAFn>
|
||||
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
|
||||
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
|
||||
switch (this->Device().device) {
|
||||
case DeviceOrd::kCPU:
|
||||
return cpu_fn();
|
||||
case DeviceOrd::kCUDA:
|
||||
return cuda_fn();
|
||||
default:
|
||||
// Do not use the device name as this is likely an internal error, the name
|
||||
// wouldn't be valid.
|
||||
LOG(FATAL) << "Unknown device type:" << static_cast<std::int16_t>(this->Device().device);
|
||||
break;
|
||||
}
|
||||
return std::invoke_result_t<CPUFn>();
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(Context) {
|
||||
|
||||
@ -6,24 +6,22 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/cache.h> // DMatrixCache
|
||||
#include <xgboost/cache.h> // for DMatrixCache
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
#include <functional> // std::function
|
||||
#include <memory>
|
||||
#include <functional> // for function
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string>
|
||||
#include <thread> // for get_id
|
||||
#include <utility> // for make_pair
|
||||
#include <vector>
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
namespace xgboost::gbm {
|
||||
struct GBTreeModel;
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::gbm
|
||||
|
||||
namespace xgboost {
|
||||
/**
|
||||
|
||||
@ -47,5 +47,9 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
|
||||
<< "Unfortunately, XGBoost does not support data matrices with "
|
||||
<< std::numeric_limits<bst_feature_t>::max() << " features or greater";
|
||||
}
|
||||
|
||||
constexpr StringView InplacePredictProxy() {
|
||||
return "Inplace predict accepts only DMatrixProxy as input.";
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||
}
|
||||
|
||||
std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
|
||||
CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
|
||||
std::size_t bytes = 0;
|
||||
bytes += WriteHistogramCuts(page.cut, fo);
|
||||
// indptr
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#include "gradient_index_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
void GradientIndexPageSource::Fetch() {
|
||||
if (!this->ReadCache()) {
|
||||
if (count_ != 0 && !sync_) {
|
||||
@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() {
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/error_msg.h" // for UnknownDevice
|
||||
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/timer.h"
|
||||
@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void MismatchedDevices(Context const* booster, Context const* data) {
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
|
||||
<< "is running on: " << booster->DeviceName()
|
||||
<< ", while the input data is on: " << data->DeviceName() << ".\n"
|
||||
<< R"(Potential solutions:
|
||||
- Use a data structure that matches the device ordinal in the booster.
|
||||
- Set the device for booster before call to inplace_predict.
|
||||
)";
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
|
||||
bst_layer_t layer_begin, bst_layer_t layer_end) {
|
||||
// dispatch to const function.
|
||||
@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
|
||||
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
|
||||
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
||||
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
|
||||
<< "is running on: " << this->ctx_->DeviceName()
|
||||
<< ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
|
||||
MismatchedDevices(this->ctx_, p_m->Ctx());
|
||||
CHECK_EQ(out_preds->version, 0);
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
|
||||
auto any_adapter = proxy->Adapter();
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
|
||||
this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->ctx_->IsCPU()) {
|
||||
this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (p_m->Ctx()->IsCUDA()) {
|
||||
CHECK(this->gpu_predictor_);
|
||||
this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << error::UnknownDevice();
|
||||
bool known_type = this->ctx_->DispatchDevice(
|
||||
[&, begin = tree_begin, end = tree_end] {
|
||||
return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
|
||||
},
|
||||
[&, begin = tree_begin, end = tree_end] {
|
||||
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
|
||||
});
|
||||
if (!known_type) {
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name();
|
||||
}
|
||||
}
|
||||
|
||||
@ -808,11 +822,9 @@ class Dart : public GBTree {
|
||||
auto n_groups = model_.learner_model_param->num_output_group;
|
||||
|
||||
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
|
||||
<< "is running on: " << this->ctx_->DeviceName()
|
||||
<< ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
|
||||
MismatchedDevices(ctx_, p_fmat->Ctx());
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
|
||||
auto any_adapter = proxy->Adapter();
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
|
||||
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
|
||||
return;
|
||||
@ -825,20 +837,15 @@ class Dart : public GBTree {
|
||||
}
|
||||
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||
|
||||
auto get_predictor = [&]() -> Predictor const* {
|
||||
if (ctx_->IsCPU()) {
|
||||
return cpu_predictor_.get();
|
||||
} else if (ctx_->IsCUDA()) {
|
||||
CHECK(this->gpu_predictor_);
|
||||
return gpu_predictor_.get();
|
||||
} else {
|
||||
LOG(FATAL) << error::UnknownDevice();
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
auto predict_impl = [&](size_t i) {
|
||||
predts.predictions.Fill(0);
|
||||
bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
|
||||
bool success = this->ctx_->DispatchDevice(
|
||||
[&] {
|
||||
return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
|
||||
},
|
||||
[&] {
|
||||
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
|
||||
});
|
||||
CHECK(success) << msg;
|
||||
};
|
||||
|
||||
@ -846,7 +853,15 @@ class Dart : public GBTree {
|
||||
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
|
||||
predict_impl(i);
|
||||
if (i == tree_begin) {
|
||||
get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
|
||||
this->ctx_->DispatchDevice(
|
||||
[&] {
|
||||
this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||
model_);
|
||||
},
|
||||
[&] {
|
||||
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||
model_);
|
||||
});
|
||||
}
|
||||
// Multiple the tree weight
|
||||
auto w = this->weight_drop_.at(i);
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/categorical.h" // for IsCat, Decision
|
||||
#include "../common/common.h" // for DivRoundUp
|
||||
#include "../common/error_msg.h" // for InplacePredictProxy
|
||||
#include "../common/math.h" // for CheckNAN
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
|
||||
@ -741,7 +742,7 @@ class CPUPredictor : public Predictor {
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
|
||||
CHECK(proxy)<< error::InplacePredictProxy();
|
||||
CHECK(!p_m->Info().IsColumnSplit())
|
||||
<< "Inplace predict support for column-wise data split is not yet implemented.";
|
||||
auto x = proxy->Adapter();
|
||||
|
||||
@ -15,8 +15,9 @@
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/error_msg.h" // for InplacePredictProxy
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
@ -989,7 +990,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
PredictionCacheEntry* out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
|
||||
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
auto x = proxy->Adapter();
|
||||
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CupyAdapter,
|
||||
|
||||
@ -27,26 +27,31 @@
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
|
||||
namespace xgboost::data {
|
||||
TEST(GradientIndex, ExternalMemory) {
|
||||
TEST(GradientIndex, ExternalMemoryBaseRowID) {
|
||||
Context ctx;
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
||||
auto p_fmat = RandomDataGenerator{4096, 256, 0.5}
|
||||
.Device(ctx.gpu_id)
|
||||
.Batches(8)
|
||||
.GenerateSparsePageDMatrix("cache", true);
|
||||
|
||||
std::vector<size_t> base_rowids;
|
||||
std::vector<float> hessian(dmat->Info().num_row_, 1);
|
||||
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
|
||||
std::vector<float> hessian(p_fmat->Info().num_row_, 1);
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
|
||||
base_rowids.push_back(page.base_rowid);
|
||||
}
|
||||
size_t i = 0;
|
||||
for (auto const &page : dmat->GetBatches<SparsePage>()) {
|
||||
|
||||
std::size_t i = 0;
|
||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||
ASSERT_EQ(base_rowids[i], page.base_rowid);
|
||||
++i;
|
||||
}
|
||||
|
||||
base_rowids.clear();
|
||||
for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
|
||||
base_rowids.push_back(page.base_rowid);
|
||||
}
|
||||
i = 0;
|
||||
for (auto const &page : dmat->GetBatches<SparsePage>()) {
|
||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||
ASSERT_EQ(base_rowids[i], page.base_rowid);
|
||||
++i;
|
||||
}
|
||||
|
||||
@ -76,9 +76,11 @@ TEST(SparsePageDMatrix, LoadFile) {
|
||||
// allow caller to retain pages so they can process multiple pages at the same time.
|
||||
template <typename Page>
|
||||
void TestRetainPage() {
|
||||
auto m = CreateSparsePageDMatrix(10000);
|
||||
std::size_t n_batches = 4;
|
||||
auto p_fmat = RandomDataGenerator{1024, 128, 0.5f}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||
"cache", true);
|
||||
Context ctx;
|
||||
auto batches = m->GetBatches<Page>(&ctx);
|
||||
auto batches = p_fmat->GetBatches<Page>(&ctx);
|
||||
auto begin = batches.begin();
|
||||
auto end = batches.end();
|
||||
|
||||
@ -94,7 +96,7 @@ void TestRetainPage() {
|
||||
}
|
||||
ASSERT_EQ(pages.back().Size(), (*it).Size());
|
||||
}
|
||||
ASSERT_GE(iterators.size(), 2);
|
||||
ASSERT_GE(iterators.size(), n_batches);
|
||||
|
||||
for (size_t i = 0; i < iterators.size(); ++i) {
|
||||
ASSERT_EQ((*iterators[i]).Size(), pages.at(i).Size());
|
||||
@ -102,7 +104,7 @@ void TestRetainPage() {
|
||||
}
|
||||
|
||||
// make sure it's const and the caller can not modify the content of page.
|
||||
for (auto &page : m->GetBatches<Page>({&ctx})) {
|
||||
for (auto &page : p_fmat->GetBatches<Page>({&ctx})) {
|
||||
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
||||
}
|
||||
}
|
||||
|
||||
@ -514,4 +514,86 @@ TEST(GBTree, PredictRange) {
|
||||
dmlc::Error);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GBTree, InplacePredictionError) {
|
||||
std::size_t n_samples{2048}, n_features{32};
|
||||
|
||||
auto test_ext_err = [&](std::string booster, Context const* ctx) {
|
||||
std::shared_ptr<DMatrix> p_fmat =
|
||||
RandomDataGenerator{n_samples, n_features, 0.5f}.Batches(2).GenerateSparsePageDMatrix(
|
||||
"cache", true);
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||
learner->SetParam("booster", booster);
|
||||
ConfigLearnerByCtx(ctx, learner.get());
|
||||
learner->Configure();
|
||||
for (std::int32_t i = 0; i < 3; ++i) {
|
||||
learner->UpdateOneIter(i, p_fmat);
|
||||
}
|
||||
HostDeviceVector<float>* out_predt;
|
||||
ASSERT_THROW(
|
||||
{
|
||||
learner->InplacePredict(p_fmat, PredictionType::kValue,
|
||||
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
|
||||
},
|
||||
dmlc::Error);
|
||||
};
|
||||
|
||||
{
|
||||
Context ctx;
|
||||
test_ext_err("gbtree", &ctx);
|
||||
test_ext_err("dart", &ctx);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
{
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
test_ext_err("gbtree", &ctx);
|
||||
test_ext_err("dart", &ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
auto test_qdm_err = [&](std::string booster, Context const* ctx) {
|
||||
std::shared_ptr<DMatrix> p_fmat;
|
||||
bst_bin_t max_bins = 16;
|
||||
auto rng = RandomDataGenerator{n_samples, n_features, 0.5f}.Device(ctx->gpu_id).Bins(max_bins);
|
||||
if (ctx->IsCPU()) {
|
||||
p_fmat = rng.GenerateQuantileDMatrix(true);
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
p_fmat = rng.GenerateDeviceDMatrix(true);
|
||||
#else
|
||||
CHECK(p_fmat);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
};
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||
learner->SetParam("booster", booster);
|
||||
learner->SetParam("max_bin", std::to_string(max_bins));
|
||||
ConfigLearnerByCtx(ctx, learner.get());
|
||||
learner->Configure();
|
||||
for (std::int32_t i = 0; i < 3; ++i) {
|
||||
learner->UpdateOneIter(i, p_fmat);
|
||||
}
|
||||
HostDeviceVector<float>* out_predt;
|
||||
ASSERT_THROW(
|
||||
{
|
||||
learner->InplacePredict(p_fmat, PredictionType::kValue,
|
||||
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
|
||||
},
|
||||
dmlc::Error);
|
||||
};
|
||||
|
||||
{
|
||||
Context ctx;
|
||||
test_qdm_err("gbtree", &ctx);
|
||||
test_qdm_err("dart", &ctx);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
{
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
test_qdm_err("gbtree", &ctx);
|
||||
test_qdm_err("dart", &ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -61,7 +61,6 @@ void TestInplaceFallback(Context const* ctx) {
|
||||
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
|
||||
&out_predt, 0, 0);
|
||||
auto output = testing::internal::GetCapturedStderr();
|
||||
std::cout << "output:" << output << std::endl;
|
||||
ASSERT_NE(output.find("Falling back"), std::string::npos);
|
||||
|
||||
// test when the contexts match
|
||||
|
||||
@ -210,6 +210,16 @@ SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
|
||||
// Make sure it's compile time constant.
|
||||
static_assert(SimpleLCG::max() - SimpleLCG::min());
|
||||
|
||||
void RandomDataGenerator::GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const {
|
||||
RandomDataGenerator{p_fmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
|
||||
p_fmat->Info().labels.Data());
|
||||
CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
|
||||
p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
|
||||
if (device_ != Context::kCpuId) {
|
||||
p_fmat->Info().labels.SetDevice(device_);
|
||||
}
|
||||
}
|
||||
|
||||
void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
|
||||
CHECK(out);
|
||||
@ -363,8 +373,9 @@ void RandomDataGenerator::GenerateCSR(
|
||||
CHECK_EQ(columns->Size(), value->Size());
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
|
||||
size_t classes) const {
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
|
||||
bool float_label,
|
||||
size_t classes) const {
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
@ -406,10 +417,58 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
|
||||
return out;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix() {
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateSparsePageDMatrix(
|
||||
std::string prefix, bool with_label) const {
|
||||
CHECK_GE(this->rows_, this->n_batches_);
|
||||
CHECK_GE(this->n_batches_, 1)
|
||||
<< "Must set the n_batches before generating an external memory DMatrix.";
|
||||
std::unique_ptr<ArrayIterForTest> iter;
|
||||
if (device_ == Context::kCpuId) {
|
||||
iter = std::make_unique<NumpyArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
iter = std::make_unique<CudaArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
|
||||
#else
|
||||
CHECK(iter);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
std::unique_ptr<DMatrix> dmat{
|
||||
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix)};
|
||||
|
||||
auto row_page_path =
|
||||
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
||||
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
|
||||
|
||||
// Loop over the batches and count the number of pages
|
||||
std::size_t batch_count = 0;
|
||||
bst_row_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
CHECK_NE(batch.data.Size(), 0);
|
||||
}
|
||||
|
||||
EXPECT_EQ(batch_count, n_batches_);
|
||||
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
||||
|
||||
if (with_label) {
|
||||
RandomDataGenerator{dmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
|
||||
dmat->Info().labels.Data());
|
||||
CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_);
|
||||
dmat->Info().labels.Reshape(this->rows_, this->n_targets_);
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
|
||||
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
|
||||
auto m = std::make_shared<data::IterativeDMatrix>(
|
||||
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
|
||||
if (with_label) {
|
||||
this->GenerateLabels(m);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
|
||||
@ -24,10 +24,13 @@ int CudaArrayIterForTest::Next() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix() {
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) {
|
||||
CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
|
||||
auto m = std::make_shared<data::IterativeDMatrix>(
|
||||
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
|
||||
if (with_label) {
|
||||
this->GenerateLabels(m);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -238,15 +238,18 @@ class RandomDataGenerator {
|
||||
bst_target_t n_targets_{1};
|
||||
|
||||
std::int32_t device_{Context::kCpuId};
|
||||
std::size_t n_batches_{0};
|
||||
std::uint64_t seed_{0};
|
||||
SimpleLCG lcg_;
|
||||
|
||||
std::size_t bins_{0};
|
||||
bst_bin_t bins_{0};
|
||||
std::vector<FeatureType> ft_;
|
||||
bst_cat_t max_cat_;
|
||||
|
||||
Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;
|
||||
|
||||
void GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const;
|
||||
|
||||
public:
|
||||
RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
|
||||
: rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {}
|
||||
@ -263,12 +266,16 @@ class RandomDataGenerator {
|
||||
device_ = d;
|
||||
return *this;
|
||||
}
|
||||
RandomDataGenerator& Batches(std::size_t n_batches) {
|
||||
n_batches_ = n_batches;
|
||||
return *this;
|
||||
}
|
||||
RandomDataGenerator& Seed(uint64_t s) {
|
||||
seed_ = s;
|
||||
lcg_.Seed(seed_);
|
||||
return *this;
|
||||
}
|
||||
RandomDataGenerator& Bins(size_t b) {
|
||||
RandomDataGenerator& Bins(bst_bin_t b) {
|
||||
bins_ = b;
|
||||
return *this;
|
||||
}
|
||||
@ -309,12 +316,17 @@ class RandomDataGenerator {
|
||||
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
|
||||
HostDeviceVector<bst_feature_t>* columns) const;
|
||||
|
||||
std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false, bool float_label = true,
|
||||
size_t classes = 1) const;
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
|
||||
bool float_label = true,
|
||||
size_t classes = 1) const;
|
||||
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
|
||||
bool with_label) const;
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
std::shared_ptr<DMatrix> GenerateDeviceDMatrix();
|
||||
std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
|
||||
#endif
|
||||
std::shared_ptr<DMatrix> GenerateQuantileDMatrix();
|
||||
std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
|
||||
};
|
||||
|
||||
// Generate an empty DMatrix, mostly for its meta info.
|
||||
@ -443,11 +455,11 @@ class ArrayIterForTest {
|
||||
size_t static constexpr Cols() { return 13; }
|
||||
|
||||
public:
|
||||
std::string AsArray() const { return interface_; }
|
||||
[[nodiscard]] std::string AsArray() const { return interface_; }
|
||||
|
||||
virtual int Next() = 0;
|
||||
virtual void Reset() { iter_ = 0; }
|
||||
size_t Iter() const { return iter_; }
|
||||
[[nodiscard]] std::size_t Iter() const { return iter_; }
|
||||
auto Proxy() -> decltype(proxy_) { return proxy_; }
|
||||
|
||||
explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches);
|
||||
|
||||
@ -216,7 +216,7 @@ void TestUpdatePredictionCache(bool use_subsampling) {
|
||||
|
||||
TEST(CPUPredictor, GHistIndex) {
|
||||
size_t constexpr kRows{128}, kCols{16}, kBins{64};
|
||||
auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix();
|
||||
auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix(false);
|
||||
HostDeviceVector<float> storage(kRows * kCols);
|
||||
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&storage);
|
||||
auto adapter = data::ArrayAdapter(columnar.c_str());
|
||||
|
||||
@ -123,7 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
for (size_t bins = 2; bins < 258; bins += 16) {
|
||||
size_t rows = bins * 16;
|
||||
auto p_m = RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix();
|
||||
auto p_m =
|
||||
RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix(false);
|
||||
ASSERT_FALSE(p_m->PageExists<SparsePage>());
|
||||
TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
|
||||
TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
|
||||
@ -133,7 +134,7 @@ TEST(GPUPredictor, EllpackBasic) {
|
||||
TEST(GPUPredictor, EllpackTraining) {
|
||||
size_t constexpr kRows { 128 }, kCols { 16 }, kBins { 64 };
|
||||
auto p_ellpack =
|
||||
RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix();
|
||||
RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix(false);
|
||||
HostDeviceVector<float> storage(kRows * kCols);
|
||||
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}
|
||||
.Device(0)
|
||||
@ -219,7 +220,7 @@ TEST(GPUPredictor, ShapStump) {
|
||||
gbm::GBTreeModel model(&mparam, &ctx);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||
trees.push_back(std::make_unique<RegTree>());
|
||||
model.CommitModelGroup(std::move(trees), 0);
|
||||
|
||||
auto gpu_lparam = MakeCUDACtx(0);
|
||||
@ -246,7 +247,7 @@ TEST(GPUPredictor, Shap) {
|
||||
gbm::GBTreeModel model(&mparam, &ctx);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||
trees.push_back(std::make_unique<RegTree>());
|
||||
trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
|
||||
model.CommitModelGroup(std::move(trees), 0);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user