Allow insecure gRPC connections for federated learning (#8181)
* Allow insecure gRPC connections for federated learning * format
This commit is contained in:
@@ -42,8 +42,13 @@ class FederatedEngine : public IEngine {
|
||||
}
|
||||
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
|
||||
server_address_.c_str(), world_size_, rank_);
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
||||
client_key_, client_cert_));
|
||||
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
|
||||
utils::Printf("Certificates not specified, turning off SSL.");
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
|
||||
} else {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
||||
client_key_, client_cert_));
|
||||
}
|
||||
}
|
||||
|
||||
void Finalize() { client_.reset(); }
|
||||
@@ -84,13 +89,9 @@ class FederatedEngine : public IEngine {
|
||||
}
|
||||
}
|
||||
|
||||
int LoadCheckPoint() override {
|
||||
return 0;
|
||||
}
|
||||
int LoadCheckPoint() override { return 0; }
|
||||
|
||||
void CheckPoint() override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
void CheckPoint() override { version_number_ += 1; }
|
||||
|
||||
int VersionNumber() const override { return version_number_; }
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class FederatedClient {
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
/** @brief Insecure client for testing only. */
|
||||
/** @brief Insecure client for connecting to localhost only. */
|
||||
FederatedClient(std::string const &server_address, int rank)
|
||||
: stub_{Federated::NewStub(
|
||||
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
|
||||
|
||||
@@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
void RunInsecureServer(int port, int world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< world_size;
|
||||
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service {
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, int world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user