Support building gradient index with cat data. (#7371)
This commit is contained in:
@@ -23,5 +23,39 @@ TEST(GradientIndex, ExternalMemory) {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Build a GHistIndexMatrix from a single categorical column and check
 *        the resulting cuts and bin indices.
 *
 * The test expects that:
 *   - every one of the kCats categories occurs in the generated column,
 *   - the cut container holds exactly one value per category, and
 *   - the bin stored for each entry maps back to that entry's own category.
 */
TEST(GradientIndex, FromCategoricalBasic) {
  size_t constexpr kRows = 1000, kCats = 13, kCols = 1;
  // Note: max_bins is smaller than kCats; the assertions below still expect
  // one cut value per category.
  size_t const max_bins = 8;

  auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
  auto m = GetDMatrixFromData(x, kRows, kCols);

  // Mark the only feature as categorical before the index is built.
  auto &h_ft = m->Info().feature_types.HostVector();
  h_ft.resize(kCols, FeatureType::kCategorical);

  GHistIndexMatrix gidx;
  gidx.Init(m.get(), max_bins, {});

  // Sanity check on the input: sort + unique leaves exactly the distinct
  // category values, and there must be kCats of them.
  auto x_copy = x;
  std::sort(x_copy.begin(), x_copy.end());
  x_copy.erase(std::unique(x_copy.begin(), x_copy.end()), x_copy.end());
  ASSERT_EQ(x_copy.size(), kCats);

  auto const &h_cut_ptr = gidx.cut.Ptrs();
  auto const &h_cut_values = gidx.cut.Values();

  // NOTE(review): presumably Ptrs() holds [begin, end) offsets per feature,
  // hence two entries for a single column -- confirm against HistogramCuts.
  ASSERT_EQ(h_cut_ptr.size(), kCols + 1);
  ASSERT_EQ(h_cut_values.size(), kCats);

  auto const &index = gidx.index;

  // With one column, entry i of the index corresponds to x[i]; the cut value
  // at that bin must denote the same category as the raw input value.
  for (size_t i = 0; i < x.size(); ++i) {
    auto const bin = index[i];
    // .at() is bounds-checked, so an out-of-range bin fails loudly.
    auto const bin_value = h_cut_values.at(bin);
    ASSERT_EQ(common::AsCat(x[i]), common::AsCat(bin_value));
  }
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
/*!
|
||||
* Copyright 2020-2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "../../../../src/tree/param.h"
|
||||
#include "../../helpers.h"
|
||||
#include "dmlc/filesystem.h"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user