Support column split with GPU quantile (#9370)
This commit is contained in:
@@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
}
|
||||
}
|
||||
}
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
|
||||
return cuts;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@@ -501,10 +501,10 @@ void SketchContainer::FixError() {
|
||||
});
|
||||
}
|
||||
|
||||
void SketchContainer::AllReduce() {
|
||||
void SketchContainer::AllReduce(bool is_column_split) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1) {
|
||||
if (world == 1 || is_column_split) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -582,13 +582,13 @@ struct InvalidCatOp {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
p_cuts->min_vals_.Resize(num_columns_);
|
||||
|
||||
// Sync between workers.
|
||||
this->AllReduce();
|
||||
this->AllReduce(is_column_split);
|
||||
|
||||
// Prune to final number of bins.
|
||||
this->Prune(num_bins_ + 1);
|
||||
|
||||
@@ -154,9 +154,9 @@ class SketchContainer {
|
||||
Span<SketchEntry const> that);
|
||||
|
||||
/* \brief Merge quantiles from other GPU workers. */
|
||||
void AllReduce();
|
||||
void AllReduce(bool is_column_split);
|
||||
/* \brief Create the final histogram cut values. */
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
void MakeCuts(HistogramCuts* cuts, bool is_column_split);
|
||||
|
||||
Span<SketchEntry const> Data() const {
|
||||
return {this->Current().data().get(), this->Current().size()};
|
||||
|
||||
@@ -106,7 +106,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
sketch_containers.clear();
|
||||
sketch_containers.shrink_to_fit();
|
||||
|
||||
final_sketch.MakeCuts(&cuts);
|
||||
final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
|
||||
} else {
|
||||
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user