[JVM-Packages] Use Python tracker in XGBoost for JVM package. (#7132)
This commit is contained in:
parent
48d5de80a2
commit
7017dd5a26
@ -148,8 +148,7 @@ if __name__ == "__main__":
|
||||
cp("../lib/" + library_name, output_folder)
|
||||
|
||||
print("copying pure-Python tracker")
|
||||
cp("../dmlc-core/tracker/dmlc_tracker/tracker.py",
|
||||
"{}/src/main/resources".format(xgboost4j))
|
||||
cp("../python-package/xgboost/tracker.py", "{}/src/main/resources".format(xgboost4j))
|
||||
|
||||
print("copying train/test files")
|
||||
maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark))
|
||||
|
||||
@ -127,7 +127,7 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
||||
"""Start Rabit tracker """
|
||||
env = {'DMLC_NUM_WORKER': n_workers}
|
||||
host = get_host_ip('auto')
|
||||
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
|
||||
rabit_context = RabitTracker(hostIP=host, nslave=n_workers, use_logger=False)
|
||||
env.update(rabit_context.slave_envs())
|
||||
|
||||
rabit_context.start(n_workers)
|
||||
|
||||
@ -10,6 +10,8 @@ import struct
|
||||
import time
|
||||
import logging
|
||||
from threading import Thread
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
class ExSocket(object):
|
||||
@ -52,29 +54,6 @@ def get_some_ip(host):
|
||||
return socket.getaddrinfo(host, None)[0][4][0]
|
||||
|
||||
|
||||
def get_host_ip(hostIP=None):
    """Resolve the host name or IP address the tracker should advertise.

    Parameters
    ----------
    hostIP:
        ``None``, ``'auto'`` or ``'ip'`` look up this machine's IP address;
        ``'dns'`` returns the fully qualified domain name; any other value is
        treated as an explicit address and returned unchanged.

    Returns
    -------
    The resolved host name or IP address as a string.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.debug(
                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
            )
            hostIP = socket.gethostbyname(socket.gethostname())
        # NOTE(review): the flattened source does not show nesting; the
        # loopback fallback is assumed to belong to the 'ip' branch - confirm.
        if hostIP.startswith("127."):
            # Resolution yielded a loopback address. "Connecting" a UDP socket
            # sends no packet; it only makes the OS pick the outgoing
            # interface, whose address is then read back. The context manager
            # guarantees the probe socket is closed (it used to be leaked).
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
    return hostIP
|
||||
|
||||
|
||||
def get_family(addr):
    """Return the address family of the first ``getaddrinfo`` entry for *addr*."""
    first_entry = socket.getaddrinfo(addr, None)[0]
    # Each entry is (family, type, proto, canonname, sockaddr).
    family = first_entry[0]
    return family
|
||||
|
||||
@ -164,7 +143,18 @@ class RabitTracker(object):
|
||||
tracker for rabit
|
||||
"""
|
||||
|
||||
def __init__(self, hostIP, nslave, port=9091, port_end=9999):
|
||||
def __init__(
|
||||
self, hostIP, nslave, port=9091, port_end=9999, use_logger: bool = True
|
||||
) -> None:
|
||||
"""A Python implementation of RABIT tracker.
|
||||
|
||||
Parameters
|
||||
..........
|
||||
use_logger:
|
||||
Use logging.info for tracker print command. When set to False, Python print
|
||||
function is used instead.
|
||||
|
||||
"""
|
||||
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
|
||||
for _port in range(port, port_end):
|
||||
try:
|
||||
@ -182,6 +172,7 @@ class RabitTracker(object):
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.nslave = nslave
|
||||
self._use_logger = use_logger
|
||||
logging.info('start listen on %s:%d', hostIP, self.port)
|
||||
|
||||
def __del__(self):
|
||||
@ -293,7 +284,11 @@ class RabitTracker(object):
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
print(msg.strip(), flush=True)
|
||||
# On dask we use print to avoid setting global verbosity.
|
||||
if self._use_logger:
|
||||
logging.info(msg.strip())
|
||||
else:
|
||||
print(msg.strip(), flush=True)
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
@ -357,3 +352,82 @@ class RabitTracker(object):
|
||||
|
||||
def alive(self):
    """Report whether the thread stored in ``self.thread`` is still running."""
    worker = self.thread
    return worker.is_alive()
|
||||
|
||||
|
||||
def get_host_ip(hostIP=None):
    """Resolve the host name or IP address the tracker should advertise.

    Parameters
    ----------
    hostIP:
        ``None``, ``'auto'`` or ``'ip'`` look up this machine's IP address;
        ``'dns'`` returns the fully qualified domain name; any other value is
        treated as an explicit address and returned unchanged.

    Returns
    -------
    The resolved host name or IP address as a string.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.debug(
                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
            )
            hostIP = socket.gethostbyname(socket.gethostname())
        # NOTE(review): the flattened source does not show nesting; the
        # loopback fallback is assumed to belong to the 'ip' branch - confirm.
        if hostIP.startswith("127."):
            # Resolution yielded a loopback address. "Connecting" a UDP socket
            # sends no packet; it only makes the OS pick the outgoing
            # interface, whose address is then read back. The context manager
            # guarantees the probe socket is closed (it used to be leaked).
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
    return hostIP
|
||||
|
||||
|
||||
def start_rabit_tracker(args):
    """Start a rabit tracker and print its environment to stdout.

    The environment is written as ``KEY=VALUE`` lines framed by the
    ``DMLC_TRACKER_ENV_START`` / ``DMLC_TRACKER_ENV_END`` markers, after which
    the call blocks in ``join()`` on the tracker.

    Parameters
    ----------
    args:
        Parsed arguments; ``num_workers``, ``num_servers`` and ``host_ip``
        are read from it.
    """
    n_workers = args.num_workers
    envs = {'DMLC_NUM_WORKER': n_workers, 'DMLC_NUM_SERVER': args.num_servers}

    tracker = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=n_workers)
    envs.update(tracker.slave_envs())
    tracker.start(n_workers)

    # simply write configuration to stdout
    emit = sys.stdout.write
    emit('DMLC_TRACKER_ENV_START\n')
    for name in envs:
        emit('%s=%s\n' % (name, str(envs[name])))
    emit('DMLC_TRACKER_ENV_END\n')
    sys.stdout.flush()

    tracker.join()
|
||||
|
||||
|
||||
def main():
    """Main function if tracker is executed in standalone mode.

    Parses the command line, configures the root logger and starts the rabit
    tracker.

    Raises
    ------
    RuntimeError
        If ``--num-servers`` is non-zero; PS trackers are not supported in
        standalone mode.
    """
    # Single source of truth for both the accepted choices and the mapping to
    # logging constants; argparse rejects anything not listed here.
    log_levels = {'INFO': logging.INFO, 'DEBUG': logging.DEBUG}

    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker processes to be launched.')
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server processes to be launched. '
                             'Only used in PS jobs.')
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP address, this is only needed '
                              'if the host IP cannot be automatically guessed.'))
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=list(log_levels),
                        help='Logging level of the logger.')
    args = parser.parse_args()

    fmt = '%(asctime)s %(levelname)s %(message)s'
    logging.basicConfig(format=fmt, level=log_levels[args.log_level])

    if args.num_servers != 0:
        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
    start_rabit_tracker(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user