Support vertical federated learning with gpu_hist (#9539)
This commit is contained in:
@@ -26,7 +26,6 @@ namespace collective {
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @tparam Args Arguments to the function.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param buffer The buffer storing the results.
|
||||
* @param size The size of the buffer.
|
||||
@@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
*
|
||||
* Normally all the workers have access to the labels, so the function is just applied locally. In
|
||||
* vertical federated learning, we assume labels are only available on worker 0, so the function is
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam T Type of the HostDeviceVector storing the results.
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global max of the given value across all workers.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user