diff --git a/test/speed_test.cpp b/test/speed_test.cpp index 25033a54a..d171c7dc1 100644 --- a/test/speed_test.cpp +++ b/test/speed_test.cpp @@ -51,7 +51,7 @@ inline void TestBcast(size_t n, int root) { bcast_tdiff += utils::GetTime() - tstart; } -inline void PrintStats(const char *name, double tdiff) { +inline void PrintStats(const char *name, double tdiff, int n, int nrep, size_t size) { int nproc = rabit::GetWorldSize(); double tsum = tdiff; rabit::Allreduce(&tsum, 1); @@ -62,6 +62,11 @@ inline void PrintStats(const char *name, double tdiff) { double tstd = sqrt(tsqr / nproc); if (rabit::GetRank() == 0) { utils::LogPrintf("%s: mean=%g, std=%g sec\n", name, tavg, tstd); + double ndata = n; + ndata *= nrep * size; + if (n != 0) { + utils::LogPrintf("%s-speed: %g MB/sec\n", name, (ndata / tavg) / 1024 / 1024 ); + } } } @@ -87,10 +92,10 @@ int main(int argc, char *argv[]) { } tot_tdiff = utils::GetTime() - tstart; // use allreduce to get the sum and std of time - PrintStats("max_tdiff", max_tdiff); - PrintStats("sum_tdiff", sum_tdiff); - PrintStats("bcast_tdiff", bcast_tdiff); - PrintStats("tot_tdiff", tot_tdiff); + PrintStats("max_tdiff", max_tdiff, n, nrep, sizeof(float)); + PrintStats("sum_tdiff", sum_tdiff, n, nrep, sizeof(float)); + PrintStats("bcast_tdiff", bcast_tdiff, n, nrep, sizeof(char)); + PrintStats("tot_tdiff", tot_tdiff, 0, nrep, sizeof(float)); rabit::Finalize(); return 0; }