From faed8285cd4846afed63e21086577c76da97be3c Mon Sep 17 00:00:00 2001 From: nachocano Date: Fri, 28 Nov 2014 00:16:35 -0800 Subject: [PATCH] execute it like ./test.sh 4 4000 testcase0.conf to obtain a successful execution updating mock. It now wraps the calls to sync and reads config from configuration file. I believe it's better not to use the preprocessor directive, i.e. not to put any test code in the engine_tcp. I just call the mock in the test_allreduce file. It's a file purely for testing purposes, so it's fine to use the mock there. --- .gitignore | 2 +- src/allreduce.h | 12 +-- src/config.h | 195 ++++++++++++++++++++++++++++++++++++++++ src/engine.h | 7 -- src/engine_tcp.cpp | 33 ------- src/mock.h | 102 ++++++++++----------- test/Makefile | 4 - test/test.sh | 6 +- test/test_allreduce.cpp | 44 +++------ test/testcase0.conf | 1 + test/testcase1.conf | 12 +++ 11 files changed, 274 insertions(+), 144 deletions(-) create mode 100644 src/config.h create mode 100644 test/testcase0.conf create mode 100644 test/testcase1.conf diff --git a/.gitignore b/.gitignore index 2922a01e6..f087cd689 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,4 @@ *.app *~ *.pyc -test \ No newline at end of file +test_allreduce \ No newline at end of file diff --git a/src/allreduce.h b/src/allreduce.h index ef3fe589b..264541211 100644 --- a/src/allreduce.h +++ b/src/allreduce.h @@ -1,12 +1,11 @@ +#ifndef ALLREDUCE_H +#define ALLREDUCE_H /*! * \file allreduce.h * \brief This file defines a template wrapper of engine to ensure * \author Tianqi Chen, Nacho, Tianyi */ #include "./engine.h" -#ifdef TEST - #include "./mock.h" -#endif /*! \brief namespace of all reduce */ namespace sync { @@ -46,11 +45,7 @@ void Init(int argc, char *argv[]) { void Finalize(void) { engine::Finalize(); } -#ifdef TEST -void SetMock(const test::Mock& mock) { - engine::SetMock(mock); -} -#endif + /*! \brief get rank of current process */ inline int GetRank(void) { return engine::GetEngine()->GetRank(); @@ -113,3 +108,4 @@ inline void CheckPoint(const utils::ISerializable &model) { engine::GetEngine()->CheckPoint(model); } } // namespace allreduce +#endif // ALLREDUCE_H diff --git a/src/config.h b/src/config.h new file mode 100644 index 000000000..45da45bdb --- /dev/null +++ b/src/config.h @@ -0,0 +1,195 @@ +#ifndef ALLREDUCE_UTILS_CONFIG_H_ +#define ALLREDUCE_UTILS_CONFIG_H_ +/*! + * \file config.h + * \brief helper class to load in configures from file + * \author Tianqi Chen + */ +#include +#include +#include +#include +#include +#include "./utils.h" + +namespace utils { +/*! + * \brief base implementation of config reader + */ +class ConfigReaderBase { + public: + /*! + * \brief get current name, called after Next returns true + * \return current parameter name + */ + inline const char *name(void) const { + return s_name; + } + /*! + * \brief get current value, called after Next returns true + * \return current parameter value + */ + inline const char *val(void) const { + return s_val; + } + /*! + * \brief move iterator to next position + * \return true if there is value in next position + */ + inline bool Next(void) { + while (!this->IsEnd()) { + GetNextToken(s_name); + if (s_name[0] == '=') return false; + if (GetNextToken( s_buf ) || s_buf[0] != '=') return false; + if (GetNextToken( s_val ) || s_val[0] == '=') return false; + return true; + } + return false; + } + // called before usage + inline void Init(void) { + ch_buf = this->GetChar(); + } + + protected: + /*! + * \brief to be implemented by subclass, + * get next token, return EOF if end of file + */ + virtual char GetChar(void) = 0; + /*! \brief to be implemented by child, check if end of stream */ + virtual bool IsEnd(void) = 0; + + private: + char ch_buf; + char s_name[100000], s_val[100000], s_buf[100000]; + + inline void SkipLine(void) { + do { + ch_buf = this->GetChar(); + } while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r'); + } + + inline void ParseStr(char tok[]) { + int i = 0; + while ((ch_buf = this->GetChar()) != EOF) { + switch (ch_buf) { + case '\\': tok[i++] = this->GetChar(); break; + case '\"': tok[i++] = '\0'; return; + case '\r': + case '\n': Error("ConfigReader: unterminated string"); + default: tok[i++] = ch_buf; + } + } + Error("ConfigReader: unterminated string"); + } + inline void ParseStrML(char tok[]) { + int i = 0; + while ((ch_buf = this->GetChar()) != EOF) { + switch (ch_buf) { + case '\\': tok[i++] = this->GetChar(); break; + case '\'': tok[i++] = '\0'; return; + default: tok[i++] = ch_buf; + } + } + Error("unterminated string"); + } + // return newline + inline bool GetNextToken(char tok[]) { + int i = 0; + bool new_line = false; + while (ch_buf != EOF) { + switch (ch_buf) { + case '#' : SkipLine(); new_line = true; break; + case '\"': + if (i == 0) { + ParseStr(tok); ch_buf = this->GetChar(); return new_line; + } else { + Error("ConfigReader: token followed directly by string"); + } + case '\'': + if (i == 0) { + ParseStrML( tok ); ch_buf = this->GetChar(); return new_line; + } else { + Error("ConfigReader: token followed directly by string"); + } + case '=': + if (i == 0) { + ch_buf = this->GetChar(); + tok[0] = '='; + tok[1] = '\0'; + } else { + tok[i] = '\0'; + } + return new_line; + case '\r': + case '\n': + if (i == 0) new_line = true; + case '\t': + case ' ' : + ch_buf = this->GetChar(); + if (i > 0) { + tok[i] = '\0'; + return new_line; + } + break; + default: + tok[i++] = ch_buf; + ch_buf = this->GetChar(); + break; + } + } + return true; + } +}; +/*! + * \brief an iterator use stream base, allows use all types of istream + */ +class ConfigStreamReader: public ConfigReaderBase { + public: + /*! + * \brief constructor + * \param istream input stream + */ + explicit ConfigStreamReader(std::istream &fin) : fin(fin) {} + + protected: + virtual char GetChar(void) { + return fin.get(); + } + /*! \brief to be implemented by child, check if end of stream */ + virtual bool IsEnd(void) { + return fin.eof(); + } + + private: + std::istream &fin; +}; + +/*! + * \brief an iterator that iterates over a configure file and gets the configures + */ +class ConfigIterator: public ConfigStreamReader { + public: + /*! + * \brief constructor + * \param fname name of configure file + */ + explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) { + fi.open(fname); + if (fi.fail()) { + utils::Error("cannot open file %s", fname); + } + ConfigReaderBase::Init(); + } + /*! \brief destructor */ + ~ConfigIterator(void) { + fi.close(); + } + + private: + std::ifstream fi; +}; +} // namespace utils + +#endif // ALLREDUCE_UTILS_CONFIG_H_ diff --git a/src/engine.h b/src/engine.h index dc3a14049..ce8603d6f 100644 --- a/src/engine.h +++ b/src/engine.h @@ -6,9 +6,6 @@ * \author Tianqi Chen, Nacho, Tianyi */ #include "./io.h" -#ifdef TEST - #include "./mock.h" -#endif namespace MPI { @@ -81,9 +78,5 @@ void Finalize(void); /*! \brief singleton method to get engine */ IEngine *GetEngine(void); -#ifdef TEST -void SetMock(const test::Mock& mock); -#endif - } // namespace engine #endif // ALLREDUCE_ENGINE_H diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp index c34d7d86b..4cbbe384f 100644 --- a/src/engine_tcp.cpp +++ b/src/engine_tcp.cpp @@ -12,9 +12,6 @@ #include #include "./engine.h" #include "./socket.h" -#ifdef TEST - #include "./mock.h" -#endif namespace MPI { class Datatype { @@ -41,12 +38,6 @@ class SyncManager : public IEngine { ~SyncManager(void) { } - #ifdef TEST - inline void SetMock(const test::Mock& mock) { - this->mock = mock; - } - #endif - inline void Shutdown(void) { for (size_t i = 0; i < links.size(); ++i) { links[i].sock.Close(); @@ -180,9 +171,6 @@ class SyncManager : public IEngine { size_t type_nbytes, size_t count, ReduceFunction reducer) { - #ifdef TEST - utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing"); - #endif if (links.size() == 0) return; // total size of message const size_t total_size = type_nbytes * count; @@ -304,10 +292,6 @@ class SyncManager : public IEngine { * \param root the root worker id to broadcast the data */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - #ifdef TEST - utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting"); - #endif - if (links.size() == 0) return; // number of links const int nlink = static_cast(links.size()); @@ -368,15 +352,9 @@ class SyncManager : public IEngine { } } virtual bool LoadCheckPoint(utils::ISerializable *p_model) { - #ifdef TEST - utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint"); - #endif return false; } virtual void CheckPoint(const utils::ISerializable &model) { - #ifdef TEST - utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing"); - #endif } private: @@ -479,10 +457,6 @@ class SyncManager : public IEngine { // select helper utils::SelectHelper selecter; - #ifdef TEST - // mock to test - test::Mock mock; - #endif }; // singleton sync manager @@ -499,13 +473,6 @@ void Init(int argc, char *argv[]) { manager.Init(); } -#ifdef TEST -/*! \brief sets a mock to the manager for testing purposes */ -void SetMock(const test::Mock& mock) { - manager.SetMock(mock); -} -#endif - /*! \brief finalize syncrhonization module */ void Finalize(void) { manager.Shutdown(); diff --git a/src/mock.h b/src/mock.h index c6bfd89fd..8eb5629e6 100644 --- a/src/mock.h +++ b/src/mock.h @@ -5,9 +5,8 @@ * \brief This file defines a mock object to test the system * \author Tianqi Chen, Nacho, Tianyi */ -#include "./engine.h" -#include "./utils.h" -#include +#include "./allreduce.h" +#include "./config.h" #include @@ -16,82 +15,73 @@ namespace test { class Mock { - typedef std::map > Map; public: - Mock() : record(true) {} - - inline void Replay() { - record = false; + Mock(const int& rank, char *config) : rank(rank) { + Init(config); } - // record methods - inline void OnAllReduce(int rank, bool success) { - onRecord(allReduce, rank, success); + template + inline void AllReduce(float *sendrecvbuf, size_t count) { + utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank); + sync::AllReduce(sendrecvbuf, count); } - inline void OnBroadcast(int rank, bool success) { - onRecord(broadcast, rank, success); + inline bool LoadCheckPoint(utils::ISerializable *p_model) { + utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank); + return sync::LoadCheckPoint(p_model); } - inline void OnLoadCheckPoint(int rank, bool success) { - onRecord(loadCheckpoint, rank, success); + inline void CheckPoint(const utils::ISerializable &model) { + utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank); + sync::CheckPoint(model); } - inline void OnCheckPoint(int rank, bool success) { - onRecord(checkpoint, rank, success); + inline void Broadcast(std::string *sendrecv_data, int root) { + utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank); + sync::Bcast(sendrecv_data, root); + } - - // replay methods - inline bool AllReduce(int rank) { - return onReplay(allReduce, rank); - } - - inline bool Broadcast(int rank) { - return onReplay(broadcast, rank); - } - - inline bool LoadCheckPoint(int rank) { - return onReplay(loadCheckpoint, rank); - } - - inline bool CheckPoint(int rank) { - return onReplay(checkpoint, rank); - } - - private: - inline void onRecord(Map& m, int rank, bool success) { - utils::Check(record, "Not in record state"); - Map::iterator it = m.find(rank); - if (it == m.end()) { - std::queue aQueue; - m[rank] = aQueue; + inline void Init(char* config) { + utils::ConfigIterator itr(config); + while (itr.Next()) { + char round[4], node_rank[4]; + sscanf(itr.name(), "%[^_]_%s", round, node_rank); + int i_round = atoi(round); + if (i_round == 1) { + int i_node_rank = atoi(node_rank); + if (i_node_rank == rank) { + printf("[%d] round %d, value %s\n", rank, i_round, itr.val()); + if (strcmp("allreduce", itr.val())) record(allReduce); + else if (strcmp("broadcast", itr.val())) record(broadcast); + else if (strcmp("loadcheckpoint", itr.val())) record(loadCheckpoint); + else if (strcmp("checkpoint", itr.val())) record(checkpoint); + } + } } - m[rank].push(success); } - inline bool onReplay(Map& m, int rank) { - utils::Check(!record, "Not in replay state"); - utils::Check(m.find(rank) != m.end(), "Not recorded"); + inline void record(std::map& m) { + m[rank] = false; + } + + inline bool verify(std::map& m) { bool result = true; - if (!m[rank].empty()) { - result = m[rank].front(); - m[rank].pop(); + if (m.find(rank) != m.end()) { + result = m[rank]; } return result; } - // flag to indicate if the mock is in record state - bool record; - - Map allReduce; - Map broadcast; - Map loadCheckpoint; - Map checkpoint; + int rank; + std::map allReduce; + std::map broadcast; + std::map loadCheckpoint; + std::map checkpoint; }; } diff --git a/test/Makefile b/test/Makefile index ef752de74..c71db5f86 100644 --- a/test/Makefile +++ b/test/Makefile @@ -9,10 +9,6 @@ ifeq ($(no_omp),1) else CFLAGS += -fopenmp endif -ifeq ($(test),1) - CFLAGS += -DTEST -endif - # specify tensor path BIN = test_allreduce diff --git a/test/test.sh b/test/test.sh index 5e5ef546d..753085724 100755 --- a/test/test.sh +++ b/test/test.sh @@ -1,7 +1,7 @@ #!/bin/bash -if [ "$#" -ne 2 ]; +if [ "$#" -ne 3 ]; then - echo "Usage " + echo "Usage " exit -1 fi -../submit_job_tcp.py $1 test_allreduce $2 +../submit_job_tcp.py $1 test_allreduce $2 $3 diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 185d8ea6d..807b1a1bb 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -8,7 +8,7 @@ using namespace sync; -inline void TestMax(size_t n) { +inline void TestMax(test::Mock &mock, size_t n) { int rank = sync::GetRank(); int nproc = sync::GetWorldSize(); @@ -16,7 +16,7 @@ inline void TestMax(size_t n) { for (size_t i = 0; i < ndata.size(); ++i) { ndata[i] = (i * (rank+1)) % 111; } - sync::AllReduce(&ndata[0], ndata.size()); + mock.AllReduce(&ndata[0], ndata.size()); for (size_t i = 0; i < ndata.size(); ++i) { float rmax = (i * 1) % 111; for (int r = 0; r < nproc; ++r) { @@ -26,7 +26,7 @@ inline void TestMax(size_t n) { } } -inline void TestSum(size_t n) { +inline void TestSum(test::Mock &mock, size_t n) { int rank = sync::GetRank(); int nproc = sync::GetWorldSize(); const int z = 131; @@ -35,7 +35,7 @@ inline void TestSum(size_t n) { for (size_t i = 0; i < ndata.size(); ++i) { ndata[i] = (i * (rank+1)) % z; } - sync::AllReduce(&ndata[0], ndata.size()); + mock.AllReduce(&ndata[0], ndata.size()); for (size_t i = 0; i < ndata.size(); ++i) { float rsum = 0.0f; for (int r = 0; r < nproc; ++r) { @@ -46,7 +46,7 @@ inline void TestSum(size_t n) { } } -inline void TestBcast(size_t n, int root) { +inline void TestBcast(test::Mock &mock, size_t n, int root) { int rank = sync::GetRank(); std::string s; s.resize(n); for (size_t i = 0; i < n; ++i) { @@ -55,31 +55,16 @@ inline void TestBcast(size_t n, int root) { std::string res; if (root == rank) { res = s; - sync::Bcast(&res, root); + mock.Broadcast(&res, root); } else { - sync::Bcast(&res, root); + mock.Broadcast(&res, root); } utils::Check(res == s, "[%d] TestBcast fail", rank); } -// ugly stuff, just to see if it works. To be removed -inline void Record(test::Mock& mock, const int rank) { - switch(rank) { - case 0: - mock.OnAllReduce(0, false); - break; - case 1: - mock.OnAllReduce(1, false); - break; - case 2: - mock.OnAllReduce(2, true); - break; - } -} - int main(int argc, char *argv[]) { - if (argc < 2) { - printf("Usage: \n"); + if (argc < 3) { + printf("Usage: \n"); return 0; } int n = atoi(argv[1]); @@ -87,17 +72,12 @@ int main(int argc, char *argv[]) { int rank = sync::GetRank(); std::string name = sync::GetProcessorName(); - #ifdef TEST - test::Mock mock; - Record(mock, rank); - mock.Replay(); - sync::SetMock(mock); - #endif + test::Mock mock(rank, argv[2]); printf("[%d] start at %s\n", rank, name.c_str()); - TestMax(n); + TestMax(mock, n); printf("[%d] TestMax pass\n", rank); - TestSum(n); + TestSum(mock, n); printf("[%d] TestSum pass\n", rank); sync::Finalize(); printf("[%d] all check pass\n", rank); diff --git a/test/testcase0.conf b/test/testcase0.conf new file mode 100644 index 000000000..4c324d282 --- /dev/null +++ b/test/testcase0.conf @@ -0,0 +1 @@ +# Test Case 0 -> nothing fails \ No newline at end of file diff --git a/test/testcase1.conf b/test/testcase1.conf new file mode 100644 index 000000000..f5aa31892 --- /dev/null +++ b/test/testcase1.conf @@ -0,0 +1,12 @@ +# Test Case example config +# You configure which methods should fail +# Format _ = +# can be one of the following = allreduce, broadcast, loadcheckpoint, checkpoint + +1_0 = allreduce +1_1 = broadcast +1_2 = loadcheckpoint +1_3 = checkpoint + +2_0 = allreduce +2_2 = checkpoint