[JVM-Packages] Use Python tracker in XGBoost for JVM package. (#7132)

This commit is contained in:
Jiaming Yuan 2021-07-27 16:20:42 +08:00 committed by GitHub
parent 48d5de80a2
commit 7017dd5a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 28 deletions

View File

@ -148,8 +148,7 @@ if __name__ == "__main__":
cp("../lib/" + library_name, output_folder)
print("copying pure-Python tracker")
cp("../dmlc-core/tracker/dmlc_tracker/tracker.py",
"{}/src/main/resources".format(xgboost4j))
cp("../python-package/xgboost/tracker.py", "{}/src/main/resources".format(xgboost4j))
print("copying train/test files")
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))

View File

@ -127,7 +127,7 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
host = get_host_ip('auto')
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
env.update(rabit_context.slave_envs())
rabit_context.start(n_workers)

View File

@ -10,6 +10,8 @@ import struct
import time
import logging
from threading import Thread
import argparse
import sys
class ExSocket(object):
@ -52,29 +54,6 @@ def get_some_ip(host):
return socket.getaddrinfo(host, None)[0][4][0]
def get_host_ip(hostIP=None):
    """Resolve the host address the tracker should listen on.

    Parameters
    ----------
    hostIP:
        ``None`` or ``'auto'`` is handled like ``'ip'``.  ``'dns'`` yields
        ``socket.getfqdn()``.  ``'ip'`` resolves the local host name to an
        address.  Any other value is returned unchanged.

    Returns
    -------
    The resolved host string.

    Raises
    ------
    OSError
        Propagated from the ``socket`` calls when resolution or the loopback
        fallback fails.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.debug(
                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
            )
            hostIP = socket.gethostbyname(socket.gethostname())
        # NOTE(review): the loopback fallback is applied only to the
        # auto-resolved address; an explicitly passed host is returned
        # untouched -- confirm this nesting is the intended one.
        if hostIP.startswith("127."):
            # Connecting a UDP socket sends no packet; it only makes the OS
            # pick the outgoing interface, whose address is then read back.
            # The context manager closes the socket even if connect() fails.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
    return hostIP
def get_family(addr):
    """Return the address family of the first ``getaddrinfo`` entry for *addr*."""
    first_entry = socket.getaddrinfo(addr, None)[0]
    # Each entry is (family, type, proto, canonname, sockaddr).
    family = first_entry[0]
    return family
@ -164,7 +143,18 @@ class RabitTracker(object):
tracker for rabit
"""
def __init__(self, hostIP, nslave, port=9091, port_end=9999):
def __init__(
self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True
) -> None:
"""A Python implementation of RABIT tracker.
Parameters
..........
use_logger:
Use logging.info for tracker print command. When set to False, Python print
function is used instead.
"""
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
for _port in range(port, port_end):
try:
@ -182,6 +172,7 @@ class RabitTracker(object):
self.start_time = None
self.end_time = None
self.nslave = nslave
self._use_logger = use_logger
logging.info('start listen on %s:%d', hostIP, self.port)
def __del__(self):
@ -293,7 +284,11 @@ class RabitTracker(object):
s = SlaveEntry(fd, s_addr)
if s.cmd == 'print':
msg = s.sock.recvstr()
print(msg.strip(), flush=True)
# On dask we use print to avoid setting global verbosity.
if self._use_logger:
logging.info(msg.strip())
else:
print(msg.strip(), flush=True)
continue
if s.cmd == 'shutdown':
assert s.rank >= 0 and s.rank not in shutdown
@ -357,3 +352,82 @@ class RabitTracker(object):
def alive(self):
return self.thread.is_alive()
def get_host_ip(hostIP=None):
    """Resolve the host address the tracker should listen on.

    Parameters
    ----------
    hostIP:
        ``None`` or ``'auto'`` is handled like ``'ip'``.  ``'dns'`` yields
        ``socket.getfqdn()``.  ``'ip'`` resolves the local host name to an
        address.  Any other value is returned unchanged.

    Returns
    -------
    The resolved host string.

    Raises
    ------
    OSError
        Propagated from the ``socket`` calls when resolution or the loopback
        fallback fails.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.debug(
                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
            )
            hostIP = socket.gethostbyname(socket.gethostname())
        # NOTE(review): the loopback fallback is applied only to the
        # auto-resolved address; an explicitly passed host is returned
        # untouched -- confirm this nesting is the intended one.
        if hostIP.startswith("127."):
            # Connecting a UDP socket sends no packet; it only makes the OS
            # pick the outgoing interface, whose address is then read back.
            # The context manager closes the socket even if connect() fails.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
    return hostIP
def start_rabit_tracker(args):
    """Start a rabit tracker and print its environment to stdout.

    The worker environment is written as ``KEY=VALUE`` lines between the
    ``DMLC_TRACKER_ENV_START`` and ``DMLC_TRACKER_ENV_END`` markers, after
    which the call blocks on the tracker's ``join()``.

    Parameters
    ----------
    args: parsed arguments; ``num_workers``, ``num_servers`` and ``host_ip``
        are read from it.
    """
    n_workers = args.num_workers
    envs = {'DMLC_NUM_WORKER': n_workers,
            'DMLC_NUM_SERVER': args.num_servers}

    tracker = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=n_workers)
    envs.update(tracker.slave_envs())
    tracker.start(n_workers)

    # simply write configuration to stdout, framed by the start/end markers
    out = sys.stdout
    report = ['DMLC_TRACKER_ENV_START\n']
    report.extend('%s=%s\n' % (key, str(value)) for key, value in envs.items())
    report.append('DMLC_TRACKER_ENV_END\n')
    out.writelines(report)
    out.flush()

    tracker.join()
def main():
    """Entry point used when the tracker runs as a standalone script.

    Parses the command line, configures the root logger, then starts the
    rabit tracker.  A non-zero ``--num-servers`` raises ``RuntimeError``.
    """
    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    option_table = (
        ('--num-workers', {
            'required': True,
            'type': int,
            'help': 'Number of worker proccess to be launched.',
        }),
        ('--num-servers', {
            'default': 0,
            'type': int,
            'help': 'Number of server process to be launched. Only used in PS jobs.',
        }),
        ('--host-ip', {
            'default': None,
            'type': str,
            'help': ('Host IP addressed, this is only needed '
                     'if the host IP cannot be automatically guessed.'),
        }),
        ('--log-level', {
            'default': 'INFO',
            'type': str,
            'choices': ['INFO', 'DEBUG'],
            'help': 'Logging level of the logger.',
        }),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Map the textual level onto the logging constant.
    known_levels = {'INFO': logging.INFO, 'DEBUG': logging.DEBUG}
    level = known_levels.get(args.log_level)
    if level is None:
        raise RuntimeError("Unknown logging level %s" % args.log_level)
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=level)

    # Only the plain rabit tracker can be started this way.
    if args.num_servers != 0:
        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
    start_rabit_tracker(args)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()