diff --git a/tracker/rabit_hadoop.py b/tracker/rabit_hadoop.py index 682ec69a1..4c87460d8 100755 --- a/tracker/rabit_hadoop.py +++ b/tracker/rabit_hadoop.py @@ -37,6 +37,8 @@ parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs 'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended') parser.add_argument('-n', '--nworker', required=True, type=int, help = 'number of worker proccess to be launched') +parser.add_argument('-hip', '--host_ip', default='auto', type=str, + help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine') parser.add_argument('-nt', '--nthread', default = -1, type=int, help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded') parser.add_argument('-i', '--input', required=True, @@ -149,4 +151,4 @@ def hadoop_streaming(nworker, worker_args, use_yarn): subprocess.check_call(cmd, shell = True) fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2) -tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose) +tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip) diff --git a/tracker/rabit_tracker.py b/tracker/rabit_tracker.py index aa21973e5..e324f7a9b 100644 --- a/tracker/rabit_tracker.py +++ b/tracker/rabit_tracker.py @@ -122,7 +122,7 @@ class SlaveEntry: return rmset class Tracker: - def __init__(self, port = 9091, port_end = 9999, verbose = True): + def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) for port in range(port, port_end): try: @@ -134,11 +134,18 @@ class Tracker: sock.listen(16) self.sock = sock self.verbose = verbose + self.hostIP = hostIP self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1) def __del__(self): self.sock.close() def slave_args(self): - return ['rabit_tracker_uri=%s' % socket.gethostname(), + if self.hostIP == 'auto': + host = socket.gethostname() + elif self.hostIP = 'ip': + host = socket.gethostbyname(socket.getfqdn()) + else: + host = hostIP + return ['rabit_tracker_uri=%s' % hostIP, 'rabit_tracker_port=%s' % self.port] def get_neighbor(self, rank, nslave): rank = rank + 1