From 7785d65c8a086bc757a86472fb0ffc4d9c918216 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 22 Jul 2022 20:23:05 +0800
Subject: [PATCH] Fix feature weights with multiple column sampling. (#8100)

---
Reviewer notes (text here is ignored by git when the patch is applied):

* ColSample previously passed the whole feature_weights_ vector to
  WeightedSamplingWithoutReplacement together with p_features. The patch
  instead gathers one weight per entry of p_features, looked up by the
  feature index stored in that entry (feature_weights_[h_features[i]]), so
  the weights line up with the features actually being sampled from.
* The remaining changes in random.cc are line re-wrapping only.
* The new test enables bytree, bylevel and bynode sampling together (0.5
  each) with feature weights and checks the size of the resulting feature
  set at depth 0 and depth 1.
* NOTE(review): template arguments in this patch were restored after being
  lost in extraction (HostDeviceVector<bst_feature_t>, std::vector<float>);
  confirm against the upstream commit before applying.

 src/common/random.cc            | 19 ++++++++++---------
 tests/cpp/common/test_random.cc | 16 ++++++++++++++++
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/src/common/random.cc b/src/common/random.cc
index f386cad91..f66b084cc 100644
--- a/src/common/random.cc
+++ b/src/common/random.cc
@@ -7,8 +7,7 @@
 namespace xgboost {
 namespace common {
 std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
-    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
-    float colsample) {
+    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
   if (colsample == 1.0f) {
     return p_features;
   }
@@ -20,19 +19,21 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
   auto &new_features = *p_new_features;
 
   if (feature_weights_.size() != 0) {
-    new_features.HostVector() = WeightedSamplingWithoutReplacement(
-        p_features->HostVector(), feature_weights_, n);
+    auto const &h_features = p_features->HostVector();
+    std::vector<float> weights(h_features.size());
+    for (size_t i = 0; i < h_features.size(); ++i) {
+      weights[i] = feature_weights_[h_features[i]];
+    }
+    new_features.HostVector() =
+        WeightedSamplingWithoutReplacement(p_features->HostVector(), weights, n);
   } else {
     new_features.Resize(features.size());
-    std::copy(features.begin(), features.end(),
-              new_features.HostVector().begin());
-    std::shuffle(new_features.HostVector().begin(),
-                 new_features.HostVector().end(), rng_);
+    std::copy(features.begin(), features.end(), new_features.HostVector().begin());
+    std::shuffle(new_features.HostVector().begin(), new_features.HostVector().end(), rng_);
     new_features.Resize(n);
   }
   std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
   return p_new_features;
 }
-
 }  // namespace common
 }  // namespace xgboost
diff --git a/tests/cpp/common/test_random.cc b/tests/cpp/common/test_random.cc
index 9b2a15155..201f7b407 100644
--- a/tests/cpp/common/test_random.cc
+++ b/tests/cpp/common/test_random.cc
@@ -126,5 +126,21 @@ TEST(ColumnSampler, WeightedSampling) {
     EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
   }
 }
+
+TEST(ColumnSampler, WeightedMultiSampling) {
+  size_t constexpr kCols = 32;
+  std::vector<float> feature_weights(kCols, 0);
+  for (size_t i = 0; i < feature_weights.size(); ++i) {
+    feature_weights[i] = i;
+  }
+  ColumnSampler cs{0};
+  float bytree{0.5}, bylevel{0.5}, bynode{0.5};
+  cs.Init(feature_weights.size(), feature_weights, bytree, bylevel, bynode);
+  auto feature_set = cs.GetFeatureSet(0);
+  size_t n_sampled = kCols * bytree * bylevel * bynode;
+  ASSERT_EQ(feature_set->Size(), n_sampled);
+  feature_set = cs.GetFeatureSet(1);
+  ASSERT_EQ(feature_set->Size(), n_sampled);
+}
 }  // namespace common
 }  // namespace xgboost