adding some references to mock inside TEST preprocessor directive.
It shouldn't be an assert because it shutdowns the process. Instead should check on the value and return some sort of error, so that we can recover. The mock contains queues, indexed by the rank of the process. For each node, you can configure the behavior you expect (success or failure for now) when you call any of the methods (AllReduce, Broadcast, LoadCheckPoint and CheckPoint)... If you call several times AllReduce, the outputs will pop from the queue, i.e., first you can retrieve a success, then a failure and so on. Pretty basic for now, need to tune it better
This commit is contained in:
parent
54fcff189f
commit
c565104491
@ -4,6 +4,9 @@
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#include "./engine.h"
|
||||
#ifdef TEST
|
||||
#include "./mock.h"
|
||||
#endif
|
||||
|
||||
/*! \brief namespace of all reduce */
|
||||
namespace sync {
|
||||
@ -43,6 +46,11 @@ void Init(int argc, char *argv[]) {
|
||||
void Finalize(void) {
|
||||
engine::Finalize();
|
||||
}
|
||||
#ifdef TEST
|
||||
void SetMock(test::Mock& mock) {
|
||||
engine::SetMock(mock);
|
||||
}
|
||||
#endif
|
||||
/*! \brief get rank of current process */
|
||||
inline int GetRank(void) {
|
||||
return engine::GetEngine()->GetRank();
|
||||
|
||||
@ -6,6 +6,10 @@
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#include "./io.h"
|
||||
#ifdef TEST
|
||||
#include "./mock.h"
|
||||
#endif
|
||||
|
||||
|
||||
namespace MPI {
|
||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||
@ -76,5 +80,10 @@ void Init(int argc, char *argv[]);
|
||||
void Finalize(void);
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void);
|
||||
|
||||
#ifdef TEST
|
||||
void SetMock(test::Mock& mock);
|
||||
#endif
|
||||
|
||||
} // namespace engine
|
||||
#endif // ALLREDUCE_ENGINE_H
|
||||
|
||||
@ -12,6 +12,9 @@
|
||||
#include <cstring>
|
||||
#include "./engine.h"
|
||||
#include "./socket.h"
|
||||
#ifdef TEST
|
||||
#include "./mock.h"
|
||||
#endif
|
||||
|
||||
namespace MPI {
|
||||
class Datatype {
|
||||
@ -37,6 +40,13 @@ class SyncManager : public IEngine {
|
||||
}
|
||||
~SyncManager(void) {
|
||||
}
|
||||
|
||||
#ifdef TEST
|
||||
inline void SetMock(test::Mock& mock) {
|
||||
this->mock = mock;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void Shutdown(void) {
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
links[i].sock.Close();
|
||||
@ -168,6 +178,9 @@ class SyncManager : public IEngine {
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
#ifdef TEST
|
||||
utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing");
|
||||
#endif
|
||||
if (links.size() == 0) return;
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
@ -275,6 +288,10 @@ class SyncManager : public IEngine {
|
||||
* \param root the root worker id to broadcast the data
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
#ifdef TEST
|
||||
utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting");
|
||||
#endif
|
||||
|
||||
if (links.size() == 0) return;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
@ -329,9 +346,15 @@ class SyncManager : public IEngine {
|
||||
}
|
||||
}
|
||||
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
#ifdef TEST
|
||||
utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint");
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||
#ifdef TEST
|
||||
utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing");
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
@ -421,6 +444,11 @@ class SyncManager : public IEngine {
|
||||
std::vector<LinkRecord> links;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
|
||||
#ifdef TEST
|
||||
// mock to test
|
||||
test::Mock mock;
|
||||
#endif
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
@ -436,6 +464,14 @@ void Init(int argc, char *argv[]) {
|
||||
}
|
||||
manager.Init();
|
||||
}
|
||||
|
||||
#ifdef TEST
|
||||
/*! \brief sets a mock to the manager for testing purposes */
|
||||
void SetMock(test::Mock& mock) {
|
||||
manager.SetMock(mock);
|
||||
}
|
||||
#endif
|
||||
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
manager.Shutdown();
|
||||
|
||||
73
src/mock.h
73
src/mock.h
@ -16,7 +16,7 @@ namespace test {
|
||||
|
||||
class Mock {
|
||||
|
||||
typedef std::map<int,std::queue<int> > Map;
|
||||
typedef std::map<int,std::queue<bool> > Map;
|
||||
|
||||
public:
|
||||
|
||||
@ -27,61 +27,64 @@ public:
|
||||
}
|
||||
|
||||
// record methods
|
||||
|
||||
inline void OnAllReduce(int rank, int code) {
|
||||
utils::Check(record, "Not in record state");
|
||||
Map::iterator it = allReduce.find(rank);
|
||||
if (it == allReduce.end()) {
|
||||
std::queue<int> aQueue;
|
||||
allReduce[rank] = aQueue;
|
||||
}
|
||||
allReduce[rank].push(code);
|
||||
inline void OnAllReduce(int rank, bool success) {
|
||||
onRecord(allReduce, rank, success);
|
||||
}
|
||||
|
||||
inline void OnBroadcast() {
|
||||
utils::Check(record, "Not in record state");
|
||||
inline void OnBroadcast(int rank, bool success) {
|
||||
onRecord(broadcast, rank, success);
|
||||
}
|
||||
|
||||
inline void OnLoadCheckpoint() {
|
||||
utils::Check(record, "Not in record state");
|
||||
inline void OnLoadCheckPoint(int rank, bool success) {
|
||||
onRecord(loadCheckpoint, rank, success);
|
||||
}
|
||||
|
||||
inline void OnCheckpoint() {
|
||||
utils::Check(record, "Not in record state");
|
||||
inline void OnCheckPoint(int rank, bool success) {
|
||||
onRecord(checkpoint, rank, success);
|
||||
}
|
||||
|
||||
|
||||
// replay methods
|
||||
|
||||
inline int AllReduce(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded");
|
||||
int result = 0;
|
||||
if (!allReduce[rank].empty()) {
|
||||
result = allReduce[rank].front();
|
||||
allReduce[rank].pop();
|
||||
}
|
||||
return result;
|
||||
inline bool AllReduce(int rank) {
|
||||
return onReplay(allReduce, rank);
|
||||
}
|
||||
|
||||
inline int Broadcast(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
inline bool Broadcast(int rank) {
|
||||
return onReplay(broadcast, rank);
|
||||
}
|
||||
|
||||
inline int LoadCheckpoint(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
inline bool LoadCheckPoint(int rank) {
|
||||
return onReplay(loadCheckpoint, rank);
|
||||
}
|
||||
|
||||
inline int Checkpoint(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
inline bool CheckPoint(int rank) {
|
||||
return onReplay(checkpoint, rank);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
|
||||
inline void onRecord(Map& m, int rank, bool success) {
|
||||
utils::Check(record, "Not in record state");
|
||||
Map::iterator it = m.find(rank);
|
||||
if (it == m.end()) {
|
||||
std::queue<bool> aQueue;
|
||||
m[rank] = aQueue;
|
||||
}
|
||||
m[rank].push(success);
|
||||
}
|
||||
|
||||
inline bool onReplay(Map& m, int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
utils::Check(m.find(rank) != m.end(), "Not recorded");
|
||||
bool result = true;
|
||||
if (!m[rank].empty()) {
|
||||
result = m[rank].front();
|
||||
m[rank].pop();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// flag to indicate if the mock is in record state
|
||||
bool record;
|
||||
|
||||
|
||||
@ -66,13 +66,13 @@ inline void TestBcast(size_t n, int root) {
|
||||
inline void record(test::Mock& mock, int rank) {
|
||||
switch(rank) {
|
||||
case 0:
|
||||
mock.OnAllReduce(0, -1);
|
||||
mock.OnAllReduce(0, false);
|
||||
break;
|
||||
case 1:
|
||||
mock.OnAllReduce(1, -1);
|
||||
mock.OnAllReduce(1, false);
|
||||
break;
|
||||
case 2:
|
||||
mock.OnAllReduce(2, 0);
|
||||
mock.OnAllReduce(2, true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -97,9 +97,12 @@ int main(int argc, char *argv[]) {
|
||||
test::Mock mock;
|
||||
record(mock, rank);
|
||||
mock.Replay();
|
||||
replay(mock, rank);
|
||||
//replay(mock, rank);
|
||||
sync::SetMock(mock);
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
printf("[%d] start at %s\n", rank, name.c_str());
|
||||
TestMax(n);
|
||||
printf("[%d] TestMax pass\n", rank);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user