unittests mock, cleanup (#111)
* cleanup, fix issue involved after remove is_bootstrap parameter * misc * clean * add unittests
This commit is contained in:
parent
ddcc2d85da
commit
af7281afe3
@ -1,6 +1,7 @@
|
||||
option(DMLC_ROOT "Specify root of external dmlc core.")
|
||||
|
||||
add_library(allreduce_base "")
|
||||
add_library(allreduce_mock "")
|
||||
|
||||
target_sources(
|
||||
allreduce_base
|
||||
@ -9,9 +10,22 @@ target_sources(
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
|
||||
)
|
||||
target_sources(
|
||||
allreduce_mock
|
||||
PRIVATE
|
||||
allreduce_robust.cc
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
allreduce_base
|
||||
PUBLIC
|
||||
${DMLC_ROOT}/include
|
||||
${DMLC_ROOT}/include
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||
|
||||
target_include_directories(
|
||||
allreduce_mock
|
||||
PUBLIC
|
||||
${DMLC_ROOT}/include
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||
|
||||
@ -48,16 +48,23 @@ class AllreduceMock : public AllreduceRobust {
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
void *prepare_arg,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg);
|
||||
count, reducer, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
tsum_allreduce += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||
}
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
@ -168,8 +175,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
inline void Verify(const MockKey &key, const char *name) {
|
||||
if (mock_map.count(key) != 0) {
|
||||
num_trial += 1;
|
||||
fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
|
||||
exit(-2);
|
||||
// data processing frameworks runs on shared process
|
||||
utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -113,7 +113,7 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
|
||||
}
|
||||
|
||||
int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
const size_t type_nbytes, const size_t count, const bool byref) {
|
||||
const size_t type_nbytes, const size_t count) {
|
||||
// as requester sync with rest of nodes on latest cache content
|
||||
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
|
||||
seq_counter, cur_cache_seq)) return -1;
|
||||
@ -136,14 +136,7 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
||||
|
||||
// immutable cache, save copy time by pointer manipulation
|
||||
if (byref) {
|
||||
buf = temp;
|
||||
} else {
|
||||
std::memcpy(buf, temp, type_nbytes*count);
|
||||
}
|
||||
|
||||
std::memcpy(buf, temp, type_nbytes*count);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -184,7 +177,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
|
||||
// try fetch bootstrap allreduce results from cache
|
||||
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
||||
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return;
|
||||
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return;
|
||||
|
||||
double start = utils::GetTime();
|
||||
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
|
||||
@ -244,8 +237,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
|
||||
// try fetch bootstrap allreduce results from cache
|
||||
if (!checkpoint_loaded && rabit_bootstrap_cache &&
|
||||
GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return;
|
||||
|
||||
GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return;
|
||||
double start = utils::GetTime();
|
||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
|
||||
// now we are free to remove the last result, if any
|
||||
@ -1171,9 +1163,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
||||
// if all nodes are requester in load cache, skip
|
||||
if (act.load_cache(SeqType::kCache)) return false;
|
||||
|
||||
// only restore when at least one pair of max_seq are different
|
||||
if (act.diff_seq(SeqType::kCache)) {
|
||||
// if restore cache failed, retry from what's left
|
||||
// bootstrap cache always restore before loadcheckpoint
|
||||
// requester always have seq diff with non requester
|
||||
if (act.diff_seq()) {
|
||||
// restore cache failed, retry from what's left
|
||||
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
|
||||
!= kSuccess) continue;
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param buflen total number of bytes
|
||||
*/
|
||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||
const size_t count, const bool byref = false);
|
||||
const size_t count);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -255,9 +255,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
return (code & kCheckAck) != 0;
|
||||
}
|
||||
// whether the operation set contains different sequence number
|
||||
inline bool diff_seq(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
return (code & kDiffSeq) != 0;
|
||||
inline bool diff_seq() const {
|
||||
return (seqcode & kDiffSeq) != 0;
|
||||
}
|
||||
// returns the operation flag of the result
|
||||
inline int flag(SeqType t = SeqType::kSeq) const {
|
||||
@ -266,11 +265,10 @@ class AllreduceRobust : public AllreduceBase {
|
||||
}
|
||||
// print flags in user friendly way
|
||||
inline void print_flags(int rank, std::string prefix ) {
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n",
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
|
||||
rank, prefix.c_str(),
|
||||
seqno(), check_point(), check_ack(), load_cache(),
|
||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache),
|
||||
diff_seq(SeqType::kCache));
|
||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
|
||||
}
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
@ -286,12 +284,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
||||
// if seqno is different in src and destination
|
||||
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
||||
// if cache seqno is different in src and destination
|
||||
int cache_diff_flag =
|
||||
src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0;
|
||||
// apply or to both seq diff flag as well as cache seq diff flag
|
||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||
role_flag | cache_diff_flag, min_seqno, max_seqno);
|
||||
role_flag, min_seqno, max_seqno);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,12 +3,13 @@ find_package(GTest REQUIRED)
|
||||
add_executable(
|
||||
unit_tests
|
||||
allreduce_base_test.cpp
|
||||
allreduce_mock_test.cpp
|
||||
test_main.cpp)
|
||||
|
||||
target_link_libraries(
|
||||
unit_tests
|
||||
GTest::GTest GTest::Main
|
||||
rabit_base)
|
||||
rabit_base rabit_mock)
|
||||
|
||||
target_include_directories(unit_tests PUBLIC
|
||||
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
|
||||
|
||||
36
test/cpp/allreduce_mock_test.cpp
Normal file
36
test/cpp/allreduce_mock_test.cpp
Normal file
@ -0,0 +1,36 @@
|
||||
#define RABIT_CXXTESTDEFS_H
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "../../src/allreduce_mock.h"
|
||||
|
||||
TEST(allreduce_mock, mock_allreduce)
|
||||
{
|
||||
rabit::engine::AllreduceMock m;
|
||||
|
||||
std::string mock_str = "mock=0,0,0,0";
|
||||
char cmd[mock_str.size()+1];
|
||||
std::copy(mock_str.begin(), mock_str.end(), cmd);
|
||||
cmd[mock_str.size()] = '\0';
|
||||
|
||||
char* argv[] = {cmd};
|
||||
m.Init(1, argv);
|
||||
m.rank = 0;
|
||||
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
|
||||
}
|
||||
|
||||
TEST(allreduce_mock, mock_broadcast)
|
||||
{
|
||||
rabit::engine::AllreduceMock m;
|
||||
std::string mock_str = "mock=0,1,2,0";
|
||||
char cmd[mock_str.size()+1];
|
||||
std::copy(mock_str.begin(), mock_str.end(), cmd);
|
||||
cmd[mock_str.size()] = '\0';
|
||||
char* argv[] = {cmd};
|
||||
m.Init(1, argv);
|
||||
m.rank = 0;
|
||||
m.version_number=1;
|
||||
m.seq_counter=2;
|
||||
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
|
||||
}
|
||||
@ -96,7 +96,7 @@ int main(int argc, char *argv[]) {
|
||||
std::string name = rabit::GetProcessorName();
|
||||
|
||||
int max_rank = rank;
|
||||
rabit::Allreduce<op::Max>(&max_rank, sizeof(int));
|
||||
rabit::Allreduce<op::Max>(&max_rank, 1);
|
||||
utils::Check(max_rank == nproc - 1, "max rank is world size-1");
|
||||
|
||||
Model model;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user