Support categorical data in ellpack. (#6140)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 XGBoost contributors
|
||||
* Copyright 2019-2020 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/base.h>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "../histogram_helpers.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "../../../src/common/categorical.h"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
|
||||
@@ -77,6 +78,45 @@ TEST(EllpackPage, BuildGidxSparse) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(EllpackPage, FromCategoricalBasic) {
|
||||
using common::AsCat;
|
||||
size_t constexpr kRows = 1000, kCats = 13, kCols = 1;
|
||||
size_t max_bins = 8;
|
||||
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
|
||||
auto m = GetDMatrixFromData(x, kRows, 1);
|
||||
auto& h_ft = m->Info().feature_types.HostVector();
|
||||
h_ft.resize(kCols, FeatureType::kCategorical);
|
||||
|
||||
BatchParam p(0, max_bins);
|
||||
auto ellpack = EllpackPage(m.get(), p);
|
||||
auto accessor = ellpack.Impl()->GetDeviceAccessor(0);
|
||||
ASSERT_EQ(kCats, accessor.NumBins());
|
||||
|
||||
auto x_copy = x;
|
||||
std::sort(x_copy.begin(), x_copy.end());
|
||||
auto n_uniques = std::unique(x_copy.begin(), x_copy.end()) - x_copy.begin();
|
||||
ASSERT_EQ(n_uniques, kCats);
|
||||
|
||||
std::vector<uint32_t> h_cuts_ptr(accessor.feature_segments.size());
|
||||
dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments);
|
||||
std::vector<float> h_cuts_values(accessor.gidx_fvalue_map.size());
|
||||
dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map);
|
||||
|
||||
ASSERT_EQ(h_cuts_ptr.size(), 2);
|
||||
ASSERT_EQ(h_cuts_values.size(), kCats);
|
||||
|
||||
std::vector<common::CompressedByteT> const &h_gidx_buffer =
|
||||
ellpack.Impl()->gidx_buffer.HostVector();
|
||||
auto h_gidx_iter = common::CompressedIterator<uint32_t>(
|
||||
h_gidx_buffer.data(), accessor.NumSymbols());
|
||||
|
||||
for (size_t i = 0; i < x.size(); ++i) {
|
||||
auto bin = h_gidx_iter[i];
|
||||
auto bin_value = h_cuts_values.at(bin);
|
||||
ASSERT_EQ(AsCat(x[i]), AsCat(bin_value));
|
||||
}
|
||||
}
|
||||
|
||||
struct ReadRowFunction {
|
||||
EllpackDeviceAccessor matrix;
|
||||
int row;
|
||||
|
||||
Reference in New Issue
Block a user