[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -181,7 +181,7 @@ void TestSyncHist(bool is_distributed) {
histogram.Buffer().Reset(1, n_nodes, space, target_hists);
// sync hist
histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
histogram.SyncHistogram(&ctx, &tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
using GHistRowT = common::GHistRow;
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
@@ -266,7 +266,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
}
histogram.SyncHistogram(&tree, nodes_to_build, {});
histogram.SyncHistogram(&ctx, &tree, nodes_to_build, {});
// Check if number of histogram bins is correct
ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
@@ -366,7 +366,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
force_read_by_column);
}
cat_hist.SyncHistogram(&tree, nodes_to_build, {});
cat_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});
/**
* Generate hist with one hot encoded data.
@@ -382,7 +382,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
force_read_by_column);
}
onehot_hist.SyncHistogram(&tree, nodes_to_build, {});
onehot_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});
auto cat = cat_hist.Histogram()[0];
auto onehot = onehot_hist.Histogram()[0];
@@ -451,7 +451,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
force_read_by_column);
++page_idx;
}
multi_build.SyncHistogram(&tree, nodes, {});
multi_build.SyncHistogram(ctx, &tree, nodes, {});
multi_page = multi_build.Histogram()[RegTree::kRoot];
}
@@ -480,7 +480,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
single_build.BuildHist(0, space, gmat, row_set_collection, nodes,
linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()),
force_read_by_column);
single_build.SyncHistogram(&tree, nodes, {});
single_build.SyncHistogram(ctx, &tree, nodes, {});
single_page = single_build.Histogram()[RegTree::kRoot];
}
@@ -570,7 +570,7 @@ class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
CHECK_NE(partitioners.front()[tree.RightChild(best.nid)].Size(), 0);
hist_builder.BuildHistLeftRight(
Xy.get(), &tree, partitioners, valid_candidates,
&ctx, Xy.get(), &tree, partitioners, valid_candidates,
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), batch);
if (limit) {