add auto caching of python in hadoop script, mock test module to python, with checkpt
This commit is contained in:
parent
877fc42e40
commit
3419cf9aa7
5
Makefile
5
Makefile
@ -10,14 +10,14 @@ BPATH=.
|
|||||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
||||||
$(BPATH)/rabit_wrapper.o
|
$(BPATH)/rabit_wrapper.o
|
||||||
SLIB= wrapper/librabit_wrapper.so
|
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
||||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
||||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||||
.PHONY: clean all install mpi python
|
.PHONY: clean all install mpi python
|
||||||
|
|
||||||
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
||||||
mpi: lib/librabit_mpi.a
|
mpi: lib/librabit_mpi.a
|
||||||
python: wrapper/librabit_wrapper.so
|
python: wrapper/librabit_wrapper.so wrpper/librabit_wrapper_mock.so
|
||||||
|
|
||||||
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
||||||
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
||||||
@ -33,6 +33,7 @@ lib/librabit_mpi.a: $(MPIOBJ)
|
|||||||
# wrapper code
|
# wrapper code
|
||||||
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
||||||
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
||||||
|
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|||||||
@ -5,8 +5,9 @@ demo python script of rabit
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
# add path to wrapper
|
# import rabit, the tracker script will setup the lib path correctly
|
||||||
sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
# for normal run without tracker script, add following line
|
||||||
|
# sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||||
import rabit
|
import rabit
|
||||||
|
|
||||||
rabit.init()
|
rabit.init()
|
||||||
|
|||||||
@ -5,7 +5,8 @@ demo python script of rabit
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
# add path to wrapper
|
# add path to wrapper
|
||||||
sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
# for normal run without tracker script, add following line
|
||||||
|
# sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||||
import rabit
|
import rabit
|
||||||
|
|
||||||
rabit.init()
|
rabit.init()
|
||||||
|
|||||||
@ -8,13 +8,15 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=
|
|||||||
BIN = speed_test test_model_recover test_local_recover
|
BIN = speed_test test_model_recover test_local_recover
|
||||||
OBJ = $(RABIT_OBJ) speed_test.o test_model_recover.o test_local_recover.o
|
OBJ = $(RABIT_OBJ) speed_test.o test_model_recover.o test_local_recover.o
|
||||||
MPIBIN = speed_test.mpi
|
MPIBIN = speed_test.mpi
|
||||||
.PHONY: clean all lib
|
.PHONY: clean all lib mpi
|
||||||
|
|
||||||
all: $(BIN) $(MPIBIN)
|
all: $(BIN) $(MPIBIN)
|
||||||
lib:
|
lib:
|
||||||
cd ..;make;cd -
|
cd ..;make;cd -
|
||||||
|
mpi:
|
||||||
|
cd ..;make mpi;cd -
|
||||||
# programs
|
# programs
|
||||||
speed_test.o: speed_test.cpp ../include/*.h lib
|
speed_test.o: speed_test.cpp ../include/*.h lib mpi
|
||||||
test_model_recover.o: test_model_recover.cpp ../include/*.h lib
|
test_model_recover.o: test_model_recover.cpp ../include/*.h lib
|
||||||
test_local_recover.o: test_local_recover.cpp ../include/*.h lib
|
test_local_recover.o: test_local_recover.cpp ../include/*.h lib
|
||||||
|
|
||||||
|
|||||||
@ -9,12 +9,6 @@ endif
|
|||||||
.PHONY: model_recover local_recover speed
|
.PHONY: model_recover local_recover speed
|
||||||
|
|
||||||
|
|
||||||
local_recover:
|
|
||||||
../tracker/rabit_demo.py -n $(nslave) test_local_recover $(ndata) rabit_local_replica=1
|
|
||||||
|
|
||||||
local_recover_10_10k:
|
|
||||||
../tracker/rabit_demo.py -n 10 test_local_recover 10000 rabit_local_replica=1
|
|
||||||
|
|
||||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||||
model_recover_10_10k:
|
model_recover_10_10k:
|
||||||
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0
|
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0
|
||||||
|
|||||||
@ -5,28 +5,8 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include "./mock.h"
|
|
||||||
|
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
namespace rabit {
|
|
||||||
namespace test {
|
|
||||||
inline void CallBegin(const char *fun, int ntrial, int iter) {
|
|
||||||
int rank = rabit::GetRank();
|
|
||||||
if (!strcmp(fun, "Allreduce::Sum")) {
|
|
||||||
if (ntrial == iter && rank == 0) throw MockException();
|
|
||||||
}
|
|
||||||
if (!strcmp(fun, "Allreduce::Max")) {
|
|
||||||
if (ntrial == iter && rank == 3) throw MockException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inline void CallEnd(const char *fun, int ntrial, int iter) {
|
|
||||||
int rank = rabit::GetRank();
|
|
||||||
if (!strcmp(fun, "Allreduce::Bcast")) {
|
|
||||||
if (ntrial == iter && rand() % 10 == rank) throw MockException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dummy model
|
// dummy model
|
||||||
class Model : public rabit::ISerializable {
|
class Model : public rabit::ISerializable {
|
||||||
@ -52,8 +32,6 @@ inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
|
|||||||
int nproc = rabit::GetWorldSize();
|
int nproc = rabit::GetWorldSize();
|
||||||
const int z = iter + 111;
|
const int z = iter + 111;
|
||||||
std::vector<float> ndata(model->data.size());
|
std::vector<float> ndata(model->data.size());
|
||||||
|
|
||||||
test::CallBegin("Allreduce::Max", ntrial, iter);
|
|
||||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
|
rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
|
||||||
[&]() {
|
[&]() {
|
||||||
// use lambda expression to prepare the data
|
// use lambda expression to prepare the data
|
||||||
@ -61,7 +39,6 @@ inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
|
|||||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
test::CallEnd("Allreduce::Max", ntrial, iter);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % z + model->data[i];
|
float rmax = (i * 1) % z + model->data[i];
|
||||||
@ -86,9 +63,7 @@ inline void TestSum(Model *model, Model *local, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||||
}
|
}
|
||||||
test::CallBegin("Allreduce::Sum", ntrial, iter);
|
|
||||||
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
test::CallEnd("Allreduce::Sum", ntrial, iter);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rsum = 0.0f;
|
float rsum = 0.0f;
|
||||||
@ -113,13 +88,9 @@ inline void TestBcast(size_t n, int root, int ntrial, int iter) {
|
|||||||
std::string res;
|
std::string res;
|
||||||
if (root == rank) {
|
if (root == rank) {
|
||||||
res = s;
|
res = s;
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
rabit::Broadcast(&res, root);
|
rabit::Broadcast(&res, root);
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
} else {
|
} else {
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
rabit::Broadcast(&res, root);
|
rabit::Broadcast(&res, root);
|
||||||
test::CallEnd("Broadcast", ntrial, iter);
|
|
||||||
}
|
}
|
||||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||||
}
|
}
|
||||||
@ -141,8 +112,6 @@ int main(int argc, char *argv[]) {
|
|||||||
int n;
|
int n;
|
||||||
if (sscanf(argv[i], "repeat=%d", &n) == 1) ntrial = n;
|
if (sscanf(argv[i], "repeat=%d", &n) == 1) ntrial = n;
|
||||||
}
|
}
|
||||||
while (true) {
|
|
||||||
try {
|
|
||||||
int iter = rabit::LoadCheckPoint(&model, &local);
|
int iter = rabit::LoadCheckPoint(&model, &local);
|
||||||
if (iter == 0) {
|
if (iter == 0) {
|
||||||
model.InitModel(n, 1.0f);
|
model.InitModel(n, 1.0f);
|
||||||
@ -164,12 +133,6 @@ int main(int argc, char *argv[]) {
|
|||||||
rabit::CheckPoint(&model, &local);
|
rabit::CheckPoint(&model, &local);
|
||||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
} catch (MockException &e) {
|
|
||||||
rabit::engine::GetEngine()->InitAfterException();
|
|
||||||
++ntrial;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rabit::Finalize();
|
rabit::Finalize();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,30 +5,7 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include "./mock.h"
|
|
||||||
|
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
namespace rabit {
|
|
||||||
namespace test {
|
|
||||||
inline void CallBegin(const char *fun, int ntrial, int iter) {
|
|
||||||
return;
|
|
||||||
int rank = rabit::GetRank();
|
|
||||||
if (!strcmp(fun, "Allreduce::Sum")) {
|
|
||||||
if (ntrial == iter && rank == 0) exit(-1);
|
|
||||||
}
|
|
||||||
if (!strcmp(fun, "Allreduce::Max")) {
|
|
||||||
if (ntrial == iter && rank == 3) exit(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inline void CallEnd(const char *fun, int ntrial, int iter) {
|
|
||||||
return;
|
|
||||||
int rank = rabit::GetRank();
|
|
||||||
if (!strcmp(fun, "Allreduce::Bcast")) {
|
|
||||||
if (ntrial == iter && rand() % 10 == rank) exit(-1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dummy model
|
// dummy model
|
||||||
class Model : public rabit::ISerializable {
|
class Model : public rabit::ISerializable {
|
||||||
@ -58,9 +35,7 @@ inline void TestMax(Model *model, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
test::CallBegin("Allreduce::Max", ntrial, iter);
|
|
||||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
|
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||||
test::CallEnd("Allreduce::Max", ntrial, iter);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % z + model->data[i];
|
float rmax = (i * 1) % z + model->data[i];
|
||||||
@ -81,9 +56,7 @@ inline void TestSum(Model *model, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
test::CallBegin("Allreduce::Sum", ntrial, iter);
|
|
||||||
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
test::CallEnd("Allreduce::Sum", ntrial, iter);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rsum = model->data[i] * nproc;
|
float rsum = model->data[i] * nproc;
|
||||||
@ -105,13 +78,9 @@ inline void TestBcast(size_t n, int root, int ntrial, int iter) {
|
|||||||
std::string res;
|
std::string res;
|
||||||
if (root == rank) {
|
if (root == rank) {
|
||||||
res = s;
|
res = s;
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
rabit::Broadcast(&res, root);
|
rabit::Broadcast(&res, root);
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
} else {
|
} else {
|
||||||
test::CallBegin("Broadcast", ntrial, iter);
|
|
||||||
rabit::Broadcast(&res, root);
|
rabit::Broadcast(&res, root);
|
||||||
test::CallEnd("Broadcast", ntrial, iter);
|
|
||||||
}
|
}
|
||||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||||
}
|
}
|
||||||
@ -133,8 +102,6 @@ int main(int argc, char *argv[]) {
|
|||||||
int n;
|
int n;
|
||||||
if (sscanf(argv[i], "rabit_num_trial=%d", &n) == 1) ntrial = n;
|
if (sscanf(argv[i], "rabit_num_trial=%d", &n) == 1) ntrial = n;
|
||||||
}
|
}
|
||||||
while (true) {
|
|
||||||
try {
|
|
||||||
int iter = rabit::LoadCheckPoint(&model);
|
int iter = rabit::LoadCheckPoint(&model);
|
||||||
if (iter == 0) {
|
if (iter == 0) {
|
||||||
model.InitModel(n);
|
model.InitModel(n);
|
||||||
@ -155,12 +122,6 @@ int main(int argc, char *argv[]) {
|
|||||||
rabit::CheckPoint(&model);
|
rabit::CheckPoint(&model);
|
||||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
} catch (MockException &e) {
|
|
||||||
rabit::engine::GetEngine()->InitAfterException();
|
|
||||||
++ntrial;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rabit::Finalize();
|
rabit::Finalize();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import rabit_tracker as tracker
|
import rabit_tracker as tracker
|
||||||
|
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job locally using python subprocess')
|
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job locally using python subprocess')
|
||||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||||
@ -25,8 +26,9 @@ def exec_cmd(cmd, taskid):
|
|||||||
cmd = ' '.join(cmd)
|
cmd = ' '.join(cmd)
|
||||||
ntrial = 0
|
ntrial = 0
|
||||||
while True:
|
while True:
|
||||||
|
prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
|
||||||
arg = ' rabit_task_id=%d rabit_num_trial=%d' % (taskid, ntrial)
|
arg = ' rabit_task_id=%d rabit_num_trial=%d' % (taskid, ntrial)
|
||||||
ret = subprocess.call(cmd + arg, shell = True)
|
ret = subprocess.call(prep + cmd + arg, shell = True)
|
||||||
if ret == 254 or ret == -2:
|
if ret == 254 or ret == -2:
|
||||||
ntrial += 1
|
ntrial += 1
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import subprocess
|
|||||||
import warnings
|
import warnings
|
||||||
import rabit_tracker as tracker
|
import rabit_tracker as tracker
|
||||||
|
|
||||||
|
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||||
|
|
||||||
#!!! Set path to hadoop and hadoop streaming jar here
|
#!!! Set path to hadoop and hadoop streaming jar here
|
||||||
hadoop_binary = 'hadoop'
|
hadoop_binary = 'hadoop'
|
||||||
@ -102,6 +103,13 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
|
|||||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||||
else:
|
else:
|
||||||
args.command[i] = args.command[i].split('/')[-1]
|
args.command[i] = args.command[i].split('/')[-1]
|
||||||
|
if args.commands[0].endswith('.py'):
|
||||||
|
flst = [WRAPPER_PATH + '/rabit.py',
|
||||||
|
WRAPPER_PATH + '/librabit_wrapper.so',
|
||||||
|
WRAPPER_PATH + '/librabit_wrapper_mock.so']
|
||||||
|
for f in flst:
|
||||||
|
if os.path.exists(f):
|
||||||
|
fset.add(f)
|
||||||
kmap = {}
|
kmap = {}
|
||||||
# setup keymaps
|
# setup keymaps
|
||||||
if use_yarn:
|
if use_yarn:
|
||||||
|
|||||||
@ -7,19 +7,35 @@ import cPickle as pickle
|
|||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
assert False, "Rabit windows is not yet compiled"
|
assert False, "Rabit windows is not yet compiled"
|
||||||
else:
|
else:
|
||||||
RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so'
|
WRAPPER_PATH = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
rbtlib = None
|
||||||
|
|
||||||
# load in xgboost library
|
# load in xgboost library
|
||||||
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
|
def loadlib__(with_mock = False):
|
||||||
|
global rbtlib
|
||||||
|
if rbtlib != None:
|
||||||
|
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2)
|
||||||
|
return
|
||||||
|
if with_mock:
|
||||||
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mock.so')
|
||||||
|
else:
|
||||||
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so')
|
||||||
rbtlib.RabitGetRank.restype = ctypes.c_int
|
rbtlib.RabitGetRank.restype = ctypes.c_int
|
||||||
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
||||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||||
|
|
||||||
|
def unloadlib__():
|
||||||
|
global rbtlib
|
||||||
|
del rbtlib
|
||||||
|
rbtlib = None
|
||||||
|
|
||||||
# reduction operators
|
# reduction operators
|
||||||
MAX = 0
|
MAX = 0
|
||||||
MIN = 1
|
MIN = 1
|
||||||
@ -32,14 +48,17 @@ def check_err__():
|
|||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|
||||||
def init(args = sys.argv):
|
def init(args = sys.argv, with_mock = False):
|
||||||
"""
|
"""
|
||||||
intialize the rabit module, call this once before using anything
|
intialize the rabit module, call this once before using anything
|
||||||
Arguments:
|
Arguments:
|
||||||
args: list(string)
|
args: list(string) [default=sys.argv]
|
||||||
the list of arguments used to initialized the rabit
|
the list of arguments used to initialized the rabit
|
||||||
usually you need to pass in sys.argv
|
usually you need to pass in sys.argv
|
||||||
|
with_mock: boolean [default=False]
|
||||||
|
Whether initialize the mock test module
|
||||||
"""
|
"""
|
||||||
|
loadlib__(with_mock)
|
||||||
arr = (ctypes.c_char_p * len(args))()
|
arr = (ctypes.c_char_p * len(args))()
|
||||||
arr[:] = args
|
arr[:] = args
|
||||||
rbtlib.RabitInit(len(args), arr)
|
rbtlib.RabitInit(len(args), arr)
|
||||||
@ -51,6 +70,7 @@ def finalize():
|
|||||||
"""
|
"""
|
||||||
rbtlib.RabitFinalize()
|
rbtlib.RabitFinalize()
|
||||||
check_err__()
|
check_err__()
|
||||||
|
unloadlib__()
|
||||||
|
|
||||||
def get_rank():
|
def get_rank():
|
||||||
"""
|
"""
|
||||||
@ -156,7 +176,7 @@ DTYPE_ENUM__ = {
|
|||||||
np.dtype('float64') : 7
|
np.dtype('float64') : 7
|
||||||
}
|
}
|
||||||
|
|
||||||
def allreduce(data, op):
|
def allreduce(data, op, prepare_fun = None):
|
||||||
"""
|
"""
|
||||||
perform allreduce, return the result, this function is not thread-safe
|
perform allreduce, return the result, this function is not thread-safe
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -164,8 +184,8 @@ def allreduce(data, op):
|
|||||||
input data
|
input data
|
||||||
op: int
|
op: int
|
||||||
reduction operators, can be MIN, MAX, SUM, BITOR
|
reduction operators, can be MIN, MAX, SUM, BITOR
|
||||||
prepare_fun: lambda
|
prepare_fun: lambda data
|
||||||
Lazy preprocessing function, if it is not None, prepare_fun()
|
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||||
will be called by the function before performing allreduce, to intialize the data
|
will be called by the function before performing allreduce, to intialize the data
|
||||||
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||||
Returns:
|
Returns:
|
||||||
@ -178,9 +198,17 @@ def allreduce(data, op):
|
|||||||
buf = buf.copy()
|
buf = buf.copy()
|
||||||
if buf.dtype not in DTYPE_ENUM__:
|
if buf.dtype not in DTYPE_ENUM__:
|
||||||
raise Exception('data type %s not supported' % str(buf.dtype))
|
raise Exception('data type %s not supported' % str(buf.dtype))
|
||||||
|
if prepare_fun is None:
|
||||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
buf.size, DTYPE_ENUM__[dtype],
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
op, None, None);
|
op, None, None)
|
||||||
|
else:
|
||||||
|
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||||
|
def pfunc(args):
|
||||||
|
prepare_fun(data)
|
||||||
|
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
|
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||||
|
op, PFUNC(pfunc), None)
|
||||||
check_err__()
|
check_err__()
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
@ -273,4 +301,3 @@ def version_number():
|
|||||||
ret = rbtlib.RabitVersionNumber()
|
ret = rbtlib.RabitVersionNumber()
|
||||||
check_err__()
|
check_err__()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user