Fix LTR with weighted Quantile DMatrix. (#7975)
* Fix LTR with weighted Quantile DMatrix. * Better tests.
This commit is contained in:
parent
1a33b50a0d
commit
8f8bd8147a
@ -184,8 +184,6 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
info.weights_.SetDevice(device);
|
info.weights_.SetDevice(device);
|
||||||
auto weights = info.weights_.ConstDeviceSpan();
|
auto weights = info.weights_.ConstDeviceSpan();
|
||||||
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
|
||||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
|
||||||
|
|
||||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
@ -205,9 +203,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
auto d_temp_weights = dh::ToSpan(temp_weights);
|
auto d_temp_weights = dh::ToSpan(temp_weights);
|
||||||
|
|
||||||
if (is_ranking) {
|
if (is_ranking) {
|
||||||
|
if (!weights.empty()) {
|
||||||
|
CHECK_EQ(weights.size(), info.group_ptr_.size() - 1);
|
||||||
|
}
|
||||||
|
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
|
||||||
|
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||||
auto const weight_iter = dh::MakeTransformIterator<float>(
|
auto const weight_iter = dh::MakeTransformIterator<float>(
|
||||||
thrust::make_constant_iterator(0lu),
|
thrust::make_counting_iterator(0lu), [=] __device__(size_t idx) -> float {
|
||||||
[=]__device__(size_t idx) -> float {
|
|
||||||
auto ridx = batch.GetElement(idx).row_idx;
|
auto ridx = batch.GetElement(idx).row_idx;
|
||||||
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
|
||||||
return weights[group_idx];
|
return weights[group_idx];
|
||||||
@ -272,7 +274,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
|
|||||||
size_t num_cols = batch.NumCols();
|
size_t num_cols = batch.NumCols();
|
||||||
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
|
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
|
||||||
int32_t device = sketch_container->DeviceIdx();
|
int32_t device = sketch_container->DeviceIdx();
|
||||||
bool weighted = info.weights_.Size() != 0;
|
bool weighted = !info.weights_.Empty();
|
||||||
|
|
||||||
if (weighted) {
|
if (weighted) {
|
||||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||||
|
|||||||
@ -122,27 +122,6 @@ std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian,
|
|||||||
}
|
}
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
|
||||||
std::vector<float> const &group_weights = info.weights_.HostVector();
|
|
||||||
if (group_weights.empty()) {
|
|
||||||
return group_weights;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t n_samples = info.num_row_;
|
|
||||||
auto const &group_ptr = info.group_ptr_;
|
|
||||||
std::vector<float> results(n_samples);
|
|
||||||
CHECK_GE(group_ptr.size(), 2);
|
|
||||||
CHECK_EQ(group_ptr.back(), n_samples);
|
|
||||||
size_t cur_group = 0;
|
|
||||||
for (size_t i = 0; i < n_samples; ++i) {
|
|
||||||
results[i] = group_weights[cur_group];
|
|
||||||
if (i == group_ptr[cur_group + 1]) {
|
|
||||||
cur_group++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
template <typename WQSketch>
|
template <typename WQSketch>
|
||||||
@ -156,12 +135,10 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
|
|||||||
|
|
||||||
// glue these conditions using ternary operator to avoid making data copies.
|
// glue these conditions using ternary operator to avoid making data copies.
|
||||||
auto const &weights =
|
auto const &weights =
|
||||||
hessian.empty()
|
hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight
|
||||||
? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
|
: info.weights_.HostVector()) // use sample weight
|
||||||
: info.weights_.HostVector()) // use sample weight
|
: MergeWeights(info, hessian, use_group_ind_,
|
||||||
: MergeWeights(
|
n_threads_); // use hessian merged with group/sample weights
|
||||||
info, hessian, use_group_ind_,
|
|
||||||
n_threads_); // use hessian merged with group/sample weights
|
|
||||||
if (!weights.empty()) {
|
if (!weights.empty()) {
|
||||||
CHECK_EQ(weights.size(), info.num_row_);
|
CHECK_EQ(weights.size(), info.num_row_);
|
||||||
}
|
}
|
||||||
@ -563,8 +540,8 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
|
|||||||
monitor_.Start(__func__);
|
monitor_.Start(__func__);
|
||||||
// glue these conditions using ternary operator to avoid making data copies.
|
// glue these conditions using ternary operator to avoid making data copies.
|
||||||
auto const &weights =
|
auto const &weights =
|
||||||
hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
|
hessian.empty() ? (use_group_ind_ ? detail::UnrollGroupWeights(info) // use group weight
|
||||||
: info.weights_.HostVector()) // use sample weight
|
: info.weights_.HostVector()) // use sample weight
|
||||||
: MergeWeights(info, hessian, use_group_ind_,
|
: MergeWeights(info, hessian, use_group_ind_,
|
||||||
n_threads_); // use hessian merged with group/sample weights
|
n_threads_); // use hessian merged with group/sample weights
|
||||||
CHECK_EQ(weights.size(), info.num_row_);
|
CHECK_EQ(weights.size(), info.num_row_);
|
||||||
|
|||||||
@ -697,6 +697,29 @@ class WXQuantileSketch :
|
|||||||
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
|
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||||
|
std::vector<float> const &group_weights = info.weights_.HostVector();
|
||||||
|
if (group_weights.empty()) {
|
||||||
|
return group_weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t n_samples = info.num_row_;
|
||||||
|
auto const &group_ptr = info.group_ptr_;
|
||||||
|
std::vector<float> results(n_samples);
|
||||||
|
CHECK_GE(group_ptr.size(), 2);
|
||||||
|
CHECK_EQ(group_ptr.back(), n_samples);
|
||||||
|
size_t cur_group = 0;
|
||||||
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
|
results[i] = group_weights[cur_group];
|
||||||
|
if (i == group_ptr[cur_group + 1]) {
|
||||||
|
cur_group++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
class HistogramCuts;
|
class HistogramCuts;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -366,6 +366,7 @@ void TestSketchFromWeights(bool with_group) {
|
|||||||
ValidateCuts(cuts, m.get(), kBins);
|
ValidateCuts(cuts, m.get(), kBins);
|
||||||
|
|
||||||
if (with_group) {
|
if (with_group) {
|
||||||
|
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
||||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||||
@ -377,6 +378,17 @@ void TestSketchFromWeights(bool with_group) {
|
|||||||
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (with_group) {
|
||||||
|
auto& h_weights = info.weights_.HostVector();
|
||||||
|
h_weights.resize(kGroups);
|
||||||
|
// Generate a different weight for each group.
|
||||||
|
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||||
|
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
||||||
|
}
|
||||||
|
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||||
|
ValidateCuts(weighted, m.get(), kBins);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, SketchFromWeights) {
|
TEST(HistUtil, SketchFromWeights) {
|
||||||
|
|||||||
@ -593,9 +593,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
|||||||
ValidateCuts(cuts, dmat.get(), kBins);
|
ValidateCuts(cuts, dmat.get(), kBins);
|
||||||
|
|
||||||
if (with_group) {
|
if (with_group) {
|
||||||
|
dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight
|
||||||
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
|
HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
|
||||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
|
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
|
||||||
ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
|
ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
|
||||||
@ -604,6 +605,24 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
|||||||
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (with_group) {
|
||||||
|
common::HistogramCuts weighted;
|
||||||
|
auto& h_weights = info.weights_.HostVector();
|
||||||
|
h_weights.resize(kGroups);
|
||||||
|
// Generate non-uniform weights (alternating between two values).
|
||||||
|
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||||
|
// FIXME(jiamingy): Some entries generated by the GPU test cannot pass ValidateCuts if
|
||||||
|
// we use more diverse weights, partially caused by
|
||||||
|
// https://github.com/dmlc/xgboost/issues/7946
|
||||||
|
h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast<float>(kGroups);
|
||||||
|
}
|
||||||
|
SketchContainer sketch_container(ft, kBins, kCols, kRows, 0);
|
||||||
|
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
&sketch_container);
|
||||||
|
sketch_container.MakeCuts(&weighted);
|
||||||
|
ValidateCuts(weighted, dmat.get(), kBins);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HistUtil, AdapterSketchFromWeights) {
|
TEST(HistUtil, AdapterSketchFromWeights) {
|
||||||
|
|||||||
@ -98,7 +98,11 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
|
|||||||
int num_bins) {
|
int num_bins) {
|
||||||
std::map<int, int> bin_weights;
|
std::map<int, int> bin_weights;
|
||||||
for (auto i = 0ull; i < sorted_column.size(); i++) {
|
for (auto i = 0ull; i < sorted_column.size(); i++) {
|
||||||
bin_weights[cuts.SearchBin(sorted_column[i], column_idx)] += sorted_weights[i];
|
auto bin_idx = cuts.SearchBin(sorted_column[i], column_idx);
|
||||||
|
if (bin_weights.find(bin_idx) == bin_weights.cend()) {
|
||||||
|
bin_weights[bin_idx] = 0;
|
||||||
|
}
|
||||||
|
bin_weights.at(bin_idx) += sorted_weights[i];
|
||||||
}
|
}
|
||||||
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
|
int local_num_bins = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
|
||||||
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
|
auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(),0);
|
||||||
@ -176,8 +180,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) {
|
||||||
int num_bins) {
|
|
||||||
// Collect data into columns
|
// Collect data into columns
|
||||||
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
|
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
|
||||||
for (auto& batch : dmat->GetBatches<SparsePage>()) {
|
for (auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||||
@ -189,17 +192,22 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// construct weights.
|
||||||
|
std::vector<float> w = dmat->Info().group_ptr_.empty() ? dmat->Info().weights_.HostVector()
|
||||||
|
: detail::UnrollGroupWeights(dmat->Info());
|
||||||
|
|
||||||
// Sort
|
// Sort
|
||||||
for (auto i = 0ull; i < columns.size(); i++) {
|
for (auto i = 0ull; i < columns.size(); i++) {
|
||||||
auto& col = columns.at(i);
|
auto& col = columns.at(i);
|
||||||
const auto& w = dmat->Info().weights_.HostVector();
|
std::vector<size_t> index(col.size());
|
||||||
std::vector<size_t > index(col.size());
|
|
||||||
std::iota(index.begin(), index.end(), 0);
|
std::iota(index.begin(), index.end(), 0);
|
||||||
std::sort(index.begin(), index.end(),
|
std::sort(index.begin(), index.end(), [=](size_t a, size_t b) { return col[a] < col[b]; });
|
||||||
[=](size_t a, size_t b) { return col[a] < col[b]; });
|
|
||||||
|
|
||||||
std::vector<float> sorted_column(col.size());
|
std::vector<float> sorted_column(col.size());
|
||||||
std::vector<float> sorted_weights(col.size(), 1.0);
|
std::vector<float> sorted_weights(col.size(), 1.0);
|
||||||
|
const auto& w = dmat->Info().weights_.HostVector();
|
||||||
|
|
||||||
for (auto j = 0ull; j < col.size(); j++) {
|
for (auto j = 0ull; j < col.size(); j++) {
|
||||||
sorted_column[j] = col[index[j]];
|
sorted_column[j] = col[index[j]];
|
||||||
if (w.size() == col.size()) {
|
if (w.size() == col.size()) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user