Use simple print in tracker print function. (#6609)

This commit is contained in:
Jiaming Yuan 2021-01-21 21:15:43 +08:00 committed by GitHub
parent 26982f9fce
commit 7bc56fa0ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 6 deletions

View File

@ -1,7 +1,6 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Distributed XGBoost Rabit related API."""
import sys
import ctypes
import pickle
import numpy as np
@ -79,8 +78,7 @@ def tracker_print(msg):
if is_dist != 0:
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
else:
sys.stdout.write(msg)
sys.stdout.flush()
print(msg.strip(), flush=True)
def get_processor_name():

View File

@ -293,7 +293,7 @@ class RabitTracker(object):
s = SlaveEntry(fd, s_addr)
if s.cmd == 'print':
msg = s.sock.recvstr()
logging.info(msg.strip())
print(msg.strip(), flush=True)
continue
if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown

View File

@ -867,7 +867,9 @@ class TestWithDask:
test = "--gtest_filter=Quantile." + name
def runit(worker_addr: str, rabit_args: List[bytes]) -> subprocess.CompletedProcess:
def runit(
worker_addr: str, rabit_args: List[bytes]
) -> subprocess.CompletedProcess:
port_env = ''
# setup environment for running the c++ part.
for arg in rabit_args:
@ -965,7 +967,9 @@ class TestWithDask:
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'log')
def sqr(labels: np.ndarray, predts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
def sqr(
labels: np.ndarray, predts: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
with open(path, 'a') as fd:
print('Running sqr', file=fd)
grad = predts - labels