Change the default behavior to behave normally
This commit is contained in:
parent
478d250818
commit
348a1e7619
7
Makefile
7
Makefile
@ -10,14 +10,14 @@ BPATH=.
|
|||||||
MPIOBJ= $(BPATH)/engine_mpi.o
|
MPIOBJ= $(BPATH)/engine_mpi.o
|
||||||
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\
|
||||||
$(BPATH)/rabit_wrapper.o
|
$(BPATH)/rabit_wrapper.o
|
||||||
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so
|
||||||
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
||||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||||
.PHONY: clean all install mpi python
|
.PHONY: clean all install mpi python
|
||||||
|
|
||||||
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
all: lib/librabit.a lib/librabit_mock.a $(SLIB)
|
||||||
mpi: lib/librabit_mpi.a
|
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
|
||||||
python: wrapper/librabit_wrapper.so wrpper/librabit_wrapper_mock.so
|
python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
||||||
|
|
||||||
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
$(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS)
|
||||||
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
$(BPATH)/engine.o: src/engine.cc $(HEADERS)
|
||||||
@ -34,6 +34,7 @@ lib/librabit_mpi.a: $(MPIOBJ)
|
|||||||
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
$(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc
|
||||||
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a
|
||||||
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a
|
||||||
|
wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
|
||||||
|
|||||||
@ -38,7 +38,7 @@ class AllreduceBase : public IEngine {
|
|||||||
AllreduceBase(void);
|
AllreduceBase(void);
|
||||||
virtual ~AllreduceBase(void) {}
|
virtual ~AllreduceBase(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
void Init(void);
|
virtual void Init(void);
|
||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -20,10 +20,16 @@ namespace rabit {
|
|||||||
namespace engine {
|
namespace engine {
|
||||||
AllreduceRobust::AllreduceRobust(void) {
|
AllreduceRobust::AllreduceRobust(void) {
|
||||||
num_local_replica = 0;
|
num_local_replica = 0;
|
||||||
|
num_global_replica = 5;
|
||||||
|
default_local_replica = 2;
|
||||||
seq_counter = 0;
|
seq_counter = 0;
|
||||||
local_chkpt_version = 0;
|
local_chkpt_version = 0;
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
}
|
}
|
||||||
|
void AllreduceRobust::Init(void) {
|
||||||
|
AllreduceBase::Init();
|
||||||
|
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||||
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
void AllreduceRobust::Shutdown(void) {
|
void AllreduceRobust::Shutdown(void) {
|
||||||
// need to sync the exec before we shut down, do a pseudo checkpoint
|
// need to sync the exec before we shut down, do a pseudo checkpoint
|
||||||
@ -44,10 +50,7 @@ void AllreduceRobust::Shutdown(void) {
|
|||||||
*/
|
*/
|
||||||
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||||
AllreduceBase::SetParam(name, val);
|
AllreduceBase::SetParam(name, val);
|
||||||
if (!strcmp(name, "rabit_buffer_round")) result_buffer_round = atoi(val);
|
if (!strcmp(name, "rabit_global_replica")) num_global_replica = atoi(val);
|
||||||
if (!strcmp(name, "rabit_global_replica")) {
|
|
||||||
result_buffer_round = std::max(world_size / atoi(val), 1);
|
|
||||||
}
|
|
||||||
if (!strcmp(name, "rabit_local_replica")) {
|
if (!strcmp(name, "rabit_local_replica")) {
|
||||||
num_local_replica = atoi(val);
|
num_local_replica = atoi(val);
|
||||||
}
|
}
|
||||||
@ -151,9 +154,12 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
|||||||
ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
// skip action in single node
|
// skip action in single node
|
||||||
if (world_size == 1) return 0;
|
if (world_size == 1) return 0;
|
||||||
|
if (local_model != NULL && num_local_replica == 0) {
|
||||||
|
num_local_replica = default_local_replica;
|
||||||
|
}
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL,
|
utils::Check(local_model == NULL,
|
||||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||||
}
|
}
|
||||||
// check if we are successful
|
// check if we are successful
|
||||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||||
@ -214,9 +220,12 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
|||||||
if (world_size == 1) {
|
if (world_size == 1) {
|
||||||
version_number += 1; return;
|
version_number += 1; return;
|
||||||
}
|
}
|
||||||
|
if (local_model != NULL && num_local_replica == 0) {
|
||||||
|
num_local_replica = default_local_replica;
|
||||||
|
}
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL,
|
utils::Check(local_model == NULL,
|
||||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||||
}
|
}
|
||||||
if (num_local_replica != 0) {
|
if (num_local_replica != 0) {
|
||||||
while (true) {
|
while (true) {
|
||||||
|
|||||||
@ -23,6 +23,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
public:
|
public:
|
||||||
AllreduceRobust(void);
|
AllreduceRobust(void);
|
||||||
virtual ~AllreduceRobust(void) {}
|
virtual ~AllreduceRobust(void) {}
|
||||||
|
// initialize the manager
|
||||||
|
virtual void Init(void);
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual void Shutdown(void);
|
virtual void Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
@ -468,6 +470,10 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
std::string global_checkpoint;
|
std::string global_checkpoint;
|
||||||
// number of replica for local state/model
|
// number of replica for local state/model
|
||||||
int num_local_replica;
|
int num_local_replica;
|
||||||
|
// default number of local replicas, used when a local model is given but num_local_replica is 0
|
||||||
|
int default_local_replica;
|
||||||
|
// number of replica for global state/model
|
||||||
|
int num_global_replica;
|
||||||
// --- recovery data structure for local checkpoint
|
// --- recovery data structure for local checkpoint
|
||||||
// there are two versions of the data structure,
|
// there are two versions of the data structure,
|
||||||
// at one time one version is valid and another is used as temp memory
|
// at one time one version is valid and another is used as temp memory
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
import rabit
|
import rabit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
rabit.init(with_mock = True)
|
rabit.init(lib='mock')
|
||||||
rank = rabit.get_rank()
|
rank = rabit.get_rank()
|
||||||
n = 10
|
n = 10
|
||||||
nround = 3
|
nround = 3
|
||||||
|
|||||||
@ -18,15 +18,19 @@ else:
|
|||||||
rbtlib = None
|
rbtlib = None
|
||||||
|
|
||||||
# load in rabit wrapper library
|
# load in rabit wrapper library
|
||||||
def loadlib__(with_mock = False):
|
def loadlib__(lib = 'standard'):
|
||||||
global rbtlib
|
global rbtlib
|
||||||
if rbtlib != None:
|
if rbtlib != None:
|
||||||
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2)
|
warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2)
|
||||||
return
|
return
|
||||||
if with_mock:
|
if lib == 'standard':
|
||||||
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so')
|
||||||
|
elif lib == 'mock':
|
||||||
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mock.so')
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mock.so')
|
||||||
|
elif lib == 'mpi':
|
||||||
|
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mpi.so')
|
||||||
else:
|
else:
|
||||||
rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so')
|
raise Exception('unknown rabit lib %s, can be standard, mock, mpi' % lib)
|
||||||
rbtlib.RabitGetRank.restype = ctypes.c_int
|
rbtlib.RabitGetRank.restype = ctypes.c_int
|
||||||
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
rbtlib.RabitGetWorldSize.restype = ctypes.c_int
|
||||||
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
rbtlib.RabitVersionNumber.restype = ctypes.c_int
|
||||||
@ -48,7 +52,7 @@ def check_err__():
|
|||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|
||||||
def init(args = sys.argv, with_mock = False):
|
def init(args = sys.argv, lib = 'standard'):
|
||||||
"""
|
"""
|
||||||
initialize the rabit module, call this once before using anything
|
initialize the rabit module, call this once before using anything
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -58,7 +62,7 @@ def init(args = sys.argv, with_mock = False):
|
|||||||
with_mock: boolean [default=False]
|
lib: string [default='standard']
|
||||||
Whether initialize the mock test module
|
Which rabit wrapper library to load, can be 'standard', 'mock' or 'mpi'
|
||||||
"""
|
"""
|
||||||
loadlib__(with_mock)
|
loadlib__(lib)
|
||||||
arr = (ctypes.c_char_p * len(args))()
|
arr = (ctypes.c_char_p * len(args))()
|
||||||
arr[:] = args
|
arr[:] = args
|
||||||
rbtlib.RabitInit(len(args), arr)
|
rbtlib.RabitInit(len(args), arr)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user