Pass shared pointer instead of raw pointer to Learner. (#5302)

Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
Jiaming Yuan
2020-02-11 14:16:38 +08:00
committed by GitHub
parent 2e0067e790
commit 29eeea709a
12 changed files with 97 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2019 XGBoost contributors
* Copyright 2017-2020 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
@@ -9,7 +9,6 @@
#include <thrust/system_error.h>
#include <thrust/logical.h>
#include <omp.h>
#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>

View File

@@ -1,11 +1,12 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
* \file observer.h
*/
#ifndef XGBOOST_COMMON_OBSERVER_H_
#define XGBOOST_COMMON_OBSERVER_H_
#include <iostream>
#include <limits>
#include <string>
#include <vector>
@@ -63,7 +64,8 @@ class TrainingObserver {
}
/*\brief Observe data hosted by `std::vector'. */
template <typename T>
void Observe(std::vector<T> const& h_vec, std::string name) const {
void Observe(std::vector<T> const& h_vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
@@ -72,20 +74,25 @@ class TrainingObserver {
if (i % 8 == 0) {
OBSERVER_PRINT << OBSERVER_NEWLINE;
}
if ((i + 1) == n) {
break;
}
}
OBSERVER_PRINT << OBSERVER_ENDL;
}
/*\brief Observe data hosted by `HostDeviceVector'. */
template <typename T>
void Observe(HostDeviceVector<T> const& vec, std::string name) const {
void Observe(HostDeviceVector<T> const& vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
auto const& h_vec = vec.HostVector();
this->Observe(h_vec, name);
this->Observe(h_vec, name, n);
}
template <typename T>
void Observe(HostDeviceVector<T>* vec, std::string name) const {
void Observe(HostDeviceVector<T>* vec, std::string name,
size_t n = std::numeric_limits<std::size_t>::max()) const {
if (XGBOOST_EXPECT(!observe_, true)) { return; }
this->Observe(*vec, name);
this->Observe(*vec, name, n);
}
/*\brief Observe objects with `XGBoostParamer' type. */