cleanup testcases
This commit is contained in:
@@ -26,7 +26,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
hadoop_mode = 0;
|
||||
version_number = 0;
|
||||
task_id = "NULL";
|
||||
this->SetParam("reduce_buffer", "256MB");
|
||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
@@ -38,8 +38,8 @@ void AllreduceBase::Init(void) {
|
||||
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != NULL) {
|
||||
this->SetParam("task_id", task_id);
|
||||
this->SetParam("hadoop_mode", "1");
|
||||
this->SetParam("rabit_task_id", task_id);
|
||||
this->SetParam("rabit_hadoop_mode", "1");
|
||||
}
|
||||
}
|
||||
// start socket
|
||||
@@ -83,9 +83,9 @@ void AllreduceBase::Shutdown(void) {
|
||||
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
||||
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
||||
if (!strcmp(name, "task_id")) task_id = val;
|
||||
if (!strcmp(name, "hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "reduce_buffer")) {
|
||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
char unit;
|
||||
unsigned long amount;
|
||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||
|
||||
@@ -17,10 +17,10 @@
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
AllreduceRobust::AllreduceRobust(void) {
|
||||
result_buffer_round = 1;
|
||||
num_local_replica = 0;
|
||||
seq_counter = 0;
|
||||
local_chkpt_version = 0;
|
||||
result_buffer_round = 1;
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
void AllreduceRobust::Shutdown(void) {
|
||||
@@ -42,11 +42,11 @@ void AllreduceRobust::Shutdown(void) {
|
||||
*/
|
||||
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
||||
AllreduceBase::SetParam(name, val);
|
||||
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
|
||||
if (!strcmp(name, "result_replicate")) {
|
||||
if (!strcmp(name, "rabit_buffer_round")) result_buffer_round = atoi(val);
|
||||
if (!strcmp(name, "rabit_global_replica")) {
|
||||
result_buffer_round = std::max(world_size / atoi(val), 1);
|
||||
}
|
||||
if (!strcmp(name, "num_local_replica")) {
|
||||
if (!strcmp(name, "rabit_local_replica")) {
|
||||
num_local_replica = atoi(val);
|
||||
}
|
||||
}
|
||||
|
||||
196
src/config.h
196
src/config.h
@@ -1,196 +0,0 @@
|
||||
#ifndef RABIT_UTILS_CONFIG_H_
|
||||
#define RABIT_UTILS_CONFIG_H_
|
||||
/*!
|
||||
* \file config.h
|
||||
* \brief helper class to load in configures from file
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <istream>
|
||||
#include <fstream>
|
||||
#include "./utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*!
|
||||
* \brief base implementation of config reader
|
||||
*/
|
||||
class ConfigReaderBase {
|
||||
public:
|
||||
/*!
|
||||
* \brief get current name, called after Next returns true
|
||||
* \return current parameter name
|
||||
*/
|
||||
inline const char *name(void) const {
|
||||
return s_name;
|
||||
}
|
||||
/*!
|
||||
* \brief get current value, called after Next returns true
|
||||
* \return current parameter value
|
||||
*/
|
||||
inline const char *val(void) const {
|
||||
return s_val;
|
||||
}
|
||||
/*!
|
||||
* \brief move iterator to next position
|
||||
* \return true if there is value in next position
|
||||
*/
|
||||
inline bool Next(void) {
|
||||
while (!this->IsEnd()) {
|
||||
GetNextToken(s_name);
|
||||
if (s_name[0] == '=') return false;
|
||||
if (GetNextToken( s_buf ) || s_buf[0] != '=') return false;
|
||||
if (GetNextToken( s_val ) || s_val[0] == '=') return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// called before usage
|
||||
inline void Init(void) {
|
||||
ch_buf = this->GetChar();
|
||||
}
|
||||
|
||||
protected:
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get next token, return EOF if end of file
|
||||
*/
|
||||
virtual char GetChar(void) = 0;
|
||||
/*! \brief to be implemented by child, check if end of stream */
|
||||
virtual bool IsEnd(void) = 0;
|
||||
|
||||
private:
|
||||
char ch_buf;
|
||||
char s_name[100000], s_val[100000], s_buf[100000];
|
||||
|
||||
inline void SkipLine(void) {
|
||||
do {
|
||||
ch_buf = this->GetChar();
|
||||
} while (ch_buf != EOF && ch_buf != '\n' && ch_buf != '\r');
|
||||
}
|
||||
|
||||
inline void ParseStr(char tok[]) {
|
||||
int i = 0;
|
||||
while ((ch_buf = this->GetChar()) != EOF) {
|
||||
switch (ch_buf) {
|
||||
case '\\': tok[i++] = this->GetChar(); break;
|
||||
case '\"': tok[i++] = '\0'; return;
|
||||
case '\r':
|
||||
case '\n': Error("ConfigReader: unterminated string");
|
||||
default: tok[i++] = ch_buf;
|
||||
}
|
||||
}
|
||||
Error("ConfigReader: unterminated string");
|
||||
}
|
||||
inline void ParseStrML(char tok[]) {
|
||||
int i = 0;
|
||||
while ((ch_buf = this->GetChar()) != EOF) {
|
||||
switch (ch_buf) {
|
||||
case '\\': tok[i++] = this->GetChar(); break;
|
||||
case '\'': tok[i++] = '\0'; return;
|
||||
default: tok[i++] = ch_buf;
|
||||
}
|
||||
}
|
||||
Error("unterminated string");
|
||||
}
|
||||
// return newline
|
||||
inline bool GetNextToken(char tok[]) {
|
||||
int i = 0;
|
||||
bool new_line = false;
|
||||
while (ch_buf != EOF) {
|
||||
switch (ch_buf) {
|
||||
case '#' : SkipLine(); new_line = true; break;
|
||||
case '\"':
|
||||
if (i == 0) {
|
||||
ParseStr(tok); ch_buf = this->GetChar(); return new_line;
|
||||
} else {
|
||||
Error("ConfigReader: token followed directly by string");
|
||||
}
|
||||
case '\'':
|
||||
if (i == 0) {
|
||||
ParseStrML( tok ); ch_buf = this->GetChar(); return new_line;
|
||||
} else {
|
||||
Error("ConfigReader: token followed directly by string");
|
||||
}
|
||||
case '=':
|
||||
if (i == 0) {
|
||||
ch_buf = this->GetChar();
|
||||
tok[0] = '=';
|
||||
tok[1] = '\0';
|
||||
} else {
|
||||
tok[i] = '\0';
|
||||
}
|
||||
return new_line;
|
||||
case '\r':
|
||||
case '\n':
|
||||
if (i == 0) new_line = true;
|
||||
case '\t':
|
||||
case ' ' :
|
||||
ch_buf = this->GetChar();
|
||||
if (i > 0) {
|
||||
tok[i] = '\0';
|
||||
return new_line;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
tok[i++] = ch_buf;
|
||||
ch_buf = this->GetChar();
|
||||
break;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief an iterator use stream base, allows use all types of istream
|
||||
*/
|
||||
class ConfigStreamReader: public ConfigReaderBase {
|
||||
public:
|
||||
/*!
|
||||
* \brief constructor
|
||||
* \param istream input stream
|
||||
*/
|
||||
explicit ConfigStreamReader(std::istream &fin) : fin(fin) {}
|
||||
|
||||
protected:
|
||||
virtual char GetChar(void) {
|
||||
return fin.get();
|
||||
}
|
||||
/*! \brief to be implemented by child, check if end of stream */
|
||||
virtual bool IsEnd(void) {
|
||||
return fin.eof();
|
||||
}
|
||||
|
||||
private:
|
||||
std::istream &fin;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief an iterator that iterates over a configure file and gets the configures
|
||||
*/
|
||||
class ConfigIterator: public ConfigStreamReader {
|
||||
public:
|
||||
/*!
|
||||
* \brief constructor
|
||||
* \param fname name of configure file
|
||||
*/
|
||||
explicit ConfigIterator(const char *fname) : ConfigStreamReader(fi) {
|
||||
fi.open(fname);
|
||||
if (fi.fail()) {
|
||||
utils::Error("cannot open file %s", fname);
|
||||
}
|
||||
ConfigReaderBase::Init();
|
||||
}
|
||||
/*! \brief destructor */
|
||||
~ConfigIterator(void) {
|
||||
fi.close();
|
||||
}
|
||||
|
||||
private:
|
||||
std::ifstream fi;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_UTILS_CONFIG_H_
|
||||
118
src/mock.h
118
src/mock.h
@@ -1,118 +0,0 @@
|
||||
#ifndef RABIT_MOCK_H
|
||||
#define RABIT_MOCK_H
|
||||
/*!
|
||||
* \file mock.h
|
||||
* \brief This file defines a mock object to test the system
|
||||
* \author Ignacio Cano
|
||||
*/
|
||||
#include "./rabit.h"
|
||||
#include "./config.h"
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
namespace rabit {
|
||||
/*! \brief namespace of mock */
|
||||
namespace test {
|
||||
|
||||
class Mock {
|
||||
|
||||
|
||||
public:
|
||||
|
||||
explicit Mock(const int& rank, char *config, char* round_dir) : rank(rank) {
|
||||
Init(config, round_dir);
|
||||
}
|
||||
|
||||
template<typename OP>
|
||||
inline void Allreduce(float *sendrecvbuf, size_t count) {
|
||||
utils::Assert(verify(allReduce), "[%d] error when calling allReduce", rank);
|
||||
rabit::Allreduce<OP>(sendrecvbuf, count);
|
||||
}
|
||||
|
||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
|
||||
return rabit::LoadCheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
|
||||
rabit::CheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
|
||||
rabit::Broadcast(sendrecv_data, root);
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
inline void Init(char* config, char* round_dir) {
|
||||
std::stringstream ss;
|
||||
ss << round_dir << "node" << rank << ".round";
|
||||
const char* round_file = ss.str().c_str();
|
||||
std::ifstream ifs(round_file);
|
||||
int current_round = 1;
|
||||
if (!ifs.good()) {
|
||||
// file does not exists, it's the first time, so save the current round to 1
|
||||
std::ofstream ofs(round_file);
|
||||
ofs << current_round;
|
||||
ofs.close();
|
||||
} else {
|
||||
// file does exists, read the previous round, increment by one, and save it back
|
||||
ifs >> current_round;
|
||||
current_round++;
|
||||
ifs.close();
|
||||
std::ofstream ofs(round_file);
|
||||
ofs << current_round;
|
||||
ofs.close();
|
||||
}
|
||||
printf("[%d] in round %d\n", rank, current_round);
|
||||
utils::ConfigIterator itr(config);
|
||||
while (itr.Next()) {
|
||||
char round[4], node_rank[4];
|
||||
sscanf(itr.name(), "%[^_]_%s", round, node_rank);
|
||||
int i_node_rank = atoi(node_rank);
|
||||
// if it's something for me
|
||||
if (i_node_rank == rank) {
|
||||
int i_round = atoi(round);
|
||||
// in my current round
|
||||
if (i_round == current_round) {
|
||||
printf("[%d] round %d, value %s\n", rank, i_round, itr.val());
|
||||
if (strcmp("allreduce", itr.val())) record(allReduce);
|
||||
else if (strcmp("broadcast", itr.val())) record(broadcast);
|
||||
else if (strcmp("loadcheckpoint", itr.val())) record(loadCheckpoint);
|
||||
else if (strcmp("checkpoint", itr.val())) record(checkpoint);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void record(std::map<int,bool>& m) {
|
||||
m[rank] = false;
|
||||
}
|
||||
|
||||
inline bool verify(std::map<int,bool>& m) {
|
||||
bool result = true;
|
||||
if (m.find(rank) != m.end()) {
|
||||
result = m[rank];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int rank;
|
||||
std::map<int,bool> allReduce;
|
||||
std::map<int,bool> broadcast;
|
||||
std::map<int,bool> loadCheckpoint;
|
||||
std::map<int,bool> checkpoint;
|
||||
|
||||
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace rabit
|
||||
|
||||
#endif // RABIT_MOCK_H
|
||||
Reference in New Issue
Block a user