diff --git a/src/engine.h b/src/engine.h
index aede4ac74..807f7c6ad 100644
--- a/src/engine.h
+++ b/src/engine.h
@@ -52,7 +52,13 @@ class IEngine {
    * \param root the root worker id to broadcast the data
    */
   virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
-  /*!
+  /*!
+   * \brief explicitly re-init everything before calling LoadCheckPoint
+   *    call this function when IEngine throws an exception,
+   *    this function is only used for test purpose
+   */
+  virtual void InitAfterException(void) = 0;
+  /*!
    * \brief load latest check point
    * \param p_model pointer to the model
    * \return true if there was stored checkpoint and load was successful
@@ -63,7 +69,7 @@ class IEngine {
    * \brief checkpoint the model, meaning we finished a stage of execution
    * \param p_model pointer to the model
    */
-  virtual void CheckPoint(const utils::ISerializable &model) = 0;
+  virtual void CheckPoint(const utils::ISerializable &model) = 0;
   /*! \brief get rank of current node */
   virtual int GetRank(void) const = 0;
   /*! \brief get total number of */
diff --git a/src/engine_base.h b/src/engine_base.h
index 2fd5a761b..0cc281cff 100644
--- a/src/engine_base.h
+++ b/src/engine_base.h
@@ -93,7 +93,15 @@ class AllReduceBase : public IEngine {
    */
   virtual void CheckPoint(const utils::ISerializable &model) {
   }
-  
+  /*!
+   * \brief explicitly re-init everything before calling LoadCheckPoint
+   *    call this function when IEngine throws an exception,
+   *    this function is only used for test purpose
+   */
+  virtual void InitAfterException(void) {
+    utils::Error("InitAfterException: not implemented");
+  }
+
  protected:
   /*! \brief enumeration of possible returning results from Try functions */
   enum ReturnType {
diff --git a/src/engine_robust.h b/src/engine_robust.h
index 703a54469..783b2deb7 100644
--- a/src/engine_robust.h
+++ b/src/engine_robust.h
@@ -51,6 +51,14 @@ class AllReduceRobust : public AllReduceBase {
    * \param p_model pointer to the model
    */
   virtual void CheckPoint(const utils::ISerializable &model);
+  /*!
+   * \brief explicitly re-init everything before calling LoadCheckPoint
+   *    call this function when IEngine throws an exception,
+   *    this function is only used for test purpose
+   */
+  virtual void InitAfterException(void) {
+    this->CheckAndRecover(kGetExcept);
+  }
 
  private:
   // constant one byte out of band message to indicate error happening
diff --git a/test/Makefile b/test/Makefile
index 49aca06e1..a3f6b07c7 100644
--- a/test/Makefile
+++ b/test/Makefile
@@ -11,7 +11,7 @@ else
 endif
 
 # specify tensor path
-BIN = test_allreduce
+BIN = test_allreduce test_recover
 OBJ = engine_base.o engine_robust.o engine.o
 .PHONY: clean all
 
@@ -22,6 +22,7 @@ engine_base.o: ../src/engine_base.cc ../src/*.h
 engine.o: ../src/engine.cc ../src/*.h
 engine_robust.o: ../src/engine_robust.cc ../src/*.h
 test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ)
+test_recover: test_recover.cpp ../src/*.h $(OBJ)
 
 $(BIN) :
 	$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)
diff --git a/test/test.sh b/test/test.sh
index 5c70404ac..30d6bbca7 100755
--- a/test/test.sh
+++ b/test/test.sh
@@ -5,4 +5,4 @@ then
 	exit -1
 fi
 
-../submit_job_tcp.py $1 test_allreduce $2 $3 $4
\ No newline at end of file
+../submit_job_tcp.py $1 test_recover $2 $3 $4
diff --git a/test/test_recover.cpp b/test/test_recover.cpp
new file mode 100644
index 000000000..81e81c6fa
--- /dev/null
+++ b/test/test_recover.cpp
@@ -0,0 +1,157 @@
+// NOTE(review): the include targets and the <...> template arguments in this
+// file were lost when the patch was extracted; they are restored from usage
+// (rabit::, utils::, test::Mock, printf, atoi, fabsf) -- confirm before merging.
+#include <rabit.h>
+#include <utils.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cmath>
+#include <mock.h>
+
+using namespace rabit;
+
+/*! \brief thrown by TestSum to simulate a failure inside one worker */
+struct MockException {
+};
+
+/*!
+ * \brief allreduce n values and check each slot against the maximum
+ *  recomputed locally over all ranks
+ * \param mock wrapper through which the collective call is issued
+ * \param n number of values to reduce
+ * \param ntrial index of the current attempt, not used by this check
+ */
+inline void TestMax(test::Mock &mock, size_t n, int ntrial) {
+  int rank = rabit::GetRank();
+  int nproc = rabit::GetWorldSize();
+
+  std::vector<float> ndata(n);
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    ndata[i] = (i * (rank+1)) % 111;
+  }
+  mock.AllReduce<op::Max>(&ndata[0], ndata.size());
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    float rmax = (i * 1) % 111;
+    for (int r = 0; r < nproc; ++r) {
+      rmax = std::max(rmax, (float)((i * (r+1)) % 111));
+    }
+    utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank);
+  }
+}
+
+/*!
+ * \brief allreduce n values and check each slot against the sum
+ *  recomputed locally over all ranks
+ * \param mock wrapper through which the collective call is issued
+ * \param n number of values to reduce
+ * \param ntrial index of the current attempt; on attempt 0, rank 0 throws
+ *  MockException right after the allreduce returns
+ */
+inline void TestSum(test::Mock &mock, size_t n, int ntrial) {
+  int rank = rabit::GetRank();
+  int nproc = rabit::GetWorldSize();
+  const int z = 131;
+
+  std::vector<float> ndata(n);
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    ndata[i] = (i * (rank+1)) % z;
+  }
+  mock.AllReduce<op::Sum>(&ndata[0], ndata.size());
+
+  // simulated failure: only rank 0 throws, and only on the first attempt
+  if (ntrial == 0 && rank == 0) throw MockException();
+
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    float rsum = 0.0f;
+    for (int r = 0; r < nproc; ++r) {
+      rsum += (float)((i * (r+1)) % z);
+    }
+    utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
+                 "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
+  }
+}
+
+/*!
+ * \brief broadcast an n-character string from root and check what arrives
+ * \param mock wrapper through which the collective call is issued
+ * \param n length of the string to broadcast
+ * \param root rank that supplies the string
+ * \param ntrial index of the current attempt, not used by this check
+ */
+inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) {
+  int rank = rabit::GetRank();
+  std::string s; s.resize(n);
+  for (size_t i = 0; i < n; ++i) {
+    s[i] = char(i % 126 + 1);
+  }
+  std::string res;
+  if (root == rank) {
+    res = s;
+    mock.Broadcast(&res, root);
+  } else {
+    mock.Broadcast(&res, root);
+  }
+  utils::Check(res == s, "[%d] TestBcast fail", rank);
+}
+// dummy model
+class Model : public rabit::utils::ISerializable {
+ public:
+  // iterations
+  int iter;
+  // load from stream
+  virtual void Load(rabit::utils::IStream &fi) {
+    fi.Read(&iter, sizeof(iter));
+  }
+  /*! \brief save the model to the stream */
+  virtual void Save(rabit::utils::IStream &fo) const {
+    fo.Write(&iter, sizeof(iter));
+  }
+  virtual void InitModel(void) {
+    iter = 0;
+  }
+};
+
+int main(int argc, char *argv[]) {
+  // argv[1], argv[2] and argv[3] are all read below, so argc must be at least 4
+  // NOTE(review): the usage placeholders were lost in extraction -- restore them
+  if (argc < 4) {
+    printf("Usage: \n");
+    return 0;
+  }
+  int n = atoi(argv[1]);
+  rabit::Init(argc, argv);
+  int rank = rabit::GetRank();
+  int nproc = rabit::GetWorldSize();
+  std::string name = rabit::GetProcessorName();
+  test::Mock mock(rank, argv[2], argv[3]);
+  Model model;
+  srand(0);
+  int ntrial = 0;
+  // nproc / 3 is zero for fewer than three workers; clamp so the loop advances
+  const int bcast_step = std::max(1, nproc / 3);
+  while (true) {
+    try {
+      if (!rabit::LoadCheckPoint(&model)) {
+        model.InitModel();
+      }
+      utils::LogPrintf("[%d] start at %s\n", rank, name.c_str());
+      TestMax(mock, n, ntrial);
+      utils::LogPrintf("[%d] !!!TestMax pass\n", rank);
+      TestSum(mock, n, ntrial);
+      utils::LogPrintf("[%d] !!!TestSum pass\n", rank);
+
+      for (int i = 0; i < nproc; i += bcast_step) {
+        TestBcast(mock, n, i, ntrial);
+      }
+      utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);
+      // reach here
+      break;
+    } catch (const MockException &) {
+      rabit::engine::GetEngine()->InitAfterException();
+      ++ntrial;
+    }
+  }
+  rabit::Finalize();
+  printf("[%d] all check pass\n", rank);
+  return 0;
+}