More refactoring to take advantage of collective aggregators (#9081)
This commit is contained in:
@@ -31,18 +31,16 @@ namespace collective {
|
||||
* @param buffer The buffer storing the results.
|
||||
* @param size The size of the buffer.
|
||||
* @param function The function used to calculate the results.
|
||||
* @param args Arguments to the function.
|
||||
*/
|
||||
template <typename Function, typename T, typename... Args>
|
||||
void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function,
|
||||
Args&&... args) {
|
||||
template <typename Function>
|
||||
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
@@ -55,7 +53,7 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
|
||||
LOG(FATAL) << &message[0];
|
||||
}
|
||||
} else {
|
||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||
std::forward<Function>(function)();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user