Allow insecure gRPC connections for federated learning (#8181)

* Allow insecure gRPC connections for federated learning

* format
This commit is contained in:
Rong Ou
2022-08-18 21:16:14 -07:00
committed by GitHub
parent 53d2a733b0
commit ad3bc0edee
7 changed files with 75 additions and 37 deletions

View File

@@ -42,8 +42,13 @@ class FederatedEngine : public IEngine {
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
utils::Printf("Certificates not specified, turning off SSL.");
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
} else {
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
}
void Finalize() { client_.reset(); }
@@ -84,13 +89,9 @@ class FederatedEngine : public IEngine {
}
}
int LoadCheckPoint() override {
return 0;
}
int LoadCheckPoint() override { return 0; }
void CheckPoint() override {
version_number_ += 1;
}
void CheckPoint() override { version_number_ += 1; }
int VersionNumber() const override { return version_number_; }

View File

@@ -33,7 +33,7 @@ class FederatedClient {
}()},
rank_{rank} {}
/** @brief Insecure client for testing only. */
/** @brief Insecure client for connecting to localhost only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{Federated::NewStub(
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},

View File

@@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
server->Wait();
}
void RunInsecureServer(int port, int world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< world_size;
server->Wait();
}
} // namespace federated
} // namespace xgboost

View File

@@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service {
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);
void RunInsecureServer(int port, int world_size);
} // namespace federated
} // namespace xgboost