update tracker for host IP

This commit is contained in:
tqchen 2015-03-01 23:27:59 -08:00
parent e4ce8efab5
commit 75c647cd84
2 changed files with 12 additions and 3 deletions

View File

@ -37,6 +37,8 @@ parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs
'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended')
parser.add_argument('-n', '--nworker', required=True, type=int,
help = 'number of worker proccess to be launched')
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
parser.add_argument('-nt', '--nthread', default = -1, type=int,
help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded')
parser.add_argument('-i', '--input', required=True,
@ -149,4 +151,4 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
subprocess.check_call(cmd, shell = True)
fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2)
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose)
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip)

View File

@ -122,7 +122,7 @@ class SlaveEntry:
return rmset
class Tracker:
def __init__(self, port = 9091, port_end = 9999, verbose = True):
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for port in range(port, port_end):
try:
@ -134,11 +134,18 @@ class Tracker:
sock.listen(16)
self.sock = sock
self.verbose = verbose
self.hostIP = hostIP
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
def __del__(self):
self.sock.close()
def slave_args(self):
return ['rabit_tracker_uri=%s' % socket.gethostname(),
if self.hostIP == 'auto':
host = socket.gethostname()
elif self.hostIP = 'ip':
host = socket.gethostbyname(socket.getfqdn())
else:
host = hostIP
return ['rabit_tracker_uri=%s' % hostIP,
'rabit_tracker_port=%s' % self.port]
def get_neighbor(self, rank, nslave):
rank = rank + 1