[TRACKER] remove tracker in rabit, use DMLC

parent 112d866dc9
commit 73b6e9bbd0

43 doc/guide.md
@@ -13,9 +13,6 @@ To run the examples locally, you will need to build them with ```make```.
  - [Checkpoint and LazyCheckpoint](#checkpoint-and-lazycheckpoint)
* [Compile Programs with Rabit](#compile-programs-with-rabit)
* [Running Rabit Jobs](#running-rabit-jobs)
  - [Running Rabit on Hadoop](#running-rabit-on-hadoop)
  - [Running Rabit using MPI](#running-rabit-using-mpi)
  - [Customize Tracker Script](#customize-tracker-script)
* [Fault Tolerance](#fault-tolerance)

What is Allreduce
@@ -334,45 +331,7 @@ For example, consider the following script in the test case

Running Rabit Jobs
------------------
Rabit is a portable library that can run on multiple platforms.

#### Running Rabit Locally
* You can use [../tracker/rabit_demo.py](https://github.com/dmlc/rabit/blob/master/tracker/rabit_demo.py) to start n processes locally
* This script will restart the program when it exits with -2, so it can be used for [mock tests](#link-against-mock-test-library)

#### Running Rabit on Hadoop
* You can use [../tracker/rabit_yarn.py](https://github.com/dmlc/rabit/blob/master/tracker/rabit_yarn.py) to run rabit programs as a Yarn application
* This will start the rabit programs as Yarn applications
  - This allows multi-threaded programs in each node, which can be more efficient
  - An easy multi-threading solution is to use OpenMP together with rabit code
* It is also possible to run rabit programs via Hadoop Streaming; however, YARN is highly recommended.

#### Running Rabit using MPI
* You can submit rabit programs to an MPI cluster using [../tracker/rabit_mpi.py](https://github.com/dmlc/rabit/blob/master/tracker/rabit_mpi.py).
* If you linked your code against librabit_mpi.a, you can directly use mpirun to submit the job.

#### Customize Tracker Script
You can also modify the tracker script to allow rabit to run on other platforms. To do so, refer to existing
tracker scripts, such as [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) and [../tracker/rabit_mpi.py](https://github.com/dmlc/rabit/blob/master/tracker/rabit_mpi.py), to get a sense of how it is done.

You will need to implement a platform-dependent submission function with the following definition
```python
def fun_submit(nworkers, worker_args, worker_envs):
    """
    Customized submit function that submits nworkers jobs,
    each of which must take worker_args as parameters.
    Note this can be a lambda closure.

    Parameters
      nworkers    number of worker processes to start
      worker_args additional arguments that need to be passed to the workers
      worker_envs environment variables that need to be set for the workers
    """
```
The submission function should start nworkers processes on the platform and append worker_args to the end of the other arguments.
Then you can simply call ```tracker.submit``` with fun_submit to submit jobs to the target platform.
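
For illustration, below is a minimal sketch of such a function (not part of rabit itself) that launches the workers as plain local subprocesses. The program name `./my_rabit_program` and the worker count are placeholders; the `tracker.submit` call follows the signature of the bundled `rabit_tracker.py`.
```python
import os
import subprocess
from threading import Thread

import rabit_tracker as tracker

def local_submit(nworkers, worker_args, worker_envs):
    """Sketch: start nworkers local processes running a rabit program."""
    def run_worker(taskid):
        env = os.environ.copy()
        # worker_envs carries rabit_tracker_uri, rabit_tracker_port, ...
        for k, v in worker_envs.items():
            env[k] = str(v)
        env['rabit_task_id'] = str(taskid)
        # './my_rabit_program' is a placeholder for your compiled program
        subprocess.check_call(['./my_rabit_program'] + worker_args, env=env)
    for i in range(nworkers):
        t = Thread(target=run_worker, args=(i,))
        t.daemon = True
        t.start()

# start the tracker and launch 4 workers through local_submit
tracker.submit(4, [], fun_submit=local_submit, verbose=1)
```
The bundled [../tracker/rabit_demo.py](https://github.com/dmlc/rabit/blob/master/tracker/rabit_demo.py) implements essentially this pattern, plus a keepalive loop that restarts a worker when it exits with -2 (for mock tests).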

Note that the current rabit tracker does not restart a worker when it dies; restarting a node is handled by the platform, otherwise the fail-restart logic would have to be written into the custom script.
* Fail-restart is usually provided by most platforms.
  - rabit-yarn provides such functionality in YARN

All rabit jobs can be submitted using [dmlc-tracker](https://github.com/dmlc/dmlc-core/tree/master/tracker)

Fault Tolerance
---------------
@@ -1,12 +0,0 @@
Trackers
=====
This folder contains tracker scripts that can be used to submit rabit jobs to different platforms;
the usage guidelines are in the scripts themselves.

***Supported Platforms***
* Local demo: [rabit_demo.py](rabit_demo.py)
* MPI: [rabit_mpi.py](rabit_mpi.py)
* Yarn (Hadoop): [rabit_yarn.py](rabit_yarn.py)
  - It is also possible to submit via Hadoop Streaming with rabit_hadoop_streaming.py
  - However, it is highly recommended to use rabit_yarn.py because it allocates resources more precisely and fits machine learning scenarios better
* Sun Grid Engine: [rabit_sge.py](rabit_sge.py)
@@ -1,96 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This is the demo submission script of rabit for submitting jobs on the local machine
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
from threading import Thread
|
||||
import rabit_tracker as tracker
|
||||
if os.name == 'nt':
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '\\..\\wrapper'
|
||||
else:
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job locally using python subprocess')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker processes to be launched')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
|
||||
# bash script for keepalive
# use it so that python does not need to communicate with the subprocess
echo="echo %s rabit_num_trial=$nrep;"
|
||||
keepalive = """
|
||||
nrep=0
|
||||
rc=254
|
||||
while [ $rc -eq 254 ];
|
||||
do
|
||||
export rabit_num_trial=$nrep
|
||||
%s
|
||||
%s
|
||||
rc=$?;
|
||||
nrep=$((nrep+1));
|
||||
done
|
||||
"""
|
||||
|
||||
def exec_cmd(cmd, taskid, worker_env):
|
||||
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
|
||||
cmd[0] = './' + cmd[0]
|
||||
cmd = ' '.join(cmd)
|
||||
env = os.environ.copy()
|
||||
for k, v in worker_env.items():
|
||||
env[k] = str(v)
|
||||
env['rabit_task_id'] = str(taskid)
|
||||
env['PYTHONPATH'] = WRAPPER_PATH
|
||||
|
||||
ntrial = 0
|
||||
while True:
|
||||
if os.name == 'nt':
|
||||
env['rabit_num_trial'] = str(ntrial)
|
||||
ret = subprocess.call(cmd, shell=True, env = env)
|
||||
if ret == 254:
|
||||
ntrial += 1
|
||||
continue
|
||||
else:
|
||||
if args.verbose != 0:
|
||||
bash = keepalive % (echo % cmd, cmd)
|
||||
else:
|
||||
bash = keepalive % ('', cmd)
|
||||
ret = subprocess.call(bash, shell=True, executable='bash', env = env)
|
||||
if ret == 0:
|
||||
if args.verbose != 0:
|
||||
print 'Thread %d exit with 0' % taskid
|
||||
return
|
||||
else:
|
||||
if os.name == 'nt':
|
||||
os._exit(-1)
|
||||
else:
|
||||
raise Exception('Get nonzero return code=%d' % ret)
|
||||
#
|
||||
# Note: this submit script is only used for demo purpose
|
||||
# submission script using python multi-threading
|
||||
#
|
||||
def mthread_submit(nslave, worker_args, worker_envs):
|
||||
"""
|
||||
customized submit script, that submit nslave jobs, each must contain args as parameter
|
||||
note this can be a lambda function containing additional parameters in input
|
||||
Parameters
|
||||
nslave number of slave process to start up
|
||||
args arguments to launch each job
|
||||
this usually includes the parameters of master_uri and parameters passed into submit
|
||||
"""
|
||||
procs = {}
|
||||
for i in range(nslave):
|
||||
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i, worker_envs))
|
||||
procs[i].daemon = True
|
||||
procs[i].start()
|
||||
for i in range(nslave):
|
||||
procs[i].join()
|
||||
|
||||
# call submit, with nslave, the commands to run each job and submit function
|
||||
tracker.submit(args.nworker, [], fun_submit = mthread_submit, verbose = args.verbose)
|
||||
@@ -1,165 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Deprecated
|
||||
|
||||
This is a script to submit rabit job using hadoop streaming.
|
||||
It will submit the rabit process as mappers of MapReduce.
|
||||
|
||||
This script is deprecated, it is highly recommended to use rabit_yarn.py instead
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import subprocess
|
||||
import warnings
|
||||
import rabit_tracker as tracker
|
||||
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
|
||||
#!!! Set path to hadoop and hadoop streaming jar here
|
||||
hadoop_binary = 'hadoop'
|
||||
hadoop_streaming_jar = None
|
||||
|
||||
# code
|
||||
hadoop_home = os.getenv('HADOOP_HOME')
|
||||
if hadoop_home != None:
|
||||
if hadoop_binary == None:
|
||||
hadoop_binary = hadoop_home + '/bin/hadoop'
|
||||
assert os.path.exists(hadoop_binary), "HADOOP_HOME does not contain the hadoop binary"
|
||||
if hadoop_streaming_jar == None:
|
||||
hadoop_streaming_jar = hadoop_home + '/lib/hadoop-streaming.jar'
|
||||
assert os.path.exists(hadoop_streaming_jar), "HADOOP_HOME does not contain the hadoop streaming jar"
|
||||
|
||||
if hadoop_binary == None or hadoop_streaming_jar == None:
|
||||
warnings.warn('Warning: Cannot auto-detect path to hadoop or hadoop-streaming jar\n'\
|
||||
'\tneed to set them via arguments -hs and -hb\n'\
|
||||
'\tTo enable auto-detection, you can set environment variable HADOOP_HOME'\
|
||||
', or modify rabit_hadoop.py line 16', stacklevel = 2)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs using Hadoop Streaming.'\
|
||||
' It is highly recommended to use rabit_yarn.py instead')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker processes to be launched')
|
||||
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||
parser.add_argument('-i', '--input', required=True,
|
||||
help = 'input path in HDFS')
|
||||
parser.add_argument('-o', '--output', required=True,
|
||||
help = 'output path in HDFS')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
|
||||
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
|
||||
parser.add_argument('-f', '--files', default = [], action='append',
|
||||
help = 'the cached file list in mapreduce,'\
|
||||
' the submission script will automatically cache all the files which appears in command'\
|
||||
' This will also cause rewritten of all the file names in the command to current path,'\
|
||||
' for example `../../kmeans ../kmeans.conf` will be rewritten to `./kmeans kmeans.conf`'\
|
||||
' because the two files are cached to running folder.'\
|
||||
' You may need this option to cache additional files.'\
|
||||
' You can also use it to manually cache files when auto_file_cache is off')
|
||||
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
|
||||
parser.add_argument('--timeout', default=600000000, type=int,
|
||||
help = 'timeout (in milliseconds) of each mapper job, automatically set to a very long time,'\
|
||||
'normally you do not need to set this ')
|
||||
parser.add_argument('--vcores', default = -1, type=int,
|
||||
help = 'number of vcores to request in each mapper, set it if each rabit job is multi-threaded')
|
||||
parser.add_argument('-mem', '--memory_mb', default=-1, type=int,
|
||||
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
|
||||
'if you are running multi-threading rabit,'\
|
||||
'so that each node can occupy all the mapper slots in a machine for maximum performance')
|
||||
if hadoop_binary == None:
|
||||
parser.add_argument('-hb', '--hadoop_binary', required = True,
|
||||
help="path to hadoop binary file")
|
||||
else:
|
||||
parser.add_argument('-hb', '--hadoop_binary', default = hadoop_binary,
|
||||
help="path to hadoop binary file")
|
||||
|
||||
if hadoop_streaming_jar == None:
|
||||
parser.add_argument('-hs', '--hadoop_streaming_jar', required = True,
|
||||
help='path to hadoop streaming jar file')
|
||||
else:
|
||||
parser.add_argument('-hs', '--hadoop_streaming_jar', default = hadoop_streaming_jar,
|
||||
help='path to hadoop streaming jar file')
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.jobname == 'auto':
|
||||
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
|
||||
|
||||
# detect hadoop version
|
||||
(out, err) = subprocess.Popen('%s version' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
|
||||
out = out.split('\n')[0].split()
|
||||
assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
|
||||
hadoop_version = out[1].split('.')
|
||||
use_yarn = int(hadoop_version[0]) >= 2
|
||||
if use_yarn:
|
||||
warnings.warn('It is highly recommended to use rabit_yarn.py to submit jobs to yarn instead', stacklevel = 2)
|
||||
|
||||
print 'Current Hadoop Version is %s' % out[1]
|
||||
|
||||
def hadoop_streaming(nworker, worker_args, worker_envs, use_yarn):
|
||||
worker_envs['CLASSPATH'] = '`$HADOOP_HOME/bin/hadoop classpath --glob` '
|
||||
worker_envs['LD_LIBRARY_PATH'] = '{LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server'
|
||||
fset = set()
|
||||
if args.auto_file_cache:
|
||||
for i in range(len(args.command)):
|
||||
f = args.command[i]
|
||||
if os.path.exists(f):
|
||||
fset.add(f)
|
||||
if i == 0:
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
else:
|
||||
args.command[i] = args.command[i].split('/')[-1]
|
||||
if args.command[0].endswith('.py'):
|
||||
flst = [WRAPPER_PATH + '/rabit.py',
|
||||
WRAPPER_PATH + '/librabit_wrapper.so',
|
||||
WRAPPER_PATH + '/librabit_wrapper_mock.so']
|
||||
for f in flst:
|
||||
if os.path.exists(f):
|
||||
fset.add(f)
|
||||
kmap = {}
|
||||
kmap['env'] = 'mapred.child.env'
|
||||
# setup keymaps
|
||||
if use_yarn:
|
||||
kmap['nworker'] = 'mapreduce.job.maps'
|
||||
kmap['jobname'] = 'mapreduce.job.name'
|
||||
kmap['nthread'] = 'mapreduce.map.cpu.vcores'
|
||||
kmap['timeout'] = 'mapreduce.task.timeout'
|
||||
kmap['memory_mb'] = 'mapreduce.map.memory.mb'
|
||||
else:
|
||||
kmap['nworker'] = 'mapred.map.tasks'
|
||||
kmap['jobname'] = 'mapred.job.name'
|
||||
kmap['nthread'] = None
|
||||
kmap['timeout'] = 'mapred.task.timeout'
|
||||
kmap['memory_mb'] = 'mapred.job.map.memory.mb'
|
||||
cmd = '%s jar %s' % (args.hadoop_binary, args.hadoop_streaming_jar)
|
||||
cmd += ' -D%s=%d' % (kmap['nworker'], nworker)
|
||||
cmd += ' -D%s=%s' % (kmap['jobname'], args.jobname)
|
||||
envstr = ','.join('%s=%s' % (k, str(v)) for k, v in worker_envs.items())
|
||||
cmd += ' -D%s=\"%s\"' % (kmap['env'], envstr)
|
||||
if args.vcores != -1:
|
||||
if kmap['nthread'] is None:
|
||||
warnings.warn('nthread can only be set in Yarn(Hadoop version greater than 2.0),'\
|
||||
'it is recommended to use Yarn to submit rabit jobs', stacklevel = 2)
|
||||
else:
|
||||
cmd += ' -D%s=%d' % (kmap['nthread'], args.vcores)
|
||||
cmd += ' -D%s=%d' % (kmap['timeout'], args.timeout)
|
||||
if args.memory_mb != -1:
|
||||
cmd += ' -D%s=%d' % (kmap['memory_mb'], args.memory_mb)
|
||||
|
||||
cmd += ' -input %s -output %s' % (args.input, args.output)
|
||||
cmd += ' -mapper \"%s\" -reducer \"/bin/cat\" ' % (' '.join(args.command + worker_args))
|
||||
if args.files != None:
|
||||
for flst in args.files:
|
||||
for f in flst.split('#'):
|
||||
fset.add(f)
|
||||
for f in fset:
|
||||
cmd += ' -file %s' % f
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True)
|
||||
|
||||
fun_submit = lambda nworker, worker_args, worker_envs: hadoop_streaming(nworker, worker_args, worker_envs, int(hadoop_version[0]) >= 2)
|
||||
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip)
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Submission script to submit rabit jobs using MPI
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import rabit_tracker as tracker
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job using MPI')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker processes to be launched')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('-H', '--hostfile', type=str,
|
||||
help = 'the hostfile of mpi server')
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
#
|
||||
# submission script using MPI
|
||||
#
|
||||
def mpi_submit(nslave, worker_args, worker_envs):
|
||||
"""
|
||||
customized submit script, that submit nslave jobs, each must contain args as parameter
|
||||
note this can be a lambda function containing additional parameters in input
|
||||
Parameters
|
||||
nslave number of slave process to start up
|
||||
args arguments to launch each job
|
||||
this usually includes the parameters of master_uri and parameters passed into submit
|
||||
"""
|
||||
worker_args += ['%s=%s' % (k, str(v)) for k, v in worker_envs.items()]
|
||||
sargs = ' '.join(args.command + worker_args)
|
||||
if args.hostfile is None:
|
||||
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args.command + worker_args)
|
||||
else:
|
||||
cmd = ' '.join(['mpirun -n %d --hostfile %s' % (nslave, args.hostfile)] + args.command + worker_args)
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True)
|
||||
|
||||
# call submit, with nslave, the commands to run each job and submit function
|
||||
tracker.submit(args.nworker, [], fun_submit = mpi_submit, verbose = args.verbose)
|
||||
@@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Submit rabit jobs to Sun Grid Engine
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import rabit_tracker as tracker
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job to Sun Grid Engine')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker processes to be launched')
|
||||
parser.add_argument('-q', '--queue', default='default', type=str,
|
||||
help = 'the queue we want to submit the job to')
|
||||
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||
parser.add_argument('--vcores', default = 1, type=int,
|
||||
help = 'number of vcores to request in each worker, set it if each rabit job is multi-threaded')
|
||||
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
|
||||
parser.add_argument('--logdir', default='auto', help = 'customize the directory to place the logs')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.jobname == 'auto':
|
||||
args.jobname = ('rabit%d.' % args.nworker) + args.command[0].split('/')[-1];
|
||||
if args.logdir == 'auto':
|
||||
args.logdir = args.jobname + '.log'
|
||||
|
||||
if os.path.exists(args.logdir):
|
||||
if not os.path.isdir(args.logdir):
|
||||
raise RuntimeError('specified logdir %s is a file instead of directory' % args.logdir)
|
||||
else:
|
||||
os.mkdir(args.logdir)
|
||||
|
||||
runscript = '%s/runrabit.sh' % args.logdir
|
||||
fo = open(runscript, 'w')
|
||||
fo.write('source ~/.bashrc\n')
|
||||
fo.write('\"$@\"\n')
|
||||
fo.close()
|
||||
#
|
||||
# submission script using Sun Grid Engine (qsub)
|
||||
#
|
||||
def sge_submit(nslave, worker_args, worker_envs):
|
||||
"""
|
||||
customized submit script, that submit nslave jobs, each must contain args as parameter
|
||||
note this can be a lambda function containing additional parameters in input
|
||||
Parameters
|
||||
nslave number of slave process to start up
|
||||
args arguments to launch each job
|
||||
this usually includes the parameters of master_uri and parameters passed into submit
|
||||
"""
|
||||
env_arg = ','.join(['%s=\"%s\"' % (k, str(v)) for k, v in worker_envs.items()])
|
||||
cmd = 'qsub -cwd -t 1-%d -S /bin/bash' % nslave
|
||||
if args.queue != 'default':
|
||||
cmd += ' -q %s' % args.queue
|
||||
cmd += ' -N %s ' % args.jobname
|
||||
cmd += ' -e %s -o %s' % (args.logdir, args.logdir)
|
||||
cmd += ' -pe orte %d' % (args.vcores)
|
||||
cmd += ' -v %s,PATH=${PATH}:.' % env_arg
|
||||
cmd += ' %s %s' % (runscript, ' '.join(args.command + worker_args))
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True)
|
||||
print 'Waiting for the jobs to get up...'
|
||||
|
||||
# call submit, with nslave, the commands to run each job and submit function
|
||||
tracker.submit(args.nworker, [], fun_submit = sge_submit, verbose = args.verbose)
|
||||
@@ -1,317 +0,0 @@
|
||||
"""
|
||||
Tracker script for rabit
|
||||
Implements the tracker control protocol
|
||||
- start rabit jobs
|
||||
- help nodes to establish links with each other
|
||||
|
||||
Tianqi Chen
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import random
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
"""
|
||||
Extension of socket to handle recv and send of special data
|
||||
"""
|
||||
class ExSocket:
|
||||
def __init__(self, sock):
|
||||
self.sock = sock
|
||||
def recvall(self, nbytes):
|
||||
res = []
|
||||
sock = self.sock
|
||||
nread = 0
|
||||
while nread < nbytes:
|
||||
chunk = self.sock.recv(min(nbytes - nread, 1024))
|
||||
nread += len(chunk)
|
||||
res.append(chunk)
|
||||
return ''.join(res)
|
||||
def recvint(self):
|
||||
return struct.unpack('@i', self.recvall(4))[0]
|
||||
def sendint(self, n):
|
||||
self.sock.sendall(struct.pack('@i', n))
|
||||
def sendstr(self, s):
|
||||
self.sendint(len(s))
|
||||
self.sock.sendall(s)
|
||||
def recvstr(self):
|
||||
slen = self.recvint()
|
||||
return self.recvall(slen)
|
||||
|
||||
# magic number used to verify existence of data
|
||||
kMagic = 0xff99
|
||||
|
||||
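# A worker session starts with a small handshake (see SlaveEntry.__init__
# below): the worker sends kMagic, the tracker echoes kMagic back, then the
# worker sends its requested rank, the world size, a job id string and a
# command string ('start', 'recover', 'shutdown' or 'print').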
class SlaveEntry:
|
||||
def __init__(self, sock, s_addr):
|
||||
slave = ExSocket(sock)
|
||||
self.sock = slave
|
||||
self.host = socket.gethostbyname(s_addr[0])
|
||||
magic = slave.recvint()
|
||||
assert magic == kMagic, 'invalid magic number=%d from %s' % (magic, self.host)
|
||||
slave.sendint(kMagic)
|
||||
self.rank = slave.recvint()
|
||||
self.world_size = slave.recvint()
|
||||
self.jobid = slave.recvstr()
|
||||
self.cmd = slave.recvstr()
|
||||
|
||||
def decide_rank(self, job_map):
|
||||
if self.rank >= 0:
|
||||
return self.rank
|
||||
if self.jobid != 'NULL' and self.jobid in job_map:
|
||||
return job_map[self.jobid]
|
||||
return -1
|
||||
|
||||
def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
|
||||
self.rank = rank
|
||||
nnset = set(tree_map[rank])
|
||||
rprev, rnext = ring_map[rank]
|
||||
self.sock.sendint(rank)
|
||||
# send parent rank
|
||||
self.sock.sendint(parent_map[rank])
|
||||
# send world size
|
||||
self.sock.sendint(len(tree_map))
|
||||
self.sock.sendint(len(nnset))
|
||||
# send the rprev and next link
|
||||
for r in nnset:
|
||||
self.sock.sendint(r)
|
||||
# send prev link
|
||||
if rprev != -1 and rprev != rank:
|
||||
nnset.add(rprev)
|
||||
self.sock.sendint(rprev)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
# send next link
|
||||
if rnext != -1 and rnext != rank:
|
||||
nnset.add(rnext)
|
||||
self.sock.sendint(rnext)
|
||||
else:
|
||||
self.sock.sendint(-1)
|
||||
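# Connection negotiation loop: the worker reports which neighbour ranks it is
# already connected to; the tracker replies with the host/port of neighbours
# currently waiting in wait_conn and with the number of peers the worker still
# has to accept itself. This repeats until the worker reports zero errors,
# after which it sends back the port it listens on.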
while True:
|
||||
ngood = self.sock.recvint()
|
||||
goodset = set([])
|
||||
for i in xrange(ngood):
|
||||
goodset.add(self.sock.recvint())
|
||||
assert goodset.issubset(nnset)
|
||||
badset = nnset - goodset
|
||||
conset = []
|
||||
for r in badset:
|
||||
if r in wait_conn:
|
||||
conset.append(r)
|
||||
self.sock.sendint(len(conset))
|
||||
self.sock.sendint(len(badset) - len(conset))
|
||||
for r in conset:
|
||||
self.sock.sendstr(wait_conn[r].host)
|
||||
self.sock.sendint(wait_conn[r].port)
|
||||
self.sock.sendint(r)
|
||||
nerr = self.sock.recvint()
|
||||
if nerr != 0:
|
||||
continue
|
||||
self.port = self.sock.recvint()
|
||||
rmset = []
|
||||
# all connections were successfully set up
|
||||
for r in conset:
|
||||
wait_conn[r].wait_accept -= 1
|
||||
if wait_conn[r].wait_accept == 0:
|
||||
rmset.append(r)
|
||||
for r in rmset:
|
||||
wait_conn.pop(r, None)
|
||||
self.wait_accept = len(badset) - len(conset)
|
||||
return rmset
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
for port in range(port, port_end):
|
||||
try:
|
||||
sock.bind(('', port))
|
||||
self.port = port
|
||||
break
|
||||
except socket.error:
|
||||
continue
|
||||
sock.listen(128)
|
||||
self.sock = sock
|
||||
self.verbose = verbose
|
||||
if hostIP == 'auto':
|
||||
hostIP = 'ip'
|
||||
self.hostIP = hostIP
|
||||
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
|
||||
def __del__(self):
|
||||
self.sock.close()
|
||||
def slave_envs(self):
|
||||
"""
|
||||
get environment variables for slaves
|
||||
can be passed in as args or envs
|
||||
"""
|
||||
if self.hostIP == 'dns':
|
||||
host = socket.gethostname()
|
||||
elif self.hostIP == 'ip':
|
||||
host = socket.gethostbyname(socket.getfqdn())
|
||||
else:
|
||||
host = self.hostIP
|
||||
return {'rabit_tracker_uri': host,
|
||||
'rabit_tracker_port': self.port}
|
||||
def get_neighbor(self, rank, nslave):
|
||||
rank = rank + 1
|
||||
ret = []
|
||||
if rank > 1:
|
||||
ret.append(rank / 2 - 1)
|
||||
if rank * 2 - 1 < nslave:
|
||||
ret.append(rank * 2 - 1)
|
||||
if rank * 2 < nslave:
|
||||
ret.append(rank * 2)
|
||||
return ret
|
||||
def get_tree(self, nslave):
|
||||
tree_map = {}
|
||||
parent_map = {}
|
||||
for r in range(nslave):
|
||||
tree_map[r] = self.get_neighbor(r, nslave)
|
||||
parent_map[r] = (r + 1) / 2 - 1
|
||||
return tree_map, parent_map
|
||||
def find_share_ring(self, tree_map, parent_map, r):
|
||||
"""
|
||||
get a ring structure that tends to share nodes with the tree
|
||||
return a list starting from r
|
||||
"""
|
||||
nset = set(tree_map[r])
|
||||
cset = nset - set([parent_map[r]])
|
||||
if len(cset) == 0:
|
||||
return [r]
|
||||
rlst = [r]
|
||||
cnt = 0
|
||||
for v in cset:
|
||||
vlst = self.find_share_ring(tree_map, parent_map, v)
|
||||
cnt += 1
|
||||
if cnt == len(cset):
|
||||
vlst.reverse()
|
||||
rlst += vlst
|
||||
return rlst
|
||||
|
||||
def get_ring(self, tree_map, parent_map):
|
||||
"""
|
||||
get a ring connection used to recover local data
|
||||
"""
|
||||
assert parent_map[0] == -1
|
||||
rlst = self.find_share_ring(tree_map, parent_map, 0)
|
||||
assert len(rlst) == len(tree_map)
|
||||
ring_map = {}
|
||||
nslave = len(tree_map)
|
||||
for r in range(nslave):
|
||||
rprev = (r + nslave - 1) % nslave
|
||||
rnext = (r + 1) % nslave
|
||||
ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
|
||||
return ring_map
|
||||
|
||||
def get_link_map(self, nslave):
|
||||
"""
|
||||
get the link map; this is a bit hacky and calls for a better algorithm
|
||||
to place similar nodes together
|
||||
"""
|
||||
tree_map, parent_map = self.get_tree(nslave)
|
||||
ring_map = self.get_ring(tree_map, parent_map)
|
||||
rmap = {0 : 0}
|
||||
k = 0
|
||||
for i in range(nslave - 1):
|
||||
k = ring_map[k][1]
|
||||
rmap[k] = i + 1
|
||||
|
||||
ring_map_ = {}
|
||||
tree_map_ = {}
|
||||
parent_map_ ={}
|
||||
for k, v in ring_map.items():
|
||||
ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
|
||||
for k, v in tree_map.items():
|
||||
tree_map_[rmap[k]] = [rmap[x] for x in v]
|
||||
for k, v in parent_map.items():
|
||||
if k != 0:
|
||||
parent_map_[rmap[k]] = rmap[v]
|
||||
else:
|
||||
parent_map_[rmap[k]] = -1
|
||||
return tree_map_, parent_map_, ring_map_
|
||||
|
||||
def handle_print(self,slave, msg):
|
||||
sys.stdout.write(msg)
|
||||
|
||||
def log_print(self, msg, level):
|
||||
if level == 1:
|
||||
if self.verbose:
|
||||
sys.stderr.write(msg + '\n')
|
||||
else:
|
||||
sys.stderr.write(msg + '\n')
|
||||
|
||||
def accept_slaves(self, nslave):
|
||||
# set of nodes that have finished the job
shutdown = {}
# set of nodes that are waiting for connections
wait_conn = {}
# maps job id to rank
job_map = {}
# list of workers that are pending rank assignment
pending = []
# lazily initialized tree_map
tree_map = None
|
||||
|
||||
while len(shutdown) != nslave:
|
||||
fd, s_addr = self.sock.accept()
|
||||
s = SlaveEntry(fd, s_addr)
|
||||
if s.cmd == 'print':
|
||||
msg = s.sock.recvstr()
|
||||
self.handle_print(s, msg)
|
||||
continue
|
||||
if s.cmd == 'shutdown':
|
||||
assert s.rank >= 0 and s.rank not in shutdown
|
||||
assert s.rank not in wait_conn
|
||||
shutdown[s.rank] = s
|
||||
self.log_print('Receive %s signal from %d' % (s.cmd, s.rank), 1)
|
||||
continue
|
||||
assert s.cmd == 'start' or s.cmd == 'recover'
|
||||
# lazily initialize the slaves
|
||||
if tree_map == None:
|
||||
assert s.cmd == 'start'
|
||||
if s.world_size > 0:
|
||||
nslave = s.world_size
|
||||
tree_map, parent_map, ring_map = self.get_link_map(nslave)
|
||||
# set of nodes that is pending for getting up
|
||||
todo_nodes = range(nslave)
|
||||
else:
|
||||
assert s.world_size == -1 or s.world_size == nslave
|
||||
if s.cmd == 'recover':
|
||||
assert s.rank >= 0
|
||||
|
||||
rank = s.decide_rank(job_map)
|
||||
# batch assignment of ranks
|
||||
if rank == -1:
|
||||
assert len(todo_nodes) != 0
|
||||
pending.append(s)
|
||||
if len(pending) == len(todo_nodes):
|
||||
pending.sort(key = lambda x : x.host)
|
||||
for s in pending:
|
||||
rank = todo_nodes.pop(0)
|
||||
if s.jobid != 'NULL':
|
||||
job_map[s.jobid] = rank
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
self.log_print('Receive %s signal from %s; assign rank %d' % (s.cmd, s.host, s.rank), 1)
|
||||
if len(todo_nodes) == 0:
|
||||
self.log_print('@tracker All of %d nodes getting started' % nslave, 2)
|
||||
self.start_time = time.time()
|
||||
else:
|
||||
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
|
||||
self.log_print('Receive %s signal from %d' % (s.cmd, s.rank), 1)
|
||||
if s.wait_accept > 0:
|
||||
wait_conn[rank] = s
|
||||
self.log_print('@tracker All nodes finished their jobs', 2)
|
||||
self.end_time = time.time()
|
||||
self.log_print('@tracker %s secs between node start and job finish' % str(self.end_time - self.start_time), 2)
|
||||
|
||||
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
|
||||
master = Tracker(verbose = verbose, hostIP = hostIP)
|
||||
submit_thread = Thread(target = fun_submit, args = (nslave, args, master.slave_envs()))
|
||||
submit_thread.daemon = True
|
||||
submit_thread.start()
|
||||
master.accept_slaves(nslave)
|
||||
submit_thread.join()
|
||||
@@ -1,140 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This is a script to submit rabit job via Yarn
|
||||
rabit will run as a Yarn application
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import subprocess
|
||||
import warnings
|
||||
import rabit_tracker as tracker
|
||||
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
YARN_JAR_PATH = os.path.dirname(__file__) + '/../yarn/rabit-yarn.jar'
|
||||
YARN_BOOT_PY = os.path.dirname(__file__) + '/../yarn/run_hdfs_prog.py'
|
||||
|
||||
if not os.path.exists(YARN_JAR_PATH):
|
||||
warnings.warn("cannot find \"%s\", I will try to run build" % YARN_JAR_PATH)
|
||||
cmd = 'cd %s;./build.sh' % (os.path.dirname(__file__) + '/../yarn/')
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True, env = os.environ)
|
||||
assert os.path.exists(YARN_JAR_PATH), "failed to build rabit-yarn.jar, try it manually"
|
||||
|
||||
hadoop_binary = None
|
||||
# code
|
||||
hadoop_home = os.getenv('HADOOP_HOME')
|
||||
|
||||
if hadoop_home != None:
|
||||
if hadoop_binary == None:
|
||||
hadoop_binary = hadoop_home + '/bin/hadoop'
|
||||
assert os.path.exists(hadoop_binary), "HADOOP_HOME does not contain the hadoop binary"
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs to Yarn.')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
help = 'number of worker processes to be launched')
|
||||
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||
parser.add_argument('-v', '--verbose', default=0, choices=[0, 1], type=int,
|
||||
help = 'print more messages into the console')
|
||||
parser.add_argument('-q', '--queue', default='default', type=str,
|
||||
help = 'the queue we want to submit the job to')
|
||||
parser.add_argument('-ac', '--auto_file_cache', default=1, choices=[0, 1], type=int,
|
||||
help = 'whether automatically cache the files in the command to hadoop localfile, this is on by default')
|
||||
parser.add_argument('-f', '--files', default = [], action='append',
|
||||
help = 'the cached file list in mapreduce,'\
|
||||
' the submission script will automatically cache all the files which appears in command'\
|
||||
' This will also cause rewritten of all the file names in the command to current path,'\
|
||||
' for example `../../kmeans ../kmeans.conf` will be rewritten to `./kmeans kmeans.conf`'\
|
||||
' because the two files are cached to running folder.'\
|
||||
' You may need this option to cache additional files.'\
|
||||
' You can also use it to manually cache files when auto_file_cache is off')
|
||||
parser.add_argument('--jobname', default='auto', help = 'customize jobname in tracker')
|
||||
parser.add_argument('--tempdir', default='/tmp', help = 'temporary directory in HDFS that can be used to store intermediate results')
|
||||
parser.add_argument('--vcores', default = 1, type=int,
|
||||
help = 'number of vcores to request in each worker, set it if each rabit job is multi-threaded')
|
||||
parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
|
||||
help = 'maximum memory used by the process. Guide: set it large (near mapred.cluster.max.map.memory.mb)'\
|
||||
'if you are running multi-threading rabit,'\
|
||||
'so that each node can occupy all the mapper slots in a machine for maximum performance')
|
||||
parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
|
||||
help = 'setting to be passed to libhdfs')
|
||||
parser.add_argument('--name-node', default='default', type=str,
|
||||
help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
|
||||
|
||||
parser.add_argument('command', nargs='+',
|
||||
help = 'command for rabit program')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.jobname == 'auto':
|
||||
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
|
||||
|
||||
if hadoop_binary == None:
|
||||
parser.add_argument('-hb', '--hadoop_binary', required = True,
|
||||
help="path to hadoop binary file")
|
||||
else:
|
||||
parser.add_argument('-hb', '--hadoop_binary', default = hadoop_binary,
|
||||
help="path to hadoop binary file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.jobname == 'auto':
|
||||
args.jobname = ('Rabit[nworker=%d]:' % args.nworker) + args.command[0].split('/')[-1];
|
||||
|
||||
# detect hadoop version
|
||||
(out, err) = subprocess.Popen('%s version' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
|
||||
out = out.split('\n')[0].split()
|
||||
assert out[0] == 'Hadoop', 'cannot parse hadoop version string'
|
||||
hadoop_version = out[1].split('.')
|
||||
|
||||
(classpath, err) = subprocess.Popen('%s classpath --glob' % args.hadoop_binary, shell = True, stdout=subprocess.PIPE).communicate()
|
||||
|
||||
if int(hadoop_version[0]) < 2:
|
||||
print 'Current Hadoop Version is %s, rabit_yarn will need Yarn(Hadoop 2.0)' % out[1]
|
||||
|
||||
def submit_yarn(nworker, worker_args, worker_env):
|
||||
fset = set([YARN_JAR_PATH, YARN_BOOT_PY])
|
||||
if args.auto_file_cache != 0:
|
||||
for i in range(len(args.command)):
|
||||
f = args.command[i]
|
||||
if os.path.exists(f):
|
||||
fset.add(f)
|
||||
if i == 0:
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
else:
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
if args.command[0].endswith('.py'):
|
||||
flst = [WRAPPER_PATH + '/rabit.py',
|
||||
WRAPPER_PATH + '/librabit_wrapper.so',
|
||||
WRAPPER_PATH + '/librabit_wrapper_mock.so']
|
||||
for f in flst:
|
||||
if os.path.exists(f):
|
||||
fset.add(f)
|
||||
|
||||
cmd = 'java -cp `%s classpath`:%s org.apache.hadoop.yarn.rabit.Client ' % (args.hadoop_binary, YARN_JAR_PATH)
|
||||
env = os.environ.copy()
|
||||
for k, v in worker_env.items():
|
||||
env[k] = str(v)
|
||||
env['rabit_cpu_vcores'] = str(args.vcores)
|
||||
env['rabit_memory_mb'] = str(args.memory_mb)
|
||||
env['rabit_world_size'] = str(args.nworker)
|
||||
env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
|
||||
env['rabit_hdfs_namenode'] = str(args.name_node)
|
||||
|
||||
if args.files != None:
|
||||
for flst in args.files:
|
||||
for f in flst.split('#'):
|
||||
fset.add(f)
|
||||
for f in fset:
|
||||
cmd += ' -file %s' % f
|
||||
cmd += ' -jobname %s ' % args.jobname
|
||||
cmd += ' -tempdir %s ' % args.tempdir
|
||||
cmd += ' -queue %s ' % args.queue
|
||||
cmd += (' '.join(['./run_hdfs_prog.py'] + args.command + worker_args))
|
||||
if args.verbose != 0:
|
||||
print cmd
|
||||
subprocess.check_call(cmd, shell = True, env = env)
|
||||
|
||||
tracker.submit(args.nworker, [], fun_submit = submit_yarn, verbose = args.verbose, hostIP = args.host_ip)
|
||||
4 yarn/.gitignore vendored
@@ -1,4 +0,0 @@
|
||||
bin
|
||||
.classpath
|
||||
.project
|
||||
*.jar
|
||||
@@ -1,5 +0,0 @@
rabit-yarn
=====
* This folder contains the application code that allows rabit to run on Yarn.
* You can use [../tracker/rabit_yarn.py](../tracker/rabit_yarn.py) to submit the job
  - run ```./build.sh``` to build the jar before using the script
@@ -1,8 +0,0 @@
#!/bin/bash
if [ ! -d bin ]; then
    mkdir bin
fi

CPATH=`${HADOOP_HOME}/bin/hadoop classpath`
javac -cp $CPATH -d bin src/org/apache/hadoop/yarn/rabit/*
jar cf rabit-yarn.jar -C bin .
@@ -1,45 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
this script helps setup classpath env for HDFS, before running program
|
||||
that links with libhdfs
|
||||
"""
|
||||
import glob
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print 'Usage: the command you want to run'
|
||||
|
||||
hadoop_home = os.getenv('HADOOP_HOME')
|
||||
hdfs_home = os.getenv('HADOOP_HDFS_HOME')
|
||||
java_home = os.getenv('JAVA_HOME')
|
||||
if hadoop_home is None:
|
||||
hadoop_home = os.getenv('HADOOP_PREFIX')
|
||||
assert hadoop_home is not None, 'need to set HADOOP_HOME'
|
||||
assert hdfs_home is not None, 'need to set HADOOP_HDFS_HOME'
|
||||
assert java_home is not None, 'need to set JAVA_HOME'
|
||||
|
||||
(classpath, err) = subprocess.Popen('%s/bin/hadoop classpath' % hadoop_home,
|
||||
stdout=subprocess.PIPE, shell = True,
|
||||
env = os.environ).communicate()
|
||||
cpath = []
|
||||
for f in classpath.split(':'):
|
||||
cpath += glob.glob(f)
|
||||
|
||||
lpath = []
|
||||
lpath.append('%s/lib/native' % hdfs_home)
|
||||
lpath.append('%s/jre/lib/amd64/server' % java_home)
|
||||
|
||||
env = os.environ.copy()
|
||||
env['CLASSPATH'] = '${CLASSPATH}:' + (':'.join(cpath))
|
||||
|
||||
# setup hdfs options
|
||||
if 'rabit_hdfs_opts' in env:
|
||||
env['LIBHDFS_OPTS'] = env['rabit_hdfs_opts']
|
||||
elif 'LIBHDFS_OPTS' not in env:
|
||||
env['LIBHDFS_OPTS'] = '-Xmx128m'
|
||||
|
||||
env['LD_LIBRARY_PATH'] = '${LD_LIBRARY_PATH}:' + (':'.join(lpath))
|
||||
ret = subprocess.call(args = sys.argv[1:], env = env)
|
||||
sys.exit(ret)
|
||||
@@ -1,570 +0,0 @@
|
||||
package org.apache.hadoop.yarn.rabit;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileStatus;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.io.DataOutputBuffer;
|
||||
import org.apache.hadoop.yarn.util.ConverterUtils;
|
||||
import org.apache.hadoop.yarn.util.Records;
|
||||
import org.apache.hadoop.yarn.conf.YarnConfiguration;
|
||||
import org.apache.hadoop.yarn.api.ApplicationConstants;
|
||||
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
|
||||
import org.apache.hadoop.yarn.api.records.Container;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerState;
|
||||
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
|
||||
import org.apache.hadoop.yarn.api.records.LocalResource;
|
||||
import org.apache.hadoop.yarn.api.records.LocalResourceType;
|
||||
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
|
||||
import org.apache.hadoop.yarn.api.records.Priority;
|
||||
import org.apache.hadoop.yarn.api.records.Resource;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerId;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerStatus;
|
||||
import org.apache.hadoop.yarn.api.records.NodeReport;
|
||||
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
|
||||
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
|
||||
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
|
||||
import org.apache.hadoop.security.Credentials;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
|
||||
/**
|
||||
* application master for allocating resources of rabit client
|
||||
*
|
||||
* @author Tianqi Chen
|
||||
*/
|
||||
public class ApplicationMaster {
|
||||
// logger
|
||||
private static final Log LOG = LogFactory.getLog(ApplicationMaster.class);
|
||||
// configuration
|
||||
private Configuration conf = new YarnConfiguration();
|
||||
// hdfs handler
|
||||
private FileSystem dfs;
|
||||
|
||||
// number of cores allocated for each task
|
||||
private int numVCores = 1;
|
||||
// memory needed requested for the task
|
||||
private int numMemoryMB = 10;
|
||||
// priority of the app master
|
||||
private int appPriority = 0;
|
||||
// total number of tasks
|
||||
private int numTasks = 1;
|
||||
// maximum number of attempts to try in each task
|
||||
private int maxNumAttempt = 3;
|
||||
// command to launch
|
||||
private String command = "";
|
||||
|
||||
// username
|
||||
private String userName = "";
|
||||
// user credentials
|
||||
private Credentials credentials = null;
|
||||
// security tokens
|
||||
private ByteBuffer securityTokens = null;
|
||||
// application tracker hostname
|
||||
private String appHostName = "";
|
||||
// tracker URL to do
|
||||
private String appTrackerUrl = "";
|
||||
// tracker port
|
||||
private int appTrackerPort = 0;
|
||||
|
||||
// whether we start to abort the application, due to whatever fatal reasons
|
||||
private boolean startAbort = false;
|
||||
// worker resources
|
||||
private Map<String, LocalResource> workerResources = new java.util.HashMap<String, LocalResource>();
|
||||
// record the aborting reason
|
||||
private String abortDiagnosis = "";
|
||||
// resource manager
|
||||
private AMRMClientAsync<ContainerRequest> rmClient = null;
|
||||
// node manager
|
||||
private NMClientAsync nmClient = null;
|
||||
|
||||
// list of tasks that pending for resources to be allocated
|
||||
private final Queue<TaskRecord> pendingTasks = new java.util.LinkedList<TaskRecord>();
|
||||
// map containerId->task record of tasks that was running
|
||||
private final Map<ContainerId, TaskRecord> runningTasks = new java.util.HashMap<ContainerId, TaskRecord>();
|
||||
// collection of tasks
|
||||
private final Collection<TaskRecord> finishedTasks = new java.util.LinkedList<TaskRecord>();
|
||||
// collection of killed tasks
|
||||
private final Collection<TaskRecord> killedTasks = new java.util.LinkedList<TaskRecord>();
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
new ApplicationMaster().run(args);
|
||||
}
|
||||
|
||||
private ApplicationMaster() throws IOException {
|
||||
dfs = FileSystem.get(conf);
|
||||
userName = UserGroupInformation.getCurrentUser().getShortUserName();
|
||||
credentials = UserGroupInformation.getCurrentUser().getCredentials();
|
||||
DataOutputBuffer buffer = new DataOutputBuffer();
|
||||
this.credentials.writeTokenStorageToStream(buffer);
|
||||
this.securityTokens = ByteBuffer.wrap(buffer.getData());
|
||||
}
|
||||
/**
|
||||
* get integer argument from environment variable
|
||||
*
|
||||
* @param name
|
||||
* name of key
|
||||
* @param required
|
||||
* whether this is required
|
||||
* @param defv
|
||||
* default value
|
||||
* @return the requested result
|
||||
*/
|
||||
private int getEnvInteger(String name, boolean required, int defv)
|
||||
throws IOException {
|
||||
String value = System.getenv(name);
|
||||
if (value == null) {
|
||||
if (required) {
|
||||
throw new IOException("environment variable " + name
|
||||
+ " not set");
|
||||
} else {
|
||||
return defv;
|
||||
}
|
||||
}
|
||||
return Integer.valueOf(value);
|
||||
}
|
||||
|
||||
/**
|
||||
* initialize from arguments and command lines
|
||||
*
|
||||
* @param args
|
||||
*/
|
||||
private void initArgs(String args[]) throws IOException {
|
||||
LOG.info("Start AM as user=" + this.userName);
|
||||
// get user name
|
||||
userName = UserGroupInformation.getCurrentUser().getShortUserName();
|
||||
// cached maps
|
||||
Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
|
||||
for (int i = 0; i < args.length; ++i) {
|
||||
if (args[i].equals("-file")) {
|
||||
String[] arr = args[++i].split("#");
|
||||
Path path = new Path(arr[0]);
|
||||
if (arr.length == 1) {
|
||||
cacheFiles.put(path.getName(), path);
|
||||
} else {
|
||||
cacheFiles.put(arr[1], path);
|
||||
}
|
||||
} else {
|
||||
this.command += args[i] + " ";
|
||||
}
|
||||
}
|
||||
for (Map.Entry<String, Path> e : cacheFiles.entrySet()) {
|
||||
LocalResource r = Records.newRecord(LocalResource.class);
|
||||
FileStatus status = dfs.getFileStatus(e.getValue());
|
||||
r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue()));
|
||||
r.setSize(status.getLen());
|
||||
r.setTimestamp(status.getModificationTime());
|
||||
r.setType(LocalResourceType.FILE);
|
||||
r.setVisibility(LocalResourceVisibility.APPLICATION);
|
||||
workerResources.put(e.getKey(), r);
|
||||
}
|
||||
numVCores = this.getEnvInteger("rabit_cpu_vcores", true, numVCores);
|
||||
numMemoryMB = this.getEnvInteger("rabit_memory_mb", true, numMemoryMB);
|
||||
numTasks = this.getEnvInteger("rabit_world_size", true, numTasks);
|
||||
maxNumAttempt = this.getEnvInteger("rabit_max_attempt", false,
|
||||
maxNumAttempt);
|
||||
}
|
||||
|
||||
/**
|
||||
* called to start the application
|
||||
*/
|
||||
private void run(String args[]) throws Exception {
|
||||
this.initArgs(args);
|
||||
this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000,
|
||||
new RMCallbackHandler());
|
||||
this.nmClient = NMClientAsync
|
||||
.createNMClientAsync(new NMCallbackHandler());
|
||||
this.rmClient.init(conf);
|
||||
this.rmClient.start();
|
||||
this.nmClient.init(conf);
|
||||
this.nmClient.start();
|
||||
RegisterApplicationMasterResponse response = this.rmClient
|
||||
.registerApplicationMaster(this.appHostName,
|
||||
this.appTrackerPort, this.appTrackerUrl);
|
||||
|
||||
boolean success = false;
|
||||
String diagnostics = "";
|
||||
try {
|
||||
// list of tasks that waits to be submit
|
||||
java.util.Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
|
||||
// add waiting tasks
|
||||
for (int i = 0; i < this.numTasks; ++i) {
|
||||
tasks.add(new TaskRecord(i));
|
||||
}
|
||||
Resource maxResource = response.getMaximumResourceCapability();
|
||||
|
||||
if (maxResource.getMemory() < this.numMemoryMB) {
|
||||
LOG.warn("[Rabit] memory requested exceed bound "
|
||||
+ maxResource.getMemory());
|
||||
this.numMemoryMB = maxResource.getMemory();
|
||||
}
|
||||
if (maxResource.getVirtualCores() < this.numVCores) {
|
||||
LOG.warn("[Rabit] memory requested exceed bound "
|
||||
+ maxResource.getVirtualCores());
|
||||
this.numVCores = maxResource.getVirtualCores();
|
||||
}
|
||||
this.submitTasks(tasks);
|
||||
LOG.info("[Rabit] ApplicationMaster started");
|
||||
while (!this.doneAllJobs()) {
|
||||
try {
|
||||
Thread.sleep(100);
|
||||
} catch (InterruptedException e) {
|
||||
}
|
||||
}
|
||||
assert (killedTasks.size() + finishedTasks.size() == numTasks);
|
||||
success = finishedTasks.size() == numTasks;
|
||||
LOG.info("Application completed. Stopping running containers");
|
||||
diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
|
||||
+ ", finished=" + this.finishedTasks.size() + ", failed="
|
||||
+ this.killedTasks.size() + "\n" + this.abortDiagnosis;
|
||||
nmClient.stop();
|
||||
LOG.info(diagnostics);
|
||||
} catch (Exception e) {
|
||||
diagnostics = e.toString();
|
||||
}
|
||||
rmClient.unregisterApplicationMaster(
|
||||
success ? FinalApplicationStatus.SUCCEEDED
|
||||
: FinalApplicationStatus.FAILED, diagnostics,
|
||||
appTrackerUrl);
|
||||
if (!success)
|
||||
throw new Exception("Application not successful");
|
||||
}
|
||||
|
||||
/**
|
||||
* check if the job finishes
|
||||
*
|
||||
* @return whether we finished all the jobs
|
||||
*/
|
||||
private synchronized boolean doneAllJobs() {
|
||||
return pendingTasks.size() == 0 && runningTasks.size() == 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* submit tasks to request containers for the tasks
|
||||
*
|
||||
* @param tasks
|
||||
* a collection of tasks we want to ask container for
|
||||
*/
|
||||
private synchronized void submitTasks(Collection<TaskRecord> tasks) {
|
||||
for (TaskRecord r : tasks) {
|
||||
Resource resource = Records.newRecord(Resource.class);
|
||||
resource.setMemory(numMemoryMB);
|
||||
resource.setVirtualCores(numVCores);
|
||||
Priority priority = Records.newRecord(Priority.class);
|
||||
priority.setPriority(this.appPriority);
|
||||
r.containerRequest = new ContainerRequest(resource, null, null,
|
||||
priority);
|
||||
rmClient.addContainerRequest(r.containerRequest);
|
||||
pendingTasks.add(r);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* launch the task on container
|
||||
*
|
||||
* @param container
|
||||
* container to run the task
|
||||
* @param task
|
||||
* the task
|
||||
*/
|
||||
private void launchTask(Container container, TaskRecord task) {
|
||||
task.container = container;
|
||||
task.containerRequest = null;
|
||||
ContainerLaunchContext ctx = Records
|
||||
.newRecord(ContainerLaunchContext.class);
|
||||
String cmd =
|
||||
// use this to setup CLASSPATH correctly for libhdfs
|
||||
this.command + " 1>"
|
||||
+ ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
|
||||
+ " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
|
||||
+ "/stderr";
|
||||
ctx.setCommands(Collections.singletonList(cmd));
|
||||
ctx.setTokens(this.securityTokens);
|
||||
LOG.info(workerResources);
|
||||
ctx.setLocalResources(this.workerResources);
|
||||
// setup environment variables
|
||||
Map<String, String> env = new java.util.HashMap<String, String>();
|
||||
|
||||
// setup class path, this is kind of duplicated, ignoring
|
||||
StringBuilder cpath = new StringBuilder("${CLASSPATH}:./*");
|
||||
for (String c : conf.getStrings(
|
||||
YarnConfiguration.YARN_APPLICATION_CLASSPATH,
|
||||
YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
|
||||
String[] arrPath = c.split(":");
|
||||
for (String ps : arrPath) {
|
||||
if (ps.endsWith("*.jar") || ps.endsWith("*")) {
|
||||
ps = ps.substring(0, ps.lastIndexOf('*'));
|
||||
String prefix = ps.substring(0, ps.lastIndexOf('/'));
|
||||
if (ps.startsWith("$")) {
|
||||
String[] arr =ps.split("/", 2);
|
||||
if (arr.length != 2) continue;
|
||||
try {
|
||||
ps = System.getenv(arr[0].substring(1)) + '/' + arr[1];
|
||||
} catch (Exception e){
|
||||
continue;
|
||||
}
|
||||
}
|
||||
File dir = new File(ps);
|
||||
if (dir.isDirectory()) {
|
||||
for (File f: dir.listFiles()) {
|
||||
if (f.isFile() && f.getPath().endsWith(".jar")) {
|
||||
cpath.append(":");
|
||||
cpath.append(prefix + '/' + f.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cpath.append(':');
|
||||
cpath.append(ps.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
// already use hadoop command to get class path in worker, maybe a
|
||||
// better solution in future
|
||||
env.put("CLASSPATH", cpath.toString());
|
||||
//LOG.info("CLASSPATH =" + cpath.toString());
|
||||
// setup LD_LIBRARY_PATH for libhdfs
|
||||
env.put("LD_LIBRARY_PATH",
|
||||
"${LD_LIBRARY_PATH}:$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
|
||||
env.put("PYTHONPATH", "${PYTHONPATH}:.");
|
||||
// inherit all rabit variables
|
||||
for (Map.Entry<String, String> e : System.getenv().entrySet()) {
|
||||
if (e.getKey().startsWith("rabit_")) {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
if (e.getKey().equals("LIBHDFS_OPTS")) {
|
||||
env.put(e.getKey(), e.getValue());
|
||||
}
|
||||
}
|
||||
env.put("rabit_task_id", String.valueOf(task.taskId));
|
||||
env.put("rabit_num_trial", String.valueOf(task.attemptCounter));
|
||||
// ctx.setUser(userName);
|
||||
ctx.setEnvironment(env);
|
||||
synchronized (this) {
|
||||
assert (!this.runningTasks.containsKey(container.getId()));
|
||||
this.runningTasks.put(container.getId(), task);
|
||||
this.nmClient.startContainerAsync(container, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* free the containers that have not yet been launched
|
||||
*
|
||||
* @param containers
|
||||
*/
|
||||
private synchronized void freeUnusedContainers(
|
||||
Collection<Container> containers) {
|
||||
}
|
||||
|
||||
/**
|
||||
* handle method for AMRMClientAsync.CallbackHandler container allocation
|
||||
*
|
||||
* @param containers
|
||||
*/
|
||||
private synchronized void onContainersAllocated(List<Container> containers) {
|
||||
if (this.startAbort) {
|
||||
this.freeUnusedContainers(containers);
|
||||
return;
|
||||
}
|
||||
Collection<Container> freelist = new java.util.LinkedList<Container>();
|
||||
for (Container c : containers) {
|
||||
TaskRecord task;
|
||||
task = pendingTasks.poll();
|
||||
if (task == null) {
|
||||
freelist.add(c);
|
||||
continue;
|
||||
}
|
||||
this.launchTask(c, task);
|
||||
}
|
||||
this.freeUnusedContainers(freelist);
|
||||
}
|
||||
|
||||
/**
|
||||
* start aborting the job
|
||||
*
|
||||
* @param msg
|
||||
* the fatal message
|
||||
*/
|
||||
private synchronized void abortJob(String msg) {
|
||||
if (!this.startAbort)
|
||||
this.abortDiagnosis = msg;
|
||||
this.startAbort = true;
|
||||
for (TaskRecord r : this.runningTasks.values()) {
|
||||
if (!r.abortRequested) {
|
||||
nmClient.stopContainerAsync(r.container.getId(),
|
||||
r.container.getNodeId());
|
||||
r.abortRequested = true;
|
||||
}
|
||||
}
|
||||
this.killedTasks.addAll(this.pendingTasks);
|
||||
for (TaskRecord r : this.pendingTasks) {
|
||||
rmClient.removeContainerRequest(r.containerRequest);
|
||||
}
|
||||
this.pendingTasks.clear();
|
||||
LOG.info(msg);
|
||||
}
|
||||
|
||||
    /**
     * handle non fatal failures
     *
     * @param failed
     *            the ids of the failed containers
     */
    private synchronized void handleFailure(Collection<ContainerId> failed) {
        Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
        for (ContainerId cid : failed) {
            TaskRecord r = runningTasks.remove(cid);
            if (r == null) {
                continue;
            }
            LOG.info("Task "
                    + r.taskId
                    + " failed on "
                    + r.container.getId()
                    + ". See LOG at : "
                    + String.format("http://%s/node/containerlogs/%s/"
                            + userName, r.container.getNodeHttpAddress(),
                            r.container.getId()));
            r.attemptCounter += 1;
            r.container = null;
            tasks.add(r);
            if (r.attemptCounter >= this.maxNumAttempt) {
                this.abortJob("[Rabit] Task " + r.taskId + " failed more than "
                        + r.attemptCounter + " times");
            }
        }
        if (this.startAbort) {
            this.killedTasks.addAll(tasks);
        } else {
            this.submitTasks(tasks);
        }
    }

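    // Failure policy: a failed container is re-queued through submitTasks until it
    // has failed maxNumAttempt times, after which the whole job is aborted; when an
    // abort is already in progress the task is recorded as killed instead.
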
    /**
     * handle method for AMRMClientAsync.CallbackHandler container completion
     *
     * @param status
     *            list of container status
     */
    private synchronized void onContainersCompleted(List<ContainerStatus> status) {
        Collection<ContainerId> failed = new java.util.LinkedList<ContainerId>();
        for (ContainerStatus s : status) {
            assert (s.getState().equals(ContainerState.COMPLETE));
            int exstatus = s.getExitStatus();
            TaskRecord r = runningTasks.get(s.getContainerId());
            if (r == null)
                continue;
            if (exstatus == ContainerExitStatus.SUCCESS) {
                finishedTasks.add(r);
                runningTasks.remove(s.getContainerId());
            } else {
                try {
                    if (exstatus == ContainerExitStatus.class.getField(
                            "KILLED_EXCEEDED_PMEM").getInt(null)) {
                        this.abortJob("[Rabit] Task "
                                + r.taskId
                                + " killed because of exceeding allocated physical memory");
                        continue;
                    }
                    if (exstatus == ContainerExitStatus.class.getField(
                            "KILLED_EXCEEDED_VMEM").getInt(null)) {
                        this.abortJob("[Rabit] Task "
                                + r.taskId
                                + " killed because of exceeding allocated virtual memory");
                        continue;
                    }
                } catch (Exception e) {
                    // the exit-status constants above may not exist in this
                    // version of YARN; fall through and treat it as a failure
                }
                LOG.info("[Rabit] Task " + r.taskId + " exited with status "
                        + exstatus + " Diagnostics: " + s.getDiagnostics());
                failed.add(s.getContainerId());
            }
        }
        this.handleFailure(failed);
    }

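    // Completion policy: successful containers move their task to finishedTasks;
    // containers killed for exceeding their physical or virtual memory limit abort
    // the whole job (the exit-status constants are looked up via reflection,
    // presumably to stay compatible with older YARN releases); everything else is
    // handed to handleFailure for a possible retry.
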
    /**
     * callback handler for resource manager
     */
    private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler {
        @Override
        public float getProgress() {
            return 1.0f - (float) (pendingTasks.size()) / numTasks;
        }

        @Override
        public void onContainersAllocated(List<Container> containers) {
            ApplicationMaster.this.onContainersAllocated(containers);
        }

        @Override
        public void onContainersCompleted(List<ContainerStatus> status) {
            ApplicationMaster.this.onContainersCompleted(status);
        }

        @Override
        public void onError(Throwable ex) {
            ApplicationMaster.this.abortJob("[Rabit] Resource manager Error "
                    + ex.toString());
        }

        @Override
        public void onNodesUpdated(List<NodeReport> nodereport) {
        }

        @Override
        public void onShutdownRequest() {
            ApplicationMaster.this
                    .abortJob("[Rabit] Get shutdown request, start to shutdown...");
        }
    }

    private class NMCallbackHandler implements NMClientAsync.CallbackHandler {
        @Override
        public void onContainerStarted(ContainerId cid,
                Map<String, ByteBuffer> services) {
            LOG.debug("onContainerStarted Invoked");
        }

        @Override
        public void onContainerStatusReceived(ContainerId cid,
                ContainerStatus status) {
            LOG.debug("onContainerStatusReceived Invoked");
        }

        @Override
        public void onContainerStopped(ContainerId cid) {
            LOG.debug("onContainerStopped Invoked");
        }

        @Override
        public void onGetContainerStatusError(ContainerId cid, Throwable ex) {
            LOG.debug("onGetContainerStatusError Invoked: " + ex.toString());
            ApplicationMaster.this
                    .handleFailure(Collections.singletonList(cid));
        }

        @Override
        public void onStartContainerError(ContainerId cid, Throwable ex) {
            LOG.debug("onStartContainerError Invoked: " + ex.toString());
            ApplicationMaster.this
                    .handleFailure(Collections.singletonList(cid));
        }

        @Override
        public void onStopContainerError(ContainerId cid, Throwable ex) {
            LOG.info("onStopContainerError Invoked: " + ex.toString());
        }
    }
}
@@ -1,269 +0,0 @@
package org.apache.hadoop.yarn.rabit;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.QueueInfo;
import org.apache.hadoop.yarn.api.records.YarnApplicationState;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.hadoop.yarn.client.api.YarnClientApplication;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;

public class Client {
    // logger
    private static final Log LOG = LogFactory.getLog(Client.class);
    // permission for temp file
    private static final FsPermission permTemp = new FsPermission("777");
    // configuration
    private YarnConfiguration conf = new YarnConfiguration();
    // hdfs handler
    private FileSystem dfs;
    // cached maps
    private Map<String, String> cacheFiles = new java.util.HashMap<String, String>();
    // the -file argument string passed to the application master to set up the cache files
    private String cacheFileArg = "";
    // args to pass to application master
    private String appArgs = "";
    // HDFS path used to store temporary results
    private String tempdir = "/tmp";
    // user name
    private String userName = "";
    // user credentials
    private Credentials credentials = null;
    // job name
    private String jobName = "";
    // queue
    private String queue = "default";

    /**
     * constructor
     * @throws IOException
     */
    private Client() throws IOException {
        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") + "/core-site.xml"));
        conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") + "/hdfs-site.xml"));
        dfs = FileSystem.get(conf);
        userName = UserGroupInformation.getCurrentUser().getShortUserName();
        credentials = UserGroupInformation.getCurrentUser().getCredentials();
    }

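    // Note: the client assumes HADOOP_CONF_DIR is set in the submitting environment;
    // core-site.xml and hdfs-site.xml are loaded from that directory so the correct
    // HDFS and YARN cluster configuration is picked up.
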
    /**
     * setup security token given current user
     * @return the ByteBuffer containing the security tokens
     * @throws IOException
     */
    private ByteBuffer setupTokens() throws IOException {
        DataOutputBuffer buffer = new DataOutputBuffer();
        this.credentials.writeTokenStorageToStream(buffer);
        return ByteBuffer.wrap(buffer.getData());
    }

    /**
     * setup all the cached files
     *
     * @param appId
     *            the application id, used to name the temporary HDFS directory
     * @return the resource map
     * @throws IOException
     */
    private Map<String, LocalResource> setupCacheFiles(ApplicationId appId) throws IOException {
        // create temporary rabit directory
        Path tmpPath = new Path(this.tempdir);
        if (!dfs.exists(tmpPath)) {
            dfs.mkdirs(tmpPath, permTemp);
            LOG.info("HDFS temp directory does not exist, creating " + tmpPath);
        }
        tmpPath = new Path(tmpPath + "/temp-rabit-yarn-" + appId);
        if (dfs.exists(tmpPath)) {
            dfs.delete(tmpPath, true);
        }
        // create temporary directory
        FileSystem.mkdirs(dfs, tmpPath, permTemp);

        StringBuilder cstr = new StringBuilder();
        Map<String, LocalResource> rmap = new java.util.HashMap<String, LocalResource>();
        for (Map.Entry<String, String> e : cacheFiles.entrySet()) {
            LocalResource r = Records.newRecord(LocalResource.class);
            Path path = new Path(e.getValue());
            // copy local data to the temporary folder in HDFS
            if (!e.getValue().startsWith("hdfs://")) {
                Path dst = new Path("hdfs://" + tmpPath + "/" + path.getName());
                dfs.copyFromLocalFile(false, true, path, dst);
                dfs.setPermission(dst, permTemp);
                dfs.deleteOnExit(dst);
                path = dst;
            }
            FileStatus status = dfs.getFileStatus(path);
            r.setResource(ConverterUtils.getYarnUrlFromPath(path));
            r.setSize(status.getLen());
            r.setTimestamp(status.getModificationTime());
            r.setType(LocalResourceType.FILE);
            r.setVisibility(LocalResourceVisibility.APPLICATION);
            rmap.put(e.getKey(), r);
            cstr.append(" -file \"");
            cstr.append(path.toString());
            cstr.append('#');
            cstr.append(e.getKey());
            cstr.append("\"");
        }

        dfs.deleteOnExit(tmpPath);
        this.cacheFileArg = cstr.toString();
        return rmap;
    }

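    // Note: each cache file follows the "path#alias" convention; local files are
    // first copied into the per-application temporary HDFS directory, and
    // cacheFileArg re-creates the corresponding -file flags, presumably so the
    // application master can pass the same files on to the worker containers.
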
    /**
     * get the environment variables for the container
     *
     * @return the env variables for the child class
     */
    private Map<String, String> getEnvironment() {
        // Setup environment variables
        Map<String, String> env = new java.util.HashMap<String, String>();
        String cpath = "${CLASSPATH}:./*";
        for (String c : conf.getStrings(
                YarnConfiguration.YARN_APPLICATION_CLASSPATH,
                YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
            cpath += ':';
            cpath += c.trim();
        }
        env.put("CLASSPATH", cpath);
        for (Map.Entry<String, String> e : System.getenv().entrySet()) {
            if (e.getKey().startsWith("rabit_")) {
                env.put(e.getKey(), e.getValue());
            }
            if (e.getKey().equals("LIBHDFS_OPTS")) {
                env.put(e.getKey(), e.getValue());
            }
        }
        LOG.debug(env);
        return env;
    }

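    // Note: the rabit_* environment variables forwarded here are presumably set by
    // the submitting tracker script (e.g. rabit_yarn.py) and carry the tracker
    // address and job parameters down to the application master and its workers.
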
    /**
     * initialize the settings
     *
     * @param args
     */
    private void initArgs(String[] args) {
        // parse the known options; all other arguments are passed through unchanged
        StringBuilder sargs = new StringBuilder("");
        for (int i = 0; i < args.length; ++i) {
            if (args[i].equals("-file")) {
                String[] arr = args[++i].split("#");
                if (arr.length == 1) {
                    cacheFiles.put(new Path(arr[0]).getName(), arr[0]);
                } else {
                    cacheFiles.put(arr[1], arr[0]);
                }
            } else if (args[i].equals("-jobname")) {
                this.jobName = args[++i];
            } else if (args[i].equals("-tempdir")) {
                this.tempdir = args[++i];
            } else if (args[i].equals("-queue")) {
                this.queue = args[++i];
            } else {
                sargs.append(" ");
                sargs.append(args[i]);
            }
        }
        this.appArgs = sargs.toString();
    }

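    // Hypothetical invocation (the actual jar name and worker command depend on how
    // the build and tracker script package things):
    //   hadoop jar rabit-yarn.jar org.apache.hadoop.yarn.rabit.Client \
    //       -jobname demo -queue default -file ./rabit_worker#rabit_worker ./rabit_worker <args...>
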
    private void run(String[] args) throws Exception {
        if (args.length == 0) {
            System.out.println("Usage: [options] [commands..]");
            System.out.println("options: [-file filename] [-jobname name] [-tempdir dir] [-queue queue]");
            return;
        }
        this.initArgs(args);
        // Create yarnClient
        YarnClient yarnClient = YarnClient.createYarnClient();
        yarnClient.init(conf);
        yarnClient.start();

        // Create application via yarnClient
        YarnClientApplication app = yarnClient.createApplication();

        // Set up the container launch context for the application master
        ContainerLaunchContext amContainer = Records
                .newRecord(ContainerLaunchContext.class);
        ApplicationSubmissionContext appContext = app
                .getApplicationSubmissionContext();
        // get the application id
        ApplicationId appId = appContext.getApplicationId();
        // setup security token
        amContainer.setTokens(this.setupTokens());
        // setup cache-files and environment variables
        amContainer.setLocalResources(this.setupCacheFiles(appId));
        amContainer.setEnvironment(this.getEnvironment());
        String cmd = "$JAVA_HOME/bin/java"
                + " -Xmx900M"
                + " org.apache.hadoop.yarn.rabit.ApplicationMaster"
                + this.cacheFileArg + ' ' + this.appArgs + " 1>"
                + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
                + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr";
        LOG.debug(cmd);
        amContainer.setCommands(Collections.singletonList(cmd));

        // Set up resource type requirements for ApplicationMaster
        Resource capability = Records.newRecord(Resource.class);
        capability.setMemory(1024);
        capability.setVirtualCores(1);
        LOG.info("jobname=" + this.jobName + ",username=" + this.userName);

        appContext.setApplicationName(jobName + ":RABIT-YARN");
        appContext.setAMContainerSpec(amContainer);
        appContext.setResource(capability);
        appContext.setQueue(queue);
        //appContext.setUser(userName);
        LOG.info("Submitting application " + appId);
        yarnClient.submitApplication(appContext);

        ApplicationReport appReport = yarnClient.getApplicationReport(appId);
        YarnApplicationState appState = appReport.getYarnApplicationState();
        while (appState != YarnApplicationState.FINISHED
                && appState != YarnApplicationState.KILLED
                && appState != YarnApplicationState.FAILED) {
            Thread.sleep(100);
            appReport = yarnClient.getApplicationReport(appId);
            appState = appReport.getYarnApplicationState();
        }

        System.out.println("Application " + appId + " finished with"
                + " state " + appState + " at " + appReport.getFinishTime());
        if (!appReport.getFinalApplicationStatus().equals(
                FinalApplicationStatus.SUCCEEDED)) {
            System.err.println(appReport.getDiagnostics());
            System.out.println("Available queues:");
            for (QueueInfo q : yarnClient.getAllQueues()) {
                System.out.println(q.getQueueName());
            }
        }
    }

    public static void main(String[] args) throws Exception {
        new Client().run(args);
    }
}
@@ -1,24 +0,0 @@
package org.apache.hadoop.yarn.rabit;

import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;

/**
 * data structure to hold the task information
 */
public class TaskRecord {
    // task id of the task
    public int taskId = 0;
    // number of failed attempts to run the task
    public int attemptCounter = 0;
    // container request, can be null if the task is already running
    public ContainerRequest containerRequest = null;
    // running container, can be null if the task is not launched
    public Container container = null;
    // whether we have requested that this task be aborted
    public boolean abortRequested = false;

    public TaskRecord(int taskId) {
        this.taskId = taskId;
    }
}