dummy mock for now
This commit is contained in:
parent
d37f38c455
commit
54fcff189f
96
src/mock.h
Normal file
96
src/mock.h
Normal file
@ -0,0 +1,96 @@
|
||||
#ifndef ALLREDUCE_MOCK_H
|
||||
#define ALLREDUCE_MOCK_H
|
||||
/*!
|
||||
* \file mock.h
|
||||
* \brief This file defines a mock object to test the system
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#include "./engine.h"
|
||||
#include "./utils.h"
|
||||
#include <queue>
|
||||
#include <map>
|
||||
|
||||
|
||||
/*! \brief namespace of mock */
|
||||
namespace test {
|
||||
|
||||
class Mock {
|
||||
|
||||
typedef std::map<int,std::queue<int> > Map;
|
||||
|
||||
public:
|
||||
|
||||
Mock() : record(true) {}
|
||||
|
||||
inline void Replay() {
|
||||
record = false;
|
||||
}
|
||||
|
||||
// record methods
|
||||
|
||||
inline void OnAllReduce(int rank, int code) {
|
||||
utils::Check(record, "Not in record state");
|
||||
Map::iterator it = allReduce.find(rank);
|
||||
if (it == allReduce.end()) {
|
||||
std::queue<int> aQueue;
|
||||
allReduce[rank] = aQueue;
|
||||
}
|
||||
allReduce[rank].push(code);
|
||||
}
|
||||
|
||||
inline void OnBroadcast() {
|
||||
utils::Check(record, "Not in record state");
|
||||
}
|
||||
|
||||
inline void OnLoadCheckpoint() {
|
||||
utils::Check(record, "Not in record state");
|
||||
}
|
||||
|
||||
inline void OnCheckpoint() {
|
||||
utils::Check(record, "Not in record state");
|
||||
}
|
||||
|
||||
|
||||
// replay methods
|
||||
|
||||
inline int AllReduce(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded");
|
||||
int result = 0;
|
||||
if (!allReduce[rank].empty()) {
|
||||
result = allReduce[rank].front();
|
||||
allReduce[rank].pop();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline int Broadcast(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline int LoadCheckpoint(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline int Checkpoint(int rank) {
|
||||
utils::Check(!record, "Not in replay state");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// flag to indicate if the mock is in record state
|
||||
bool record;
|
||||
|
||||
Map allReduce;
|
||||
Map broadcast;
|
||||
Map loadCheckpoint;
|
||||
Map checkpoint;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // ALLREDUCE_MOCK_H
|
||||
@ -9,6 +9,10 @@ ifeq ($(no_omp),1)
|
||||
else
|
||||
CFLAGS += -fopenmp
|
||||
endif
|
||||
ifeq ($(test),1)
|
||||
CFLAGS += -DTEST
|
||||
endif
|
||||
|
||||
|
||||
# specify tensor path
|
||||
BIN = test_allreduce
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cmath>
|
||||
#include <mock.h>
|
||||
|
||||
|
||||
using namespace sync;
|
||||
|
||||
@ -60,6 +62,27 @@ inline void TestBcast(size_t n, int root) {
|
||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||
}
|
||||
|
||||
// ugly stuff, just to see if it works
|
||||
inline void record(test::Mock& mock, int rank) {
|
||||
switch(rank) {
|
||||
case 0:
|
||||
mock.OnAllReduce(0, -1);
|
||||
break;
|
||||
case 1:
|
||||
mock.OnAllReduce(1, -1);
|
||||
break;
|
||||
case 2:
|
||||
mock.OnAllReduce(2, 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// to be removed, should be added in engine tcp
|
||||
inline void replay(test::Mock& mock, int rank) {
|
||||
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
|
||||
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
if (argc < 2) {
|
||||
printf("Usage: <ndata>\n");
|
||||
@ -69,6 +92,14 @@ int main(int argc, char *argv[]) {
|
||||
sync::Init(argc, argv);
|
||||
int rank = sync::GetRank();
|
||||
std::string name = sync::GetProcessorName();
|
||||
|
||||
#ifdef TEST
|
||||
test::Mock mock;
|
||||
record(mock, rank);
|
||||
mock.Replay();
|
||||
replay(mock, rank);
|
||||
#endif
|
||||
|
||||
printf("[%d] start at %s\n", rank, name.c_str());
|
||||
TestMax(n);
|
||||
printf("[%d] TestMax pass\n", rank);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user