Fix feature weights with multiple column sampling. (#8100)
This commit is contained in:
parent
4a4e5c7c18
commit
7785d65c8a
@ -7,8 +7,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
|
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
|
||||||
float colsample) {
|
|
||||||
if (colsample == 1.0f) {
|
if (colsample == 1.0f) {
|
||||||
return p_features;
|
return p_features;
|
||||||
}
|
}
|
||||||
@ -20,19 +19,21 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
|||||||
auto &new_features = *p_new_features;
|
auto &new_features = *p_new_features;
|
||||||
|
|
||||||
if (feature_weights_.size() != 0) {
|
if (feature_weights_.size() != 0) {
|
||||||
new_features.HostVector() = WeightedSamplingWithoutReplacement(
|
auto const &h_features = p_features->HostVector();
|
||||||
p_features->HostVector(), feature_weights_, n);
|
std::vector<float> weights(h_features.size());
|
||||||
|
for (size_t i = 0; i < h_features.size(); ++i) {
|
||||||
|
weights[i] = feature_weights_[h_features[i]];
|
||||||
|
}
|
||||||
|
new_features.HostVector() =
|
||||||
|
WeightedSamplingWithoutReplacement(p_features->HostVector(), weights, n);
|
||||||
} else {
|
} else {
|
||||||
new_features.Resize(features.size());
|
new_features.Resize(features.size());
|
||||||
std::copy(features.begin(), features.end(),
|
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
|
||||||
new_features.HostVector().begin());
|
std::shuffle(new_features.HostVector().begin(), new_features.HostVector().end(), rng_);
|
||||||
std::shuffle(new_features.HostVector().begin(),
|
|
||||||
new_features.HostVector().end(), rng_);
|
|
||||||
new_features.Resize(n);
|
new_features.Resize(n);
|
||||||
}
|
}
|
||||||
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
|
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
|
||||||
return p_new_features;
|
return p_new_features;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -126,5 +126,21 @@ TEST(ColumnSampler, WeightedSampling) {
|
|||||||
EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
|
EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test for feature weights combined with multiple column sampling
// stages: by-tree, by-level and by-node sampling are all enabled together with
// non-uniform feature weights, and only the size of the sampled set is checked.
TEST(ColumnSampler, WeightedMultiSampling) {
  constexpr size_t kCols = 32;

  // The weight of each feature equals its index: 0, 1, ..., kCols - 1.
  std::vector<float> feature_weights(kCols, 0);
  size_t next_weight = 0;
  for (auto &weight : feature_weights) {
    weight = next_weight++;
  }

  // Every sampling stage keeps half of the columns it is given.
  float bytree{0.5f};
  float bylevel{0.5f};
  float bynode{0.5f};

  ColumnSampler sampler{0};
  sampler.Init(feature_weights.size(), feature_weights, bytree, bylevel, bynode);

  // 32 * 0.5 * 0.5 * 0.5 == 4 features are expected after all three stages.
  auto const n_expected = static_cast<size_t>(kCols * bytree * bylevel * bynode);

  // NOTE(review): the integer argument is presumably the tree depth - confirm
  // against ColumnSampler::GetFeatureSet.
  auto sampled = sampler.GetFeatureSet(0);
  ASSERT_EQ(sampled->Size(), n_expected);

  sampled = sampler.GetFeatureSet(1);
  ASSERT_EQ(sampled->Size(), n_expected);
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user