Support column split with GPU quantile (#9370)

This commit is contained in:
Rong Ou
2023-07-10 21:15:56 -07:00
committed by GitHub
parent 97ed944209
commit 3632242e0b
6 changed files with 68 additions and 16 deletions

View File

@@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
}
}
}
sketch_container.MakeCuts(&cuts);
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
return cuts;
}
} // namespace common

View File

@@ -501,10 +501,10 @@ void SketchContainer::FixError() {
});
}
void SketchContainer::AllReduce() {
void SketchContainer::AllReduce(bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_));
auto world = collective::GetWorldSize();
if (world == 1) {
if (world == 1 || is_column_split) {
return;
}
@@ -582,13 +582,13 @@ struct InvalidCatOp {
};
} // anonymous namespace
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
p_cuts->min_vals_.Resize(num_columns_);
// Sync between workers.
this->AllReduce();
this->AllReduce(is_column_split);
// Prune to final number of bins.
this->Prune(num_bins_ + 1);

View File

@@ -154,9 +154,9 @@ class SketchContainer {
Span<SketchEntry const> that);
/* \brief Merge quantiles from other GPU workers. */
void AllReduce();
void AllReduce(bool is_column_split);
/* \brief Create the final histogram cut values. */
void MakeCuts(HistogramCuts* cuts);
void MakeCuts(HistogramCuts* cuts, bool is_column_split);
Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};