Test CPU histogram with cat data. (#7465)
This commit is contained in:
parent
24be04e848
commit
bf7bb575b4
@ -2,7 +2,11 @@
|
|||||||
* Copyright 2018-2021 by Contributors
|
* Copyright 2018-2021 by Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../helpers.h"
|
#include "../../helpers.h"
|
||||||
|
#include "../../categorical_helpers.h"
|
||||||
|
|
||||||
|
#include "../../../../src/common/categorical.h"
|
||||||
#include "../../../../src/tree/hist/histogram.h"
|
#include "../../../../src/tree/hist/histogram.h"
|
||||||
#include "../../../../src/tree/updater_quantile_hist.h"
|
#include "../../../../src/tree/updater_quantile_hist.h"
|
||||||
|
|
||||||
@ -311,9 +315,72 @@ TEST(CPUHistogram, BuildHist) {
|
|||||||
TestBuildHistogram<double>(false);
|
TestBuildHistogram<double>(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void TestHistogramCategorical(size_t n_categories) {
|
||||||
|
size_t constexpr kRows = 340;
|
||||||
|
int32_t constexpr kBins = 256;
|
||||||
|
auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories);
|
||||||
|
auto cat_m = GetDMatrixFromData(x, kRows, 1);
|
||||||
|
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||||
|
BatchParam batch_param{0, static_cast<int32_t>(kBins)};
|
||||||
|
|
||||||
|
RegTree tree;
|
||||||
|
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
|
||||||
|
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
|
||||||
|
nodes_for_explicit_hist_build.push_back(node);
|
||||||
|
|
||||||
|
auto gpair = GenerateRandomGradients(kRows, 0, 2);
|
||||||
|
|
||||||
|
RowSetCollection row_set_collection;
|
||||||
|
row_set_collection.Clear();
|
||||||
|
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||||
|
row_indices.resize(kRows);
|
||||||
|
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||||
|
row_set_collection.Init();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate hist with cat data.
|
||||||
|
*/
|
||||||
|
HistogramBuilder<double, CPUExpandEntry> cat_hist;
|
||||||
|
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>(
|
||||||
|
{GenericParameter::kCpuId, kBins})) {
|
||||||
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
|
cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
|
||||||
|
omp_get_max_threads(), 1, false);
|
||||||
|
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||||
|
nodes_for_explicit_hist_build, {}, gpair.HostVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate hist with one hot encoded data.
|
||||||
|
*/
|
||||||
|
auto x_encoded = OneHotEncodeFeature(x, n_categories);
|
||||||
|
auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories);
|
||||||
|
HistogramBuilder<double, CPUExpandEntry> onehot_hist;
|
||||||
|
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>(
|
||||||
|
{GenericParameter::kCpuId, kBins})) {
|
||||||
|
auto total_bins = gidx.cut.TotalBins();
|
||||||
|
onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins},
|
||||||
|
omp_get_max_threads(), 1, false);
|
||||||
|
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||||
|
nodes_for_explicit_hist_build, {},
|
||||||
|
gpair.HostVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cat = cat_hist.Histogram()[0];
|
||||||
|
auto onehot = onehot_hist.Histogram()[0];
|
||||||
|
ValidateCategoricalHistogram(n_categories, onehot, cat);
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
TEST(CPUHistogram, Categorical) {
|
||||||
|
for (size_t n_categories = 2; n_categories < 8; ++n_categories) {
|
||||||
|
TestHistogramCategorical(n_categories);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CPUHistogram, ExternalMemory) {
|
TEST(CPUHistogram, ExternalMemory) {
|
||||||
size_t constexpr kEntries = 1 << 16;
|
size_t constexpr kEntries = 1 << 16;
|
||||||
|
|
||||||
int32_t constexpr kBins = 32;
|
int32_t constexpr kBins = 32;
|
||||||
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
||||||
std::vector<size_t> partition_size(1, 0);
|
std::vector<size_t> partition_size(1, 0);
|
||||||
@ -358,8 +425,8 @@ TEST(CPUHistogram, ExternalMemory) {
|
|||||||
size_t page_idx{0};
|
size_t page_idx{0};
|
||||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
|
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
|
||||||
{GenericParameter::kCpuId, kBins, hess})) {
|
{GenericParameter::kCpuId, kBins, hess})) {
|
||||||
multi_build.BuildHist(page_idx, space, page, &tree,
|
multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {},
|
||||||
rows_set.at(page_idx), nodes, {}, h_gpair);
|
h_gpair);
|
||||||
++page_idx;
|
++page_idx;
|
||||||
}
|
}
|
||||||
ASSERT_EQ(page_idx, 2);
|
ASSERT_EQ(page_idx, 2);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user