Use simple print in tracker print function. (#6609)
This commit is contained in:
parent
26982f9fce
commit
7bc56fa0ed
@ -1,7 +1,6 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable= invalid-name
|
# pylint: disable= invalid-name
|
||||||
"""Distributed XGBoost Rabit related API."""
|
"""Distributed XGBoost Rabit related API."""
|
||||||
import sys
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -79,8 +78,7 @@ def tracker_print(msg):
|
|||||||
if is_dist != 0:
|
if is_dist != 0:
|
||||||
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
|
_check_call(_LIB.RabitTrackerPrint(c_str(msg)))
|
||||||
else:
|
else:
|
||||||
sys.stdout.write(msg)
|
print(msg.strip(), flush=True)
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def get_processor_name():
|
def get_processor_name():
|
||||||
|
|||||||
@ -293,7 +293,7 @@ class RabitTracker(object):
|
|||||||
s = SlaveEntry(fd, s_addr)
|
s = SlaveEntry(fd, s_addr)
|
||||||
if s.cmd == 'print':
|
if s.cmd == 'print':
|
||||||
msg = s.sock.recvstr()
|
msg = s.sock.recvstr()
|
||||||
logging.info(msg.strip())
|
print(msg.strip(), flush=True)
|
||||||
continue
|
continue
|
||||||
if s.cmd == 'shutdown':
|
if s.cmd == 'shutdown':
|
||||||
assert s.rank >= 0 and s.rank not in shutdown
|
assert s.rank >= 0 and s.rank not in shutdown
|
||||||
|
|||||||
@ -867,7 +867,9 @@ class TestWithDask:
|
|||||||
|
|
||||||
test = "--gtest_filter=Quantile." + name
|
test = "--gtest_filter=Quantile." + name
|
||||||
|
|
||||||
def runit(worker_addr: str, rabit_args: List[bytes]) -> subprocess.CompletedProcess:
|
def runit(
|
||||||
|
worker_addr: str, rabit_args: List[bytes]
|
||||||
|
) -> subprocess.CompletedProcess:
|
||||||
port_env = ''
|
port_env = ''
|
||||||
# setup environment for running the c++ part.
|
# setup environment for running the c++ part.
|
||||||
for arg in rabit_args:
|
for arg in rabit_args:
|
||||||
@ -965,7 +967,9 @@ class TestWithDask:
|
|||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
path = os.path.join(tmpdir, 'log')
|
path = os.path.join(tmpdir, 'log')
|
||||||
|
|
||||||
def sqr(labels: np.ndarray, predts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
def sqr(
|
||||||
|
labels: np.ndarray, predts: np.ndarray
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
with open(path, 'a') as fd:
|
with open(path, 'a') as fd:
|
||||||
print('Running sqr', file=fd)
|
print('Running sqr', file=fd)
|
||||||
grad = predts - labels
|
grad = predts - labels
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user