Support column split with GPU quantile (#9370)

parent 97ed944209
commit 3632242e0b
@@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
       }
     }
   }
-  sketch_container.MakeCuts(&cuts);
+  sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
   return cuts;
 }
 }  // namespace common
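For orientation, a minimal caller-side sketch, not part of this commit: the point of the hunk above is that DeviceSketch now reads the split mode from the DMatrix's own MetaInfo, so code holding a column-split slice calls it exactly as before. The wrapper name and include path below are assumptions.

// Hypothetical wrapper, assumed to be built inside the XGBoost source tree
// where the GPU sketching declarations are visible.
#include "common/hist_util.cuh"  // assumed location of common::DeviceSketch

namespace xgboost {
common::HistogramCuts SketchForTraining(int device, DMatrix* p_fmat, int max_bins) {
  // DeviceSketch forwards p_fmat->Info().IsColumnSplit() into
  // SketchContainer::MakeCuts, so a column-split slice skips the
  // cross-worker quantile merge without any extra argument here.
  return common::DeviceSketch(device, p_fmat, max_bins);
}
}  // namespace xgboost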
@@ -501,10 +501,10 @@ void SketchContainer::FixError() {
   });
 }
 
-void SketchContainer::AllReduce() {
+void SketchContainer::AllReduce(bool is_column_split) {
   dh::safe_cuda(cudaSetDevice(device_));
   auto world = collective::GetWorldSize();
-  if (world == 1) {
+  if (world == 1 || is_column_split) {
     return;
   }
 
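The reasoning behind the widened early return: under a row-wise split every worker holds only some rows of every column, so the per-column sketches are partial and must be merged across workers; under a column-wise split each worker holds all rows of the columns it owns, so its local sketch is already complete and there is nothing to merge. A standalone restatement of that condition (the helper below is illustrative, not XGBoost code):

// Illustrative only: the complement of the `world == 1 || is_column_split`
// guard that now short-circuits SketchContainer::AllReduce.
bool NeedsQuantileAllReduce(int world_size, bool is_column_split) {
  // A merge is needed only when several workers each hold partial,
  // row-sliced sketches of the same columns.
  return world_size > 1 && !is_column_split;
}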
@@ -582,13 +582,13 @@ struct InvalidCatOp {
 };
 }  // anonymous namespace
 
-void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
+void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_));
   p_cuts->min_vals_.Resize(num_columns_);
 
   // Sync between workers.
-  this->AllReduce();
+  this->AllReduce(is_column_split);
 
   // Prune to final number of bins.
   this->Prune(num_bins_ + 1);
@@ -154,9 +154,9 @@ class SketchContainer {
              Span<SketchEntry const> that);
 
   /* \brief Merge quantiles from other GPU workers. */
-  void AllReduce();
+  void AllReduce(bool is_column_split);
   /* \brief Create the final histogram cut values. */
-  void MakeCuts(HistogramCuts* cuts);
+  void MakeCuts(HistogramCuts* cuts, bool is_column_split);
 
   Span<SketchEntry const> Data() const {
     return {this->Current().data().get(), this->Current().size()};
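A hedged caller-side sketch of the updated interface, with a hypothetical helper name: the flag is supplied once, at MakeCuts, which threads it down to AllReduce, so call sites only need MetaInfo::IsColumnSplit().

// Hypothetical helper; assumes the SketchContainer, HistogramCuts and MetaInfo
// declarations are already in scope, as in the files touched by this diff.
namespace xgboost::common {
void FinalizeCuts(SketchContainer* sketch, MetaInfo const& info, HistogramCuts* p_cuts) {
  // MakeCuts(p_cuts, is_column_split) calls AllReduce(is_column_split), which
  // returns immediately for single-worker or column-split setups, then prunes
  // to the final number of bins as before.
  sketch->MakeCuts(p_cuts, info.IsColumnSplit());
}
}  // namespace xgboost::common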
@@ -106,7 +106,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
     sketch_containers.clear();
     sketch_containers.shrink_to_fit();
 
-    final_sketch.MakeCuts(&cuts);
+    final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
   } else {
     GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
   }
@@ -351,7 +351,7 @@ auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing,
   SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0);
   MetaInfo info;
   AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
-  sketch_container.MakeCuts(&batched_cuts);
+  sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit());
   return batched_cuts;
 }
 
@@ -419,7 +419,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
   AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
                       &sketch_container);
   HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts);
+  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
   size_t bytes_required = detail::RequiredMemory(
       num_rows, num_columns, num_rows * num_columns, num_bins, false);
   EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
@@ -449,7 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
                       &sketch_container);
 
   HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts);
+  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
   ConsoleLogger::Configure({{"verbosity", "0"}});
   size_t bytes_required = detail::RequiredMemory(
       num_rows, num_columns, num_rows * num_columns, num_bins, true);
@@ -482,7 +482,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
   AdapterDeviceSketch(adapter.Value(), num_bins, info,
                       std::numeric_limits<float>::quiet_NaN(), &container);
   HistogramCuts cuts;
-  container.MakeCuts(&cuts);
+  container.MakeCuts(&cuts, info.IsColumnSplit());
 
   thrust::sort(x.begin(), x.end());
   auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin();
@@ -710,7 +710,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
                       &sketch_container);
 
   common::HistogramCuts cuts;
-  sketch_container.MakeCuts(&cuts);
+  sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
 
   auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
   if (with_group) {
@@ -751,7 +751,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
     SketchContainer sketch_container(ft, kBins, kCols, kRows, 0);
    AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
                        &sketch_container);
-    sketch_container.MakeCuts(&weighted);
+    sketch_container.MakeCuts(&weighted, info.IsColumnSplit());
     ValidateCuts(weighted, dmat.get(), kBins);
   }
 }
@@ -388,7 +388,7 @@ void TestAllReduceBasic(int32_t n_gpus) {
     AdapterDeviceSketch(adapter.Value(), n_bins, info,
                         std::numeric_limits<float>::quiet_NaN(),
                         &sketch_distributed);
-    sketch_distributed.AllReduce();
+    sketch_distributed.AllReduce(false);
     sketch_distributed.Unique();
 
     ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
@@ -425,6 +425,58 @@ TEST(GPUQuantile, MGPUAllReduceBasic) {
   RunWithInMemoryCommunicator(n_gpus, TestAllReduceBasic, n_gpus);
 }
 
+namespace {
+void TestColumnSplitBasic() {
+  auto const world = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::size_t constexpr kRows = 1000, kCols = 100, kBins = 64;
+
+  auto m = std::unique_ptr<DMatrix>{[=]() {
+    auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
+    return dmat->SliceCol(world, rank);
+  }()};
+
+  // Generate cuts for distributed environment.
+  auto const device = rank;
+  HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins);
+
+  // Generate cuts for single node environment
+  collective::Finalize();
+  CHECK_EQ(collective::GetWorldSize(), 1);
+  HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins);
+
+  auto const& sptrs = single_node_cuts.Ptrs();
+  auto const& dptrs = distributed_cuts.Ptrs();
+  auto const& svals = single_node_cuts.Values();
+  auto const& dvals = distributed_cuts.Values();
+  auto const& smins = single_node_cuts.MinValues();
+  auto const& dmins = distributed_cuts.MinValues();
+
+  EXPECT_EQ(sptrs.size(), dptrs.size());
+  for (size_t i = 0; i < sptrs.size(); ++i) {
+    EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i;
+  }
+
+  EXPECT_EQ(svals.size(), dvals.size());
+  for (size_t i = 0; i < svals.size(); ++i) {
+    EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
+  }
+
+  EXPECT_EQ(smins.size(), dmins.size());
+  for (size_t i = 0; i < smins.size(); ++i) {
+    EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i;
+  }
+}
+}  // anonymous namespace
+
+TEST(GPUQuantile, MGPUColumnSplitBasic) {
+  auto const n_gpus = AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUColumnSplitBasic test with # GPUs = " << n_gpus;
+  }
+  RunWithInMemoryCommunicator(n_gpus, TestColumnSplitBasic);
+}
+
 namespace {
 void TestSameOnAllWorkers(std::int32_t n_gpus) {
   auto world = collective::GetWorldSize();
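The comparison in the new test rests on a simple equivalence: because the column-split path never consults other workers' data, the cuts a rank computes for its slice while the multi-worker communicator is live must match what a single node computes from the same slice, exactly for the column pointers and, up to a small tolerance, for the cut values and minimums. A compact hedged restatement of that per-rank pattern, with an illustrative function name and assuming the same test utilities (RandomDataGenerator, SliceCol) shown above:

// Illustrative recap of the pattern TestColumnSplitBasic exercises; assumes it
// is built next to the XGBoost GPU tests so DMatrix, HistogramCuts,
// common::DeviceSketch and the collective:: helpers are visible.
#include <memory>

namespace xgboost {
common::HistogramCuts SketchOwnColumns(DMatrix* p_full, int device, int max_bins) {
  auto const world = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  // Each rank keeps its own block of columns; the slice is flagged as
  // column-split, which routes MakeCuts away from the cross-worker merge.
  std::unique_ptr<DMatrix> slice{p_full->SliceCol(world, rank)};
  // Because no other worker's data is consulted, the result depends only on
  // the local slice, which is why the test can check it against a
  // single-node sketch of the same slice.
  return common::DeviceSketch(device, slice.get(), max_bins);
}
}  // namespace xgboost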
@@ -445,7 +497,7 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
    AdapterDeviceSketch(adapter.Value(), n_bins, info,
                        std::numeric_limits<float>::quiet_NaN(),
                        &sketch_distributed);
-    sketch_distributed.AllReduce();
+    sketch_distributed.AllReduce(false);
     sketch_distributed.Unique();
     TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
 