diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 963b9782d..62b2f9cd0 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -63,6 +63,7 @@ #include "../src/learner.cc" #include "../src/logging.cc" #include "../src/common/common.cc" +#include "../src/common/timer.cc" #include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" #include "../src/common/json.cc" diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 477705e70..790eca8b8 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -162,6 +162,10 @@ class JsonNumber : public Value { JsonNumber(FloatT value) : Value(ValueKind::Number) { // NOLINT number_ = value; } + template ::value>::type* = nullptr> + JsonNumber(FloatT value) : Value{ValueKind::Number}, // NOLINT + number_{static_cast(value)} {} void Save(JsonWriter* writer) override; @@ -193,6 +197,10 @@ class JsonInteger : public Value { template ::value>::type* = nullptr> JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {} // NOLINT + template ::value>::type* = nullptr> + JsonInteger(IntT value) : Value(ValueKind::Integer), // NOLINT + integer_{static_cast(value)} {} Json& operator[](std::string const & key) override; Json& operator[](int ind) override; diff --git a/src/common/timer.cc b/src/common/timer.cc new file mode 100644 index 000000000..41aa6da26 --- /dev/null +++ b/src/common/timer.cc @@ -0,0 +1,115 @@ +/*! + * Copyright by Contributors 2019 + */ +#include +#include +#include +#include +#include +#include +#include "timer.h" +#include "xgboost/json.h" + +namespace xgboost { +namespace common { + +std::vector Monitor::CollectFromOtherRanks() const { + // Since other nodes might have started timers that this one haven't, so + // we can't simply call all reduce. + size_t const world_size = rabit::GetWorldSize(); + size_t const rank = rabit::GetRank(); + + // It's much easier to work with rabit if we have a string serialization. So we go with + // json. + Json j_statistic { Object() }; + j_statistic["rank"] = Integer(rank); + j_statistic["statistic"] = Object(); + + auto& statistic = j_statistic["statistic"]; + for (auto const& kv : statistics_map) { + statistic[kv.first] = Object(); + auto& j_pair = statistic[kv.first]; + j_pair["count"] = Integer(kv.second.count); + j_pair["elapsed"] = Integer(std::chrono::duration_cast( + kv.second.timer.elapsed).count()); + } + + std::stringstream ss; + Json::Dump(j_statistic, &ss); + std::string const str { ss.str() }; + + size_t str_size = str.size(); + rabit::Allreduce(&str_size, 1); + std::string buffer; + buffer.resize(str_size); + + // vector storing stat from all workers + std::vector world(world_size); + + // Actually only rank 0 is printing. + for (size_t i = 0; i < world_size; ++i) { + std::copy(str.cbegin(), str.cend(), buffer.begin()); + rabit::Broadcast(&buffer, i); + auto j_other = Json::Load(StringView{buffer.c_str(), buffer.size()}); + auto& other = world[i]; + + auto const& j_statistic = get(j_other["statistic"]); + + for (auto const& kv : j_statistic) { + std::string const& timer_name = kv.first; + auto const& pair = kv.second; + other[timer_name] = {get(pair["count"]), get(pair["elapsed"])}; + } + + // FIXME(trivialfis): How to ask rabit to block here? + } + + return world; +} + +void Monitor::PrintStatistics(StatMap const& statistics) const { + for (auto &kv : statistics) { + if (kv.second.first == 0) { + LOG(WARNING) << + "Timer for " << kv.first << " did not get stopped properly."; + continue; + } + std::cout << kv.first << ": " << static_cast(kv.second.second) / 1e+6 + << "s, " << kv.second.first << " calls @ " + << kv.second.second + << "us" << std::endl; + } +} + +void Monitor::Print() const { + if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; } + + bool is_distributed = rabit::IsDistributed(); + + if (is_distributed) { + auto world = this->CollectFromOtherRanks(); + // rank zero is in charge of printing + if (rabit::GetRank() == 0) { + LOG(CONSOLE) << "======== Monitor: " << label << " ========"; + for (size_t i = 0; i < world.size(); ++i) { + std::cout << "From rank: " << i << ": " << std::endl; + auto const& statistic = world[i]; + this->PrintStatistics(statistic); + std::cout << std::endl; + } + } + } else { + StatMap stat_map; + for (auto const& kv : statistics_map) { + stat_map[kv.first] = std::make_pair( + kv.second.count, std::chrono::duration_cast( + kv.second.timer.elapsed).count()); + } + LOG(CONSOLE) << "======== Monitor: " << label << " ========"; + this->PrintStatistics(stat_map); + } + std::cout << std::endl; +} + +} // namespace common +} // namespace xgboost diff --git a/src/common/timer.h b/src/common/timer.h index 72db5d8fc..a899e8798 100644 --- a/src/common/timer.h +++ b/src/common/timer.h @@ -1,5 +1,5 @@ /*! - * Copyright by Contributors 2017 + * Copyright by Contributors 2017-2019 */ #pragma once #include @@ -7,6 +7,8 @@ #include #include #include +#include +#include #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) #include @@ -14,6 +16,7 @@ namespace xgboost { namespace common { + struct Timer { using ClockT = std::chrono::high_resolution_clock; using TimePointT = std::chrono::high_resolution_clock::time_point; @@ -45,7 +48,6 @@ struct Timer { * \brief Timing utility used to measure total method execution time over the * lifetime of the containing object. */ - struct Monitor { private: struct Statistics { @@ -53,32 +55,34 @@ struct Monitor { size_t count{0}; uint64_t nvtx_id; }; + + // from left to right, > + using StatMap = std::map>; + std::string label = ""; std::map statistics_map; Timer self_timer; + /*! \brief Collect time statistics across all workers. */ + std::vector CollectFromOtherRanks() const; + void PrintStatistics(StatMap const& statistics) const; + public: Monitor() { self_timer.Start(); } - + /*\brief Print statistics info during destruction. + * + * Please note that this may not work, as with distributed frameworks like Dask, the + * model is pickled to other workers, and the global parameters like `global_verbosity_` + * are not included in the pickle. + */ ~Monitor() { - if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) return; - - LOG(CONSOLE) << "======== Monitor: " << label << " ========"; - for (auto &kv : statistics_map) { - if (kv.second.count == 0) { - LOG(WARNING) << - "Timer for " << kv.first << " did not get stopped properly."; - continue; - } - LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds() - << "s, " << kv.second.count << " calls @ " - << std::chrono::duration_cast( - kv.second.timer.elapsed / kv.second.count) - .count() - << "us"; - } + this->Print(); self_timer.Stop(); } + + /*! \brief Print all the statistics. */ + void Print() const; + void Init(std::string label) { this->label = label; } void Start(const std::string &name) { if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { diff --git a/src/learner.cc b/src/learner.cc index b6cf889c7..7967c081b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -142,11 +142,13 @@ DMLC_REGISTER_PARAMETER(GenericParameter); class LearnerImpl : public Learner { public: explicit LearnerImpl(std::vector > cache) - : configured_{false}, cache_(std::move(cache)) {} + : configured_{false}, cache_(std::move(cache)) { + monitor_.Init("Learner"); + } // Configuration before data is known. void Configure() override { if (configured_) { return; } - monitor_.Init("Learner"); + monitor_.Start("Configure"); auto old_tparam = tparam_; Args args = {cfg_.cbegin(), cfg_.cend()};