Support hessian in host sketch container. (#7081)
Prepare for migrating approx onto hist's codebase.
This commit is contained in:
@@ -226,6 +226,39 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, QuantileWithHessian) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto w = GenerateRandomWeights(num_rows);
|
||||
auto hessian = GenerateRandomWeights(num_rows);
|
||||
std::mt19937 rng(0);
|
||||
std::shuffle(hessian.begin(), hessian.end(), rng);
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian);
|
||||
for (size_t i = 0; i < w.size(); ++i) {
|
||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||
}
|
||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||
|
||||
HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
for (size_t i = 0; i < cuts_hess.Values().size(); ++i) {
|
||||
ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps);
|
||||
}
|
||||
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DenseCutsExternalMemory) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
|
||||
Reference in New Issue
Block a user