[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -178,7 +178,7 @@ class MetaInfo {
* in vertical federated learning, since each worker loads its own list of columns,
* we need to sum them.
*/
void SynchronizeNumberOfColumns();
void SynchronizeNumberOfColumns(Context const* ctx);
/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {

View File

@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
}
template <typename T, typename... S>
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, device};
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
}
template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};