dummy mock for now
This commit is contained in:
parent
d37f38c455
commit
54fcff189f
96
src/mock.h
Normal file
96
src/mock.h
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#ifndef ALLREDUCE_MOCK_H
|
||||||
|
#define ALLREDUCE_MOCK_H
|
||||||
|
/*!
|
||||||
|
* \file mock.h
|
||||||
|
* \brief This file defines a mock object to test the system
|
||||||
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
|
*/
|
||||||
|
#include "./engine.h"
|
||||||
|
#include "./utils.h"
|
||||||
|
#include <queue>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
|
||||||
|
/*! \brief namespace of mock */
|
||||||
|
namespace test {
|
||||||
|
|
||||||
|
class Mock {
|
||||||
|
|
||||||
|
typedef std::map<int,std::queue<int> > Map;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
Mock() : record(true) {}
|
||||||
|
|
||||||
|
inline void Replay() {
|
||||||
|
record = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// record methods
|
||||||
|
|
||||||
|
inline void OnAllReduce(int rank, int code) {
|
||||||
|
utils::Check(record, "Not in record state");
|
||||||
|
Map::iterator it = allReduce.find(rank);
|
||||||
|
if (it == allReduce.end()) {
|
||||||
|
std::queue<int> aQueue;
|
||||||
|
allReduce[rank] = aQueue;
|
||||||
|
}
|
||||||
|
allReduce[rank].push(code);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void OnBroadcast() {
|
||||||
|
utils::Check(record, "Not in record state");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void OnLoadCheckpoint() {
|
||||||
|
utils::Check(record, "Not in record state");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void OnCheckpoint() {
|
||||||
|
utils::Check(record, "Not in record state");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// replay methods
|
||||||
|
|
||||||
|
inline int AllReduce(int rank) {
|
||||||
|
utils::Check(!record, "Not in replay state");
|
||||||
|
utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded");
|
||||||
|
int result = 0;
|
||||||
|
if (!allReduce[rank].empty()) {
|
||||||
|
result = allReduce[rank].front();
|
||||||
|
allReduce[rank].pop();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int Broadcast(int rank) {
|
||||||
|
utils::Check(!record, "Not in replay state");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int LoadCheckpoint(int rank) {
|
||||||
|
utils::Check(!record, "Not in replay state");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int Checkpoint(int rank) {
|
||||||
|
utils::Check(!record, "Not in replay state");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
// flag to indicate if the mock is in record state
|
||||||
|
bool record;
|
||||||
|
|
||||||
|
Map allReduce;
|
||||||
|
Map broadcast;
|
||||||
|
Map loadCheckpoint;
|
||||||
|
Map checkpoint;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // ALLREDUCE_MOCK_H
|
||||||
@ -9,6 +9,10 @@ ifeq ($(no_omp),1)
|
|||||||
else
|
else
|
||||||
CFLAGS += -fopenmp
|
CFLAGS += -fopenmp
|
||||||
endif
|
endif
|
||||||
|
ifeq ($(test),1)
|
||||||
|
CFLAGS += -DTEST
|
||||||
|
endif
|
||||||
|
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = test_allreduce
|
BIN = test_allreduce
|
||||||
|
|||||||
@ -3,6 +3,8 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <mock.h>
|
||||||
|
|
||||||
|
|
||||||
using namespace sync;
|
using namespace sync;
|
||||||
|
|
||||||
@ -60,6 +62,27 @@ inline void TestBcast(size_t n, int root) {
|
|||||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ugly stuff, just to see if it works
|
||||||
|
inline void record(test::Mock& mock, int rank) {
|
||||||
|
switch(rank) {
|
||||||
|
case 0:
|
||||||
|
mock.OnAllReduce(0, -1);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
mock.OnAllReduce(1, -1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
mock.OnAllReduce(2, 0);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// to be removed, should be added in engine tcp
|
||||||
|
inline void replay(test::Mock& mock, int rank) {
|
||||||
|
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
|
||||||
|
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
if (argc < 2) {
|
if (argc < 2) {
|
||||||
printf("Usage: <ndata>\n");
|
printf("Usage: <ndata>\n");
|
||||||
@ -69,6 +92,14 @@ int main(int argc, char *argv[]) {
|
|||||||
sync::Init(argc, argv);
|
sync::Init(argc, argv);
|
||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
std::string name = sync::GetProcessorName();
|
std::string name = sync::GetProcessorName();
|
||||||
|
|
||||||
|
#ifdef TEST
|
||||||
|
test::Mock mock;
|
||||||
|
record(mock, rank);
|
||||||
|
mock.Replay();
|
||||||
|
replay(mock, rank);
|
||||||
|
#endif
|
||||||
|
|
||||||
printf("[%d] start at %s\n", rank, name.c_str());
|
printf("[%d] start at %s\n", rank, name.c_str());
|
||||||
TestMax(n);
|
TestMax(n);
|
||||||
printf("[%d] TestMax pass\n", rank);
|
printf("[%d] TestMax pass\n", rank);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user