More refactoring to take advantage of collective aggregators (#9081)

This commit is contained in:
Rong Ou
2023-04-25 12:36:09 -07:00
committed by GitHub
parent 49ccae7fb9
commit a320b402a5
10 changed files with 81 additions and 81 deletions

View File

@@ -31,18 +31,16 @@ namespace collective {
* @param buffer The buffer storing the results.
* @param size The size of the buffer.
* @param function The function used to calculate the results.
* @param args Arguments to the function.
*/
template <typename Function, typename T, typename... Args>
void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function,
Args&&... args) {
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@@ -55,7 +53,7 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
}
}