Test CPU histogram with cat data. (#7465)
This commit is contained in:
parent
24be04e848
commit
bf7bb575b4
@ -2,7 +2,11 @@
|
||||
* Copyright 2018-2021 by Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../helpers.h"
|
||||
#include "../../categorical_helpers.h"
|
||||
|
||||
#include "../../../../src/common/categorical.h"
|
||||
#include "../../../../src/tree/hist/histogram.h"
|
||||
#include "../../../../src/tree/updater_quantile_hist.h"
|
||||
|
||||
@ -311,9 +315,72 @@ TEST(CPUHistogram, BuildHist) {
|
||||
TestBuildHistogram<double>(false);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TestHistogramCategorical(size_t n_categories) {
|
||||
size_t constexpr kRows = 340;
|
||||
int32_t constexpr kBins = 256;
|
||||
auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories);
|
||||
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
||||
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
||||
|
||||
RegTree tree;
|
||||
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
|
||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
||||
nodes_for_explicit_hist_build.push_back(node);
|
||||
|
||||
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
||||
|
||||
RowSetCollection row_set_collection;
|
||||
row_set_collection.Clear();
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kRows);
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection.Init();
|
||||
|
||||
/**
|
||||
* Generate hist with cat data.
|
||||
*/
|
||||
HistogramBuilder<double, CPUExpandEntry> cat_hist;
|
||||
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, kBins})) {
|
||||
auto total_bins = gidx.cut.TotalBins();
|
||||
cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
|
||||
omp_get_max_threads(), 1, false);
|
||||
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||
nodes_for_explicit_hist_build, {}, gpair.HostVector());
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate hist with one hot encoded data.
|
||||
*/
|
||||
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
||||
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
||||
HistogramBuilder<double, CPUExpandEntry> onehot_hist;
|
||||
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, kBins})) {
|
||||
auto total_bins = gidx.cut.TotalBins();
|
||||
onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
|
||||
omp_get_max_threads(), 1, false);
|
||||
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||
nodes_for_explicit_hist_build, {},
|
||||
gpair.HostVector());
|
||||
}
|
||||
|
||||
auto cat = cat_hist.Histogram()[0];
|
||||
auto onehot = onehot_hist.Histogram()[0];
|
||||
ValidateCategoricalHistogram(n_categories, onehot, cat);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
// Run the categorical-vs-one-hot histogram check for every category count in
// the half-open range [2, 8).
TEST(CPUHistogram, Categorical) {
  size_t constexpr kFirstCategoryCount = 2;
  size_t constexpr kEndCategoryCount = 8;  // exclusive upper bound
  size_t count = kFirstCategoryCount;
  while (count < kEndCategoryCount) {
    TestHistogramCategorical(count);
    ++count;
  }
}
|
||||
|
||||
TEST(CPUHistogram, ExternalMemory) {
|
||||
size_t constexpr kEntries = 1 << 16;
|
||||
|
||||
int32_t constexpr kBins = 32;
|
||||
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
||||
std::vector<size_t> partition_size(1, 0);
|
||||
@ -358,8 +425,8 @@ TEST(CPUHistogram, ExternalMemory) {
|
||||
size_t page_idx{0};
|
||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, kBins, hess})) {
|
||||
multi_build.BuildHist(page_idx, space, page, &tree,
|
||||
rows_set.at(page_idx), nodes, {}, h_gpair);
|
||||
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {},
|
||||
h_gpair);
|
||||
++page_idx;
|
||||
}
|
||||
ASSERT_EQ(page_idx, 2);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user