add speed test

This commit is contained in:
tqchen 2014-12-06 11:05:24 -08:00
parent 19631ecef6
commit 0e012cb05e
6 changed files with 234 additions and 36 deletions

View File

@ -24,12 +24,12 @@ AllreduceRobust::AllreduceRobust(void) {
void AllreduceRobust::Shutdown(void) { void AllreduceRobust::Shutdown(void) {
// need to sync the exec before we shutdown, do a pesudo check point // need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen // execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"check point must return true"); "check point must return true");
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here // execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true"); "check ack must return true");
AllreduceBase::Shutdown(); AllreduceBase::Shutdown();
} }
@ -133,7 +133,7 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) { utils::ISerializable *local_model) {
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported"); utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
// check if we succesfll // check if we succesfll
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) { if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// load from buffer // load from buffer
@ -142,7 +142,7 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
if (version_number == 0) return version_number; if (version_number == 0) return version_number;
global_model->Load(fs); global_model->Load(fs);
// run another phase of check ack, if recovered from data // run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true"); "check ack must return true");
return version_number; return version_number;
} else { } else {
@ -172,7 +172,7 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) { const utils::ISerializable *local_model) {
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet"); utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
// execute checkpoint, note: when checkpoint existing, load will not happen // execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"check point must return true"); "check point must return true");
// this is the critical region where we will change all the stored models // this is the critical region where we will change all the stored models
// increase version number // increase version number
@ -185,7 +185,7 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here // execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true"); "check ack must return true");
} }
/*! /*!
@ -608,6 +608,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
*/ */
AllreduceRobust::ReturnType AllreduceRobust::ReturnType
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role; AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
// if minimum sequence requested is local check point ack,
// this means all nodes have finished local check point, directly return
if (seqno == ActionSummary::kLocalCheckAck) return kSuccess;
if (!requester) { if (!requester) {
sendrecvbuf = resbuf.Query(seqno, &size); sendrecvbuf = resbuf.Query(seqno, &size);
role = sendrecvbuf != NULL ? kHaveData : kPassData; role = sendrecvbuf != NULL ? kHaveData : kPassData;
@ -631,7 +635,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
* \param size the total size of the buffer * \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary * \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, * \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kMaxSeq * seqno needs to be set to ActionSummary::kSpecialOp
* *
* \return if this function can return true or false * \return if this function can return true or false
* - true means buf already set to the * - true means buf already set to the
@ -640,7 +644,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
*/ */
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (flag != 0) { if (flag != 0) {
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations"); utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
} }
// request // request
ActionSummary req(flag, seqno); ActionSummary req(flag, seqno);
@ -672,7 +676,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
} else { } else {
if (act.check_point()) { if (act.check_point()) {
if (act.diff_seq()) { if (act.diff_seq()) {
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug"); utils::Assert(act.min_seqno() != ActionSummary::kSpecialOp, "min seq bug");
bool requester = req.min_seqno() == act.min_seqno(); bool requester = req.min_seqno() == act.min_seqno();
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue; if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
if (requester) return true; if (requester) return true;
@ -691,7 +695,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (req.load_check()) return true; if (req.load_check()) return true;
} else { } else {
// no special flags, no checkpoint, check ack, load_check // no special flags, no checkpoint, check ack, load_check
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug"); utils::Assert(act.min_seqno() != ActionSummary::kSpecialOp, "min seq bug");
if (act.diff_seq()) { if (act.diff_seq()) {
bool requester = req.min_seqno() == act.min_seqno(); bool requester = req.min_seqno() == act.min_seqno();
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue; if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
@ -708,7 +712,54 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
utils::Assert(false, "RecoverExec: should not reach here"); utils::Assert(false, "RecoverExec: should not reach here");
return true; return true;
} }
/*!
* \brief try to recover the local state, making each local state to be the result of itself
* plus replication of states in previous num_local_replica hops in the ring
*
* The input parameters must contain the valid local states available in current nodes,
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
* If there is sufficient information in the ring, when the function returns, local_chkpt will
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
* If there is no sufficient information in the ring, this function the number of checkpoints
* will be less than the specified value
*
* \param p_local_rptr the pointer to the segment pointers in the states array
* \param p_local_chkpt the pointer to the storage of local check points
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceRobust::ReturnType
AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt) {
// if there is no local replica, we can do nothing
if (num_local_replica == 0) return kSuccess;
std::vector<size_t> &rptr = *p_local_rptr;
std::string &chkpt = *p_local_chkpt;
if (rptr.size() == 0) {
rptr.push_back(0);
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
// message send to previous link
{
int msg_forward[2];
int nlocal = static_cast<int>(rptr.size() - 1);
msg_forward[0] = nlocal;
utils::Assert(msg_forward[0] <= n, "invalid local replica");
// backward passing one hop the request
ReturnType succ = RingPassing(msg_forward,
1 * sizeof(int), 2 * sizeof(int),
0 * sizeof(int), 1 * sizeof(int),
ring_prev, ring_next);
if (succ != kSuccess) return succ;
// check how much current node can help with the request
// if (nlocal > ) {
//}
}
return kSuccess;
}
/*! /*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link * \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure * this allows data to stream over a ring structure
@ -723,8 +774,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
* \param read_end the ending position to read * \param read_end the ending position to read
* \param write_ptr the initial write pointer * \param write_ptr the initial write pointer
* \param write_end the ending position to write * \param write_end the ending position to write
* \param prev pointer to link to previous position in ring * \param read_link pointer to link to previous position in ring
* \param prev pointer to link of next position in ring * \param write_link pointer to link of next position in ring
*/ */
AllreduceRobust::ReturnType AllreduceRobust::ReturnType
AllreduceRobust::RingPassing(void *sendrecvbuf_, AllreduceRobust::RingPassing(void *sendrecvbuf_,
@ -732,14 +783,14 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
size_t read_end, size_t read_end,
size_t write_ptr, size_t write_ptr,
size_t write_end, size_t write_end,
LinkRecord *prev_link, LinkRecord *read_link,
LinkRecord *next_link) { LinkRecord *write_link) {
if (links.size() == 0 || read_end == 0) return kSuccess; if (links.size() == 0 || read_end == 0) return kSuccess;
utils::Assert(read_end <= write_end, "boundary check"); utils::Assert(read_end <= write_end, "boundary check");
utils::Assert(read_ptr <= read_end, "boundary check"); utils::Assert(read_ptr <= read_end, "boundary check");
utils::Assert(write_ptr <= write_end, "boundary check"); utils::Assert(write_ptr <= write_end, "boundary check");
// take reference // take reference
LinkRecord &prev = *prev_link, &next = *next_link; LinkRecord &prev = *read_link, &next = *write_link;
// send recv buffer // send recv buffer
char *buf = reinterpret_cast<char*>(sendrecvbuf_); char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) { while (true) {

View File

@ -122,7 +122,11 @@ class AllreduceRobust : public AllreduceBase {
*/ */
struct ActionSummary { struct ActionSummary {
// maximumly allowed sequence id // maximumly allowed sequence id
const static int kMaxSeq = 1 << 26; const static int kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint
const static int kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal
const static int kLocalCheckAck = (1 << 26) - 1;
//--------------------------------------------- //---------------------------------------------
// The following are bit mask of flag used in // The following are bit mask of flag used in
//---------------------------------------------- //----------------------------------------------
@ -140,7 +144,7 @@ class AllreduceRobust : public AllreduceBase {
// constructor // constructor
ActionSummary(void) {} ActionSummary(void) {}
// constructor of action // constructor of action
ActionSummary(int flag, int minseqno = kMaxSeq) { ActionSummary(int flag, int minseqno = kSpecialOp) {
seqcode = (minseqno << 4) | flag; seqcode = (minseqno << 4) | flag;
} }
// minimum number of all operations // minimum number of all operations
@ -277,14 +281,14 @@ class AllreduceRobust : public AllreduceBase {
* \param buf the buffer to store the result * \param buf the buffer to store the result
* \param size the total size of the buffer * \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary * \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kMaxSeq * \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp
* *
* \return if this function can return true or false * \return if this function can return true or false
* - true means buf already set to the * - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed * result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action * - false means this is the lastest action that has not yet been executed, need to execute the action
*/ */
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq); bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp);
/*! /*!
* \brief try to load check point * \brief try to load check point
* *
@ -344,12 +348,30 @@ class AllreduceRobust : public AllreduceBase {
* *
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType, TryDecideRouting * \sa ReturnType, TryDecideRouting
*/ */
ReturnType TryRecoverData(RecoverType role, ReturnType TryRecoverData(RecoverType role,
void *sendrecvbuf_, void *sendrecvbuf_,
size_t size, size_t size,
int recv_link, int recv_link,
const std::vector<bool> &req_in); const std::vector<bool> &req_in);
/*!
* \brief try to recover the local state, making each local state to be the result of itself
* plus replication of states in previous num_local_replica hops in the ring
*
* The input parameters must contain the valid local states available in current nodes,
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
* If there is sufficient information in the ring, when the function returns, local_chkpt will
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
* If there is no sufficient information in the ring, this function the number of checkpoints
* will be less than the specified value
*
* \param p_local_rptr the pointer to the segment pointers in the states array
* \param p_local_chkpt the pointer to the storage of local check points
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*! /*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link * \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure * this allows data to stream over a ring structure
@ -364,16 +386,16 @@ class AllreduceRobust : public AllreduceBase {
* \param read_end the ending position to read * \param read_end the ending position to read
* \param write_ptr the initial write pointer * \param write_ptr the initial write pointer
* \param write_end the ending position to write * \param write_end the ending position to write
* \param prev pointer to link to previous position in ring * \param read_link pointer to link to previous position in ring
* \param prev pointer to link of next position in ring * \param write_link pointer to link of next position in ring
*/ */
ReturnType RingPassing(void *senrecvbuf_, ReturnType RingPassing(void *senrecvbuf_,
size_t read_ptr, size_t read_ptr,
size_t read_end, size_t read_end,
size_t write_ptr, size_t write_ptr,
size_t write_end, size_t write_end,
LinkRecord *prev_link, LinkRecord *read_link,
LinkRecord *next_link); LinkRecord *write_link);
/*! /*!
* \brief run message passing algorithm on the allreduce tree * \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out * the result is edge message stored in p_edge_in and p_edge_out
@ -410,15 +432,18 @@ class AllreduceRobust : public AllreduceBase {
std::string global_checkpoint; std::string global_checkpoint;
// number of replica for local state/model // number of replica for local state/model
int num_local_replica; int num_local_replica;
// --- recovery data structure for local checkpoint
// there is two version of the data structure,
// at one time one version is valid and another is used as temp memory
// pointer to memory position in the local model // pointer to memory position in the local model
// local model is stored in CSR format(like a sparse matrices) // local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node // local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
std::vector<size_t> local_rptr; std::vector<size_t> local_rptr[2];
// storage for local model replicas // storage for local model replicas
std::string local_checkpoint; std::string local_checkpoint[2];
// temporal storage for doing local checkpointing // version of local checkpoint can be 1 or 0
std::string tmp_local_check; int local_chkpt_version;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -166,7 +166,7 @@ class Tracker:
s.assign_rank(rank, wait_conn, nslave) s.assign_rank(rank, wait_conn, nslave)
if s.wait_accept > 0: if s.wait_accept > 0:
wait_conn[rank] = s wait_conn[rank] = s
print 'all slaves setup complete' print 'All nodes finishes job'
def mpi_submit(nslave, args): def mpi_submit(nslave, args):
cmd = ' '.join(['mpirun -n %d' % nslave] + args) cmd = ' '.join(['mpirun -n %d' % nslave] + args)

23
src/timer.h Normal file
View File

@ -0,0 +1,23 @@
/*!
* \file timer.h
* \brief This file defines the utils for timing
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_TIMER_H
#define RABIT_TIMER_H
#include <time.h>
#include "./utils.h"
namespace rabit {
namespace utils {
/*!
* \brief return time in seconds
*/
inline double GetTime(void) {
timespec ts;
utils::Check(clock_gettime(CLOCK_REALTIME, &ts) == 0, "failed to get time");
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
}
}
}
#endif

View File

@ -1,17 +1,17 @@
export CC = gcc export CC = gcc
export CXX = g++ export CXX = g++
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -pthread -lm export LDFLAGS= -pthread -lm -lrt
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
# specify tensor path # specify tensor path
BIN = test_allreduce test_recover test_model_recover BIN = test_allreduce test_recover test_model_recover speed_test
# objectives that makes up rabit library # objectives that makes up rabit library
RABIT_OBJ = allreduce_base.o allreduce_robust.o engine.o RABIT_OBJ = allreduce_base.o allreduce_robust.o engine.o
MPIOBJ = engine_mpi.o MPIOBJ = engine_mpi.o
OBJ = $(RABIT_OBJ) test_allreduce.o test_recover.o test_model_recover.o OBJ = $(RABIT_OBJ) test_allreduce.o test_recover.o test_model_recover.o speed_test.o
MPIBIN = test_allreduce.mpi MPIBIN = test_allreduce.mpi speed_test.mpi
.PHONY: clean all .PHONY: clean all
all: $(BIN) $(MPIBIN) all: $(BIN) $(MPIBIN)
@ -21,23 +21,26 @@ engine.o: ../src/engine.cc ../src/*.h
allreduce_robust.o: ../src/allreduce_robust.cc ../src/*.h allreduce_robust.o: ../src/allreduce_robust.cc ../src/*.h
engine_mpi.o: ../src/engine_mpi.cc engine_mpi.o: ../src/engine_mpi.cc
test_allreduce.o: test_allreduce.cpp ../src/*.h test_allreduce.o: test_allreduce.cpp ../src/*.h
speed_test.o: speed_test.cpp ../src/*.h
test_recover.o: test_recover.cpp ../src/*.h test_recover.o: test_recover.cpp ../src/*.h
test_model_recover.o: test_model_recover.cpp ../src/*.h test_model_recover.o: test_model_recover.cpp ../src/*.h
# we can link against MPI version to get use MPI # we can link against MPI version to get use MPI
test_allreduce: test_allreduce.o $(RABIT_OBJ) test_allreduce: test_allreduce.o $(RABIT_OBJ)
test_allreduce.mpi: test_allreduce.o $(MPIOBJ) test_allreduce.mpi: test_allreduce.o $(MPIOBJ)
speed_test: speed_test.o $(RABIT_OBJ)
speed_test.mpi: speed_test.o $(MPIOBJ)
test_recover: test_recover.o $(RABIT_OBJ) test_recover: test_recover.o $(RABIT_OBJ)
test_model_recover: test_model_recover.o $(RABIT_OBJ) test_model_recover: test_model_recover.o $(RABIT_OBJ)
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS)
$(OBJ) : $(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )
$(MPIBIN) : $(MPIBIN) :
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(MPICXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(LDFLAGS)
$(MPIOBJ) : $(MPIOBJ) :
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

96
test/speed_test.cpp Normal file
View File

@ -0,0 +1,96 @@
#include <rabit.h>
#include <utils.h>
#include <timer.h>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <time.h>
using namespace rabit;
double max_tdiff, sum_tdiff, bcast_tdiff, tot_tdiff;
inline void TestMax(size_t n) {
int rank = rabit::GetRank();
//int nproc = rabit::GetWorldSize();
std::vector<float> ndata(n);
for (size_t i = 0; i < ndata.size(); ++i) {
ndata[i] = (i * (rank+1)) % 111;
}
double tstart = utils::GetTime();
rabit::Allreduce<op::Max>(&ndata[0], ndata.size());
max_tdiff += utils::GetTime() - tstart;
}
inline void TestSum(size_t n) {
int rank = rabit::GetRank();
//int nproc = rabit::GetWorldSize();
const int z = 131;
std::vector<float> ndata(n);
for (size_t i = 0; i < ndata.size(); ++i) {
ndata[i] = (i * (rank+1)) % z;
}
double tstart = utils::GetTime();
rabit::Allreduce<op::Sum>(&ndata[0], ndata.size());
sum_tdiff += utils::GetTime() - tstart;
}
inline void TestBcast(size_t n, int root) {
int rank = rabit::GetRank();
std::string s; s.resize(n);
for (size_t i = 0; i < n; ++i) {
s[i] = char(i % 126 + 1);
}
std::string res;
if (root == rank) {
res = s;
}
double tstart = utils::GetTime();
rabit::Broadcast(&res, root);
bcast_tdiff += utils::GetTime() - tstart;
}
inline void PrintStats(const char *name, double tdiff) {
int nproc = rabit::GetWorldSize();
double tsum = tdiff;
rabit::Allreduce<op::Sum>(&tsum, 1);
double tavg = tsum / nproc;
double tsqr = tdiff - tavg;
tsqr *= tsqr;
rabit::Allreduce<op::Sum>(&tsqr, 1);
double tstd = sqrt(tsqr / nproc);
if (rabit::GetRank() == 0) {
utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd);
}
}
int main(int argc, char *argv[]) {
if (argc < 3) {
printf("Usage: <ndata> <nrepeat>\n");
return 0;
}
srand(0);
int n = atoi(argv[1]);
int nrep = atoi(argv[2]);
utils::Check(nrep >= 1, "need to at least repeat running once");
rabit::Init(argc, argv);
//int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
std::string name = rabit::GetProcessorName();
max_tdiff = sum_tdiff = bcast_tdiff = 0;
double tstart = utils::GetTime();
for (int i = 0; i < nrep; ++i) {
TestMax(n);
TestSum(n);
TestBcast(n, rand() % nproc);
}
tot_tdiff = utils::GetTime() - tstart;
// use allreduce to get the sum and std of time
PrintStats("max_tdiff", max_tdiff);
PrintStats("sum_tdiff", sum_tdiff);
PrintStats("bcast_tdiff", bcast_tdiff);
PrintStats("tot_tdiff", tot_tdiff);
rabit::Finalize();
return 0;
}