Support hessian in host sketch container. (#7081)
Prepare for migrating approx onto hist's codebase.
This commit is contained in:
@@ -226,6 +226,39 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, QuantileWithHessian) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto w = GenerateRandomWeights(num_rows);
|
||||
auto hessian = GenerateRandomWeights(num_rows);
|
||||
std::mt19937 rng(0);
|
||||
std::shuffle(hessian.begin(), hessian.end(), rng);
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian);
|
||||
for (size_t i = 0; i < w.size(); ++i) {
|
||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||
}
|
||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||
|
||||
HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
for (size_t i = 0; i < cuts_hess.Values().size(); ++i) {
|
||||
ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps);
|
||||
}
|
||||
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DenseCutsExternalMemory) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
|
||||
@@ -43,7 +43,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
auto rank = rabit::GetRank();
|
||||
HostSketchContainer sketch_distributed(column_size, n_bins, false);
|
||||
HostSketchContainer sketch_distributed(column_size, n_bins, false, OmpGetNumThreads(0));
|
||||
auto m = RandomDataGenerator{rows, cols, sparsity}
|
||||
.Seed(rank)
|
||||
.Lower(.0f)
|
||||
@@ -59,7 +59,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
HostSketchContainer sketch_on_single_node(column_size, n_bins, false);
|
||||
HostSketchContainer sketch_on_single_node(column_size, n_bins, false, OmpGetNumThreads(0));
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
auto m = RandomDataGenerator{rows, cols, sparsity}
|
||||
.Seed(rank)
|
||||
|
||||
Reference in New Issue
Block a user