Add references to the mock inside the TEST preprocessor directive.
It shouldn't be an assert, because that shuts down the process. Instead, it should check the value and return some sort of error, so that we can recover. The mock contains queues, indexed by the rank of the process. For each node, you can configure the behavior you expect (success or failure for now) when you call any of the methods (AllReduce, Broadcast, LoadCheckPoint and CheckPoint). If you call AllReduce several times, the outputs are popped from the queue in order, i.e., first you can retrieve a success, then a failure, and so on. Pretty basic for now; it needs to be tuned further.
This commit is contained in:
parent
54fcff189f
commit
c565104491
@ -4,6 +4,9 @@
|
|||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
*/
|
*/
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
|
#ifdef TEST
|
||||||
|
#include "./mock.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
/*! \brief namespace of all reduce */
|
/*! \brief namespace of all reduce */
|
||||||
namespace sync {
|
namespace sync {
|
||||||
@ -43,6 +46,11 @@ void Init(int argc, char *argv[]) {
|
|||||||
void Finalize(void) {
|
void Finalize(void) {
|
||||||
engine::Finalize();
|
engine::Finalize();
|
||||||
}
|
}
|
||||||
|
#ifdef TEST
|
||||||
|
void SetMock(test::Mock& mock) {
|
||||||
|
engine::SetMock(mock);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
/*! \brief get rank of current process */
|
/*! \brief get rank of current process */
|
||||||
inline int GetRank(void) {
|
inline int GetRank(void) {
|
||||||
return engine::GetEngine()->GetRank();
|
return engine::GetEngine()->GetRank();
|
||||||
|
|||||||
@ -6,6 +6,10 @@
|
|||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
*/
|
*/
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
|
#ifdef TEST
|
||||||
|
#include "./mock.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI {
|
||||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||||
@ -76,5 +80,10 @@ void Init(int argc, char *argv[]);
|
|||||||
void Finalize(void);
|
void Finalize(void);
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void);
|
IEngine *GetEngine(void);
|
||||||
|
|
||||||
|
#ifdef TEST
|
||||||
|
void SetMock(test::Mock& mock);
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
#endif // ALLREDUCE_ENGINE_H
|
#endif // ALLREDUCE_ENGINE_H
|
||||||
|
|||||||
@ -12,6 +12,9 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./socket.h"
|
#include "./socket.h"
|
||||||
|
#ifdef TEST
|
||||||
|
#include "./mock.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI {
|
||||||
class Datatype {
|
class Datatype {
|
||||||
@ -37,6 +40,13 @@ class SyncManager : public IEngine {
|
|||||||
}
|
}
|
||||||
~SyncManager(void) {
|
~SyncManager(void) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TEST
|
||||||
|
inline void SetMock(test::Mock& mock) {
|
||||||
|
this->mock = mock;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
inline void Shutdown(void) {
|
inline void Shutdown(void) {
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
links[i].sock.Close();
|
links[i].sock.Close();
|
||||||
@ -168,6 +178,9 @@ class SyncManager : public IEngine {
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
|
#ifdef TEST
|
||||||
|
utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing");
|
||||||
|
#endif
|
||||||
if (links.size() == 0) return;
|
if (links.size() == 0) return;
|
||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
@ -275,6 +288,10 @@ class SyncManager : public IEngine {
|
|||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
*/
|
*/
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
|
#ifdef TEST
|
||||||
|
utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting");
|
||||||
|
#endif
|
||||||
|
|
||||||
if (links.size() == 0) return;
|
if (links.size() == 0) return;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
@ -329,9 +346,15 @@ class SyncManager : public IEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
|
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
|
#ifdef TEST
|
||||||
|
utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint");
|
||||||
|
#endif
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||||
|
#ifdef TEST
|
||||||
|
utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing");
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -421,6 +444,11 @@ class SyncManager : public IEngine {
|
|||||||
std::vector<LinkRecord> links;
|
std::vector<LinkRecord> links;
|
||||||
// select helper
|
// select helper
|
||||||
utils::SelectHelper selecter;
|
utils::SelectHelper selecter;
|
||||||
|
|
||||||
|
#ifdef TEST
|
||||||
|
// mock to test
|
||||||
|
test::Mock mock;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
@ -436,6 +464,14 @@ void Init(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
manager.Init();
|
manager.Init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TEST
|
||||||
|
/*! \brief sets a mock to the manager for testing purposes */
|
||||||
|
void SetMock(test::Mock& mock) {
|
||||||
|
manager.SetMock(mock);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/*! \brief finalize syncrhonization module */
|
/*! \brief finalize syncrhonization module */
|
||||||
void Finalize(void) {
|
void Finalize(void) {
|
||||||
manager.Shutdown();
|
manager.Shutdown();
|
||||||
|
|||||||
73
src/mock.h
73
src/mock.h
@ -16,7 +16,7 @@ namespace test {
|
|||||||
|
|
||||||
class Mock {
|
class Mock {
|
||||||
|
|
||||||
typedef std::map<int,std::queue<int> > Map;
|
typedef std::map<int,std::queue<bool> > Map;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@ -27,61 +27,64 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// record methods
|
// record methods
|
||||||
|
inline void OnAllReduce(int rank, bool success) {
|
||||||
inline void OnAllReduce(int rank, int code) {
|
onRecord(allReduce, rank, success);
|
||||||
utils::Check(record, "Not in record state");
|
|
||||||
Map::iterator it = allReduce.find(rank);
|
|
||||||
if (it == allReduce.end()) {
|
|
||||||
std::queue<int> aQueue;
|
|
||||||
allReduce[rank] = aQueue;
|
|
||||||
}
|
|
||||||
allReduce[rank].push(code);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnBroadcast() {
|
inline void OnBroadcast(int rank, bool success) {
|
||||||
utils::Check(record, "Not in record state");
|
onRecord(broadcast, rank, success);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnLoadCheckpoint() {
|
inline void OnLoadCheckPoint(int rank, bool success) {
|
||||||
utils::Check(record, "Not in record state");
|
onRecord(loadCheckpoint, rank, success);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnCheckpoint() {
|
inline void OnCheckPoint(int rank, bool success) {
|
||||||
utils::Check(record, "Not in record state");
|
onRecord(checkpoint, rank, success);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// replay methods
|
// replay methods
|
||||||
|
inline bool AllReduce(int rank) {
|
||||||
inline int AllReduce(int rank) {
|
return onReplay(allReduce, rank);
|
||||||
utils::Check(!record, "Not in replay state");
|
|
||||||
utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded");
|
|
||||||
int result = 0;
|
|
||||||
if (!allReduce[rank].empty()) {
|
|
||||||
result = allReduce[rank].front();
|
|
||||||
allReduce[rank].pop();
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int Broadcast(int rank) {
|
inline bool Broadcast(int rank) {
|
||||||
utils::Check(!record, "Not in replay state");
|
return onReplay(broadcast, rank);
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int LoadCheckpoint(int rank) {
|
inline bool LoadCheckPoint(int rank) {
|
||||||
utils::Check(!record, "Not in replay state");
|
return onReplay(loadCheckpoint, rank);
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int Checkpoint(int rank) {
|
inline bool CheckPoint(int rank) {
|
||||||
utils::Check(!record, "Not in replay state");
|
return onReplay(checkpoint, rank);
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
inline void onRecord(Map& m, int rank, bool success) {
|
||||||
|
utils::Check(record, "Not in record state");
|
||||||
|
Map::iterator it = m.find(rank);
|
||||||
|
if (it == m.end()) {
|
||||||
|
std::queue<bool> aQueue;
|
||||||
|
m[rank] = aQueue;
|
||||||
|
}
|
||||||
|
m[rank].push(success);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool onReplay(Map& m, int rank) {
|
||||||
|
utils::Check(!record, "Not in replay state");
|
||||||
|
utils::Check(m.find(rank) != m.end(), "Not recorded");
|
||||||
|
bool result = true;
|
||||||
|
if (!m[rank].empty()) {
|
||||||
|
result = m[rank].front();
|
||||||
|
m[rank].pop();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// flag to indicate if the mock is in record state
|
// flag to indicate if the mock is in record state
|
||||||
bool record;
|
bool record;
|
||||||
|
|
||||||
|
|||||||
@ -66,13 +66,13 @@ inline void TestBcast(size_t n, int root) {
|
|||||||
inline void record(test::Mock& mock, int rank) {
|
inline void record(test::Mock& mock, int rank) {
|
||||||
switch(rank) {
|
switch(rank) {
|
||||||
case 0:
|
case 0:
|
||||||
mock.OnAllReduce(0, -1);
|
mock.OnAllReduce(0, false);
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
mock.OnAllReduce(1, -1);
|
mock.OnAllReduce(1, false);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
mock.OnAllReduce(2, 0);
|
mock.OnAllReduce(2, true);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -97,9 +97,12 @@ int main(int argc, char *argv[]) {
|
|||||||
test::Mock mock;
|
test::Mock mock;
|
||||||
record(mock, rank);
|
record(mock, rank);
|
||||||
mock.Replay();
|
mock.Replay();
|
||||||
replay(mock, rank);
|
//replay(mock, rank);
|
||||||
|
sync::SetMock(mock);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
printf("[%d] start at %s\n", rank, name.c_str());
|
printf("[%d] start at %s\n", rank, name.c_str());
|
||||||
TestMax(n);
|
TestMax(n);
|
||||||
printf("[%d] TestMax pass\n", rank);
|
printf("[%d] TestMax pass\n", rank);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user