[rabit] Small cleanup to tracker initialization. (#9524)

- Remove recover related code.
- Clean startup, no need to consider previously connected nodes.
This commit is contained in:
Jiaming Yuan 2023-08-27 05:10:59 +08:00 committed by GitHub
parent 209335b18c
commit 1b87a1d8f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 20 deletions

View File

@ -137,15 +137,9 @@ class WorkerEntry:
return self._get_remote(wait_conn, nnset) return self._get_remote(wait_conn, nnset)
def _get_remote( def _get_remote(
self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int] self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
) -> List[int]: ) -> List[int]:
while True: while True:
ngood = self.sock.recvint()
goodset = set()
for _ in range(ngood):
goodset.add(self.sock.recvint())
assert goodset.issubset(nnset)
badset = nnset - goodset
conset = [] conset = []
for r in badset: for r in badset:
if r in wait_conn: if r in wait_conn:
@ -343,7 +337,7 @@ class RabitTracker:
shutdown[s.rank] = s shutdown[s.rank] = s
logging.debug("Received %s signal from %d", s.cmd, s.rank) logging.debug("Received %s signal from %d", s.cmd, s.rank)
continue continue
assert s.cmd in ("start", "recover") assert s.cmd == "start"
# lazily initialize the workers # lazily initialize the workers
if tree_map is None: if tree_map is None:
assert s.cmd == "start" assert s.cmd == "start"

View File

@ -318,21 +318,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// get number of to connect and number of to accept nodes from tracker // get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1; int num_conn, num_accept, num_error = 1;
do { do {
// send over good links
std::vector<int> good_link;
for (auto & all_link : all_links) { for (auto & all_link : all_links) {
if (!all_link.sock.BadSocket()) { all_link.sock.Close();
good_link.push_back(static_cast<int>(all_link.rank));
} else {
if (!all_link.sock.IsClosed()) all_link.sock.Close();
} }
}
int ngood = static_cast<int>(good_link.size());
// tracker construct goodset // tracker construct goodset
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
for (int &i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7"); "ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),