From 2750679270219102c51a2da8f22bbe171bf40d4a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 7 Dec 2014 20:57:29 -0800 Subject: [PATCH] normal state running ok --- src/allreduce_robust.cc | 17 ++-- src/mock.h | 12 +-- src/rabit-inl.h | 10 ++- src/rabit.h | 26 ++++-- test/Makefile | 6 +- test/keepalive.sh | 4 +- test/test_local_recover.cpp | 154 ++++++++++++++++++++++++++++++++++++ test/test_model_recover.cpp | 2 +- toolkit/kmeans.cpp | 2 +- 9 files changed, 203 insertions(+), 30 deletions(-) create mode 100644 test/test_local_recover.cpp diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 3d2f128f2..1de92b7d6 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -26,12 +26,12 @@ void AllreduceRobust::Shutdown(void) { // need to sync the exec before we shutdown, do a pesudo check point // execute checkpoint, note: when checkpoint existing, load will not happen utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), - "check point must return true"); + "Shutdown: check point must return true"); // reset result buffer resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), - "check ack must return true"); + "Shutdown: check ack must return true"); AllreduceBase::Shutdown(); } /*! @@ -201,9 +201,8 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, if (CheckAndRecover(TryCheckinLocalState(&local_rptr[new_version], &local_chkpt[new_version]))) break; } - // run the ack phase - utils::Assert(RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck), - "check point must return true"); + // run the ack phase, can be true or false + RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckAck); // switch pointer to new version local_chkpt_version = !local_chkpt_version; } @@ -678,7 +677,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { // if minimum sequence requested is local check point ack, // this means all nodes have finished local check point, directly return - if (seqno == ActionSummary::kLocalCheckAck) return kSuccess; + if (seqno == ActionSummary::kLocalCheckAck) return kSuccess; if (seqno == ActionSummary::kLocalCheckPoint) { // new version of local model int new_version = !local_chkpt_version; @@ -972,8 +971,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, ring_prev, ring_next); if (succ != kSuccess) return succ; // update rptr - rptr.resize(n + 1); - for (int i = 1; i < n; ++i) { + rptr.resize(n + 2); + for (int i = 1; i <= n; ++i) { rptr[i + 1] = rptr[i] + sizes[i]; } chkpt.resize(rptr.back()); @@ -1013,7 +1012,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, LinkRecord *read_link, LinkRecord *write_link) { if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; - utils::Assert(write_end <= read_end, "RingPassing: boundary check1"); + utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end); utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2"); utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3"); // take reference diff --git a/src/mock.h b/src/mock.h index 31c93d113..e5a4c283a 100644 --- a/src/mock.h +++ b/src/mock.h @@ -30,14 +30,16 @@ public: rabit::Allreduce(sendrecvbuf, count); } - inline bool LoadCheckPoint(utils::ISerializable *p_model) { +inline int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model) { utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank); - return rabit::LoadCheckPoint(p_model); + return rabit::LoadCheckPoint(global_model, local_model); } - - inline void CheckPoint(const utils::ISerializable &model) { + + inline void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model) { utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank); - rabit::CheckPoint(model); + rabit::CheckPoint(global_model, local_model); } inline void Broadcast(std::string *sendrecv_data, int root) { diff --git a/src/rabit-inl.h b/src/rabit-inl.h index b13ea88fc..95a2eb8fd 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -124,12 +124,14 @@ inline void Allreduce(DType *sendrecvbuf, size_t count) { engine::mpi::GetType(), OP::kType); } // load latest check point -inline int LoadCheckPoint(utils::ISerializable *p_model) { - return engine::GetEngine()->LoadCheckPoint(p_model); +inline int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model) { + return engine::GetEngine()->LoadCheckPoint(global_model, local_model); } // checkpoint the model, meaning we finished a stage of execution -inline void CheckPoint(const utils::ISerializable &model) { - engine::GetEngine()->CheckPoint(&model); +inline void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model) { + engine::GetEngine()->CheckPoint(global_model, local_model); } // return the version number of currently stored model inline int VersionNumber(void) { diff --git a/src/rabit.h b/src/rabit.h index 68e39f3fa..f19792442 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -84,7 +84,12 @@ template inline void Allreduce(DType *sendrecvbuf, size_t count); /*! * \brief load latest check point - * \param p_model pointer to the model + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller need to gauranttees that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local model is needed + * * \return the version number of check point loaded * if returned version == 0, this means no model has been CheckPointed * the p_model is not touched, user should do necessary initialization by themselves @@ -99,15 +104,24 @@ inline void Allreduce(DType *sendrecvbuf, size_t count); * * \sa CheckPoint, VersionNumber */ -inline int LoadCheckPoint(utils::ISerializable *p_model); +inline int LoadCheckPoint(utils::ISerializable *global_model, + utils::ISerializable *local_model = NULL); /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one * - * \param p_model pointer to the model - * \sa LoadCheckPoint, VersionNumber - */ -inline void CheckPoint(const utils::ISerializable &model); + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller need to gauranttees that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will + * bring replication cost in CheckPoint function. global_model do not need explicit replication. + * So only CheckPoint with global_model if possible + * \sa LoadCheckPoint, VersionNumber + */ +inline void CheckPoint(const utils::ISerializable *global_model, + const utils::ISerializable *local_model = NULL); /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/test/Makefile b/test/Makefile index f10229a1f..9f742be74 100644 --- a/test/Makefile +++ b/test/Makefile @@ -5,12 +5,12 @@ export LDFLAGS= -pthread -lm -lrt export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src # specify tensor path -BIN = test_allreduce test_recover test_model_recover speed_test +BIN = test_allreduce test_recover test_model_recover speed_test test_local_recover # objectives that makes up rabit library RABIT_OBJ = allreduce_base.o allreduce_robust.o engine.o MPIOBJ = engine_mpi.o -OBJ = $(RABIT_OBJ) test_allreduce.o test_recover.o test_model_recover.o speed_test.o +OBJ = $(RABIT_OBJ) test_allreduce.o test_recover.o test_model_recover.o speed_test.o test_local_recover.o MPIBIN = test_allreduce.mpi speed_test.mpi .PHONY: clean all @@ -24,6 +24,7 @@ test_allreduce.o: test_allreduce.cpp ../src/*.h speed_test.o: speed_test.cpp ../src/*.h test_recover.o: test_recover.cpp ../src/*.h test_model_recover.o: test_model_recover.cpp ../src/*.h +test_local_recover.o: test_local_recover.cpp ../src/*.h # we can link against MPI version to get use MPI test_allreduce: test_allreduce.o $(RABIT_OBJ) @@ -32,6 +33,7 @@ speed_test: speed_test.o $(RABIT_OBJ) speed_test.mpi: speed_test.o $(MPIOBJ) test_recover: test_recover.o $(RABIT_OBJ) test_model_recover: test_model_recover.o $(RABIT_OBJ) +test_local_recover: test_local_recover.o $(RABIT_OBJ) $(BIN) : $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS) diff --git a/test/keepalive.sh b/test/keepalive.sh index e72a2bba9..ddfc5d618 100755 --- a/test/keepalive.sh +++ b/test/keepalive.sh @@ -6,8 +6,8 @@ then exit -1 fi nrep=0 -echo ./$@ job_id=$OMPI_COMM_WORLD_RANK -until ./$@ job_id=$OMPI_COMM_WORLD_RANK repeat=$nrep; do +echo ./$@ task_id=$OMPI_COMM_WORLD_RANK +until ./$@ task_id=$OMPI_COMM_WORLD_RANK repeat=$nrep; do sleep 1 nrep=$((nrep+1)) echo ./$@ job_id=$OMPI_COMM_WORLD_RANK repeat=$nrep diff --git a/test/test_local_recover.cpp b/test/test_local_recover.cpp new file mode 100644 index 000000000..87262ba7b --- /dev/null +++ b/test/test_local_recover.cpp @@ -0,0 +1,154 @@ +// this is a test case to test whether rabit can recover model when +// facing an exception +#include +#include +#include +#include +#include +#include + +using namespace rabit; + +struct MockException { +}; + +// dummy model +class Model : public rabit::utils::ISerializable { + public: + // iterations + std::vector data; + // load from stream + virtual void Load(rabit::utils::IStream &fi) { + fi.Read(&data); + } + /*! \brief save the model to the stream */ + virtual void Save(rabit::utils::IStream &fo) const { + fo.Write(data); + } + virtual void InitModel(size_t n, float v) { + data.resize(n, v); + } +}; + +inline void TestMax(test::Mock &mock, Model *model, Model *local, int ntrial, int iter) { + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + const int z = iter + 111; + + std::vector ndata(model->data.size()); + for (size_t i = 0; i < ndata.size(); ++i) { + ndata[i] = (i * (rank+1)) % z + local->data[i]; + } + mock.Allreduce(&ndata[0], ndata.size()); + if (ntrial == iter && rank == 3) { + //exit(-1); + } + for (size_t i = 0; i < ndata.size(); ++i) { + float rmax = (i * 1) % z + model->data[i]; + for (int r = 0; r < nproc; ++r) { + rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i] + r); + } + utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank); + } + model->data = ndata; + local->data = ndata; + for (size_t i = 0; i < ndata.size(); ++i) { + local->data[i] = ndata[i] + rank; + } +} + +inline void TestSum(test::Mock &mock, Model *model, Model *local, int ntrial, int iter) { + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + const int z = 131 + iter; + + std::vector ndata(model->data.size()); + for (size_t i = 0; i < ndata.size(); ++i) { + ndata[i] = (i * (rank+1)) % z + local->data[i]; + } + mock.Allreduce(&ndata[0], ndata.size()); + + if (ntrial == iter && rank == 0) { + exit(-1); + } + + for (size_t i = 0; i < ndata.size(); ++i) { + float rsum = 0.0f; + for (int r = 0; r < nproc; ++r) { + rsum += (float)((i * (r+1)) % z) + model->data[i] + r; + } + utils::Check(fabsf(rsum - ndata[i]) < 1e-5 , + "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]); + } + model->data = ndata; + for (size_t i = 0; i < ndata.size(); ++i) { + local->data[i] = ndata[i] + rank; + } +} + +inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) { + int rank = rabit::GetRank(); + std::string s; s.resize(n); + for (size_t i = 0; i < n; ++i) { + s[i] = char(i % 126 + 1); + } + std::string res; + if (root == rank) { + res = s; + mock.Broadcast(&res, root); + } else { + mock.Broadcast(&res, root); + } + utils::Check(res == s, "[%d] TestBcast fail", rank); +} + +int main(int argc, char *argv[]) { + if (argc < 3) { + printf("Usage: \n"); + return 0; + } + int n = atoi(argv[1]); + rabit::Init(argc, argv); + int rank = rabit::GetRank(); + int nproc = rabit::GetWorldSize(); + std::string name = rabit::GetProcessorName(); + test::Mock mock(rank, argv[2], argv[3]); + Model model, local; + srand(0); + int ntrial = 0; + for (int i = 1; i < argc; ++i) { + int n; + if (sscanf(argv[i], "repeat=%d", &n) == 1) ntrial = n; + } + while (true) { + try { + int iter = rabit::LoadCheckPoint(&model, &local); + if (iter == 0) { + model.InitModel(n, 1.0f); + local.InitModel(n, 1.0f + rank); + utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + } else { + utils::LogPrintf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); + } + for (int r = iter; r < 3; ++r) { + TestMax(mock, &model, &local, ntrial, r); + utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + int step = std::max(nproc / 3, 1); + for (int i = 0; i < nproc; i += step) { + TestBcast(mock, n, i, ntrial); + } + utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + TestSum(mock, &model, &local, ntrial, r); + utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + rabit::CheckPoint(&model, &local); + utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + } + break; + } catch (MockException &e) { + rabit::engine::GetEngine()->InitAfterException(); + ++ntrial; + } + } + rabit::Finalize(); + return 0; +} diff --git a/test/test_model_recover.cpp b/test/test_model_recover.cpp index 86762c671..2b72cde75 100644 --- a/test/test_model_recover.cpp +++ b/test/test_model_recover.cpp @@ -132,7 +132,7 @@ int main(int argc, char *argv[]) { utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); TestSum(mock, &model, ntrial, r); utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r); - rabit::CheckPoint(model); + rabit::CheckPoint(&model); utils::LogPrintf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); } break; diff --git a/toolkit/kmeans.cpp b/toolkit/kmeans.cpp index 674223cc6..e6dffd500 100644 --- a/toolkit/kmeans.cpp +++ b/toolkit/kmeans.cpp @@ -137,7 +137,7 @@ int main(int argc, char *argv[]) { } } model.Normalize(); - rabit::CheckPoint(model); + rabit::CheckPoint(&model); } // output the model file to somewhere if (rabit::GetRank() == 0) {