xgboost/plugin/federated/federated_tracker.h
2023-10-31 02:39:55 +08:00

42 lines
1.3 KiB
C++

/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h> // for Server
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;
std::string server_cert_file_;
std::string client_cert_file_;
public:
/**
* @brief CTOR
*
* @param config Configuration, other than the base configuration from Tracker, we have:
*
* - federated_secure: bool whether this is a secure server.
* - server_key_path: path to the key.
* - server_cert_path: certificate path.
* - client_cert_path: certificate path for client.
*/
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
std::future<Result> Run() override;
// federated tracker do not provide initialization parameters, users have to provide it
// themseleves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective