Enhance nvtx support. (#5636)
This commit is contained in:
@@ -10,12 +10,21 @@
|
||||
#include "timer.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
#include <nvToolsExt.h>
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
void Monitor::Start(std::string const &name) {
|
||||
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
|
||||
statistics_map_[name].timer.Start();
|
||||
auto &stats = statistics_map_[name];
|
||||
stats.timer.Start();
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
std::string nvtx_name = label_ + "::" + name;
|
||||
stats.nvtx_id = nvtxRangeStartA(nvtx_name.c_str());
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +33,9 @@ void Monitor::Stop(const std::string &name) {
|
||||
auto &stats = statistics_map_[name];
|
||||
stats.timer.Stop();
|
||||
stats.count++;
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
nvtxRangeEnd(stats.nvtx_id);
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2019
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
#include <nvToolsExt.h>
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "device_helpers.cuh"
|
||||
#include "timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
void Monitor::StartCuda(const std::string& name) {
|
||||
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
|
||||
auto &stats = statistics_map_[name];
|
||||
stats.timer.Start();
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
stats.nvtx_id = nvtxRangeStartA(name.c_str());
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
}
|
||||
}
|
||||
|
||||
void Monitor::StopCuda(const std::string& name) {
|
||||
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
|
||||
auto &stats = statistics_map_[name];
|
||||
stats.timer.Stop();
|
||||
stats.count++;
|
||||
#if defined(XGBOOST_USE_NVTX)
|
||||
nvtxRangeEnd(stats.nvtx_id);
|
||||
#endif // defined(XGBOOST_USE_NVTX)
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@@ -82,8 +82,6 @@ struct Monitor {
|
||||
void Init(std::string label) { this->label_ = label; }
|
||||
void Start(const std::string &name);
|
||||
void Stop(const std::string &name);
|
||||
void StartCuda(const std::string &name);
|
||||
void StopCuda(const std::string &name);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user