[rabit] Small cleanup to tracker initialization. (#9524)
- Remove recover related code. - Clean startup, no need to consider previously connected nodes.
This commit is contained in:
parent
209335b18c
commit
1b87a1d8f8
@ -137,15 +137,9 @@ class WorkerEntry:
|
|||||||
return self._get_remote(wait_conn, nnset)
|
return self._get_remote(wait_conn, nnset)
|
||||||
|
|
||||||
def _get_remote(
|
def _get_remote(
|
||||||
self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int]
|
self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int]
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
while True:
|
while True:
|
||||||
ngood = self.sock.recvint()
|
|
||||||
goodset = set()
|
|
||||||
for _ in range(ngood):
|
|
||||||
goodset.add(self.sock.recvint())
|
|
||||||
assert goodset.issubset(nnset)
|
|
||||||
badset = nnset - goodset
|
|
||||||
conset = []
|
conset = []
|
||||||
for r in badset:
|
for r in badset:
|
||||||
if r in wait_conn:
|
if r in wait_conn:
|
||||||
@ -343,7 +337,7 @@ class RabitTracker:
|
|||||||
shutdown[s.rank] = s
|
shutdown[s.rank] = s
|
||||||
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
logging.debug("Received %s signal from %d", s.cmd, s.rank)
|
||||||
continue
|
continue
|
||||||
assert s.cmd in ("start", "recover")
|
assert s.cmd == "start"
|
||||||
# lazily initialize the workers
|
# lazily initialize the workers
|
||||||
if tree_map is None:
|
if tree_map is None:
|
||||||
assert s.cmd == "start"
|
assert s.cmd == "start"
|
||||||
|
|||||||
@ -318,21 +318,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
// get number of to connect and number of to accept nodes from tracker
|
// get number of to connect and number of to accept nodes from tracker
|
||||||
int num_conn, num_accept, num_error = 1;
|
int num_conn, num_accept, num_error = 1;
|
||||||
do {
|
do {
|
||||||
// send over good links
|
|
||||||
std::vector<int> good_link;
|
|
||||||
for (auto & all_link : all_links) {
|
for (auto & all_link : all_links) {
|
||||||
if (!all_link.sock.BadSocket()) {
|
all_link.sock.Close();
|
||||||
good_link.push_back(static_cast<int>(all_link.rank));
|
|
||||||
} else {
|
|
||||||
if (!all_link.sock.IsClosed()) all_link.sock.Close();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
|
||||||
// tracker construct goodset
|
// tracker construct goodset
|
||||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
|
|
||||||
for (int &i : good_link) {
|
|
||||||
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
|
|
||||||
}
|
|
||||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
"ReConnectLink failure 7");
|
"ReConnectLink failure 7");
|
||||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user