monitor for distributed environment. (#4829)
* Collect statistics from other ranks in monitor. * Workaround old GCC bug.
This commit is contained in:
parent
c0fbeff0ab
commit
52d44e07fe
@ -63,6 +63,7 @@
|
|||||||
#include "../src/learner.cc"
|
#include "../src/learner.cc"
|
||||||
#include "../src/logging.cc"
|
#include "../src/logging.cc"
|
||||||
#include "../src/common/common.cc"
|
#include "../src/common/common.cc"
|
||||||
|
#include "../src/common/timer.cc"
|
||||||
#include "../src/common/host_device_vector.cc"
|
#include "../src/common/host_device_vector.cc"
|
||||||
#include "../src/common/hist_util.cc"
|
#include "../src/common/hist_util.cc"
|
||||||
#include "../src/common/json.cc"
|
#include "../src/common/json.cc"
|
||||||
|
|||||||
@ -162,6 +162,10 @@ class JsonNumber : public Value {
|
|||||||
JsonNumber(FloatT value) : Value(ValueKind::Number) { // NOLINT
|
JsonNumber(FloatT value) : Value(ValueKind::Number) { // NOLINT
|
||||||
number_ = value;
|
number_ = value;
|
||||||
}
|
}
|
||||||
|
template <typename FloatT,
|
||||||
|
typename std::enable_if<std::is_same<FloatT, double>::value>::type* = nullptr>
|
||||||
|
JsonNumber(FloatT value) : Value{ValueKind::Number}, // NOLINT
|
||||||
|
number_{static_cast<Float>(value)} {}
|
||||||
|
|
||||||
void Save(JsonWriter* writer) override;
|
void Save(JsonWriter* writer) override;
|
||||||
|
|
||||||
@ -193,6 +197,10 @@ class JsonInteger : public Value {
|
|||||||
template <typename IntT,
|
template <typename IntT,
|
||||||
typename std::enable_if<std::is_same<IntT, Int>::value>::type* = nullptr>
|
typename std::enable_if<std::is_same<IntT, Int>::value>::type* = nullptr>
|
||||||
JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {} // NOLINT
|
JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {} // NOLINT
|
||||||
|
template <typename IntT,
|
||||||
|
typename std::enable_if<std::is_same<IntT, size_t>::value>::type* = nullptr>
|
||||||
|
JsonInteger(IntT value) : Value(ValueKind::Integer), // NOLINT
|
||||||
|
integer_{static_cast<Int>(value)} {}
|
||||||
|
|
||||||
Json& operator[](std::string const & key) override;
|
Json& operator[](std::string const & key) override;
|
||||||
Json& operator[](int ind) override;
|
Json& operator[](int ind) override;
|
||||||
|
|||||||
115
src/common/timer.cc
Normal file
115
src/common/timer.cc
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright by Contributors 2019
|
||||||
|
*/
|
||||||
|
#include <rabit/rabit.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
#include "timer.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
std::vector<Monitor::StatMap> Monitor::CollectFromOtherRanks() const {
|
||||||
|
// Since other nodes might have started timers that this one haven't, so
|
||||||
|
// we can't simply call all reduce.
|
||||||
|
size_t const world_size = rabit::GetWorldSize();
|
||||||
|
size_t const rank = rabit::GetRank();
|
||||||
|
|
||||||
|
// It's much easier to work with rabit if we have a string serialization. So we go with
|
||||||
|
// json.
|
||||||
|
Json j_statistic { Object() };
|
||||||
|
j_statistic["rank"] = Integer(rank);
|
||||||
|
j_statistic["statistic"] = Object();
|
||||||
|
|
||||||
|
auto& statistic = j_statistic["statistic"];
|
||||||
|
for (auto const& kv : statistics_map) {
|
||||||
|
statistic[kv.first] = Object();
|
||||||
|
auto& j_pair = statistic[kv.first];
|
||||||
|
j_pair["count"] = Integer(kv.second.count);
|
||||||
|
j_pair["elapsed"] = Integer(std::chrono::duration_cast<std::chrono::microseconds>(
|
||||||
|
kv.second.timer.elapsed).count());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
Json::Dump(j_statistic, &ss);
|
||||||
|
std::string const str { ss.str() };
|
||||||
|
|
||||||
|
size_t str_size = str.size();
|
||||||
|
rabit::Allreduce<rabit::op::Max>(&str_size, 1);
|
||||||
|
std::string buffer;
|
||||||
|
buffer.resize(str_size);
|
||||||
|
|
||||||
|
// vector storing stat from all workers
|
||||||
|
std::vector<StatMap> world(world_size);
|
||||||
|
|
||||||
|
// Actually only rank 0 is printing.
|
||||||
|
for (size_t i = 0; i < world_size; ++i) {
|
||||||
|
std::copy(str.cbegin(), str.cend(), buffer.begin());
|
||||||
|
rabit::Broadcast(&buffer, i);
|
||||||
|
auto j_other = Json::Load(StringView{buffer.c_str(), buffer.size()});
|
||||||
|
auto& other = world[i];
|
||||||
|
|
||||||
|
auto const& j_statistic = get<Object>(j_other["statistic"]);
|
||||||
|
|
||||||
|
for (auto const& kv : j_statistic) {
|
||||||
|
std::string const& timer_name = kv.first;
|
||||||
|
auto const& pair = kv.second;
|
||||||
|
other[timer_name] = {get<Integer>(pair["count"]), get<Integer>(pair["elapsed"])};
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME(trivialfis): How to ask rabit to block here?
|
||||||
|
}
|
||||||
|
|
||||||
|
return world;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Print one line per timer to stdout.
 *
 * \param statistics Map from timer name to <count, elapsed>, where elapsed is
 *                   treated as microseconds.
 *
 * A timer whose count is zero is reported with a warning and otherwise
 * skipped.
 */
void Monitor::PrintStatistics(StatMap const& statistics) const {
  for (auto const& kv : statistics) {
    size_t const n_calls = kv.second.first;
    size_t const elapsed_us = kv.second.second;
    if (n_calls == 0) {
      LOG(WARNING) <<
          "Timer for " << kv.first << " did not get stopped properly.";
      continue;
    }
    // "<total>s, <n> calls @ <average>us": the last field is the time per
    // call, so divide by the (non-zero, checked above) call count rather than
    // repeating the total in a different unit.
    std::cout << kv.first << ": " << static_cast<double>(elapsed_us) / 1e+6
              << "s, " << n_calls << " calls @ "
              << elapsed_us / n_calls
              << "us" << std::endl;
  }
}
|
||||||
|
|
||||||
|
void Monitor::Print() const {
|
||||||
|
if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; }
|
||||||
|
|
||||||
|
bool is_distributed = rabit::IsDistributed();
|
||||||
|
|
||||||
|
if (is_distributed) {
|
||||||
|
auto world = this->CollectFromOtherRanks();
|
||||||
|
// rank zero is in charge of printing
|
||||||
|
if (rabit::GetRank() == 0) {
|
||||||
|
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||||
|
for (size_t i = 0; i < world.size(); ++i) {
|
||||||
|
std::cout << "From rank: " << i << ": " << std::endl;
|
||||||
|
auto const& statistic = world[i];
|
||||||
|
this->PrintStatistics(statistic);
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
StatMap stat_map;
|
||||||
|
for (auto const& kv : statistics_map) {
|
||||||
|
stat_map[kv.first] = std::make_pair(
|
||||||
|
kv.second.count, std::chrono::duration_cast<std::chrono::microseconds>(
|
||||||
|
kv.second.timer.elapsed).count());
|
||||||
|
}
|
||||||
|
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
||||||
|
this->PrintStatistics(stat_map);
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright by Contributors 2017
|
* Copyright by Contributors 2017-2019
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
@ -7,6 +7,8 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
#include <nvToolsExt.h>
|
#include <nvToolsExt.h>
|
||||||
@ -14,6 +16,7 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
struct Timer {
|
struct Timer {
|
||||||
using ClockT = std::chrono::high_resolution_clock;
|
using ClockT = std::chrono::high_resolution_clock;
|
||||||
using TimePointT = std::chrono::high_resolution_clock::time_point;
|
using TimePointT = std::chrono::high_resolution_clock::time_point;
|
||||||
@ -45,7 +48,6 @@ struct Timer {
|
|||||||
* \brief Timing utility used to measure total method execution time over the
|
* \brief Timing utility used to measure total method execution time over the
|
||||||
* lifetime of the containing object.
|
* lifetime of the containing object.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
struct Monitor {
|
struct Monitor {
|
||||||
private:
|
private:
|
||||||
struct Statistics {
|
struct Statistics {
|
||||||
@ -53,32 +55,34 @@ struct Monitor {
|
|||||||
size_t count{0};
|
size_t count{0};
|
||||||
uint64_t nvtx_id;
|
uint64_t nvtx_id;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// from left to right, <name <count, elapsed>>
|
||||||
|
using StatMap = std::map<std::string, std::pair<size_t, size_t>>;
|
||||||
|
|
||||||
std::string label = "";
|
std::string label = "";
|
||||||
std::map<std::string, Statistics> statistics_map;
|
std::map<std::string, Statistics> statistics_map;
|
||||||
Timer self_timer;
|
Timer self_timer;
|
||||||
|
|
||||||
|
/*! \brief Collect time statistics across all workers. */
|
||||||
|
std::vector<StatMap> CollectFromOtherRanks() const;
|
||||||
|
void PrintStatistics(StatMap const& statistics) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Monitor() { self_timer.Start(); }
|
Monitor() { self_timer.Start(); }
|
||||||
|
/*! \brief Print statistics info during destruction.
|
||||||
|
*
|
||||||
|
* Please note that this may not work, as with distributed frameworks like Dask, the
|
||||||
|
* model is pickled to other workers, and the global parameters like `global_verbosity_`
|
||||||
|
* are not included in the pickle.
|
||||||
|
*/
|
||||||
~Monitor() {
|
~Monitor() {
|
||||||
if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) return;
|
this->Print();
|
||||||
|
|
||||||
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
|
|
||||||
for (auto &kv : statistics_map) {
|
|
||||||
if (kv.second.count == 0) {
|
|
||||||
LOG(WARNING) <<
|
|
||||||
"Timer for " << kv.first << " did not get stopped properly.";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds()
|
|
||||||
<< "s, " << kv.second.count << " calls @ "
|
|
||||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
|
||||||
kv.second.timer.elapsed / kv.second.count)
|
|
||||||
.count()
|
|
||||||
<< "us";
|
|
||||||
}
|
|
||||||
self_timer.Stop();
|
self_timer.Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*! \brief Print all the statistics. */
|
||||||
|
void Print() const;
|
||||||
|
|
||||||
void Init(std::string label) { this->label = label; }
|
void Init(std::string label) { this->label = label; }
|
||||||
void Start(const std::string &name) {
|
void Start(const std::string &name) {
|
||||||
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
|
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
|
||||||
|
|||||||
@ -142,11 +142,13 @@ DMLC_REGISTER_PARAMETER(GenericParameter);
|
|||||||
class LearnerImpl : public Learner {
|
class LearnerImpl : public Learner {
|
||||||
public:
|
public:
|
||||||
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||||
: configured_{false}, cache_(std::move(cache)) {}
|
: configured_{false}, cache_(std::move(cache)) {
|
||||||
|
monitor_.Init("Learner");
|
||||||
|
}
|
||||||
// Configuration before data is known.
|
// Configuration before data is known.
|
||||||
void Configure() override {
|
void Configure() override {
|
||||||
if (configured_) { return; }
|
if (configured_) { return; }
|
||||||
monitor_.Init("Learner");
|
|
||||||
monitor_.Start("Configure");
|
monitor_.Start("Configure");
|
||||||
auto old_tparam = tparam_;
|
auto old_tparam = tparam_;
|
||||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user