[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -181,7 +181,7 @@ void TestSyncHist(bool is_distributed) {
|
||||
|
||||
histogram.Buffer().Reset(1, n_nodes, space, target_hists);
|
||||
// sync hist
|
||||
histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
|
||||
histogram.SyncHistogram(&ctx, &tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
|
||||
|
||||
using GHistRowT = common::GHistRow;
|
||||
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
|
||||
@@ -266,7 +266,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
||||
histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
|
||||
linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
|
||||
}
|
||||
histogram.SyncHistogram(&tree, nodes_to_build, {});
|
||||
histogram.SyncHistogram(&ctx, &tree, nodes_to_build, {});
|
||||
|
||||
// Check if number of histogram bins is correct
|
||||
ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
|
||||
@@ -366,7 +366,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
|
||||
force_read_by_column);
|
||||
}
|
||||
cat_hist.SyncHistogram(&tree, nodes_to_build, {});
|
||||
cat_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});
|
||||
|
||||
/**
|
||||
* Generate hist with one hot encoded data.
|
||||
@@ -382,7 +382,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size()),
|
||||
force_read_by_column);
|
||||
}
|
||||
onehot_hist.SyncHistogram(&tree, nodes_to_build, {});
|
||||
onehot_hist.SyncHistogram(&ctx, &tree, nodes_to_build, {});
|
||||
|
||||
auto cat = cat_hist.Histogram()[0];
|
||||
auto onehot = onehot_hist.Histogram()[0];
|
||||
@@ -451,7 +451,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
|
||||
force_read_by_column);
|
||||
++page_idx;
|
||||
}
|
||||
multi_build.SyncHistogram(&tree, nodes, {});
|
||||
multi_build.SyncHistogram(ctx, &tree, nodes, {});
|
||||
|
||||
multi_page = multi_build.Histogram()[RegTree::kRoot];
|
||||
}
|
||||
@@ -480,7 +480,7 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo
|
||||
single_build.BuildHist(0, space, gmat, row_set_collection, nodes,
|
||||
linalg::MakeTensorView(ctx, h_gpair, h_gpair.size()),
|
||||
force_read_by_column);
|
||||
single_build.SyncHistogram(&tree, nodes, {});
|
||||
single_build.SyncHistogram(ctx, &tree, nodes, {});
|
||||
|
||||
single_page = single_build.Histogram()[RegTree::kRoot];
|
||||
}
|
||||
@@ -570,7 +570,7 @@ class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||
CHECK_NE(partitioners.front()[tree.RightChild(best.nid)].Size(), 0);
|
||||
|
||||
hist_builder.BuildHistLeftRight(
|
||||
Xy.get(), &tree, partitioners, valid_candidates,
|
||||
&ctx, Xy.get(), &tree, partitioners, valid_candidates,
|
||||
linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), gpair.Size(), 1), batch);
|
||||
|
||||
if (limit) {
|
||||
|
||||
Reference in New Issue
Block a user