Allow insecure gRPC connections for federated learning (#8181)
* Allow insecure gRPC connections for federated learning
* format
This commit is contained in:
parent
53d2a733b0
commit
ad3bc0edee
@ -42,9 +42,14 @@ class FederatedEngine : public IEngine {
|
|||||||
}
|
}
|
||||||
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
|
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
|
||||||
server_address_.c_str(), world_size_, rank_);
|
server_address_.c_str(), world_size_, rank_);
|
||||||
|
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
|
||||||
|
utils::Printf("Certificates not specified, turning off SSL.");
|
||||||
|
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
|
||||||
|
} else {
|
||||||
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
|
||||||
client_key_, client_cert_));
|
client_key_, client_cert_));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Finalize() { client_.reset(); }
|
void Finalize() { client_.reset(); }
|
||||||
|
|
||||||
@ -84,13 +89,9 @@ class FederatedEngine : public IEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int LoadCheckPoint() override {
|
int LoadCheckPoint() override { return 0; }
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckPoint() override {
|
void CheckPoint() override { version_number_ += 1; }
|
||||||
version_number_ += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int VersionNumber() const override { return version_number_; }
|
int VersionNumber() const override { return version_number_; }
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class FederatedClient {
|
|||||||
}()},
|
}()},
|
||||||
rank_{rank} {}
|
rank_{rank} {}
|
||||||
|
|
||||||
/** @brief Insecure client for testing only. */
|
/** @brief Insecure client for connecting to localhost only. */
|
||||||
FederatedClient(std::string const &server_address, int rank)
|
FederatedClient(std::string const &server_address, int rank)
|
||||||
: stub_{Federated::NewStub(
|
: stub_{Federated::NewStub(
|
||||||
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
|
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
|
||||||
|
|||||||
@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
|||||||
server->Wait();
|
server->Wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Run the federated server without SSL.
 *
 * Builds a gRPC server that listens on "0.0.0.0:<port>" with
 * grpc::InsecureServerCredentials(), registers a FederatedService constructed
 * from world_size, logs the listening address, and then calls server->Wait().
 *
 * @param port       Port number appended to "0.0.0.0:" to form the listening address.
 * @param world_size Passed to the FederatedService constructor and printed in the log line.
 */
void RunInsecureServer(int port, int world_size) {
|
||||||
|
// The address uses 0.0.0.0, i.e. it is not restricted to the loopback interface.
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||||
|
FederatedService service{world_size};
|
||||||
|
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
// Raise the receive-size limit to the largest value an int can hold.
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||||
|
// No certificates involved: the port is opened with insecure credentials.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
// NOTE(review): the pointer returned by BuildAndStart() is dereferenced below
// without a null check — confirm whether startup failure (e.g. port in use)
// can yield a null server here.
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||||
|
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||||
|
<< world_size;
|
||||||
|
|
||||||
|
server->Wait();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace federated
|
} // namespace federated
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service {
|
|||||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||||
char const* client_cert_file);
|
char const* client_cert_file);
|
||||||
|
|
||||||
|
void RunInsecureServer(int port, int world_size);
|
||||||
|
|
||||||
} // namespace federated
|
} // namespace federated
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -6,9 +6,9 @@ from .core import _LIB, XGBoostError, _check_call, build_info, c_str
|
|||||||
def run_federated_server(
|
def run_federated_server(
|
||||||
port: int,
|
port: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
server_key_path: str,
|
server_key_path: str = "",
|
||||||
server_cert_path: str,
|
server_cert_path: str = "",
|
||||||
client_cert_path: str,
|
client_cert_path: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the Federated Learning server.
|
"""Run the Federated Learning server.
|
||||||
|
|
||||||
@ -19,13 +19,16 @@ def run_federated_server(
|
|||||||
world_size: int
|
world_size: int
|
||||||
The number of federated workers.
|
The number of federated workers.
|
||||||
server_key_path: str
|
server_key_path: str
|
||||||
Path to the server private key file.
|
Path to the server private key file. SSL is turned off if empty.
|
||||||
server_cert_path: str
|
server_cert_path: str
|
||||||
Path to the server certificate file.
|
Path to the server certificate file. SSL is turned off if empty.
|
||||||
client_cert_path: str
|
client_cert_path: str
|
||||||
Path to the client certificate file.
|
Path to the client certificate file. SSL is turned off if empty.
|
||||||
"""
|
"""
|
||||||
if build_info()["USE_FEDERATED"]:
|
if build_info()["USE_FEDERATED"]:
|
||||||
|
if not server_key_path or not server_cert_path or not client_cert_path:
|
||||||
|
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
|
||||||
|
else:
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBRunFederatedServer(
|
_LIB.XGBRunFederatedServer(
|
||||||
port,
|
port,
|
||||||
|
|||||||
@ -1377,6 +1377,13 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k
|
|||||||
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
|
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run a server without SSL for local testing.
// C API entry point: forwards `port` and `world_size` unchanged to
// federated::RunInsecureServer.
// NOTE(review): the int return value presumably comes from the
// API_BEGIN()/API_END() macros, which are not visible here — confirm.
|
||||||
|
XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) {
|
||||||
|
API_BEGIN();
|
||||||
|
federated::RunInsecureServer(port, world_size);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// force link rabit
|
// force link rabit
|
||||||
|
|||||||
@ -12,21 +12,28 @@ CLIENT_KEY = 'client-key.pem'
|
|||||||
CLIENT_CERT = 'client-cert.pem'
|
CLIENT_CERT = 'client-cert.pem'
|
||||||
|
|
||||||
|
|
||||||
def run_server(port: int, world_size: int) -> None:
|
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
||||||
|
if with_ssl:
|
||||||
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
|
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
|
||||||
CLIENT_CERT)
|
CLIENT_CERT)
|
||||||
|
else:
|
||||||
|
xgboost.federated.run_federated_server(port, world_size)
|
||||||
|
|
||||||
|
|
||||||
def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
|
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
||||||
# Always call this before using distributed module
|
|
||||||
rabit_env = [
|
rabit_env = [
|
||||||
f'federated_server_address=localhost:{port}',
|
f'federated_server_address=localhost:{port}',
|
||||||
f'federated_world_size={world_size}',
|
f'federated_world_size={world_size}',
|
||||||
f'federated_rank={rank}',
|
f'federated_rank={rank}'
|
||||||
|
]
|
||||||
|
if with_ssl:
|
||||||
|
rabit_env = rabit_env + [
|
||||||
f'federated_server_cert={SERVER_CERT}',
|
f'federated_server_cert={SERVER_CERT}',
|
||||||
f'federated_client_key={CLIENT_KEY}',
|
f'federated_client_key={CLIENT_KEY}',
|
||||||
f'federated_client_cert={CLIENT_CERT}'
|
f'federated_client_cert={CLIENT_CERT}'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Always call this before using distributed module
|
||||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||||
# Load file, file will not be sharded in federated mode.
|
# Load file, file will not be sharded in federated mode.
|
||||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
||||||
@ -52,11 +59,11 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
|
|||||||
xgb.rabit.tracker_print("Finished training\n")
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
|
|
||||||
def run_test(with_gpu: bool = False) -> None:
|
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
|
||||||
port = 9091
|
port = 9091
|
||||||
world_size = int(sys.argv[1])
|
world_size = int(sys.argv[1])
|
||||||
|
|
||||||
server = multiprocessing.Process(target=run_server, args=(port, world_size))
|
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
|
||||||
server.start()
|
server.start()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
if not server.is_alive():
|
if not server.is_alive():
|
||||||
@ -64,7 +71,8 @@ def run_test(with_gpu: bool = False) -> None:
|
|||||||
|
|
||||||
workers = []
|
workers = []
|
||||||
for rank in range(world_size):
|
for rank in range(world_size):
|
||||||
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
|
worker = multiprocessing.Process(target=run_worker,
|
||||||
|
args=(port, world_size, rank, with_ssl, with_gpu))
|
||||||
workers.append(worker)
|
workers.append(worker)
|
||||||
worker.start()
|
worker.start()
|
||||||
for worker in workers:
|
for worker in workers:
|
||||||
@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_test()
|
run_test(with_ssl=True, with_gpu=False)
|
||||||
run_test(with_gpu=True)
|
run_test(with_ssl=False, with_gpu=False)
|
||||||
|
run_test(with_ssl=True, with_gpu=True)
|
||||||
|
run_test(with_ssl=False, with_gpu=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user