unittests mock, cleanup (#111)
* cleanup, fix issue involved after remove is_bootstrap parameter * misc * clean * add unittests
This commit is contained in:
parent
ddcc2d85da
commit
af7281afe3
@ -1,6 +1,7 @@
|
|||||||
option(DMLC_ROOT "Specify root of external dmlc core.")
|
option(DMLC_ROOT "Specify root of external dmlc core.")
|
||||||
|
|
||||||
add_library(allreduce_base "")
|
add_library(allreduce_base "")
|
||||||
|
add_library(allreduce_mock "")
|
||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
allreduce_base
|
allreduce_base
|
||||||
@ -9,9 +10,22 @@ target_sources(
|
|||||||
PUBLIC
|
PUBLIC
|
||||||
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
|
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
|
||||||
)
|
)
|
||||||
|
target_sources(
|
||||||
|
allreduce_mock
|
||||||
|
PRIVATE
|
||||||
|
allreduce_robust.cc
|
||||||
|
PUBLIC
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
|
||||||
|
)
|
||||||
|
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
allreduce_base
|
allreduce_base
|
||||||
PUBLIC
|
PUBLIC
|
||||||
${DMLC_ROOT}/include
|
${DMLC_ROOT}/include
|
||||||
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||||
|
|
||||||
|
target_include_directories(
|
||||||
|
allreduce_mock
|
||||||
|
PUBLIC
|
||||||
|
${DMLC_ROOT}/include
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||||
|
|||||||
@ -48,16 +48,23 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer,
|
ReduceFunction reducer,
|
||||||
PreprocFunction prepare_fun,
|
PreprocFunction prepare_fun,
|
||||||
void *prepare_arg) {
|
void *prepare_arg,
|
||||||
|
const char* _file = _FILE,
|
||||||
|
const int _line = _LINE,
|
||||||
|
const char* _caller = _CALLER) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||||
double tstart = utils::GetTime();
|
double tstart = utils::GetTime();
|
||||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||||
count, reducer, prepare_fun, prepare_arg);
|
count, reducer, prepare_fun, prepare_arg,
|
||||||
|
_file, _line, _caller);
|
||||||
tsum_allreduce += utils::GetTime() - tstart;
|
tsum_allreduce += utils::GetTime() - tstart;
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||||
|
const char* _file = _FILE,
|
||||||
|
const int _line = _LINE,
|
||||||
|
const char* _caller = _CALLER) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||||
}
|
}
|
||||||
virtual int LoadCheckPoint(Serializable *global_model,
|
virtual int LoadCheckPoint(Serializable *global_model,
|
||||||
Serializable *local_model) {
|
Serializable *local_model) {
|
||||||
@ -168,8 +175,8 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
inline void Verify(const MockKey &key, const char *name) {
|
inline void Verify(const MockKey &key, const char *name) {
|
||||||
if (mock_map.count(key) != 0) {
|
if (mock_map.count(key) != 0) {
|
||||||
num_trial += 1;
|
num_trial += 1;
|
||||||
fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
|
// data processing frameworks runs on shared process
|
||||||
exit(-2);
|
utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -113,7 +113,7 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
|
|||||||
}
|
}
|
||||||
|
|
||||||
int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||||
const size_t type_nbytes, const size_t count, const bool byref) {
|
const size_t type_nbytes, const size_t count) {
|
||||||
// as requester sync with rest of nodes on latest cache content
|
// as requester sync with rest of nodes on latest cache content
|
||||||
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
|
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
|
||||||
seq_counter, cur_cache_seq)) return -1;
|
seq_counter, cur_cache_seq)) return -1;
|
||||||
@ -136,14 +136,7 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
|||||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
utils::Assert(siz > 0, "cache size should be greater than 0");
|
||||||
|
std::memcpy(buf, temp, type_nbytes*count);
|
||||||
// immutable cache, save copy time by pointer manipulation
|
|
||||||
if (byref) {
|
|
||||||
buf = temp;
|
|
||||||
} else {
|
|
||||||
std::memcpy(buf, temp, type_nbytes*count);
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,7 +177,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
|||||||
|
|
||||||
// try fetch bootstrap allreduce results from cache
|
// try fetch bootstrap allreduce results from cache
|
||||||
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
||||||
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return;
|
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return;
|
||||||
|
|
||||||
double start = utils::GetTime();
|
double start = utils::GetTime();
|
||||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
|
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
|
||||||
@ -244,8 +237,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
|||||||
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
|
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
|
||||||
// try fetch bootstrap allreduce results from cache
|
// try fetch bootstrap allreduce results from cache
|
||||||
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
||||||
GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return;
|
GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return;
|
||||||
|
|
||||||
double start = utils::GetTime();
|
double start = utils::GetTime();
|
||||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
|
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
|
||||||
// now we are free to remove the last result, if any
|
// now we are free to remove the last result, if any
|
||||||
@ -1171,9 +1163,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
// if all nodes are requester in load cache, skip
|
// if all nodes are requester in load cache, skip
|
||||||
if (act.load_cache(SeqType::kCache)) return false;
|
if (act.load_cache(SeqType::kCache)) return false;
|
||||||
|
|
||||||
// only restore when at least one pair of max_seq are different
|
// bootstrap cache always restore before loadcheckpoint
|
||||||
if (act.diff_seq(SeqType::kCache)) {
|
// requester always have seq diff with non requester
|
||||||
// if restore cache failed, retry from what's left
|
if (act.diff_seq()) {
|
||||||
|
// restore cache failed, retry from what's left
|
||||||
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
|
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
|
||||||
!= kSuccess) continue;
|
!= kSuccess) continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -49,7 +49,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \param buflen total number of bytes
|
* \param buflen total number of bytes
|
||||||
*/
|
*/
|
||||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||||
const size_t count, const bool byref = false);
|
const size_t count);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
@ -255,9 +255,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
return (code & kCheckAck) != 0;
|
return (code & kCheckAck) != 0;
|
||||||
}
|
}
|
||||||
// whether the operation set contains different sequence number
|
// whether the operation set contains different sequence number
|
||||||
inline bool diff_seq(SeqType t = SeqType::kSeq) const {
|
inline bool diff_seq() const {
|
||||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
return (seqcode & kDiffSeq) != 0;
|
||||||
return (code & kDiffSeq) != 0;
|
|
||||||
}
|
}
|
||||||
// returns the operation flag of the result
|
// returns the operation flag of the result
|
||||||
inline int flag(SeqType t = SeqType::kSeq) const {
|
inline int flag(SeqType t = SeqType::kSeq) const {
|
||||||
@ -266,11 +265,10 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
}
|
}
|
||||||
// print flags in user friendly way
|
// print flags in user friendly way
|
||||||
inline void print_flags(int rank, std::string prefix ) {
|
inline void print_flags(int rank, std::string prefix ) {
|
||||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n",
|
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
|
||||||
rank, prefix.c_str(),
|
rank, prefix.c_str(),
|
||||||
seqno(), check_point(), check_ack(), load_cache(),
|
seqno(), check_point(), check_ack(), load_cache(),
|
||||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache),
|
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
|
||||||
diff_seq(SeqType::kCache));
|
|
||||||
}
|
}
|
||||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||||
inline static void Reducer(const void *src_, void *dst_,
|
inline static void Reducer(const void *src_, void *dst_,
|
||||||
@ -286,12 +284,9 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
||||||
// if seqno is different in src and destination
|
// if seqno is different in src and destination
|
||||||
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
||||||
// if cache seqno is different in src and destination
|
|
||||||
int cache_diff_flag =
|
|
||||||
src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0;
|
|
||||||
// apply or to both seq diff flag as well as cache seq diff flag
|
// apply or to both seq diff flag as well as cache seq diff flag
|
||||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||||
role_flag | cache_diff_flag, min_seqno, max_seqno);
|
role_flag, min_seqno, max_seqno);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,12 +3,13 @@ find_package(GTest REQUIRED)
|
|||||||
add_executable(
|
add_executable(
|
||||||
unit_tests
|
unit_tests
|
||||||
allreduce_base_test.cpp
|
allreduce_base_test.cpp
|
||||||
|
allreduce_mock_test.cpp
|
||||||
test_main.cpp)
|
test_main.cpp)
|
||||||
|
|
||||||
target_link_libraries(
|
target_link_libraries(
|
||||||
unit_tests
|
unit_tests
|
||||||
GTest::GTest GTest::Main
|
GTest::GTest GTest::Main
|
||||||
rabit_base)
|
rabit_base rabit_mock)
|
||||||
|
|
||||||
target_include_directories(unit_tests PUBLIC
|
target_include_directories(unit_tests PUBLIC
|
||||||
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
|
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
|
||||||
|
|||||||
36
test/cpp/allreduce_mock_test.cpp
Normal file
36
test/cpp/allreduce_mock_test.cpp
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#define RABIT_CXXTESTDEFS_H
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include "../../src/allreduce_mock.h"
|
||||||
|
|
||||||
|
TEST(allreduce_mock, mock_allreduce)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceMock m;
|
||||||
|
|
||||||
|
std::string mock_str = "mock=0,0,0,0";
|
||||||
|
char cmd[mock_str.size()+1];
|
||||||
|
std::copy(mock_str.begin(), mock_str.end(), cmd);
|
||||||
|
cmd[mock_str.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd};
|
||||||
|
m.Init(1, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_mock, mock_broadcast)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceMock m;
|
||||||
|
std::string mock_str = "mock=0,1,2,0";
|
||||||
|
char cmd[mock_str.size()+1];
|
||||||
|
std::copy(mock_str.begin(), mock_str.end(), cmd);
|
||||||
|
cmd[mock_str.size()] = '\0';
|
||||||
|
char* argv[] = {cmd};
|
||||||
|
m.Init(1, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m.version_number=1;
|
||||||
|
m.seq_counter=2;
|
||||||
|
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
|
||||||
|
}
|
||||||
@ -96,7 +96,7 @@ int main(int argc, char *argv[]) {
|
|||||||
std::string name = rabit::GetProcessorName();
|
std::string name = rabit::GetProcessorName();
|
||||||
|
|
||||||
int max_rank = rank;
|
int max_rank = rank;
|
||||||
rabit::Allreduce<op::Max>(&max_rank, sizeof(int));
|
rabit::Allreduce<op::Max>(&max_rank, 1);
|
||||||
utils::Check(max_rank == nproc - 1, "max rank is world size-1");
|
utils::Check(max_rank == nproc - 1, "max rank is world size-1");
|
||||||
|
|
||||||
Model model;
|
Model model;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user