Prepare gradient index for Quantile DMatrix. (#8103)
* Prepare gradient index for Quantile DMatrix. - Implement push batch with adapter batch. - Implement `GetFvalue` for prediction.
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "../../../src/common/column_matrix.h"
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
@@ -65,5 +66,46 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
||||
ASSERT_EQ(common::AsCat(x[i]), common::AsCat(bin_value));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GradientIndex, PushBatch) {
|
||||
size_t constexpr kRows = 64, kCols = 4;
|
||||
bst_bin_t max_bins = 64;
|
||||
float st = 0.5;
|
||||
|
||||
auto test = [&](float sparisty) {
|
||||
auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true);
|
||||
auto cuts = common::SketchOnDMatrix(m.get(), max_bins, common::OmpGetNumThreads(0), false, {});
|
||||
common::HistogramCuts copy_cuts = cuts;
|
||||
|
||||
ASSERT_EQ(m->Info().num_row_, kRows);
|
||||
ASSERT_EQ(m->Info().num_col_, kCols);
|
||||
GHistIndexMatrix gmat{m->Info(), std::move(copy_cuts), max_bins};
|
||||
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
SparsePageAdapterBatch batch{page.GetView()};
|
||||
gmat.PushAdapterBatch(m->Ctx(), 0, 0, batch, std::numeric_limits<float>::quiet_NaN(), {}, st,
|
||||
m->Info().num_row_);
|
||||
gmat.PushAdapterBatchColumns(m->Ctx(), batch, std::numeric_limits<float>::quiet_NaN(), 0);
|
||||
}
|
||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(BatchParam{max_bins, st})) {
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
for (size_t j = 0; j < kCols; ++j) {
|
||||
auto v0 = gmat.GetFvalue(i, j, false);
|
||||
auto v1 = page.GetFvalue(i, j, false);
|
||||
if (sparisty == 0.0) {
|
||||
ASSERT_FALSE(std::isnan(v0));
|
||||
}
|
||||
if (!std::isnan(v0)) {
|
||||
ASSERT_EQ(v0, v1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
test(0.0f);
|
||||
test(0.5f);
|
||||
test(0.9f);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user