execute it like ./test.sh 4 4000 testcase0.conf to obtain a successful execution

updating mock. It now wraps the calls to sync and reads config from configuration file.
I believe it's better not to use the preprocessor directive, i.e. not to put any test code in the engine_tcp. I just call the mock in the test_allreduce file. It's a file purely for testing purposes, so it's fine to use the mock there.
This commit is contained in:
nachocano
2014-11-28 00:16:35 -08:00
parent 21f3f3eec4
commit faed8285cd
11 changed files with 274 additions and 144 deletions

View File

@@ -1,12 +1,11 @@
#ifndef ALLREDUCE_H
#define ALLREDUCE_H
/*!
* \file allreduce.h
* \brief This file defines a template wrapper of engine to ensure
* \author Tianqi Chen, Nacho, Tianyi
*/
#include "./engine.h"
#ifdef TEST
#include "./mock.h"
#endif
/*! \brief namespace of all reduce */
namespace sync {
@@ -46,11 +45,7 @@ void Init(int argc, char *argv[]) {
void Finalize(void) {
engine::Finalize();
}
#ifdef TEST
void SetMock(const test::Mock& mock) {
engine::SetMock(mock);
}
#endif
/*! \brief get rank of current process */
inline int GetRank(void) {
return engine::GetEngine()->GetRank();
@@ -113,3 +108,4 @@ inline void CheckPoint(const utils::ISerializable &model) {
engine::GetEngine()->CheckPoint(model);
}
} // namespace allreduce
#endif // ALLREDUCE_H

195
src/config.h Normal file
View File

@@ -0,0 +1,195 @@
#ifndef ALLREDUCE_UTILS_CONFIG_H_
#define ALLREDUCE_UTILS_CONFIG_H_
/*!
* \file config.h
* \brief helper class to load in configures from file
* \author Tianqi Chen
*/
#include <cstdio>
#include <cstring>
#include <string>
#include <istream>
#include <fstream>
#include "./utils.h"
namespace utils {
/*!
* \brief base implementation of config reader
*/
class ConfigReaderBase {
public:
/*!
* \brief get current name, called after Next returns true
* \return current parameter name
*/
inline const char *name(void) const {
return s_name;
}
/*!
* \brief get current value, called after Next returns true
* \return current parameter value
*/
inline const char *val(void) const {
return s_val;
}
/*!
* \brief move iterator to next position
* \return true if there is value in next position
*/
inline bool Next(void) {
while (!this->IsEnd()) {
GetNextToken(s_name);
if (s_name[0] == '=') return false;
if (GetNextToken( s_buf ) || s_buf[0] != '=') return false;
if (GetNextToken( s_val ) || s_val[0] == '=') return false;
return true;
}
return false;
}
// called before usage
inline void Init(void) {
ch_buf = this->GetChar();
}
protected:
/*!
* \brief to be implemented by subclass,
* get next token, return EOF if end of file
*/
virtual char GetChar(void) = 0;
/*! \brief to be implemented by child, check if end of stream */
virtual bool IsEnd(void) = 0;
private:
char ch_buf;
char s_name[100000], s_val[100000], s_buf[100000];
inline void SkipLine(void) {
do {
ch_buf = this->GetChar();
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
}
inline void ParseStr(char tok[]) {
int i = 0;
while ((ch_buf = this->GetChar()) != EOF) {
switch (ch_buf) {
case '\\': tok[i++] = this->GetChar(); break;
case '\"': tok[i++] = '\0'; return;
case '\r':
case '\n': Error("ConfigReader: unterminated string");
default: tok[i++] = ch_buf;
}
}
Error("ConfigReader: unterminated string");
}
inline void ParseStrML(char tok[]) {
int i = 0;
while ((ch_buf = this->GetChar()) != EOF) {
switch (ch_buf) {
case '\\': tok[i++] = this->GetChar(); break;
case '\'': tok[i++] = '\0'; return;
default: tok[i++] = ch_buf;
}
}
Error("unterminated string");
}
// return newline
inline bool GetNextToken(char tok[]) {
int i = 0;
bool new_line = false;
while (ch_buf != EOF) {
switch (ch_buf) {
case '#' : SkipLine(); new_line = true; break;
case '\"':
if (i == 0) {
ParseStr(tok); ch_buf = this->GetChar(); return new_line;
} else {
Error("ConfigReader: token followed directly by string");
}
case '\'':
if (i == 0) {
ParseStrML( tok ); ch_buf = this->GetChar(); return new_line;
} else {
Error("ConfigReader: token followed directly by string");
}
case '=':
if (i == 0) {
ch_buf = this->GetChar();
tok[0] = '=';
tok[1] = '\0';
} else {
tok[i] = '\0';
}
return new_line;
case '\r':
case '\n':
if (i == 0) new_line = true;
case '\t':
case ' ' :
ch_buf = this->GetChar();
if (i > 0) {
tok[i] = '\0';
return new_line;
}
break;
default:
tok[i++] = ch_buf;
ch_buf = this->GetChar();
break;
}
}
return true;
}
};
/*!
* \brief an iterator use stream base, allows use all types of istream
*/
class ConfigStreamReader: public ConfigReaderBase {
public:
/*!
* \brief constructor
* \param istream input stream
*/
explicit ConfigStreamReader(std::istream &fin) : fin(fin) {}
protected:
virtual char GetChar(void) {
return fin.get();
}
/*! \brief to be implemented by child, check if end of stream */
virtual bool IsEnd(void) {
return fin.eof();
}
private:
std::istream &fin;
};
/*!
* \brief an iterator that iterates over a configure file and gets the configures
*/
class ConfigIterator: public ConfigStreamReader {
public:
/*!
* \brief constructor
* \param fname name of configure file
*/
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
fi.open(fname);
if (fi.fail()) {
utils::Error("cannot open file %s", fname);
}
ConfigReaderBase::Init();
}
/*! \brief destructor */
~ConfigIterator(void) {
fi.close();
}
private:
std::ifstream fi;
};
} // namespace utils
#endif // ALLREDUCE_UTILS_CONFIG_H_

View File

@@ -6,9 +6,6 @@
* \author Tianqi Chen, Nacho, Tianyi
*/
#include "./io.h"
#ifdef TEST
#include "./mock.h"
#endif
namespace MPI {
@@ -81,9 +78,5 @@ void Finalize(void);
/*! \brief singleton method to get engine */
IEngine *GetEngine(void);
#ifdef TEST
void SetMock(const test::Mock& mock);
#endif
} // namespace engine
#endif // ALLREDUCE_ENGINE_H

View File

@@ -12,9 +12,6 @@
#include <cstring>
#include "./engine.h"
#include "./socket.h"
#ifdef TEST
#include "./mock.h"
#endif
namespace MPI {
class Datatype {
@@ -41,12 +38,6 @@ class SyncManager : public IEngine {
~SyncManager(void) {
}
#ifdef TEST
inline void SetMock(const test::Mock& mock) {
this->mock = mock;
}
#endif
inline void Shutdown(void) {
for (size_t i = 0; i < links.size(); ++i) {
links[i].sock.Close();
@@ -180,9 +171,6 @@ class SyncManager : public IEngine {
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
#ifdef TEST
utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing");
#endif
if (links.size() == 0) return;
// total size of message
const size_t total_size = type_nbytes * count;
@@ -304,10 +292,6 @@ class SyncManager : public IEngine {
* \param root the root worker id to broadcast the data
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
#ifdef TEST
utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting");
#endif
if (links.size() == 0) return;
// number of links
const int nlink = static_cast<int>(links.size());
@@ -368,15 +352,9 @@ class SyncManager : public IEngine {
}
}
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
#ifdef TEST
utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint");
#endif
return false;
}
virtual void CheckPoint(const utils::ISerializable &model) {
#ifdef TEST
utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing");
#endif
}
private:
@@ -479,10 +457,6 @@ class SyncManager : public IEngine {
// select helper
utils::SelectHelper selecter;
#ifdef TEST
// mock to test
test::Mock mock;
#endif
};
// singleton sync manager
@@ -499,13 +473,6 @@ void Init(int argc, char *argv[]) {
manager.Init();
}
#ifdef TEST
/*! \brief sets a mock to the manager for testing purposes */
void SetMock(const test::Mock& mock) {
manager.SetMock(mock);
}
#endif
/*! \brief finalize syncrhonization module */
void Finalize(void) {
manager.Shutdown();

View File

@@ -5,9 +5,8 @@
* \brief This file defines a mock object to test the system
* \author Tianqi Chen, Nacho, Tianyi
*/
#include "./engine.h"
#include "./utils.h"
#include <queue>
#include "./allreduce.h"
#include "./config.h"
#include <map>
@@ -16,82 +15,73 @@ namespace test {
class Mock {
typedef std::map<int,std::queue<bool> > Map;
public:
Mock() : record(true) {}
inline void Replay() {
record = false;
Mock(const int& rank, char *config) : rank(rank) {
Init(config);
}
// record methods
inline void OnAllReduce(int rank, bool success) {
onRecord(allReduce, rank, success);
template<typename OP>
inline void AllReduce(float *sendrecvbuf, size_t count) {
utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank);
sync::AllReduce<OP>(sendrecvbuf, count);
}
inline void OnBroadcast(int rank, bool success) {
onRecord(broadcast, rank, success);
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
return sync::LoadCheckPoint(p_model);
}
inline void OnLoadCheckPoint(int rank, bool success) {
onRecord(loadCheckpoint, rank, success);
inline void CheckPoint(const utils::ISerializable &model) {
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
sync::CheckPoint(model);
}
inline void OnCheckPoint(int rank, bool success) {
onRecord(checkpoint, rank, success);
inline void Broadcast(std::string *sendrecv_data, int root) {
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
sync::Bcast(sendrecv_data, root);
}
// replay methods
inline bool AllReduce(int rank) {
return onReplay(allReduce, rank);
}
inline bool Broadcast(int rank) {
return onReplay(broadcast, rank);
}
inline bool LoadCheckPoint(int rank) {
return onReplay(loadCheckpoint, rank);
}
inline bool CheckPoint(int rank) {
return onReplay(checkpoint, rank);
}
private:
inline void onRecord(Map& m, int rank, bool success) {
utils::Check(record, "Not in record state");
Map::iterator it = m.find(rank);
if (it == m.end()) {
std::queue<bool> aQueue;
m[rank] = aQueue;
inline void Init(char* config) {
utils::ConfigIterator itr(config);
while (itr.Next()) {
char round[4], node_rank[4];
sscanf(itr.name(), "%[^_]_%s", round, node_rank);
int i_round = atoi(round);
if (i_round == 1) {
int i_node_rank = atoi(node_rank);
if (i_node_rank == rank) {
printf("[%d] round %d, value %s\n", rank, i_round, itr.val());
if (strcmp("allreduce", itr.val())) record(allReduce);
else if (strcmp("broadcast", itr.val())) record(broadcast);
else if (strcmp("loadcheckpoint", itr.val())) record(loadCheckpoint);
else if (strcmp("checkpoint", itr.val())) record(checkpoint);
}
}
}
m[rank].push(success);
}
inline bool onReplay(Map& m, int rank) {
utils::Check(!record, "Not in replay state");
utils::Check(m.find(rank) != m.end(), "Not recorded");
inline void record(std::map<int,bool>& m) {
m[rank] = false;
}
inline bool verify(std::map<int,bool>& m) {
bool result = true;
if (!m[rank].empty()) {
result = m[rank].front();
m[rank].pop();
if (m.find(rank) != m.end()) {
result = m[rank];
}
return result;
}
// flag to indicate if the mock is in record state
bool record;
Map allReduce;
Map broadcast;
Map loadCheckpoint;
Map checkpoint;
int rank;
std::map<int,bool> allReduce;
std::map<int,bool> broadcast;
std::map<int,bool> loadCheckpoint;
std::map<int,bool> checkpoint;
};
}