add auto caching of python in hadoop script, mock test module to python, with checkpt
This commit is contained in:
parent
877fc42e40
commit
3419cf9aa7
5
Makefile
5
Makefile
@ -10,14 +10,14 @@ BPATH=.
|
||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
||||
$(BPATH)/rabit_wrapper.o
|
||||
SLIB= wrapper/librabit_wrapper.so
|
||||
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||
.PHONY: clean all install mpi python
|
||||
|
||||
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
||||
mpi: lib/librabit_mpi.a
|
||||
python: wrapper/librabit_wrapper.so
|
||||
# python: build both shared libraries that the Python wrapper can load
# (the regular one and the mock-test one).
python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
||||
|
||||
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
||||
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
||||
@ -33,6 +33,7 @@ lib/librabit_mpi.a: $(MPIOBJ)
|
||||
# wrapper code
|
||||
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
||||
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
||||
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
||||
|
||||
$(OBJ) :
|
||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||
|
||||
@ -5,8 +5,9 @@ demo python script of rabit
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
# add path to wrapper
|
||||
sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||
# import rabit, the tracker script will setup the lib path correctly
|
||||
# for normal run without tracker script, add following line
|
||||
# sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||
import rabit
|
||||
|
||||
rabit.init()
|
||||
@ -15,7 +16,7 @@ rank = rabit.get_rank()
|
||||
a = np.zeros(n)
|
||||
for i in xrange(n):
|
||||
a[i] = rank + i
|
||||
|
||||
|
||||
print '@node[%d] before-allreduce: a=%s' % (rank, str(a))
|
||||
a = rabit.allreduce(a, rabit.MAX)
|
||||
print '@node[%d] after-allreduce: a=%s' % (rank, str(a))
|
||||
|
||||
@ -5,7 +5,8 @@ demo python script of rabit
|
||||
import os
|
||||
import sys
|
||||
# add path to wrapper
|
||||
sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||
# for normal run without tracker script, add following line
|
||||
# sys.path.append(os.path.dirname(__file__) + '/../wrapper')
|
||||
import rabit
|
||||
|
||||
rabit.init()
|
||||
|
||||
@ -8,13 +8,15 @@ export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../include -std=
|
||||
BIN = speed_test test_model_recover test_local_recover
|
||||
OBJ = $(RABIT_OBJ) speed_test.o test_model_recover.o test_local_recover.o
|
||||
MPIBIN = speed_test.mpi
|
||||
.PHONY: clean all lib
|
||||
.PHONY: clean all lib mpi
|
||||
|
||||
all: $(BIN) $(MPIBIN)
|
||||
lib:
|
||||
cd ..;make;cd -
|
||||
mpi:
|
||||
cd ..;make mpi;cd -
|
||||
# programs
|
||||
speed_test.o: speed_test.cpp ../include/*.h lib
|
||||
speed_test.o: speed_test.cpp ../include/*.h lib mpi
|
||||
test_model_recover.o: test_model_recover.cpp ../include/*.h lib
|
||||
test_local_recover.o: test_local_recover.cpp ../include/*.h lib
|
||||
|
||||
|
||||
@ -9,12 +9,6 @@ endif
|
||||
.PHONY: model_recover local_recover speed
|
||||
|
||||
|
||||
local_recover:
|
||||
../tracker/rabit_demo.py -n $(nslave) test_local_recover $(ndata) rabit_local_replica=1
|
||||
|
||||
local_recover_10_10k:
|
||||
../tracker/rabit_demo.py -n 10 test_local_recover 10000 rabit_local_replica=1
|
||||
|
||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||
model_recover_10_10k:
|
||||
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0
|
||||
|
||||
@ -5,28 +5,8 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include "./mock.h"
|
||||
|
||||
using namespace rabit;
|
||||
namespace rabit {
|
||||
namespace test {
|
||||
inline void CallBegin(const char *fun, int ntrial, int iter) {
|
||||
int rank = rabit::GetRank();
|
||||
if (!strcmp(fun, "Allreduce::Sum")) {
|
||||
if (ntrial == iter && rank == 0) throw MockException();
|
||||
}
|
||||
if (!strcmp(fun, "Allreduce::Max")) {
|
||||
if (ntrial == iter && rank == 3) throw MockException();
|
||||
}
|
||||
}
|
||||
inline void CallEnd(const char *fun, int ntrial, int iter) {
|
||||
int rank = rabit::GetRank();
|
||||
if (!strcmp(fun, "Allreduce::Bcast")) {
|
||||
if (ntrial == iter && rand() % 10 == rank) throw MockException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dummy model
|
||||
class Model : public rabit::ISerializable {
|
||||
@ -52,8 +32,6 @@ inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
|
||||
int nproc = rabit::GetWorldSize();
|
||||
const int z = iter + 111;
|
||||
std::vector<float> ndata(model->data.size());
|
||||
|
||||
test::CallBegin("Allreduce::Max", ntrial, iter);
|
||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size(),
|
||||
[&]() {
|
||||
// use lambda expression to prepare the data
|
||||
@ -61,7 +39,6 @@ inline void TestMax(Model *model, Model *local, int ntrial, int iter) {
|
||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||
}
|
||||
});
|
||||
test::CallEnd("Allreduce::Max", ntrial, iter);
|
||||
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
float rmax = (i * 1) % z + model->data[i];
|
||||
@ -86,9 +63,7 @@ inline void TestSum(Model *model, Model *local, int ntrial, int iter) {
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
ndata[i] = (i * (rank+1)) % z + local->data[i];
|
||||
}
|
||||
test::CallBegin("Allreduce::Sum", ntrial, iter);
|
||||
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||
test::CallEnd("Allreduce::Sum", ntrial, iter);
|
||||
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
float rsum = 0.0f;
|
||||
@ -113,13 +88,9 @@ inline void TestBcast(size_t n, int root, int ntrial, int iter) {
|
||||
std::string res;
|
||||
if (root == rank) {
|
||||
res = s;
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
rabit::Broadcast(&res, root);
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
} else {
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
rabit::Broadcast(&res, root);
|
||||
test::CallEnd("Broadcast", ntrial, iter);
|
||||
}
|
||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||
}
|
||||
@ -141,34 +112,26 @@ int main(int argc, char *argv[]) {
|
||||
int n;
|
||||
if (sscanf(argv[i], "repeat=%d", &n) == 1) ntrial = n;
|
||||
}
|
||||
while (true) {
|
||||
try {
|
||||
int iter = rabit::LoadCheckPoint(&model, &local);
|
||||
if (iter == 0) {
|
||||
model.InitModel(n, 1.0f);
|
||||
local.InitModel(n, 1.0f + rank);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, &local, ntrial, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, &local, ntrial, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model, &local);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
break;
|
||||
} catch (MockException &e) {
|
||||
rabit::engine::GetEngine()->InitAfterException();
|
||||
++ntrial;
|
||||
int iter = rabit::LoadCheckPoint(&model, &local);
|
||||
if (iter == 0) {
|
||||
model.InitModel(n, 1.0f);
|
||||
local.InitModel(n, 1.0f + rank);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, &local, ntrial, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, &local, ntrial, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model, &local);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
|
||||
@ -5,30 +5,7 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include "./mock.h"
|
||||
|
||||
using namespace rabit;
|
||||
namespace rabit {
|
||||
namespace test {
|
||||
inline void CallBegin(const char *fun, int ntrial, int iter) {
|
||||
return;
|
||||
int rank = rabit::GetRank();
|
||||
if (!strcmp(fun, "Allreduce::Sum")) {
|
||||
if (ntrial == iter && rank == 0) exit(-1);
|
||||
}
|
||||
if (!strcmp(fun, "Allreduce::Max")) {
|
||||
if (ntrial == iter && rank == 3) exit(-1);
|
||||
}
|
||||
}
|
||||
inline void CallEnd(const char *fun, int ntrial, int iter) {
|
||||
return;
|
||||
int rank = rabit::GetRank();
|
||||
if (!strcmp(fun, "Allreduce::Bcast")) {
|
||||
if (ntrial == iter && rand() % 10 == rank) exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dummy model
|
||||
class Model : public rabit::ISerializable {
|
||||
@ -58,9 +35,7 @@ inline void TestMax(Model *model, int ntrial, int iter) {
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||
}
|
||||
test::CallBegin("Allreduce::Max", ntrial, iter);
|
||||
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||
test::CallEnd("Allreduce::Max", ntrial, iter);
|
||||
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
float rmax = (i * 1) % z + model->data[i];
|
||||
@ -81,9 +56,7 @@ inline void TestSum(Model *model, int ntrial, int iter) {
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||
}
|
||||
test::CallBegin("Allreduce::Sum", ntrial, iter);
|
||||
Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||
test::CallEnd("Allreduce::Sum", ntrial, iter);
|
||||
|
||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||
float rsum = model->data[i] * nproc;
|
||||
@ -105,13 +78,9 @@ inline void TestBcast(size_t n, int root, int ntrial, int iter) {
|
||||
std::string res;
|
||||
if (root == rank) {
|
||||
res = s;
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
rabit::Broadcast(&res, root);
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
} else {
|
||||
test::CallBegin("Broadcast", ntrial, iter);
|
||||
rabit::Broadcast(&res, root);
|
||||
test::CallEnd("Broadcast", ntrial, iter);
|
||||
}
|
||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||
}
|
||||
@ -133,33 +102,25 @@ int main(int argc, char *argv[]) {
|
||||
int n;
|
||||
if (sscanf(argv[i], "rabit_num_trial=%d", &n) == 1) ntrial = n;
|
||||
}
|
||||
while (true) {
|
||||
try {
|
||||
int iter = rabit::LoadCheckPoint(&model);
|
||||
if (iter == 0) {
|
||||
model.InitModel(n);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, ntrial, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, ntrial, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
break;
|
||||
} catch (MockException &e) {
|
||||
rabit::engine::GetEngine()->InitAfterException();
|
||||
++ntrial;
|
||||
int iter = rabit::LoadCheckPoint(&model);
|
||||
if (iter == 0) {
|
||||
model.InitModel(n);
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
} else {
|
||||
printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter);
|
||||
}
|
||||
for (int r = iter; r < 3; ++r) {
|
||||
TestMax(&model, ntrial, r);
|
||||
printf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||
int step = std::max(nproc / 3, 1);
|
||||
for (int i = 0; i < nproc; i += step) {
|
||||
TestBcast(n, i, ntrial, r);
|
||||
}
|
||||
printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||
TestSum(&model, ntrial, r);
|
||||
printf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||
rabit::CheckPoint(&model);
|
||||
printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r);
|
||||
}
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
|
||||
@ -9,6 +9,7 @@ import os
|
||||
import subprocess
|
||||
from threading import Thread
|
||||
import rabit_tracker as tracker
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
|
||||
parser = argparse.ArgumentParser(description='Rabit script to submit rabit job locally using python subprocess')
|
||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||
@ -25,8 +26,9 @@ def exec_cmd(cmd, taskid):
|
||||
cmd = ' '.join(cmd)
|
||||
ntrial = 0
|
||||
while True:
|
||||
prep = 'PYTHONPATH=\"%s\" ' % WRAPPER_PATH
|
||||
arg = ' rabit_task_id=%d rabit_num_trial=%d' % (taskid, ntrial)
|
||||
ret = subprocess.call(cmd + arg, shell = True)
|
||||
ret = subprocess.call(prep + cmd + arg, shell = True)
|
||||
if ret == 254 or ret == -2:
|
||||
ntrial += 1
|
||||
continue
|
||||
|
||||
@ -11,6 +11,7 @@ import subprocess
|
||||
import warnings
|
||||
import rabit_tracker as tracker
|
||||
|
||||
WRAPPER_PATH = os.path.dirname(__file__) + '/../wrapper'
|
||||
|
||||
#!!! Set path to hadoop and hadoop streaming jar here
|
||||
hadoop_binary = 'hadoop'
|
||||
@ -102,6 +103,13 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
|
||||
args.command[i] = './' + args.command[i].split('/')[-1]
|
||||
else:
|
||||
args.command[i] = args.command[i].split('/')[-1]
|
||||
# When the submitted command is a Python script, add the rabit Python wrapper
# and whichever compiled wrapper libraries exist on disk to fset
# (presumably the set of files shipped with the job -- confirm against the
# rest of hadoop_streaming).
# The parsed argument is `args.command`, as used in the loop just above;
# `args.commands` does not match it and would raise AttributeError.
if args.command[0].endswith('.py'):
    flst = [WRAPPER_PATH + '/rabit.py',
            WRAPPER_PATH + '/librabit_wrapper.so',
            WRAPPER_PATH + '/librabit_wrapper_mock.so']
    for f in flst:
        # only add files that were actually built
        if os.path.exists(f):
            fset.add(f)
|
||||
kmap = {}
|
||||
# setup keymaps
|
||||
if use_yarn:
|
||||
|
||||
@ -7,18 +7,34 @@ import cPickle as pickle
|
||||
import ctypes
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
if os.name == 'nt':
|
||||
assert False, "Rabit windows is not yet compiled"
|
||||
else:
|
||||
RABIT_PATH = os.path.dirname(__file__)+'/librabit_wrapper.so'
|
||||
WRAPPER_PATH = os.path.dirname(__file__)
|
||||
|
||||
rbtlib = None
|
||||
|
||||
# load in xgboost library
|
||||
rbtlib = ctypes.cdll.LoadLibrary(RABIT_PATH)
|
||||
rbtlib.RabitGetRank.restype = ctypes.c_int
|
||||
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||
def loadlib__(with_mock = False):
    """
    load the rabit wrapper shared library into the module-level rbtlib handle
    Arguments:
        with_mock: boolean [default=False]
            when True, load librabit_wrapper_mock.so instead of librabit_wrapper.so

    If a library is already loaded, a warning is issued and the call is ignored.
    """
    global rbtlib
    if rbtlib is not None:
        # warnings.warn (lowercase) with `stacklevel` is the actual API;
        # stacklevel=2 attributes the warning to the caller of loadlib__
        warnings.warn('rabit.init call was ignored because it has already been initialized',
                      stacklevel = 2)
        return
    libname = 'librabit_wrapper_mock.so' if with_mock else 'librabit_wrapper.so'
    rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/' + libname)
    # declare the C return types of the integer-returning entry points
    rbtlib.RabitGetRank.restype = ctypes.c_int
    rbtlib.RabitGetWorldSize.restype = ctypes.c_int
    rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||
|
||||
def unloadlib__():
    """
    drop the module-level handle to the loaded rabit library

    rbtlib is reset to None, so a later loadlib__ call will load a library again
    """
    global rbtlib
    # rebinding releases the previous handle; a separate `del` is redundant
    # and would raise NameError if the name were not bound
    rbtlib = None
|
||||
|
||||
# reduction operators
|
||||
MAX = 0
|
||||
@ -32,14 +48,17 @@ def check_err__():
|
||||
"""
|
||||
return
|
||||
|
||||
def init(args = sys.argv):
|
||||
def init(args = sys.argv, with_mock = False):
|
||||
"""
|
||||
intialize the rabit module, call this once before using anything
|
||||
Arguments:
|
||||
args: list(string)
|
||||
the list of arguments used to initialized the rabit
|
||||
usually you need to pass in sys.argv
|
||||
args: list(string) [default=sys.argv]
|
||||
the list of arguments used to initialized the rabit
|
||||
usually you need to pass in sys.argv
|
||||
with_mock: boolean [default=False]
|
||||
Whether initialize the mock test module
|
||||
"""
|
||||
loadlib__(with_mock)
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = args
|
||||
rbtlib.RabitInit(len(args), arr)
|
||||
@ -51,6 +70,7 @@ def finalize():
|
||||
"""
|
||||
rbtlib.RabitFinalize()
|
||||
check_err__()
|
||||
unloadlib__()
|
||||
|
||||
def get_rank():
|
||||
"""
|
||||
@ -156,7 +176,7 @@ DTYPE_ENUM__ = {
|
||||
np.dtype('float64') : 7
|
||||
}
|
||||
|
||||
def allreduce(data, op):
|
||||
def allreduce(data, op, prepare_fun = None):
|
||||
"""
|
||||
perform allreduce, return the result, this function is not thread-safe
|
||||
Arguments:
|
||||
@ -164,8 +184,8 @@ def allreduce(data, op):
|
||||
input data
|
||||
op: int
|
||||
reduction operators, can be MIN, MAX, SUM, BITOR
|
||||
prepare_fun: lambda
|
||||
Lazy preprocessing function, if it is not None, prepare_fun()
|
||||
prepare_fun: lambda data
|
||||
Lazy preprocessing function, if it is not None, prepare_fun(data)
|
||||
will be called by the function before performing allreduce, to intialize the data
|
||||
If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
|
||||
Returns:
|
||||
@ -178,9 +198,17 @@ def allreduce(data, op):
|
||||
buf = buf.copy()
|
||||
if buf.dtype not in DTYPE_ENUM__:
|
||||
raise Exception('data type %s not supported' % str(buf.dtype))
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[dtype],
|
||||
op, None, None);
|
||||
if prepare_fun is None:
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, None, None)
|
||||
else:
|
||||
PFUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
|
||||
def pfunc(args):
|
||||
prepare_fun(data)
|
||||
rbtlib.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
|
||||
buf.size, DTYPE_ENUM__[buf.dtype],
|
||||
op, PFUNC(pfunc), None)
|
||||
check_err__()
|
||||
return buf
|
||||
|
||||
@ -273,4 +301,3 @@ def version_number():
|
||||
ret = rbtlib.RabitVersionNumber()
|
||||
check_err__()
|
||||
return ret
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user