dummy mock for now

This commit is contained in:
nachocano 2014-11-26 16:37:23 -08:00
parent d37f38c455
commit 54fcff189f
3 changed files with 131 additions and 0 deletions

96
src/mock.h Normal file
View File

@ -0,0 +1,96 @@
#ifndef ALLREDUCE_MOCK_H
#define ALLREDUCE_MOCK_H
/*!
* \file mock.h
* \brief This file defines a mock object to test the system
* \author Tianqi Chen, Nacho, Tianyi
*/
#include "./engine.h"
#include "./utils.h"
#include <queue>
#include <map>
/*! \brief namespace of mock */
namespace test {
class Mock {
typedef std::map<int,std::queue<int> > Map;
public:
Mock() : record(true) {}
inline void Replay() {
record = false;
}
// record methods
inline void OnAllReduce(int rank, int code) {
utils::Check(record, "Not in record state");
Map::iterator it = allReduce.find(rank);
if (it == allReduce.end()) {
std::queue<int> aQueue;
allReduce[rank] = aQueue;
}
allReduce[rank].push(code);
}
inline void OnBroadcast() {
utils::Check(record, "Not in record state");
}
inline void OnLoadCheckpoint() {
utils::Check(record, "Not in record state");
}
inline void OnCheckpoint() {
utils::Check(record, "Not in record state");
}
// replay methods
inline int AllReduce(int rank) {
utils::Check(!record, "Not in replay state");
utils::Check(allReduce.find(rank) != allReduce.end(), "Not recorded");
int result = 0;
if (!allReduce[rank].empty()) {
result = allReduce[rank].front();
allReduce[rank].pop();
}
return result;
}
inline int Broadcast(int rank) {
utils::Check(!record, "Not in replay state");
return 0;
}
inline int LoadCheckpoint(int rank) {
utils::Check(!record, "Not in replay state");
return 0;
}
inline int Checkpoint(int rank) {
utils::Check(!record, "Not in replay state");
return 0;
}
private:
// flag to indicate if the mock is in record state
bool record;
Map allReduce;
Map broadcast;
Map loadCheckpoint;
Map checkpoint;
};
}
#endif // ALLREDUCE_MOCK_H

View File

@ -9,6 +9,10 @@ ifeq ($(no_omp),1)
else
CFLAGS += -fopenmp
endif
ifeq ($(test),1)
CFLAGS += -DTEST
endif
# specify tensor path
BIN = test_allreduce

View File

@ -3,6 +3,8 @@
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <mock.h>
using namespace sync;
@ -60,6 +62,27 @@ inline void TestBcast(size_t n, int root) {
utils::Check(res == s, "[%d] TestBcast fail", rank);
}
// ugly stuff, just to see if it works
inline void record(test::Mock& mock, int rank) {
switch(rank) {
case 0:
mock.OnAllReduce(0, -1);
break;
case 1:
mock.OnAllReduce(1, -1);
break;
case 2:
mock.OnAllReduce(2, 0);
break;
}
}
// to be removed, should be added in engine tcp
inline void replay(test::Mock& mock, int rank) {
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
printf("[%d] All reduce %d\n", rank, mock.AllReduce(rank));
}
int main(int argc, char *argv[]) {
if (argc < 2) {
printf("Usage: <ndata>\n");
@ -69,6 +92,14 @@ int main(int argc, char *argv[]) {
sync::Init(argc, argv);
int rank = sync::GetRank();
std::string name = sync::GetProcessorName();
#ifdef TEST
test::Mock mock;
record(mock, rank);
mock.Replay();
replay(mock, rank);
#endif
printf("[%d] start at %s\n", rank, name.c_str());
TestMax(n);
printf("[%d] TestMax pass\n", rank);