diff --git a/Makefile b/Makefile index 64b58d76e..c5d052b7c 100644 --- a/Makefile +++ b/Makefile @@ -10,14 +10,14 @@ BPATH=. MPIOBJ= $(BPATH)/engine_mpi.o OBJ= $(BPATH)/allreduce_base.o $(BPATH)/allreduce_robust.o $(BPATH)/engine.o $(BPATH)/engine_empty.o $(BPATH)/engine_mock.o\ $(BPATH)/rabit_wrapper.o -SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so +SLIB= wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so wrapper/librabit_wrapper_mpi.so ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a HEADERS=src/*.h include/*.h include/rabit/*.h .PHONY: clean all install mpi python all: lib/librabit.a lib/librabit_mock.a $(SLIB) -mpi: lib/librabit_mpi.a -python: wrapper/librabit_wrapper.so wrpper/librabit_wrapper_mock.so +mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so +python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so $(BPATH)/allreduce_base.o: src/allreduce_base.cc $(HEADERS) $(BPATH)/engine.o: src/engine.cc $(HEADERS) @@ -34,6 +34,7 @@ lib/librabit_mpi.a: $(MPIOBJ) $(BPATH)/rabit_wrapper.o: wrapper/rabit_wrapper.cc wrapper/librabit_wrapper.so: $(BPATH)/rabit_wrapper.o lib/librabit.a wrapper/librabit_wrapper_mock.so: $(BPATH)/rabit_wrapper.o lib/librabit_mock.a +wrapper/librabit_wrapper_mpi.so: $(BPATH)/rabit_wrapper.o lib/librabit_mpi.a $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 14a8cf339..da57c34f6 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -38,7 +38,7 @@ class AllreduceBase : public IEngine { AllreduceBase(void); virtual ~AllreduceBase(void) {} // initialize the manager - void Init(void); + virtual void Init(void); // shutdown the engine virtual void Shutdown(void); /*! diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index e25a9c85f..7cb2b3611 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -20,10 +20,16 @@ namespace rabit { namespace engine { AllreduceRobust::AllreduceRobust(void) { num_local_replica = 0; + num_global_replica = 5; + default_local_replica = 2; seq_counter = 0; local_chkpt_version = 0; result_buffer_round = 1; } +void AllreduceRobust::Init(void) { + AllreduceBase::Init(); + result_buffer_round = std::max(world_size / num_global_replica, 1); +} /*! \brief shutdown the engine */ void AllreduceRobust::Shutdown(void) { // need to sync the exec before we shutdown, do a pesudo check point @@ -44,10 +50,7 @@ void AllreduceRobust::Shutdown(void) { */ void AllreduceRobust::SetParam(const char *name, const char *val) { AllreduceBase::SetParam(name, val); - if (!strcmp(name, "rabit_buffer_round")) result_buffer_round = atoi(val); - if (!strcmp(name, "rabit_global_replica")) { - result_buffer_round = std::max(world_size / atoi(val), 1); - } + if (!strcmp(name, "rabit_global_replica")) num_global_replica = atoi(val); if (!strcmp(name, "rabit_local_replica")) { num_local_replica = atoi(val); } @@ -151,9 +154,12 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, ISerializable *local_model) { // skip action in single node if (world_size == 1) return 0; + if (local_model != NULL && num_local_replica == 0) { + num_local_replica = default_local_replica; + } if (num_local_replica == 0) { utils::Check(local_model == NULL, - "need to set num_local_replica larger than 1 to checkpoint local_model"); + "need to set rabit_local_replica larger than 1 to checkpoint local_model"); } // check if we succesful if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) { @@ -214,9 +220,12 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model, if (world_size == 1) { version_number += 1; return; } + if (local_model != NULL && num_local_replica == 0) { + num_local_replica = default_local_replica; + } if (num_local_replica == 0) { utils::Check(local_model == NULL, - "need to set num_local_replica larger than 1 to checkpoint local_model"); + "need to set rabit_local_replica larger than 1 to checkpoint local_model"); } if (num_local_replica != 0) { while (true) { diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index f2a804e95..921f18319 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -23,6 +23,8 @@ class AllreduceRobust : public AllreduceBase { public: AllreduceRobust(void); virtual ~AllreduceRobust(void) {} + // initialize the manager + virtual void Init(void); /*! \brief shutdown the engine */ virtual void Shutdown(void); /*! @@ -468,6 +470,10 @@ o * the input state must exactly one saved state(local state of current node) std::string global_checkpoint; // number of replica for local state/model int num_local_replica; + // number of default local replica + int default_local_replica; + // number of replica for global state/model + int num_global_replica; // --- recovery data structure for local checkpoint // there is two version of the data structure, // at one time one version is valid and another is used as temp memory diff --git a/test/test_local_recover.py b/test/test_local_recover.py index 02dcc3e7f..e35bd3177 100755 --- a/test/test_local_recover.py +++ b/test/test_local_recover.py @@ -2,7 +2,7 @@ import rabit import numpy as np -rabit.init(with_mock = True) +rabit.init(lib='mock') rank = rabit.get_rank() n = 10 nround = 3 diff --git a/wrapper/rabit.py b/wrapper/rabit.py index a4932a0a4..7fc913084 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -18,15 +18,19 @@ else: rbtlib = None # load in xgboost library -def loadlib__(with_mock = False): +def loadlib__(lib = 'standard'): global rbtlib if rbtlib != None: warnings.Warn('rabit.int call was ignored because it has already been initialized', level = 2) return - if with_mock: + if lib == 'standard': + rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so') + elif lib == 'mock': rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mock.so') + elif lib == 'mpi': + rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper_mpi.so') else: - rbtlib = ctypes.cdll.LoadLibrary(WRAPPER_PATH + '/librabit_wrapper.so') + raise Exception('unknown rabit lib %s, can be standard, mock, mpi' % lib) rbtlib.RabitGetRank.restype = ctypes.c_int rbtlib.RabitGetWorldSize.restype = ctypes.c_int rbtlib.RabitVersionNumber.restype = ctypes.c_int @@ -48,7 +52,7 @@ def check_err__(): """ return -def init(args = sys.argv, with_mock = False): +def init(args = sys.argv, lib = 'standard'): """ intialize the rabit module, call this once before using anything Arguments: @@ -58,7 +62,7 @@ def init(args = sys.argv, with_mock = False): with_mock: boolean [default=False] Whether initialize the mock test module """ - loadlib__(with_mock) + loadlib__(lib) arr = (ctypes.c_char_p * len(args))() arr[:] = args rbtlib.RabitInit(len(args), arr)