Improve test coverage with predictor configuration. (#9354)
* Improve test coverage with predictor configuration. - Test with ext memory. - Test with QDM. - Test with dart.
This commit is contained in:
@@ -210,6 +210,16 @@ SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
|
||||
// Make sure it's compile time constant.
|
||||
static_assert(SimpleLCG::max() - SimpleLCG::min());
|
||||
|
||||
void RandomDataGenerator::GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const {
|
||||
RandomDataGenerator{p_fmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
|
||||
p_fmat->Info().labels.Data());
|
||||
CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
|
||||
p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
|
||||
if (device_ != Context::kCpuId) {
|
||||
p_fmat->Info().labels.SetDevice(device_);
|
||||
}
|
||||
}
|
||||
|
||||
void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
|
||||
CHECK(out);
|
||||
@@ -363,8 +373,9 @@ void RandomDataGenerator::GenerateCSR(
|
||||
CHECK_EQ(columns->Size(), value->Size());
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
|
||||
size_t classes) const {
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
|
||||
bool float_label,
|
||||
size_t classes) const {
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
@@ -406,10 +417,58 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
|
||||
return out;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix() {
|
||||
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateSparsePageDMatrix(
|
||||
std::string prefix, bool with_label) const {
|
||||
CHECK_GE(this->rows_, this->n_batches_);
|
||||
CHECK_GE(this->n_batches_, 1)
|
||||
<< "Must set the n_batches before generating an external memory DMatrix.";
|
||||
std::unique_ptr<ArrayIterForTest> iter;
|
||||
if (device_ == Context::kCpuId) {
|
||||
iter = std::make_unique<NumpyArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
iter = std::make_unique<CudaArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
|
||||
#else
|
||||
CHECK(iter);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
std::unique_ptr<DMatrix> dmat{
|
||||
DMatrix::Create(static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix)};
|
||||
|
||||
auto row_page_path =
|
||||
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
|
||||
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
|
||||
|
||||
// Loop over the batches and count the number of pages
|
||||
std::size_t batch_count = 0;
|
||||
bst_row_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
CHECK_NE(batch.data.Size(), 0);
|
||||
}
|
||||
|
||||
EXPECT_EQ(batch_count, n_batches_);
|
||||
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
||||
|
||||
if (with_label) {
|
||||
RandomDataGenerator{dmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense(
|
||||
dmat->Info().labels.Data());
|
||||
CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_);
|
||||
dmat->Info().labels.Reshape(this->rows_, this->n_targets_);
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
|
||||
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
|
||||
auto m = std::make_shared<data::IterativeDMatrix>(
|
||||
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
|
||||
if (with_label) {
|
||||
this->GenerateLabels(m);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user