From c565104491ec998304c171729a9e8675bcaea335 Mon Sep 17 00:00:00 2001 From: nachocano Date: Wed, 26 Nov 2014 17:24:29 -0800 Subject: [PATCH] adding some references to mock inside TEST preprocessor directive. It shouldn't be an assert because it shutdowns the process. Instead should check on the value and return some sort of error, so that we can recover. The mock contains queues, indexed by the rank of the process. For each node, you can configure the behavior you expect (success or failure for now) when you call any of the methods (AllReduce, Broadcast, LoadCheckPoint and CheckPoint)... If you call several times AllReduce, the outputs will pop from the queue, i.e., first you can retrieve a success, then a failure and so on. Pretty basic for now, need to tune it better --- src/allreduce.h | 8 +++++ src/engine.h | 9 +++++ src/engine_tcp.cpp | 36 ++++++++++++++++++++ src/mock.h | 73 +++++++++++++++++++++-------------------- test/test_allreduce.cpp | 11 ++++--- 5 files changed, 98 insertions(+), 39 deletions(-) diff --git a/src/allreduce.h b/src/allreduce.h index c9bd0e579..9f150dcf4 100644 --- a/src/allreduce.h +++ b/src/allreduce.h @@ -4,6 +4,9 @@ * \author Tianqi Chen, Nacho, Tianyi */ #include "./engine.h" +#ifdef TEST + #include "./mock.h" +#endif /*! \brief namespace of all reduce */ namespace sync { @@ -43,6 +46,11 @@ void Init(int argc, char *argv[]) { void Finalize(void) { engine::Finalize(); } +#ifdef TEST +void SetMock(test::Mock& mock) { + engine::SetMock(mock); +} +#endif /*! \brief get rank of current process */ inline int GetRank(void) { return engine::GetEngine()->GetRank(); diff --git a/src/engine.h b/src/engine.h index ca928b22a..852e9187a 100644 --- a/src/engine.h +++ b/src/engine.h @@ -6,6 +6,10 @@ * \author Tianqi Chen, Nacho, Tianyi */ #include "./io.h" +#ifdef TEST + #include "./mock.h" +#endif + namespace MPI { /*! \brief MPI data type just to be compatible with MPI reduce function*/ @@ -76,5 +80,10 @@ void Init(int argc, char *argv[]); void Finalize(void); /*! \brief singleton method to get engine */ IEngine *GetEngine(void); + +#ifdef TEST +void SetMock(test::Mock& mock); +#endif + } // namespace engine #endif // ALLREDUCE_ENGINE_H diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp index a0506129d..957319db4 100644 --- a/src/engine_tcp.cpp +++ b/src/engine_tcp.cpp @@ -12,6 +12,9 @@ #include #include "./engine.h" #include "./socket.h" +#ifdef TEST + #include "./mock.h" +#endif namespace MPI { class Datatype { @@ -37,6 +40,13 @@ class SyncManager : public IEngine { } ~SyncManager(void) { } + + #ifdef TEST + inline void SetMock(test::Mock& mock) { + this->mock = mock; + } + #endif + inline void Shutdown(void) { for (size_t i = 0; i < links.size(); ++i) { links[i].sock.Close(); @@ -168,6 +178,9 @@ class SyncManager : public IEngine { size_t type_nbytes, size_t count, ReduceFunction reducer) { + #ifdef TEST + utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing"); + #endif if (links.size() == 0) return; // total size of message const size_t total_size = type_nbytes * count; @@ -275,6 +288,10 @@ class SyncManager : public IEngine { * \param root the root worker id to broadcast the data */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { + #ifdef TEST + utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting"); + #endif + if (links.size() == 0) return; // number of links const int nlink = static_cast(links.size()); @@ -329,9 +346,15 @@ class SyncManager : public IEngine { } } virtual bool LoadCheckPoint(utils::ISerializable *p_model) { + #ifdef TEST + utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint"); + #endif return false; } virtual void CheckPoint(const utils::ISerializable &model) { + #ifdef TEST + utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing"); + #endif } private: @@ -421,6 +444,11 @@ class SyncManager : public IEngine { std::vector links; // select helper utils::SelectHelper selecter; + + #ifdef TEST + // mock to test + test::Mock mock; + #endif }; // singleton sync manager @@ -436,6 +464,14 @@ void Init(int argc, char *argv[]) { } manager.Init(); } + +#ifdef TEST +/*! \brief sets a mock to the manager for testing purposes */ +void SetMock(test::Mock& mock) { + manager.SetMock(mock); +} +#endif + /*! \brief finalize syncrhonization module */ void Finalize(void) { manager.Shutdown(); diff --git a/src/mock.h b/src/mock.h index 71e65dc0a..c6bfd89fd 100644 --- a/src/mock.h +++ b/src/mock.h @@ -16,7 +16,7 @@ namespace test { class Mock { - typedef std::map > Map; + typedef std::map > Map; public: @@ -27,61 +27,64 @@ public: } // record methods - - inline void OnAllReduce(int rank, int code) { - utils::Check(record, "Not in record state"); - Map::iterator it = allReduce.find(rank); - if (it == allReduce.end()) { - std::queue aQueue; - allReduce[rank] = aQueue; - } - allReduce[rank].push(code); + inline void OnAllReduce(int rank, bool success) { + onRecord(allReduce, rank, success); } - inline void OnBroadcast() { - utils::Check(record, "Not in record state"); + inline void OnBroadcast(int rank, bool success) { + onRecord(broadcast, rank, success); } - inline void OnLoadCheckpoint() { - utils::Check(record, "Not in record state"); + inline void OnLoadCheckPoint(int rank, bool success) { + onRecord(loadCheckpoint, rank, success); } - inline void OnCheckpoint() { - utils::Check(record, "Not in record state"); + inline void OnCheckPoint(int rank, bool success) { + onRecord(checkpoint, rank, success); } // replay methods - - inline int AllReduce(int rank) { - utils::Check(!record, "Not in replay state"); - utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded"); - int result = 0; - if (!allReduce[rank].empty()) { - result = allReduce[rank].front(); - allReduce[rank].pop(); - } - return result; + inline bool AllReduce(int rank) { + return onReplay(allReduce, rank); } - inline int Broadcast(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool Broadcast(int rank) { + return onReplay(broadcast, rank); } - inline int LoadCheckpoint(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool LoadCheckPoint(int rank) { + return onReplay(loadCheckpoint, rank); } - inline int Checkpoint(int rank) { - utils::Check(!record, "Not in replay state"); - return 0; + inline bool CheckPoint(int rank) { + return onReplay(checkpoint, rank); } private: + inline void onRecord(Map& m, int rank, bool success) { + utils::Check(record, "Not in record state"); + Map::iterator it = m.find(rank); + if (it == m.end()) { + std::queue aQueue; + m[rank] = aQueue; + } + m[rank].push(success); + } + + inline bool onReplay(Map& m, int rank) { + utils::Check(!record, "Not in replay state"); + utils::Check(m.find(rank) != m.end(), "Not recorded"); + bool result = true; + if (!m[rank].empty()) { + result = m[rank].front(); + m[rank].pop(); + } + return result; + } + // flag to indicate if the mock is in record state bool record; diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 43876a4a9..6da800f90 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -66,13 +66,13 @@ inline void TestBcast(size_t n, int root) { inline void record(test::Mock& mock, int rank) { switch(rank) { case 0: - mock.OnAllReduce(0, -1); + mock.OnAllReduce(0, false); break; case 1: - mock.OnAllReduce(1, -1); + mock.OnAllReduce(1, false); break; case 2: - mock.OnAllReduce(2, 0); + mock.OnAllReduce(2, true); break; } } @@ -97,9 +97,12 @@ int main(int argc, char *argv[]) { test::Mock mock; record(mock, rank); mock.Replay(); - replay(mock, rank); + //replay(mock, rank); + sync::SetMock(mock); #endif + + printf("[%d] start at %s\n", rank, name.c_str()); TestMax(n); printf("[%d] TestMax pass\n", rank);