[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -178,7 +178,7 @@ class MetaInfo {
|
||||
* in vertical federated learning, since each worker loads its own list of columns,
|
||||
* we need to sum them.
|
||||
*/
|
||||
void SynchronizeNumberOfColumns();
|
||||
void SynchronizeNumberOfColumns(Context const* ctx);
|
||||
|
||||
/*! \brief Whether the data is split row-wise. */
|
||||
bool IsRowSplit() const {
|
||||
|
||||
@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
|
||||
std::size_t in_shape[sizeof...(S)];
|
||||
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, device};
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
|
||||
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
|
||||
}
|
||||
|
||||
template <typename T, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
|
||||
template <typename T, decltype(common::dynamic_extent) ext, typename... S>
|
||||
auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
|
||||
std::size_t in_shape[sizeof...(S)];
|
||||
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
|
||||
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
|
||||
|
||||
Reference in New Issue
Block a user