Implement training observer. (#5088)
This commit is contained in:
parent
f0ca53d9ec
commit
f5e13dcb9b
@ -30,6 +30,8 @@ option(USE_OPENMP "Build with OpenMP support." ON)
|
|||||||
option(JVM_BINDINGS "Build JVM bindings" OFF)
|
option(JVM_BINDINGS "Build JVM bindings" OFF)
|
||||||
option(R_LIB "Build shared library for R package" OFF)
|
option(R_LIB "Build shared library for R package" OFF)
|
||||||
## Dev
|
## Dev
|
||||||
|
option(USE_DEBUG_OUTPUT "Dump internal training results like gradients and predictions to stdout.
|
||||||
|
Should only be used for debugging." OFF)
|
||||||
option(GOOGLE_TEST "Build google tests" OFF)
|
option(GOOGLE_TEST "Build google tests" OFF)
|
||||||
option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
|
option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
|
||||||
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
|
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
|
||||||
@ -39,17 +41,12 @@ option(RABIT_MOCK "Build rabit with mock" OFF)
|
|||||||
option(USE_CUDA "Build with GPU acceleration" OFF)
|
option(USE_CUDA "Build with GPU acceleration" OFF)
|
||||||
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
|
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
|
||||||
option(BUILD_WITH_SHARED_NCCL "Build with shared NCCL library." OFF)
|
option(BUILD_WITH_SHARED_NCCL "Build with shared NCCL library." OFF)
|
||||||
|
set(GPU_COMPUTE_VER "" CACHE STRING
|
||||||
|
"Semicolon separated list of compute versions to be built against, e.g. '35;61'")
|
||||||
## Copied From dmlc
|
## Copied From dmlc
|
||||||
option(USE_HDFS "Build with HDFS support" OFF)
|
option(USE_HDFS "Build with HDFS support" OFF)
|
||||||
option(USE_AZURE "Build with AZURE support" OFF)
|
option(USE_AZURE "Build with AZURE support" OFF)
|
||||||
option(USE_S3 "Build with S3 support" OFF)
|
option(USE_S3 "Build with S3 support" OFF)
|
||||||
|
|
||||||
set(GPU_COMPUTE_VER "" CACHE STRING
|
|
||||||
"Semicolon separated list of compute versions to be built against, e.g. '35;61'")
|
|
||||||
if (BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
|
|
||||||
message(SEND_ERROR "Build XGBoost with -DUSE_NCCL=ON to enable BUILD_WITH_SHARED_NCCL.")
|
|
||||||
endif (BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
|
|
||||||
## Sanitizers
|
## Sanitizers
|
||||||
option(USE_SANITIZER "Use santizer flags" OFF)
|
option(USE_SANITIZER "Use santizer flags" OFF)
|
||||||
option(SANITIZER_PATH "Path to sanitizes.")
|
option(SANITIZER_PATH "Path to sanitizes.")
|
||||||
@ -60,12 +57,28 @@ address, leak and thread.")
|
|||||||
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
||||||
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
||||||
|
|
||||||
## Deprecation warning
|
#-- Checks for building XGBoost
|
||||||
|
if (USE_DEBUG_OUTPUT AND (NOT (CMAKE_BUILD_TYPE MATCHES Debug)))
|
||||||
|
message(SEND_ERROR "Do not enable `USE_DEBUG_OUTPUT' with release build.")
|
||||||
|
endif (USE_DEBUG_OUTPUT AND (NOT (CMAKE_BUILD_TYPE MATCHES Debug)))
|
||||||
|
if (USE_NCCL AND NOT (USE_CUDA))
|
||||||
|
message(SEND_ERROR "`USE_NCCL` must be enabled with `USE_CUDA` flag.")
|
||||||
|
endif (USE_NCCL AND NOT (USE_CUDA))
|
||||||
|
if (BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
|
||||||
|
message(SEND_ERROR "Build XGBoost with -DUSE_NCCL=ON to enable BUILD_WITH_SHARED_NCCL.")
|
||||||
|
endif (BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
|
||||||
|
if (JVM_BINDINGS AND R_LIB)
|
||||||
|
message(SEND_ERROR "`R_LIB' is not compatible with `JVM_BINDINGS' as they both have customized configurations.")
|
||||||
|
endif (JVM_BINDINGS AND R_LIB)
|
||||||
|
if (R_LIB AND GOOGLE_TEST)
|
||||||
|
message(WARNING "Some C++ unittests will fail with `R_LIB` enabled,
|
||||||
|
as R package redirects some functions to R runtime implementation.")
|
||||||
|
endif (R_LIB AND GOOGLE_TEST)
|
||||||
if (USE_AVX)
|
if (USE_AVX)
|
||||||
message(WARNING "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from xgboost.")
|
message(SEND_ERROR "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from XGBoost.")
|
||||||
endif (USE_AVX)
|
endif (USE_AVX)
|
||||||
|
|
||||||
# Sanitizer
|
#-- Sanitizer
|
||||||
if (USE_SANITIZER)
|
if (USE_SANITIZER)
|
||||||
# Older CMake versions have had troubles with Sanitizer
|
# Older CMake versions have had troubles with Sanitizer
|
||||||
cmake_minimum_required(VERSION 3.12)
|
cmake_minimum_required(VERSION 3.12)
|
||||||
|
|||||||
@ -63,6 +63,9 @@ target_compile_definitions(objxgboost
|
|||||||
-DDMLC_LOG_CUSTOMIZE=1 # enable custom logging
|
-DDMLC_LOG_CUSTOMIZE=1 # enable custom logging
|
||||||
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:_MWAITXINTRIN_H_INCLUDED>
|
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:_MWAITXINTRIN_H_INCLUDED>
|
||||||
${XGBOOST_DEFINITIONS})
|
${XGBOOST_DEFINITIONS})
|
||||||
|
if (USE_DEBUG_OUTPUT)
|
||||||
|
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_DEBUG_OUTPUT=1)
|
||||||
|
endif (USE_DEBUG_OUTPUT)
|
||||||
|
|
||||||
if (XGBOOST_MM_PREFETCH_PRESENT)
|
if (XGBOOST_MM_PREFETCH_PRESENT)
|
||||||
target_compile_definitions(objxgboost
|
target_compile_definitions(objxgboost
|
||||||
|
|||||||
101
src/common/observer.h
Normal file
101
src/common/observer.h
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2019 XGBoost contributors
|
||||||
|
* \file observer.h
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_OBSERVER_H_
|
||||||
|
#define XGBOOST_COMMON_OBSERVER_H_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "xgboost/host_device_vector.h"
|
||||||
|
#include "xgboost/parameter.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/tree_model.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
/*\brief An observer for logging internal data structures.
|
||||||
|
*
|
||||||
|
* This class is designed to be `diff` tool friendly, which means it uses plain
|
||||||
|
* `std::cout` for printing to avoid the time information emitted by `LOG(DEBUG)` or
|
||||||
|
* similar facilities.
|
||||||
|
*/
|
||||||
|
class TrainingObserver {
|
||||||
|
#if defined(XGBOOST_USE_DEBUG_OUTPUT)
|
||||||
|
bool constexpr static observe_ {true};
|
||||||
|
#else
|
||||||
|
bool constexpr static observe_ {false};
|
||||||
|
#endif // defined(XGBOOST_USE_DEBUG_OUTPUT)
|
||||||
|
|
||||||
|
public:
|
||||||
|
void Update(int32_t iter) const {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
std::cout << "Iter: " << iter << std::endl;
|
||||||
|
}
|
||||||
|
/*\brief Observe tree. */
|
||||||
|
void Observe(RegTree const& tree) {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
std::cout << "Tree:" << std::endl;
|
||||||
|
Json j_tree {Object()};
|
||||||
|
tree.SaveModel(&j_tree);
|
||||||
|
std::string str;
|
||||||
|
Json::Dump(j_tree, &str, true);
|
||||||
|
std::cout << str << std::endl;
|
||||||
|
}
|
||||||
|
/*\brief Observe tree. */
|
||||||
|
void Observe(RegTree const* p_tree) {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
auto const& tree = *p_tree;
|
||||||
|
this->Observe(tree);
|
||||||
|
}
|
||||||
|
/*\brief Observe data hosted by `std::vector'. */
|
||||||
|
template <typename T>
|
||||||
|
void Observe(std::vector<T> const& h_vec, std::string name) const {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
std::cout << "Procedure: " << name << std::endl;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||||
|
std::cout << h_vec[i] << ", ";
|
||||||
|
if (i % 8 == 0) {
|
||||||
|
std::cout << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
/*\brief Observe data hosted by `HostDeviceVector'. */
|
||||||
|
template <typename T>
|
||||||
|
void Observe(HostDeviceVector<T> const& vec, std::string name) const {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
auto const& h_vec = vec.HostVector();
|
||||||
|
this->Observe(h_vec, name);
|
||||||
|
}
|
||||||
|
/*\brief Observe objects with `XGBoostParamer' type. */
|
||||||
|
template <typename Parameter,
|
||||||
|
typename std::enable_if<
|
||||||
|
std::is_base_of<XGBoostParameter<Parameter>, Parameter>::value>::type* = nullptr>
|
||||||
|
void Observe(const Parameter &p, std::string name) const {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
|
||||||
|
Json obj {toJson(p)};
|
||||||
|
std::cout << "Parameter: " << name << ":\n" << obj << std::endl;
|
||||||
|
}
|
||||||
|
/*\brief Observe parameters provided by users. */
|
||||||
|
void Observe(Args const& args) const {
|
||||||
|
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||||
|
|
||||||
|
for (auto kv : args) {
|
||||||
|
std::cout << kv.first << ": " << kv.second << "\n";
|
||||||
|
}
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*\brief Get a global instance. */
|
||||||
|
static TrainingObserver& Instance() {
|
||||||
|
static TrainingObserver observer;
|
||||||
|
return observer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_OBSERVER_H_
|
||||||
Loading…
x
Reference in New Issue
Block a user