Run it as `./test.sh 4 4000 testcase0.conf` to get a successful execution.
Updated the mock: it now wraps the calls to sync and reads its configuration from a configuration file. I believe it is better not to use a preprocessor directive, i.e. not to put any test code in engine_tcp, so I just call the mock from the test_allreduce file. That file exists purely for testing purposes, so it is fine to use the mock there.
This commit is contained in:
parent
21f3f3eec4
commit
faed8285cd
2
.gitignore
vendored
2
.gitignore
vendored
@ -28,4 +28,4 @@
|
|||||||
*.app
|
*.app
|
||||||
*~
|
*~
|
||||||
*.pyc
|
*.pyc
|
||||||
test
|
test_allreduce
|
||||||
@ -1,12 +1,11 @@
|
|||||||
|
#ifndef ALLREDUCE_H
|
||||||
|
#define ALLREDUCE_H
|
||||||
/*!
|
/*!
|
||||||
* \file allreduce.h
|
* \file allreduce.h
|
||||||
* \brief This file defines a template wrapper of engine to ensure
|
* \brief This file defines a template wrapper of engine to ensure
|
||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
*/
|
*/
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#ifdef TEST
|
|
||||||
#include "./mock.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/*! \brief namespace of all reduce */
|
/*! \brief namespace of all reduce */
|
||||||
namespace sync {
|
namespace sync {
|
||||||
@ -46,11 +45,7 @@ void Init(int argc, char *argv[]) {
|
|||||||
void Finalize(void) {
|
void Finalize(void) {
|
||||||
engine::Finalize();
|
engine::Finalize();
|
||||||
}
|
}
|
||||||
#ifdef TEST
|
|
||||||
void SetMock(const test::Mock& mock) {
|
|
||||||
engine::SetMock(mock);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
/*! \brief get rank of current process */
|
/*! \brief get rank of current process */
|
||||||
inline int GetRank(void) {
|
inline int GetRank(void) {
|
||||||
return engine::GetEngine()->GetRank();
|
return engine::GetEngine()->GetRank();
|
||||||
@ -113,3 +108,4 @@ inline void CheckPoint(const utils::ISerializable &model) {
|
|||||||
engine::GetEngine()->CheckPoint(model);
|
engine::GetEngine()->CheckPoint(model);
|
||||||
}
|
}
|
||||||
} // namespace allreduce
|
} // namespace allreduce
|
||||||
|
#endif // ALLREDUCE_H
|
||||||
|
|||||||
195
src/config.h
Normal file
195
src/config.h
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
#ifndef ALLREDUCE_UTILS_CONFIG_H_
|
||||||
|
#define ALLREDUCE_UTILS_CONFIG_H_
|
||||||
|
/*!
|
||||||
|
* \file config.h
|
||||||
|
* \brief helper class to load in configures from file
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <istream>
|
||||||
|
#include <fstream>
|
||||||
|
#include "./utils.h"
|
||||||
|
|
||||||
|
namespace utils {
|
||||||
|
/*!
|
||||||
|
* \brief base implementation of config reader
|
||||||
|
*/
|
||||||
|
class ConfigReaderBase {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief get current name, called after Next returns true
|
||||||
|
* \return current parameter name
|
||||||
|
*/
|
||||||
|
inline const char *name(void) const {
|
||||||
|
return s_name;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief get current value, called after Next returns true
|
||||||
|
* \return current parameter value
|
||||||
|
*/
|
||||||
|
inline const char *val(void) const {
|
||||||
|
return s_val;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief move iterator to next position
|
||||||
|
* \return true if there is value in next position
|
||||||
|
*/
|
||||||
|
inline bool Next(void) {
|
||||||
|
while (!this->IsEnd()) {
|
||||||
|
GetNextToken(s_name);
|
||||||
|
if (s_name[0] == '=') return false;
|
||||||
|
if (GetNextToken( s_buf ) || s_buf[0] != '=') return false;
|
||||||
|
if (GetNextToken( s_val ) || s_val[0] == '=') return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// called before usage
|
||||||
|
inline void Init(void) {
|
||||||
|
ch_buf = this->GetChar();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/*!
|
||||||
|
* \brief to be implemented by subclass,
|
||||||
|
* get next token, return EOF if end of file
|
||||||
|
*/
|
||||||
|
virtual char GetChar(void) = 0;
|
||||||
|
/*! \brief to be implemented by child, check if end of stream */
|
||||||
|
virtual bool IsEnd(void) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
char ch_buf;
|
||||||
|
char s_name[100000], s_val[100000], s_buf[100000];
|
||||||
|
|
||||||
|
inline void SkipLine(void) {
|
||||||
|
do {
|
||||||
|
ch_buf = this->GetChar();
|
||||||
|
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ParseStr(char tok[]) {
|
||||||
|
int i = 0;
|
||||||
|
while ((ch_buf = this->GetChar()) != EOF) {
|
||||||
|
switch (ch_buf) {
|
||||||
|
case '\\': tok[i++] = this->GetChar(); break;
|
||||||
|
case '\"': tok[i++] = '\0'; return;
|
||||||
|
case '\r':
|
||||||
|
case '\n': Error("ConfigReader: unterminated string");
|
||||||
|
default: tok[i++] = ch_buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Error("ConfigReader: unterminated string");
|
||||||
|
}
|
||||||
|
inline void ParseStrML(char tok[]) {
|
||||||
|
int i = 0;
|
||||||
|
while ((ch_buf = this->GetChar()) != EOF) {
|
||||||
|
switch (ch_buf) {
|
||||||
|
case '\\': tok[i++] = this->GetChar(); break;
|
||||||
|
case '\'': tok[i++] = '\0'; return;
|
||||||
|
default: tok[i++] = ch_buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Error("unterminated string");
|
||||||
|
}
|
||||||
|
// return newline
|
||||||
|
inline bool GetNextToken(char tok[]) {
|
||||||
|
int i = 0;
|
||||||
|
bool new_line = false;
|
||||||
|
while (ch_buf != EOF) {
|
||||||
|
switch (ch_buf) {
|
||||||
|
case '#' : SkipLine(); new_line = true; break;
|
||||||
|
case '\"':
|
||||||
|
if (i == 0) {
|
||||||
|
ParseStr(tok); ch_buf = this->GetChar(); return new_line;
|
||||||
|
} else {
|
||||||
|
Error("ConfigReader: token followed directly by string");
|
||||||
|
}
|
||||||
|
case '\'':
|
||||||
|
if (i == 0) {
|
||||||
|
ParseStrML( tok ); ch_buf = this->GetChar(); return new_line;
|
||||||
|
} else {
|
||||||
|
Error("ConfigReader: token followed directly by string");
|
||||||
|
}
|
||||||
|
case '=':
|
||||||
|
if (i == 0) {
|
||||||
|
ch_buf = this->GetChar();
|
||||||
|
tok[0] = '=';
|
||||||
|
tok[1] = '\0';
|
||||||
|
} else {
|
||||||
|
tok[i] = '\0';
|
||||||
|
}
|
||||||
|
return new_line;
|
||||||
|
case '\r':
|
||||||
|
case '\n':
|
||||||
|
if (i == 0) new_line = true;
|
||||||
|
case '\t':
|
||||||
|
case ' ' :
|
||||||
|
ch_buf = this->GetChar();
|
||||||
|
if (i > 0) {
|
||||||
|
tok[i] = '\0';
|
||||||
|
return new_line;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
tok[i++] = ch_buf;
|
||||||
|
ch_buf = this->GetChar();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
/*!
|
||||||
|
* \brief an iterator use stream base, allows use all types of istream
|
||||||
|
*/
|
||||||
|
class ConfigStreamReader: public ConfigReaderBase {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param istream input stream
|
||||||
|
*/
|
||||||
|
explicit ConfigStreamReader(std::istream &fin) : fin(fin) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual char GetChar(void) {
|
||||||
|
return fin.get();
|
||||||
|
}
|
||||||
|
/*! \brief to be implemented by child, check if end of stream */
|
||||||
|
virtual bool IsEnd(void) {
|
||||||
|
return fin.eof();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::istream &fin;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief an iterator that iterates over a configure file and gets the configures
|
||||||
|
*/
|
||||||
|
class ConfigIterator: public ConfigStreamReader {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief constructor
|
||||||
|
* \param fname name of configure file
|
||||||
|
*/
|
||||||
|
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
|
||||||
|
fi.open(fname);
|
||||||
|
if (fi.fail()) {
|
||||||
|
utils::Error("cannot open file %s", fname);
|
||||||
|
}
|
||||||
|
ConfigReaderBase::Init();
|
||||||
|
}
|
||||||
|
/*! \brief destructor */
|
||||||
|
~ConfigIterator(void) {
|
||||||
|
fi.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::ifstream fi;
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
|
||||||
|
#endif // ALLREDUCE_UTILS_CONFIG_H_
|
||||||
@ -6,9 +6,6 @@
|
|||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
*/
|
*/
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#ifdef TEST
|
|
||||||
#include "./mock.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI {
|
||||||
@ -81,9 +78,5 @@ void Finalize(void);
|
|||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void);
|
IEngine *GetEngine(void);
|
||||||
|
|
||||||
#ifdef TEST
|
|
||||||
void SetMock(const test::Mock& mock);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
#endif // ALLREDUCE_ENGINE_H
|
#endif // ALLREDUCE_ENGINE_H
|
||||||
|
|||||||
@ -12,9 +12,6 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
#include "./socket.h"
|
#include "./socket.h"
|
||||||
#ifdef TEST
|
|
||||||
#include "./mock.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI {
|
||||||
class Datatype {
|
class Datatype {
|
||||||
@ -41,12 +38,6 @@ class SyncManager : public IEngine {
|
|||||||
~SyncManager(void) {
|
~SyncManager(void) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TEST
|
|
||||||
inline void SetMock(const test::Mock& mock) {
|
|
||||||
this->mock = mock;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
inline void Shutdown(void) {
|
inline void Shutdown(void) {
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
links[i].sock.Close();
|
links[i].sock.Close();
|
||||||
@ -180,9 +171,6 @@ class SyncManager : public IEngine {
|
|||||||
size_t type_nbytes,
|
size_t type_nbytes,
|
||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
#ifdef TEST
|
|
||||||
utils::Assert(mock.AllReduce(this->rank), "Error returned by mock when reducing");
|
|
||||||
#endif
|
|
||||||
if (links.size() == 0) return;
|
if (links.size() == 0) return;
|
||||||
// total size of message
|
// total size of message
|
||||||
const size_t total_size = type_nbytes * count;
|
const size_t total_size = type_nbytes * count;
|
||||||
@ -304,10 +292,6 @@ class SyncManager : public IEngine {
|
|||||||
* \param root the root worker id to broadcast the data
|
* \param root the root worker id to broadcast the data
|
||||||
*/
|
*/
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
#ifdef TEST
|
|
||||||
utils::Assert(mock.Broadcast(this->rank), "Error returned by mock when broadcasting");
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (links.size() == 0) return;
|
if (links.size() == 0) return;
|
||||||
// number of links
|
// number of links
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
@ -368,15 +352,9 @@ class SyncManager : public IEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
|
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
#ifdef TEST
|
|
||||||
utils::Assert(mock.LoadCheckPoint(this->rank), "Error returned by mock when loading checkpoint");
|
|
||||||
#endif
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||||
#ifdef TEST
|
|
||||||
utils::Assert(mock.CheckPoint(this->rank), "Error returned by mock when checkpointing");
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -479,10 +457,6 @@ class SyncManager : public IEngine {
|
|||||||
// select helper
|
// select helper
|
||||||
utils::SelectHelper selecter;
|
utils::SelectHelper selecter;
|
||||||
|
|
||||||
#ifdef TEST
|
|
||||||
// mock to test
|
|
||||||
test::Mock mock;
|
|
||||||
#endif
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
@ -499,13 +473,6 @@ void Init(int argc, char *argv[]) {
|
|||||||
manager.Init();
|
manager.Init();
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TEST
|
|
||||||
/*! \brief sets a mock to the manager for testing purposes */
|
|
||||||
void SetMock(const test::Mock& mock) {
|
|
||||||
manager.SetMock(mock);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/*! \brief finalize synchronization module */
|
/*! \brief finalize synchronization module */
|
||||||
void Finalize(void) {
|
void Finalize(void) {
|
||||||
manager.Shutdown();
|
manager.Shutdown();
|
||||||
|
|||||||
102
src/mock.h
102
src/mock.h
@ -5,9 +5,8 @@
|
|||||||
* \brief This file defines a mock object to test the system
|
* \brief This file defines a mock object to test the system
|
||||||
* \author Tianqi Chen, Nacho, Tianyi
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
*/
|
*/
|
||||||
#include "./engine.h"
|
#include "./allreduce.h"
|
||||||
#include "./utils.h"
|
#include "./config.h"
|
||||||
#include <queue>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
|
|
||||||
@ -16,82 +15,73 @@ namespace test {
|
|||||||
|
|
||||||
class Mock {
|
class Mock {
|
||||||
|
|
||||||
typedef std::map<int,std::queue<bool> > Map;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Mock() : record(true) {}
|
Mock(const int& rank, char *config) : rank(rank) {
|
||||||
|
Init(config);
|
||||||
inline void Replay() {
|
|
||||||
record = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// record methods
|
template<typename OP>
|
||||||
inline void OnAllReduce(int rank, bool success) {
|
inline void AllReduce(float *sendrecvbuf, size_t count) {
|
||||||
onRecord(allReduce, rank, success);
|
utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank);
|
||||||
|
sync::AllReduce<OP>(sendrecvbuf, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnBroadcast(int rank, bool success) {
|
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
onRecord(broadcast, rank, success);
|
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
|
||||||
|
return sync::LoadCheckPoint(p_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnLoadCheckPoint(int rank, bool success) {
|
inline void CheckPoint(const utils::ISerializable &model) {
|
||||||
onRecord(loadCheckpoint, rank, success);
|
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
|
||||||
|
sync::CheckPoint(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void OnCheckPoint(int rank, bool success) {
|
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||||
onRecord(checkpoint, rank, success);
|
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
|
||||||
|
sync::Bcast(sendrecv_data, root);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// replay methods
|
|
||||||
inline bool AllReduce(int rank) {
|
|
||||||
return onReplay(allReduce, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool Broadcast(int rank) {
|
|
||||||
return onReplay(broadcast, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool LoadCheckPoint(int rank) {
|
|
||||||
return onReplay(loadCheckpoint, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool CheckPoint(int rank) {
|
|
||||||
return onReplay(checkpoint, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
inline void onRecord(Map& m, int rank, bool success) {
|
inline void Init(char* config) {
|
||||||
utils::Check(record, "Not in record state");
|
utils::ConfigIterator itr(config);
|
||||||
Map::iterator it = m.find(rank);
|
while (itr.Next()) {
|
||||||
if (it == m.end()) {
|
char round[4], node_rank[4];
|
||||||
std::queue<bool> aQueue;
|
sscanf(itr.name(), "%[^_]_%s", round, node_rank);
|
||||||
m[rank] = aQueue;
|
int i_round = atoi(round);
|
||||||
|
if (i_round == 1) {
|
||||||
|
int i_node_rank = atoi(node_rank);
|
||||||
|
if (i_node_rank == rank) {
|
||||||
|
printf("[%d] round %d, value %s\n", rank, i_round, itr.val());
|
||||||
|
if (strcmp("allreduce", itr.val())) record(allReduce);
|
||||||
|
else if (strcmp("broadcast", itr.val())) record(broadcast);
|
||||||
|
else if (strcmp("loadcheckpoint", itr.val())) record(loadCheckpoint);
|
||||||
|
else if (strcmp("checkpoint", itr.val())) record(checkpoint);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m[rank].push(success);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool onReplay(Map& m, int rank) {
|
inline void record(std::map<int,bool>& m) {
|
||||||
utils::Check(!record, "Not in replay state");
|
m[rank] = false;
|
||||||
utils::Check(m.find(rank) != m.end(), "Not recorded");
|
}
|
||||||
|
|
||||||
|
inline bool verify(std::map<int,bool>& m) {
|
||||||
bool result = true;
|
bool result = true;
|
||||||
if (!m[rank].empty()) {
|
if (m.find(rank) != m.end()) {
|
||||||
result = m[rank].front();
|
result = m[rank];
|
||||||
m[rank].pop();
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// flag to indicate if the mock is in record state
|
int rank;
|
||||||
bool record;
|
std::map<int,bool> allReduce;
|
||||||
|
std::map<int,bool> broadcast;
|
||||||
Map allReduce;
|
std::map<int,bool> loadCheckpoint;
|
||||||
Map broadcast;
|
std::map<int,bool> checkpoint;
|
||||||
Map loadCheckpoint;
|
|
||||||
Map checkpoint;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,10 +9,6 @@ ifeq ($(no_omp),1)
|
|||||||
else
|
else
|
||||||
CFLAGS += -fopenmp
|
CFLAGS += -fopenmp
|
||||||
endif
|
endif
|
||||||
ifeq ($(test),1)
|
|
||||||
CFLAGS += -DTEST
|
|
||||||
endif
|
|
||||||
|
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = test_allreduce
|
BIN = test_allreduce
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
if [ "$#" -ne 2 ];
|
if [ "$#" -ne 3 ];
|
||||||
then
|
then
|
||||||
echo "Usage <nslave> <ndata>"
|
echo "Usage <nslave> <ndata> <config>"
|
||||||
exit -1
|
exit -1
|
||||||
fi
|
fi
|
||||||
../submit_job_tcp.py $1 test_allreduce $2
|
../submit_job_tcp.py $1 test_allreduce $2 $3
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
using namespace sync;
|
using namespace sync;
|
||||||
|
|
||||||
inline void TestMax(size_t n) {
|
inline void TestMax(test::Mock &mock, size_t n) {
|
||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
int nproc = sync::GetWorldSize();
|
int nproc = sync::GetWorldSize();
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ inline void TestMax(size_t n) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % 111;
|
ndata[i] = (i * (rank+1)) % 111;
|
||||||
}
|
}
|
||||||
sync::AllReduce<op::Max>(&ndata[0], ndata.size());
|
mock.AllReduce<op::Max>(&ndata[0], ndata.size());
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % 111;
|
float rmax = (i * 1) % 111;
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
@ -26,7 +26,7 @@ inline void TestMax(size_t n) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void TestSum(size_t n) {
|
inline void TestSum(test::Mock &mock, size_t n) {
|
||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
int nproc = sync::GetWorldSize();
|
int nproc = sync::GetWorldSize();
|
||||||
const int z = 131;
|
const int z = 131;
|
||||||
@ -35,7 +35,7 @@ inline void TestSum(size_t n) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z;
|
ndata[i] = (i * (rank+1)) % z;
|
||||||
}
|
}
|
||||||
sync::AllReduce<op::Sum>(&ndata[0], ndata.size());
|
mock.AllReduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rsum = 0.0f;
|
float rsum = 0.0f;
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
@ -46,7 +46,7 @@ inline void TestSum(size_t n) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void TestBcast(size_t n, int root) {
|
inline void TestBcast(test::Mock &mock, size_t n, int root) {
|
||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
std::string s; s.resize(n);
|
std::string s; s.resize(n);
|
||||||
for (size_t i = 0; i < n; ++i) {
|
for (size_t i = 0; i < n; ++i) {
|
||||||
@ -55,31 +55,16 @@ inline void TestBcast(size_t n, int root) {
|
|||||||
std::string res;
|
std::string res;
|
||||||
if (root == rank) {
|
if (root == rank) {
|
||||||
res = s;
|
res = s;
|
||||||
sync::Bcast(&res, root);
|
mock.Broadcast(&res, root);
|
||||||
} else {
|
} else {
|
||||||
sync::Bcast(&res, root);
|
mock.Broadcast(&res, root);
|
||||||
}
|
}
|
||||||
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
utils::Check(res == s, "[%d] TestBcast fail", rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ugly stuff, just to see if it works. To be removed
|
|
||||||
inline void Record(test::Mock& mock, const int rank) {
|
|
||||||
switch(rank) {
|
|
||||||
case 0:
|
|
||||||
mock.OnAllReduce(0, false);
|
|
||||||
break;
|
|
||||||
case 1:
|
|
||||||
mock.OnAllReduce(1, false);
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
mock.OnAllReduce(2, true);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
if (argc < 2) {
|
if (argc < 3) {
|
||||||
printf("Usage: <ndata>\n");
|
printf("Usage: <ndata> <config>\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
int n = atoi(argv[1]);
|
int n = atoi(argv[1]);
|
||||||
@ -87,17 +72,12 @@ int main(int argc, char *argv[]) {
|
|||||||
int rank = sync::GetRank();
|
int rank = sync::GetRank();
|
||||||
std::string name = sync::GetProcessorName();
|
std::string name = sync::GetProcessorName();
|
||||||
|
|
||||||
#ifdef TEST
|
test::Mock mock(rank, argv[2]);
|
||||||
test::Mock mock;
|
|
||||||
Record(mock, rank);
|
|
||||||
mock.Replay();
|
|
||||||
sync::SetMock(mock);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
printf("[%d] start at %s\n", rank, name.c_str());
|
printf("[%d] start at %s\n", rank, name.c_str());
|
||||||
TestMax(n);
|
TestMax(mock, n);
|
||||||
printf("[%d] TestMax pass\n", rank);
|
printf("[%d] TestMax pass\n", rank);
|
||||||
TestSum(n);
|
TestSum(mock, n);
|
||||||
printf("[%d] TestSum pass\n", rank);
|
printf("[%d] TestSum pass\n", rank);
|
||||||
sync::Finalize();
|
sync::Finalize();
|
||||||
printf("[%d] all check pass\n", rank);
|
printf("[%d] all check pass\n", rank);
|
||||||
|
|||||||
1
test/testcase0.conf
Normal file
1
test/testcase0.conf
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Test Case 0 -> nothing fails
|
||||||
12
test/testcase1.conf
Normal file
12
test/testcase1.conf
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Test Case example config
|
||||||
|
# You configure which methods should fail
|
||||||
|
# Format <round>_<rank> = <operation>
|
||||||
|
# <operation> can be one of the following: allreduce, broadcast, loadcheckpoint, checkpoint
|
||||||
|
|
||||||
|
1_0 = allreduce
|
||||||
|
1_1 = broadcast
|
||||||
|
1_2 = loadcheckpoint
|
||||||
|
1_3 = checkpoint
|
||||||
|
|
||||||
|
2_0 = allreduce
|
||||||
|
2_2 = checkpoint
|
||||||
Loading…
x
Reference in New Issue
Block a user