allow setup from env variables

This commit is contained in:
tqchen 2015-03-07 16:45:31 -08:00
parent 9b6bf57e79
commit 67ebf81e7a
8 changed files with 52 additions and 24 deletions

View File

@ -2,7 +2,7 @@ ifndef CXX
export CXX = g++
endif
export MPICXX = mpicxx
export LDFLAGS= -Llib
export LDFLAGS= -Llib -lrt
export WARNFLAGS= -Wall -Wextra -Wno-unused-parameter -Wno-unknown-pragmas -pedantic
export CFLAGS = -O3 -msse2 -fPIC $(WARNFLAGS)
@ -50,7 +50,7 @@ $(ALIB):
ar cr $@ $+
$(SLIB) :
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS)
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ wrapper/*~

View File

@ -29,11 +29,24 @@ AllreduceBase::AllreduceBase(void) {
task_id = "NULL";
err_link = NULL;
this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible environment variables of interest
env_vars.push_back("rabit_task_id");
env_vars.push_back("rabit_num_trial");
env_vars.push_back("rabit_reduce_buffer");
env_vars.push_back("rabit_tracker_uri");
env_vars.push_back("rabit_tracker_port");
}
// initialization function
void AllreduceBase::Init(void) {
// setup from environment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
}
}
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");

View File

@ -413,6 +413,8 @@ class AllreduceBase : public IEngine {
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// list of environment variables that are of possible interest
std::vector<std::string> env_vars;
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id;

View File

@ -27,7 +27,9 @@ AllreduceRobust::AllreduceRobust(void) {
result_buffer_round = 1;
global_lazycheck = NULL;
use_local_model = -1;
recover_counter = 0;
recover_counter = 0;
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(void) {
AllreduceBase::Init();

View File

@ -31,35 +31,38 @@ nrep=0
rc=254
while [ $rc -eq 254 ];
do
export rabit_num_trial=$nrep
%s
%s %s rabit_num_trial=$nrep
%s
rc=$?;
nrep=$((nrep+1));
done
"""
def exec_cmd(cmd, taskid):
def exec_cmd(cmd, taskid, worker_env):
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0]
cmd = ' '.join(cmd)
arg = ' rabit_task_id=%d' % (taskid)
cmd = cmd + arg
env = {}
for k, v in worker_env.items():
env[k] = str(v)
env['rabit_task_id'] = str(taskid)
env['PYTHONPATH'] = WRAPPER_PATH
ntrial = 0
while True:
if os.name == 'nt':
prep = 'SET PYTHONPATH=\"%s\"\n' % WRAPPER_PATH
ret = subprocess.call(prep + cmd + ('rabit_num_trial=%d' % ntrial), shell=True)
env['rabit_num_trial'] = str(ntrial)
ret = subprocess.call(cmd, shell=True, env = env)
if ret == 254:
ntrial += 1
continue
else:
prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
if args.verbose != 0:
bash = keepalive % (echo % cmd, prep, cmd)
if args.verbose != 0:
bash = keepalive % (echo % cmd, cmd)
else:
bash = keepalive % ('', prep, cmd)
ret = subprocess.call(bash, shell=True, executable='bash')
bash = keepalive % ('', cmd)
ret = subprocess.call(bash, shell=True, executable='bash', env = env)
if ret == 0:
if args.verbose != 0:
print 'Thread %d exit with 0' % taskid
@ -73,7 +76,7 @@ def exec_cmd(cmd, taskid):
# Note: this submit script is only used for demo purpose
# submission script using python multi-threading
#
def mthread_submit(nslave, worker_args):
def mthread_submit(nslave, worker_args, worker_envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
@ -84,7 +87,7 @@ def mthread_submit(nslave, worker_args):
"""
procs = {}
for i in range(nslave):
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i))
procs[i] = Thread(target = exec_cmd, args = (args.command + worker_args, i, worker_envs))
procs[i].daemon = True
procs[i].start()
for i in range(nslave):

View File

@ -94,8 +94,8 @@ use_yarn = int(hadoop_version[0]) >= 2
print 'Current Hadoop Version is %s' % out[1]
def hadoop_streaming(nworker, worker_args, use_yarn):
fset = set()
def hadoop_streaming(nworker, worker_args, worker_envs, use_yarn):
fset = set()
if args.auto_file_cache:
for i in range(len(args.command)):
f = args.command[i]
@ -113,6 +113,7 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
if os.path.exists(f):
fset.add(f)
kmap = {}
kmap['env'] = 'mapred.child.env'
# setup keymaps
if use_yarn:
kmap['nworker'] = 'mapreduce.job.maps'
@ -129,6 +130,8 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
cmd = '%s jar %s' % (args.hadoop_binary, args.hadoop_streaming_jar)
cmd += ' -D%s=%d' % (kmap['nworker'], nworker)
cmd += ' -D%s=%s' % (kmap['jobname'], args.jobname)
envstr = ','.join('%s=%s' % (k, str(v)) for k, v in worker_envs.items())
cmd += ' -D%s=\"%s\"' % (kmap['env'], envstr)
if args.nthread != -1:
if kmap['nthread'] is None:
warnings.warn('nthread can only be set in Yarn(Hadoop version greater than 2.0),'\

View File

@ -22,7 +22,7 @@ args = parser.parse_args()
#
# submission script using MPI
#
def mpi_submit(nslave, worker_args):
def mpi_submit(nslave, worker_args, worker_envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
@ -31,6 +31,7 @@ def mpi_submit(nslave, worker_args):
args arguments to launch each job
this usually includes the parameters of master_uri and parameters passed into submit
"""
worker_args += ['%s=%s' % (k, str(v)) for k, v in worker_envs.items()]
sargs = ' '.join(args.command + worker_args)
if args.hostfile is None:
cmd = ' '.join(['mpirun -n %d' % (nslave)] + args.command + worker_args)

View File

@ -140,15 +140,19 @@ class Tracker:
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
def __del__(self):
self.sock.close()
def slave_args(self):
def slave_envs(self):
"""
get environment variables for slaves
can be passed in as args or envs
"""
if self.hostIP == 'dns':
host = socket.gethostname()
elif self.hostIP == 'ip':
host = socket.gethostbyname(socket.getfqdn())
else:
host = self.hostIP
return ['rabit_tracker_uri=%s' % host,
'rabit_tracker_port=%s' % self.port]
return {'rabit_tracker_uri': host,
'rabit_tracker_port': self.port}
def get_neighbor(self, rank, nslave):
rank = rank + 1
ret = []
@ -265,7 +269,7 @@ class Tracker:
def submit(nslave, args, fun_submit, verbose, hostIP = 'auto'):
master = Tracker(verbose = verbose, hostIP = hostIP)
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
submit_thread = Thread(target = fun_submit, args = (nslave, args, master.slave_envs()))
submit_thread.daemon = True
submit_thread.start()
master.accept_slaves(nslave)