diff --git a/src/allreduce.h b/src/allreduce.h index c9bd0e579..9f150dcf4 100644 --- a/src/allreduce.h +++ b/src/allreduce.h @@ -4,6 +4,9 @@ * \author Tianqi Chen, Nacho, Tianyi */ #include "./engine.h" +#ifdef TEST + #include "./mock.h" +#endif /*! \brief namespace of all reduce */ namespace sync { @@ -43,6 +46,11 @@ void Init(int argc, char *argv[]) { void Finalize(void) { engine::Finalize(); } +#ifdef TEST +void SetMock(test::Mock& mock) { + engine::SetMock(mock); +} +#endif /*! \brief get rank of current process */ inline int GetRank(void) { return engine::GetEngine()->GetRank(); diff --git a/src/engine.h b/src/engine.h index ca928b22a..852e9187a 100644 --- a/src/engine.h +++ b/src/engine.h @@ -6,6 +6,10 @@ * \author Tianqi Chen, Nacho, Tianyi */ #include "./io.h" +#ifdef TEST + #include "./mock.h" +#endif + namespace MPI { /*! \brief MPI data type just to be compatible with MPI reduce function*/ @@ -76,5 +80,10 @@ void Init(int argc, char *argv[]); void Finalize(void); /*! \brief singleton method to get engine */ IEngine *GetEngine(void); + +#ifdef TEST +void SetMock(test::Mock& mock); +#endif + } // namespace engine #endif // ALLREDUCE_ENGINE_H diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp index a0506129d..957319db4 100644 --- a/src/engine_tcp.cpp +++ b/src/engine_tcp.cpp @@ -12,6 +12,9 @@ #include #include "./engine.h" #include "./socket.h" +#ifdef TEST + #include "./mock.h" +#endif namespace MPI { class Datatype { @@ -37,6 +40,13 @@ class SyncManager : public IEngine { } ~SyncManager(void) { } + + #ifdef TEST + inline void SetMock(test::Mock& mock) { + this->mock = mock; + } + #endif + inline void Shutdown(void) { for (size_t i = 0; i < links.size(); ++i) { links[i].sock.Close(); @@ -168,6 +178,9 @@ class SyncManager : public IEngine { size_t type_nbytes, size_t count, ReduceFunction reducer) { + #ifdef TEST + utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing"); + #endif if (links.size() == 0) return; // total size of message const size_t total_size = type_nbytes * count; @@ -275,6 +288,10 @@ class SyncManager : public IEngine { * \param root the root worker id to broadcast the data */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { + #ifdef TEST + utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting"); + #endif + if (links.size() == 0) return; // number of links const int nlink = static_cast(links.size()); @@ -329,9 +346,15 @@ class SyncManager : public IEngine { } } virtual bool LoadCheckPoint(utils::ISerializable *p_model) { + #ifdef TEST + utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint"); + #endif return false; } virtual void CheckPoint(const utils::ISerializable &model) { + #ifdef TEST + utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing"); + #endif } private: @@ -421,6 +444,11 @@ class SyncManager : public IEngine { std::vector links; // select helper utils::SelectHelper selecter; + + #ifdef TEST + // mock to test + test::Mock mock; + #endif }; // singleton sync manager @@ -436,6 +464,14 @@ void Init(int argc, char *argv[]) { } manager.Init(); } + +#ifdef TEST +/*! \brief sets a mock to the manager for testing purposes */ +void SetMock(test::Mock& mock) { + manager.SetMock(mock); +} +#endif + /*! \brief finalize syncrhonization module */ void Finalize(void) { manager.Shutdown(); diff --git a/src/mock.h b/src/mock.h index 71e65dc0a..c6bfd89fd 100644 --- a/src/mock.h +++ b/src/mock.h @@ -16,7 +16,7 @@ namespace test { class Mock { - typedef std::map > Map; + typedef std::map > Map; public: @@ -27,61 +27,64 @@ public: } // record methods - - inline void OnAllReduce(int rank, int code) { - utils::Check(record, "Not in record state"); - Map::iterator it = allReduce.find(rank); - if (it == allReduce.end()) { - std::queue aQueue; - allReduce[rank] = aQueue; - } - allReduce[rank].push(code); + inline void OnAllReduce(int rank, bool success) { + onRecord(allReduce, rank, success); } - inline void OnBroadcast() { - utils::Check(record, "Not in record state"); + inline void OnBroadcast(int rank, bool success) { + onRecord(broadcast, rank, success); } - inline void OnLoadCheckpoint() { - utils::Check(record, "Not in record state"); + inline void OnLoadCheckPoint(int rank, bool success) { + onRecord(loadCheckpoint, rank, success); } - inline void OnCheckpoint() { - utils::Check(record, "Not in record state"); + inline void OnCheckPoint(int rank, bool success) { + onRecord(checkpoint, rank, success); } // replay methods - - inline int AllReduce(int rank) { - utils::Check(!record, "Not in replay state"); - utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded"); - int result = 0; - if (!allReduce[rank].empty()) { - result = allReduce[rank].front(); - allReduce[rank].pop(); - } - return result; + inline bool AllReduce(int rank) { + return onReplay(allReduce, rank); } - inline int Broadcast(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool Broadcast(int rank) { + return onReplay(broadcast, rank); } - inline int LoadCheckpoint(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool LoadCheckPoint(int rank) { + return onReplay(loadCheckpoint, rank); } - inline int Checkpoint(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool CheckPoint(int rank) { + return onReplay(checkpoint, rank); } private: + inline void onRecord(Map& m, int rank, bool success) { + utils::Check(record, "Not in record state"); + Map::iterator it = m.find(rank); + if (it == m.end()) { + std::queue aQueue; + m[rank] = aQueue; + } + m[rank].push(success); + } + + inline bool onReplay(Map& m, int rank) { + utils::Check(!record, "Not in replay state"); + utils::Check(m.find(rank) != m.end(), "Not recorded"); + bool result = true; + if (!m[rank].empty()) { + result = m[rank].front(); + m[rank].pop(); + } + return result; + } + // flag to indicate if the mock is in record state bool record; diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 43876a4a9..6da800f90 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -66,13 +66,13 @@ inline void TestBcast(size_t n, int root) { inline void record(test::Mock& mock, int rank) { switch(rank) { case 0: - mock.OnAllReduce(0, -1); + mock.OnAllReduce(0, false); break; case 1: - mock.OnAllReduce(1, -1); + mock.OnAllReduce(1, false); break; case 2: - mock.OnAllReduce(2, 0); + mock.OnAllReduce(2, true); break; } } @@ -97,9 +97,12 @@ int main(int argc, char *argv[]) { test::Mock mock; record(mock, rank); mock.Replay(); - replay(mock, rank); + //replay(mock, rank); + sync::SetMock(mock); #endif + + printf("[%d] start at %s\n", rank, name.c_str()); TestMax(n); printf("[%d] TestMax pass\n", rank);