From af7281afe3e3a3b934ab3091f5cb58622da203ad Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Tue, 1 Oct 2019 13:36:11 -0700 Subject: [PATCH] unittests mock, cleanup (#111) * cleanup, fix issue involved after remove is_bootstrap parameter * misc * clean * add unittests --- src/CMakeLists.txt | 16 +++++++++++++- src/allreduce_mock.h | 19 +++++++++++------ src/allreduce_robust.cc | 23 +++++++------------- src/allreduce_robust.h | 17 ++++++--------- test/cpp/CMakeLists.txt | 3 ++- test/cpp/allreduce_mock_test.cpp | 36 ++++++++++++++++++++++++++++++++ test/model_recover.cc | 2 +- 7 files changed, 81 insertions(+), 35 deletions(-) create mode 100644 test/cpp/allreduce_mock_test.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cc120451c..a7aa5d03b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,6 +1,7 @@ option(DMLC_ROOT "Specify root of external dmlc core.") add_library(allreduce_base "") +add_library(allreduce_mock "") target_sources( allreduce_base @@ -9,9 +10,22 @@ target_sources( PUBLIC ${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h ) +target_sources( + allreduce_mock + PRIVATE + allreduce_robust.cc + PUBLIC + ${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h +) target_include_directories( allreduce_base PUBLIC - ${DMLC_ROOT}/include + ${DMLC_ROOT}/include ${CMAKE_CURRENT_LIST_DIR}/../../include) + +target_include_directories( + allreduce_mock + PUBLIC + ${DMLC_ROOT}/include + ${CMAKE_CURRENT_LIST_DIR}/../../include) diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 1ab9f8632..ec535483f 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -48,16 +48,23 @@ class AllreduceMock : public AllreduceRobust { size_t count, ReduceFunction reducer, PreprocFunction prepare_fun, - void *prepare_arg) { + void *prepare_arg, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce"); double tstart = utils::GetTime(); AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, - count, reducer, prepare_fun, prepare_arg); + count, reducer, prepare_fun, prepare_arg, + _file, _line, _caller); tsum_allreduce += utils::GetTime() - tstart; } - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { + virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); - AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); + AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller); } virtual int LoadCheckPoint(Serializable *global_model, Serializable *local_model) { @@ -168,8 +175,8 @@ class AllreduceMock : public AllreduceRobust { inline void Verify(const MockKey &key, const char *name) { if (mock_map.count(key) != 0) { num_trial += 1; - fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name); - exit(-2); + // data processing frameworks runs on shared process + utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name); } } }; diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 2ed004f42..b9ddf71b7 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -113,7 +113,7 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf, } int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, - const size_t type_nbytes, const size_t count, const bool byref) { + const size_t type_nbytes, const size_t count) { // as requester sync with rest of nodes on latest cache content if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) return -1; @@ -136,14 +136,7 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); utils::Assert(siz > 0, "cache size should be greater than 0"); - - // immutable cache, save copy time by pointer manipulation - if (byref) { - buf = temp; - } else { - std::memcpy(buf, temp, type_nbytes*count); - } - + std::memcpy(buf, temp, type_nbytes*count); return 0; } @@ -184,7 +177,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, // try fetch bootstrap allreduce results from cache if (!checkpoint_loaded && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return; + GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return; double start = utils::GetTime(); bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq); @@ -244,8 +237,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, + std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root); // try fetch bootstrap allreduce results from cache if (!checkpoint_loaded && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return; - + GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return; double start = utils::GetTime(); bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq); // now we are free to remove the last result, if any @@ -1171,9 +1163,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, // if all nodes are requester in load cache, skip if (act.load_cache(SeqType::kCache)) return false; - // only restore when at least one pair of max_seq are different - if (act.diff_seq(SeqType::kCache)) { - // if restore cache failed, retry from what's left + // bootstrap cache always restore before loadcheckpoint + // requester always have seq diff with non requester + if (act.diff_seq()) { + // restore cache failed, retry from what's left if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache)) != kSuccess) continue; } diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 30c7d9866..7704a31c6 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -49,7 +49,7 @@ class AllreduceRobust : public AllreduceBase { * \param buflen total number of bytes */ int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes, - const size_t count, const bool byref = false); + const size_t count); /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -255,9 +255,8 @@ class AllreduceRobust : public AllreduceBase { return (code & kCheckAck) != 0; } // whether the operation set contains different sequence number - inline bool diff_seq(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode : maxseqcode; - return (code & kDiffSeq) != 0; + inline bool diff_seq() const { + return (seqcode & kDiffSeq) != 0; } // returns the operation flag of the result inline int flag(SeqType t = SeqType::kSeq) const { @@ -266,11 +265,10 @@ class AllreduceRobust : public AllreduceBase { } // print flags in user friendly way inline void print_flags(int rank, std::string prefix ) { - utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n", + utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank, prefix.c_str(), seqno(), check_point(), check_ack(), load_cache(), - diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache), - diff_seq(SeqType::kCache)); + diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache)); } // reducer for Allreduce, get the result ActionSummary from all nodes inline static void Reducer(const void *src_, void *dst_, @@ -286,12 +284,9 @@ class AllreduceRobust : public AllreduceBase { int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache); // if seqno is different in src and destination int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0; - // if cache seqno is different in src and destination - int cache_diff_flag = - src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0; // apply or to both seq diff flag as well as cache seq diff flag dst[i] = ActionSummary(action_flag | seq_diff_flag, - role_flag | cache_diff_flag, min_seqno, max_seqno); + role_flag, min_seqno, max_seqno); } } diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 31fa764cb..987f4c01e 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -3,12 +3,13 @@ find_package(GTest REQUIRED) add_executable( unit_tests allreduce_base_test.cpp + allreduce_mock_test.cpp test_main.cpp) target_link_libraries( unit_tests GTest::GTest GTest::Main - rabit_base) + rabit_base rabit_mock) target_include_directories(unit_tests PUBLIC "$" diff --git a/test/cpp/allreduce_mock_test.cpp b/test/cpp/allreduce_mock_test.cpp new file mode 100644 index 000000000..e659d8ea8 --- /dev/null +++ b/test/cpp/allreduce_mock_test.cpp @@ -0,0 +1,36 @@ +#define RABIT_CXXTESTDEFS_H +#include + +#include +#include +#include "../../src/allreduce_mock.h" + +TEST(allreduce_mock, mock_allreduce) +{ + rabit::engine::AllreduceMock m; + + std::string mock_str = "mock=0,0,0,0"; + char cmd[mock_str.size()+1]; + std::copy(mock_str.begin(), mock_str.end(), cmd); + cmd[mock_str.size()] = '\0'; + + char* argv[] = {cmd}; + m.Init(1, argv); + m.rank = 0; + EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), ""); +} + +TEST(allreduce_mock, mock_broadcast) +{ + rabit::engine::AllreduceMock m; + std::string mock_str = "mock=0,1,2,0"; + char cmd[mock_str.size()+1]; + std::copy(mock_str.begin(), mock_str.end(), cmd); + cmd[mock_str.size()] = '\0'; + char* argv[] = {cmd}; + m.Init(1, argv); + m.rank = 0; + m.version_number=1; + m.seq_counter=2; + EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), ""); +} diff --git a/test/model_recover.cc b/test/model_recover.cc index 3745caf5a..48a4da99c 100644 --- a/test/model_recover.cc +++ b/test/model_recover.cc @@ -96,7 +96,7 @@ int main(int argc, char *argv[]) { std::string name = rabit::GetProcessorName(); int max_rank = rank; - rabit::Allreduce(&max_rank, sizeof(int)); + rabit::Allreduce(&max_rank, 1); utils::Check(max_rank == nproc - 1, "max rank is world size-1"); Model model;