Categorical data support in CPU sketching. (#7221)
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
@@ -388,5 +391,16 @@ TEST(HistUtil, SketchFromWeights) {
|
||||
TestSketchFromWeights(true);
|
||||
TestSketchFromWeights(false);
|
||||
}
|
||||
|
||||
TEST(HistUtil, SketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins);
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@@ -126,43 +129,15 @@ TEST(HistUtil, DeviceSketchCategoricalAsNumeric) {
|
||||
}
|
||||
}
|
||||
|
||||
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, bool weighted) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||
|
||||
if (weighted) {
|
||||
std::vector<float> weights(n, 0);
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist(0, 1);
|
||||
for (auto& v : weights) {
|
||||
v = dist(&lcg);
|
||||
}
|
||||
dmat->Info().weights_.HostVector() = weights;
|
||||
}
|
||||
|
||||
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
std::sort(x.begin(), x.end());
|
||||
auto n_uniques = std::unique(x.begin(), x.end()) - x.begin();
|
||||
ASSERT_NE(n_uniques, x.size());
|
||||
ASSERT_EQ(cuts.TotalBins(), n_uniques);
|
||||
ASSERT_EQ(n_uniques, num_categories);
|
||||
|
||||
auto& values = cuts.cut_values_.HostVector();
|
||||
ASSERT_TRUE(std::is_sorted(values.cbegin(), values.cend()));
|
||||
auto is_unique = (std::unique(values.begin(), values.end()) - values.begin()) == n_uniques;
|
||||
ASSERT_TRUE(is_unique);
|
||||
|
||||
x.resize(n_uniques);
|
||||
for (size_t i = 0; i < n_uniques; ++i) {
|
||||
ASSERT_EQ(x[i], values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false);
|
||||
TestCategoricalSketch(1000, 256, 32, true);
|
||||
TestCategoricalSketch(1000, 256, 32, false,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(0, p_fmat, num_bins);
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true,
|
||||
[](DMatrix *p_fmat, int32_t num_bins) {
|
||||
return DeviceSketch(0, p_fmat, num_bins);
|
||||
});
|
||||
}
|
||||
|
||||
void TestMixedSketch() {
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <gtest/gtest.h>
|
||||
#include <dmlc/filesystem.h>
|
||||
@@ -5,6 +8,8 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
@@ -206,5 +211,45 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Test for sketching on categorical data.
|
||||
*
|
||||
* \param sketch Sketch function, can be on device or on host.
|
||||
*/
|
||||
template <typename Fn>
|
||||
void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins,
|
||||
bool weighted, Fn sketch) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
|
||||
|
||||
if (weighted) {
|
||||
std::vector<float> weights(n, 0);
|
||||
SimpleLCG lcg;
|
||||
SimpleRealUniformDistribution<float> dist(0, 1);
|
||||
for (auto& v : weights) {
|
||||
v = dist(&lcg);
|
||||
}
|
||||
dmat->Info().weights_.HostVector() = weights;
|
||||
}
|
||||
|
||||
ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
|
||||
auto cuts = sketch(dmat.get(), num_bins);
|
||||
std::sort(x.begin(), x.end());
|
||||
auto n_uniques = std::unique(x.begin(), x.end()) - x.begin();
|
||||
ASSERT_NE(n_uniques, x.size());
|
||||
ASSERT_EQ(cuts.TotalBins(), n_uniques);
|
||||
ASSERT_EQ(n_uniques, num_categories);
|
||||
|
||||
auto& values = cuts.cut_values_.HostVector();
|
||||
ASSERT_TRUE(std::is_sorted(values.cbegin(), values.cend()));
|
||||
auto is_unique = (std::unique(values.begin(), values.end()) - values.begin()) == n_uniques;
|
||||
ASSERT_TRUE(is_unique);
|
||||
|
||||
x.resize(n_uniques);
|
||||
for (size_t i = 0; i < n_uniques; ++i) {
|
||||
ASSERT_EQ(x[i], values[i]);
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -43,12 +43,14 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
auto rank = rabit::GetRank();
|
||||
HostSketchContainer sketch_distributed(column_size, n_bins, false, OmpGetNumThreads(0));
|
||||
auto m = RandomDataGenerator{rows, cols, sparsity}
|
||||
.Seed(rank)
|
||||
.Lower(.0f)
|
||||
.Upper(1.0f)
|
||||
.GenerateDMatrix();
|
||||
HostSketchContainer sketch_distributed(
|
||||
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
|
||||
OmpGetNumThreads(0));
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
sketch_distributed.PushRowPage(page, m->Info());
|
||||
}
|
||||
@@ -59,7 +61,9 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
HostSketchContainer sketch_on_single_node(column_size, n_bins, false, OmpGetNumThreads(0));
|
||||
HostSketchContainer sketch_on_single_node(
|
||||
column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
|
||||
OmpGetNumThreads(0));
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
auto m = RandomDataGenerator{rows, cols, sparsity}
|
||||
.Seed(rank)
|
||||
|
||||
Reference in New Issue
Block a user