From 54fcff189f00ae7a86e8a289b810d1b8e6384815 Mon Sep 17 00:00:00 2001 From: nachocano Date: Wed, 26 Nov 2014 16:37:23 -0800 Subject: [PATCH] dummy mock for now --- src/mock.h | 96 +++++++++++++++++++++++++++++++++++++++++ test/Makefile | 4 ++ test/test_allreduce.cpp | 31 +++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 src/mock.h diff --git a/src/mock.h b/src/mock.h new file mode 100644 index 000000000..71e65dc0a --- /dev/null +++ b/src/mock.h @@ -0,0 +1,96 @@ +#ifndef ALLREDUCE_MOCK_H +#define ALLREDUCE_MOCK_H +/*! + * \file mock.h + * \brief This file defines a mock object to test the system + * \author Tianqi Chen, Nacho, Tianyi + */ +#include "./engine.h" +#include "./utils.h" +#include + #include + + +/*! \brief namespace of mock */ +namespace test { + +class Mock { + + typedef std::map > Map; + +public: + + Mock() : record(true) {} + + inline void Replay() { + record = false; + } + + // record methods + + inline void OnAllReduce(int rank, int code) { + utils::Check(record, "Not in record state"); + Map::iterator it = allReduce.find(rank); + if (it == allReduce.end()) { + std::queue aQueue; + allReduce[rank] = aQueue; + } + allReduce[rank].push(code); + } + + inline void OnBroadcast() { + utils::Check(record, "Not in record state"); + } + + inline void OnLoadCheckpoint() { + utils::Check(record, "Not in record state"); + } + + inline void OnCheckpoint() { + utils::Check(record, "Not in record state"); + } + + + // replay methods + + inline int AllReduce(int rank) { + utils::Check(!record, "Not in replay state"); + utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded"); + int result = 0; + if (!allReduce[rank].empty()) { + result = allReduce[rank].front(); + allReduce[rank].pop(); + } + return result; + } + + inline int Broadcast(int rank) { + utils::Check(!record, "Not in replay state"); + return 0; + } + + inline int LoadCheckpoint(int rank) { + utils::Check(!record, "Not in replay state"); + return 0; + } + + inline int Checkpoint(int rank) { + utils::Check(!record, "Not in replay state"); + return 0; + } + + +private: + + // flag to indicate if the mock is in record state + bool record; + + Map allReduce; + Map broadcast; + Map loadCheckpoint; + Map checkpoint; +}; + +} + +#endif // ALLREDUCE_MOCK_H diff --git a/test/Makefile b/test/Makefile index 18f8c7481..78c6095e4 100644 --- a/test/Makefile +++ b/test/Makefile @@ -9,6 +9,10 @@ ifeq ($(no_omp),1) else CFLAGS += -fopenmp endif +ifeq ($(test),1) + CFLAGS += -DTEST +endif + # specify tensor path BIN = test_allreduce diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 407abf139..43876a4a9 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -3,6 +3,8 @@ #include #include #include +#include + using namespace sync; @@ -60,6 +62,27 @@ inline void TestBcast(size_t n, int root) { utils::Check(res == s, "[%d] TestBcast fail", rank); } +// ugly stuff, just to see if it works +inline void record(test::Mock& mock, int rank) { + switch(rank) { + case 0: + mock.OnAllReduce(0, -1); + break; + case 1: + mock.OnAllReduce(1, -1); + break; + case 2: + mock.OnAllReduce(2, 0); + break; + } +} + +// to be removed, should be added in engine tcp +inline void replay(test::Mock& mock, int rank) { + printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank)); + printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank)); +} + int main(int argc, char *argv[]) { if (argc < 2) { printf("Usage: \n"); @@ -69,6 +92,14 @@ int main(int argc, char *argv[]) { sync::Init(argc, argv); int rank = sync::GetRank(); std::string name = sync::GetProcessorName(); + + #ifdef TEST + test::Mock mock; + record(mock, rank); + mock.Replay(); + replay(mock, rank); + #endif + printf("[%d] start at %s\n", rank, name.c_str()); TestMax(n); printf("[%d] TestMax pass\n", rank);