unittests mock, cleanup (#111)

* cleanup, fix issue introduced after removing the is_bootstrap parameter

* misc

* clean

* add unittests
This commit is contained in:
Chen Qin 2019-10-01 13:36:11 -07:00 committed by Nan Zhu
parent ddcc2d85da
commit af7281afe3
7 changed files with 81 additions and 35 deletions

View File

@ -1,6 +1,7 @@
option(DMLC_ROOT "Specify root of external dmlc core.") option(DMLC_ROOT "Specify root of external dmlc core.")
add_library(allreduce_base "") add_library(allreduce_base "")
add_library(allreduce_mock "")
target_sources( target_sources(
allreduce_base allreduce_base
@ -9,9 +10,22 @@ target_sources(
PUBLIC PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h ${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
) )
target_sources(
allreduce_mock
PRIVATE
allreduce_robust.cc
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
)
target_include_directories( target_include_directories(
allreduce_base allreduce_base
PUBLIC PUBLIC
${DMLC_ROOT}/include ${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include) ${CMAKE_CURRENT_LIST_DIR}/../../include)
target_include_directories(
allreduce_mock
PUBLIC
${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include)

View File

@ -48,16 +48,23 @@ class AllreduceMock : public AllreduceRobust {
size_t count, size_t count,
ReduceFunction reducer, ReduceFunction reducer,
PreprocFunction prepare_fun, PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg); count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart; tsum_allreduce += utils::GetTime() - tstart;
} }
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
} }
virtual int LoadCheckPoint(Serializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) { Serializable *local_model) {
@ -168,8 +175,8 @@ class AllreduceMock : public AllreduceRobust {
inline void Verify(const MockKey &key, const char *name) { inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) { if (mock_map.count(key) != 0) {
num_trial += 1; num_trial += 1;
fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name); // data processing frameworks runs on shared process
exit(-2); utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
} }
} }
}; };

View File

@ -113,7 +113,7 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
} }
int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
const size_t type_nbytes, const size_t count, const bool byref) { const size_t type_nbytes, const size_t count) {
// as requester sync with rest of nodes on latest cache content // as requester sync with rest of nodes on latest cache content
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
seq_counter, cur_cache_seq)) return -1; seq_counter, cur_cache_seq)) return -1;
@ -136,14 +136,7 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index"); utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
utils::Assert(siz > 0, "cache size should be greater than 0"); utils::Assert(siz > 0, "cache size should be greater than 0");
std::memcpy(buf, temp, type_nbytes*count);
// immutable cache, save copy time by pointer manipulation
if (byref) {
buf = temp;
} else {
std::memcpy(buf, temp, type_nbytes*count);
}
return 0; return 0;
} }
@ -184,7 +177,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
// try fetch bootstrap allreduce results from cache // try fetch bootstrap allreduce results from cache
if (!checkpoint_loaded && rabit_bootstrap_cache && if (!checkpoint_loaded && rabit_bootstrap_cache &&
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return; GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return;
double start = utils::GetTime(); double start = utils::GetTime();
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq); bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
@ -244,8 +237,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root); + std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
// try fetch bootstrap allreduce results from cache // try fetch bootstrap allreduce results from cache
if (!checkpoint_loaded && rabit_bootstrap_cache && if (!checkpoint_loaded && rabit_bootstrap_cache &&
GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return; GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return;
double start = utils::GetTime(); double start = utils::GetTime();
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq); bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
// now we are free to remove the last result, if any // now we are free to remove the last result, if any
@ -1171,9 +1163,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
// if all nodes are requester in load cache, skip // if all nodes are requester in load cache, skip
if (act.load_cache(SeqType::kCache)) return false; if (act.load_cache(SeqType::kCache)) return false;
// only restore when at least one pair of max_seq are different // bootstrap cache always restore before loadcheckpoint
if (act.diff_seq(SeqType::kCache)) { // requester always have seq diff with non requester
// if restore cache failed, retry from what's left if (act.diff_seq()) {
// restore cache failed, retry from what's left
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache)) if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
!= kSuccess) continue; != kSuccess) continue;
} }

View File

@ -49,7 +49,7 @@ class AllreduceRobust : public AllreduceBase {
* \param buflen total number of bytes * \param buflen total number of bytes
*/ */
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes, int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
const size_t count, const bool byref = false); const size_t count);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -255,9 +255,8 @@ class AllreduceRobust : public AllreduceBase {
return (code & kCheckAck) != 0; return (code & kCheckAck) != 0;
} }
// whether the operation set contains different sequence number // whether the operation set contains different sequence number
inline bool diff_seq(SeqType t = SeqType::kSeq) const { inline bool diff_seq() const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode; return (seqcode & kDiffSeq) != 0;
return (code & kDiffSeq) != 0;
} }
// returns the operation flag of the result // returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const { inline int flag(SeqType t = SeqType::kSeq) const {
@ -266,11 +265,10 @@ class AllreduceRobust : public AllreduceBase {
} }
// print flags in user friendly way // print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) { inline void print_flags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n", utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
rank, prefix.c_str(), rank, prefix.c_str(),
seqno(), check_point(), check_ack(), load_cache(), seqno(), check_point(), check_ack(), load_cache(),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache), diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
diff_seq(SeqType::kCache));
} }
// reducer for Allreduce, get the result ActionSummary from all nodes // reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_, inline static void Reducer(const void *src_, void *dst_,
@ -286,12 +284,9 @@ class AllreduceRobust : public AllreduceBase {
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache); int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
// if seqno is different in src and destination // if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0; int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
// if cache seqno is different in src and destination
int cache_diff_flag =
src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag // apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag, dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag | cache_diff_flag, min_seqno, max_seqno); role_flag, min_seqno, max_seqno);
} }
} }

View File

@ -3,12 +3,13 @@ find_package(GTest REQUIRED)
add_executable( add_executable(
unit_tests unit_tests
allreduce_base_test.cpp allreduce_base_test.cpp
allreduce_mock_test.cpp
test_main.cpp) test_main.cpp)
target_link_libraries( target_link_libraries(
unit_tests unit_tests
GTest::GTest GTest::Main GTest::GTest GTest::Main
rabit_base) rabit_base rabit_mock)
target_include_directories(unit_tests PUBLIC target_include_directories(unit_tests PUBLIC
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>" "$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"

View File

@ -0,0 +1,36 @@
#define RABIT_CXXTESTDEFS_H
#include <gtest/gtest.h>
#include <string>
#include <iostream>
#include "../../src/allreduce_mock.h"
// Checks that AllreduceMock::Allreduce terminates the process (exit code 255)
// when a "mock=..." argument was passed to Init and the engine state matches it.
// NOTE(review): the four comma-separated fields are presumably
// rank,version,seq,trial, mirroring the MockKey the engine verifies — confirm
// against the Init/SetParam implementation.
TEST(allreduce_mock, mock_allreduce)
{
  rabit::engine::AllreduceMock m;

  // Init takes mutable char* arguments. std::string storage is contiguous and
  // null-terminated (C++11 and later), so it can back argv directly; this avoids
  // a variable-length array, which is a compiler extension and not standard C++.
  std::string mock_str = "mock=0,0,0,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);

  m.rank = 0;

  // The call is expected to end the (forked) death-test process with status 255.
  EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
}

// Checks that AllreduceMock::Broadcast terminates the process (exit code 255)
// when a "mock=..." argument was passed to Init and the engine state
// (rank, version_number, seq_counter) is set to the matching values.
// NOTE(review): the four comma-separated fields are presumably
// rank,version,seq,trial, mirroring the MockKey the engine verifies — confirm
// against the Init/SetParam implementation.
TEST(allreduce_mock, mock_broadcast)
{
  rabit::engine::AllreduceMock m;

  // Init takes mutable char* arguments. std::string storage is contiguous and
  // null-terminated (C++11 and later), so it can back argv directly; this avoids
  // a variable-length array, which is a compiler extension and not standard C++.
  std::string mock_str = "mock=0,1,2,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);

  // Put the engine into the state named by the mock entry above.
  m.rank = 0;
  m.version_number=1;
  m.seq_counter=2;

  // The call is expected to end the (forked) death-test process with status 255.
  EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
}


View File

@ -96,7 +96,7 @@ int main(int argc, char *argv[]) {
std::string name = rabit::GetProcessorName(); std::string name = rabit::GetProcessorName();
int max_rank = rank; int max_rank = rank;
rabit::Allreduce<op::Max>(&max_rank, sizeof(int)); rabit::Allreduce<op::Max>(&max_rank, 1);
utils::Check(max_rank == nproc - 1, "max rank is world size-1"); utils::Check(max_rank == nproc - 1, "max rank is world size-1");
Model model; Model model;