From 7017dd5a261a8546d5f7aaaf17a5a5df09a4a733 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 27 Jul 2021 16:20:42 +0800
Subject: [PATCH] [JVM-Packages] Use Python tracker in XGBoost for JVM package. (#7132)

---
Review notes (commentary only; this text sits between the three-dash
separator and the first "diff --git" line and is not part of any hunk):

* jvm-packages/create_jni.py: the tracker copied into the xgboost4j
  resources directory is now python-package/xgboost/tracker.py instead of
  dmlc-core/tracker/dmlc_tracker/tracker.py.
* python-package/xgboost/dask.py: _start_tracker() now constructs
  RabitTracker with use_logger=False, so messages relayed by the tracker
  are still emitted with print() on that path.
* python-package/xgboost/tracker.py:
  - RabitTracker.__init__ accepts use_logger (default True) and stores it
    as self._use_logger. For a worker 'print' command the stripped message
    goes to logging.info() when the flag is True and to
    print(..., flush=True) otherwise.
  - get_host_ip() moves from above get_family() to below the RabitTracker
    class; the removed and the added bodies are identical.
  - start_rabit_tracker(args) and main() add a standalone entry point with
    the options --num-workers (required), --num-servers, --host-ip and
    --log-level. The tracker environment is written to stdout as KEY=VALUE
    lines between DMLC_TRACKER_ENV_START and DMLC_TRACKER_ENV_END, and
    rabit.join() is called afterwards. A non-zero --num-servers raises
    RuntimeError.
* NOTE(review): the "Parameters" heading in the new __init__ docstring is
  underlined with dots while start_rabit_tracker() uses dashes; confirm
  which one the documentation tooling expects.
* NOTE(review): the help strings contain "proccess" and "Host IP
  addressed", which look like typos; they are runtime strings and are left
  untouched here.
* NOTE(review): the datagram socket created inside get_host_ip() is never
  closed in the code shown; confirm whether that is intentional.
* NOTE(review): --log-level is restricted to INFO/DEBUG through argparse
  "choices", so the final "Unknown logging level" branch in main() appears
  unreachable from the command line.

 jvm-packages/create_jni.py        |   3 +-
 python-package/xgboost/dask.py    |   2 +-
 python-package/xgboost/tracker.py | 124 ++++++++++++++++++++++++------
 3 files changed, 101 insertions(+), 28 deletions(-)

diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py
index b8f6c1480..03937b78a 100755
--- a/jvm-packages/create_jni.py
+++ b/jvm-packages/create_jni.py
@@ -148,8 +148,7 @@ if __name__ == "__main__":
     cp("../lib/" + library_name, output_folder)
 
     print("copying pure-Python tracker")
-    cp("../dmlc-core/tracker/dmlc_tracker/tracker.py",
-       "{}/src/main/resources".format(xgboost4j))
+    cp("../python-package/xgboost/tracker.py", "{}/src/main/resources".format(xgboost4j))
 
     print("copying train/test files")
     maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index f203f88b2..a73dc2638 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -127,7 +127,7 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
     host = get_host_ip('auto')
-    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
+    rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
     env.update(rabit_context.slave_envs())
 
     rabit_context.start(n_workers)
diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py
index 02cba4aa0..f69273e1b 100644
--- a/python-package/xgboost/tracker.py
+++ b/python-package/xgboost/tracker.py
@@ -10,6 +10,8 @@ import struct
 import time
 import logging
 from threading import Thread
+import argparse
+import sys
 
 
 class ExSocket(object):
@@ -52,29 +54,6 @@ def get_some_ip(host):
     return socket.getaddrinfo(host, None)[0][4][0]
 
 
-def get_host_ip(hostIP=None):
-    if hostIP is None or hostIP == 'auto':
-        hostIP = 'ip'
-
-    if hostIP == 'dns':
-        hostIP = socket.getfqdn()
-    elif hostIP == 'ip':
-        from socket import gaierror
-        try:
-            hostIP = socket.gethostbyname(socket.getfqdn())
-        except gaierror:
-            logging.debug(
-                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
-            )
-            hostIP = socket.gethostbyname(socket.gethostname())
-        if hostIP.startswith("127."):
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-            # doesn't have to be reachable
-            s.connect(('10.255.255.255', 1))
-            hostIP = s.getsockname()[0]
-    return hostIP
-
-
 def get_family(addr):
     return socket.getaddrinfo(addr, None)[0][0]
 
@@ -164,7 +143,18 @@ class RabitTracker(object):
     tracker for rabit
     """
 
-    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
+    def __init__(
+        self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True
+    ) -> None:
+        """A Python implementation of RABIT tracker.
+
+        Parameters
+        ..........
+        use_logger:
+            Use logging.info for tracker print command. When set to False, Python print
+            function is used instead.
+
+        """
         sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
         for _port in range(port, port_end):
             try:
@@ -182,6 +172,7 @@ class RabitTracker(object):
         self.start_time = None
         self.end_time = None
         self.nslave = nslave
+        self._use_logger = use_logger
         logging.info('start listen on %s:%d', hostIP, self.port)
 
     def __del__(self):
@@ -293,7 +284,11 @@ class RabitTracker(object):
             s = SlaveEntry(fd, s_addr)
             if s.cmd == 'print':
                 msg = s.sock.recvstr()
-                print(msg.strip(), flush=True)
+                # On dask we use print to avoid setting global verbosity.
+                if self._use_logger:
+                    logging.info(msg.strip())
+                else:
+                    print(msg.strip(), flush=True)
                 continue
             if s.cmd == 'shutdown':
                 assert s.rank >= 0 and s.rank not in shutdown
@@ -357,3 +352,82 @@ class RabitTracker(object):
 
     def alive(self):
         return self.thread.is_alive()
+
+
+def get_host_ip(hostIP=None):
+    if hostIP is None or hostIP == 'auto':
+        hostIP = 'ip'
+
+    if hostIP == 'dns':
+        hostIP = socket.getfqdn()
+    elif hostIP == 'ip':
+        from socket import gaierror
+        try:
+            hostIP = socket.gethostbyname(socket.getfqdn())
+        except gaierror:
+            logging.debug(
+                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
+            )
+            hostIP = socket.gethostbyname(socket.gethostname())
+        if hostIP.startswith("127."):
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            # doesn't have to be reachable
+            s.connect(('10.255.255.255', 1))
+            hostIP = s.getsockname()[0]
+    return hostIP
+
+
+def start_rabit_tracker(args):
+    """Standalone function to start rabit tracker.
+
+    Parameters
+    ----------
+    args: arguments to start the rabit tracker.
+    """
+    envs = {'DMLC_NUM_WORKER': args.num_workers,
+            'DMLC_NUM_SERVER': args.num_servers}
+    rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers)
+    envs.update(rabit.slave_envs())
+    rabit.start(args.num_workers)
+    sys.stdout.write('DMLC_TRACKER_ENV_START\n')
+    # simply write configuration to stdout
+    for k, v in envs.items():
+        sys.stdout.write('%s=%s\n' % (k, str(v)))
+    sys.stdout.write('DMLC_TRACKER_ENV_END\n')
+    sys.stdout.flush()
+    rabit.join()
+
+
+def main():
+    """Main function if tracker is executed in standalone mode."""
+    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
+    parser.add_argument('--num-workers', required=True, type=int,
+                        help='Number of worker proccess to be launched.')
+    parser.add_argument('--num-servers', default=0, type=int,
+                        help='Number of server process to be launched. Only used in PS jobs.')
+    parser.add_argument('--host-ip', default=None, type=str,
+                        help=('Host IP addressed, this is only needed ' +
+                              'if the host IP cannot be automatically guessed.'))
+    parser.add_argument('--log-level', default='INFO', type=str,
+                        choices=['INFO', 'DEBUG'],
+                        help='Logging level of the logger.')
+    args = parser.parse_args()
+
+    fmt = '%(asctime)s %(levelname)s %(message)s'
+    if args.log_level == 'INFO':
+        level = logging.INFO
+    elif args.log_level == 'DEBUG':
+        level = logging.DEBUG
+    else:
+        raise RuntimeError("Unknown logging level %s" % args.log_level)
+
+    logging.basicConfig(format=fmt, level=level)
+
+    if args.num_servers == 0:
+        start_rabit_tracker(args)
+    else:
+        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
+
+
+if __name__ == "__main__":
+    main()