[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -360,25 +360,27 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) {
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) {
|
||||
auto MakeUnweightedCutsForTest(Context const* ctx, Adapter adapter, int32_t num_bins, float missing,
|
||||
size_t batch_size = 0) {
|
||||
common::HistogramCuts batched_cuts;
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(),
|
||||
DeviceOrd::CUDA(0));
|
||||
MetaInfo info;
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
|
||||
sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit());
|
||||
sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit());
|
||||
return batched_cuts;
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void ValidateBatchedCuts(Adapter adapter, int num_bins, DMatrix* dmat, size_t batch_size = 0) {
|
||||
void ValidateBatchedCuts(Context const* ctx, Adapter adapter, int num_bins, DMatrix* dmat, size_t batch_size = 0) {
|
||||
common::HistogramCuts batched_cuts = MakeUnweightedCutsForTest(
|
||||
adapter, num_bins, std::numeric_limits<float>::quiet_NaN(), batch_size);
|
||||
ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN(), batch_size);
|
||||
ValidateCuts(batched_cuts, dmat, num_bins);
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketch) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int rows = 5;
|
||||
int cols = 1;
|
||||
int num_bins = 4;
|
||||
@@ -391,8 +393,8 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
||||
|
||||
data::CupyAdapter adapter(str);
|
||||
|
||||
auto device_cuts = MakeUnweightedCutsForTest(adapter, num_bins, missing);
|
||||
Context ctx;
|
||||
auto device_cuts = MakeUnweightedCutsForTest(&ctx, adapter, num_bins, missing);
|
||||
ctx = ctx.MakeCPU();
|
||||
auto host_cuts = GetHostCuts(&ctx, &adapter, num_bins, missing);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
@@ -401,6 +403,7 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
@@ -410,7 +413,8 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
auto cuts =
|
||||
MakeUnweightedCutsForTest(&ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, false);
|
||||
@@ -419,6 +423,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
@@ -435,7 +440,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, false);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
|
||||
@@ -444,6 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
@@ -465,7 +471,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
||||
&sketch_container);
|
||||
|
||||
HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, true);
|
||||
@@ -475,6 +481,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
||||
|
||||
void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
||||
int32_t num_bins, bool weighted) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto h_x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
thrust::device_vector<float> x(h_x);
|
||||
auto adapter = AdapterFromData(x, n, 1);
|
||||
@@ -498,7 +505,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &container);
|
||||
HistogramCuts cuts;
|
||||
container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||
|
||||
thrust::sort(x.begin(), x.end());
|
||||
auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin();
|
||||
@@ -522,6 +529,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
||||
TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
auto categorical_sizes = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto sizes = {25, 100, 1000};
|
||||
for (auto n : sizes) {
|
||||
for (auto num_categories : categorical_sizes) {
|
||||
@@ -529,7 +537,7 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, n, 1);
|
||||
ValidateBatchedCuts(adapter, num_bins, dmat.get());
|
||||
ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
|
||||
TestCategoricalSketchAdapter(n, num_categories, num_bins, true);
|
||||
TestCategoricalSketchAdapter(n, num_categories, num_bins, false);
|
||||
}
|
||||
@@ -540,13 +548,14 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
auto bin_sizes = {2, 16, 256, 512};
|
||||
auto sizes = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
ValidateBatchedCuts(adapter, num_bins, dmat.get());
|
||||
ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -556,12 +565,13 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
|
||||
int num_rows = 5000;
|
||||
auto batch_sizes = {0, 100, 1500, 6000};
|
||||
int num_columns = 5;
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
for (auto batch_size : batch_sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
ValidateBatchedCuts(adapter, num_bins, dmat.get(), batch_size);
|
||||
ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get(), batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,12 +657,12 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest(
|
||||
adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
&ctx, adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
|
||||
EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
|
||||
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
|
||||
|
||||
ValidateBatchedCuts(adapter, num_bins, dmat.get());
|
||||
ValidateBatchedCuts(&ctx, adapter, num_bins, dmat.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -702,7 +712,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
.Device(DeviceOrd::CUDA(0))
|
||||
.GenerateArrayInterface(&storage);
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -731,7 +741,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
&sketch_container);
|
||||
|
||||
common::HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
|
||||
|
||||
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
|
||||
if (with_group) {
|
||||
@@ -744,10 +754,9 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
ValidateCuts(cuts, dmat.get(), kBins);
|
||||
|
||||
auto cuda_ctx = MakeCUDACtx(0);
|
||||
if (with_group) {
|
||||
dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight
|
||||
HistogramCuts non_weighted = DeviceSketch(&cuda_ctx, dmat.get(), kBins, 0);
|
||||
HistogramCuts non_weighted = DeviceSketch(&ctx, dmat.get(), kBins, 0);
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
@@ -773,7 +782,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)};
|
||||
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
sketch_container.MakeCuts(&weighted, info.IsColumnSplit());
|
||||
sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit());
|
||||
ValidateCuts(weighted, dmat.get(), kBins);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user