Improved gpu_hist_experimental algorithm (#2866)

- Implement colsampling, subsampling for gpu_hist_experimental

 - Optimised multi-GPU implementation for gpu_hist_experimental

 - Make nccl optional

 - Add Volta architecture flag

 - Optimise RegLossObj

 - Add timing utilities for debug verbose mode

 - Bump required cuda version to 8.0
This commit is contained in:
Rory Mitchell 2017-11-11 13:58:40 +13:00 committed by GitHub
parent 16c63f30d0
commit 40c6e2f0c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 855 additions and 473 deletions

View File

@ -9,6 +9,7 @@ msvc_use_static_runtime()
# Options
option(USE_CUDA "Build with GPU acceleration")
option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA")
option(JVM_BINDINGS "Build JVM bindings" OFF)
option(GOOGLE_TEST "Build google tests" OFF)
option(R_LIB "Build shared library for R package" OFF)
@ -97,26 +98,39 @@ set(LINK_LIBRARIES dmlccore rabit)
if(USE_CUDA)
find_package(CUDA 7.5 REQUIRED)
find_package(CUDA 8.0 REQUIRED)
cmake_minimum_required(VERSION 3.5)
add_definitions(-DXGBOOST_USE_CUDA)
include_directories(
nccl/src
cub
)
include_directories(cub)
if(USE_NCCL)
include_directories(nccl/src)
add_definitions(-DXGBOOST_USE_NCCL)
endif()
if((CUDA_VERSION_MAJOR EQUAL 9) OR (CUDA_VERSION_MAJOR GREATER 9))
message("CUDA 9.0 detected, adding Volta compute capability (7.0).")
set(GPU_COMPUTE_VER "${GPU_COMPUTE_VER};70")
endif()
set(GENCODE_FLAGS "")
format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;${GENCODE_FLAGS};-lineinfo;")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;--expt-relaxed-constexpr;${GENCODE_FLAGS};-lineinfo;")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC; -std=c++11")
endif()
add_subdirectory(nccl)
if(USE_NCCL)
add_subdirectory(nccl)
endif()
cuda_add_library(gpuxgboost ${CUDA_SOURCES} STATIC)
target_link_libraries(gpuxgboost nccl)
if(USE_NCCL)
target_link_libraries(gpuxgboost nccl)
endif()
list(APPEND LINK_LIBRARIES gpuxgboost)
endif()

View File

@ -15,7 +15,10 @@
#include <sstream>
#include <string>
#include <vector>
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
#endif
// Uncomment to enable
#define TIMERS
@ -44,6 +47,7 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file,
return code;
}
#ifdef XGBOOST_USE_NCCL
#define safe_nccl(ans) throw_on_nccl_error((ans), __FILE__, __LINE__)
inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
@ -57,6 +61,7 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file,
return code;
}
#endif
template <typename T>
T *raw(thrust::device_vector<T> &v) { // NOLINT
@ -137,6 +142,19 @@ inline int get_device_idx(int gpu_id) {
return (std::abs(gpu_id) + 0) % dh::n_visible_devices();
}
inline void check_compute_capability() {
int n_devices = n_visible_devices();
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
cudaDeviceProp prop;
safe_cuda(cudaGetDeviceProperties(&prop, d_idx));
std::ostringstream oss;
oss << "CUDA Capability Major/Minor version number: " << prop.major << "."
<< prop.minor << " is insufficient. Need >=3.5";
int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
if (failed) LOG(WARNING) << oss.str() << " for device: " << d_idx;
}
}
/*
* Range iterator
*/
@ -241,37 +259,12 @@ inline void launch_n(int device_idx, size_t n, L lambda) {
}
safe_cuda(cudaSetDevice(device_idx));
const int GRID_SIZE = static_cast<int>(div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS));
const int GRID_SIZE =
static_cast<int>(div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS));
launch_n_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(static_cast<size_t>(0), n,
lambda);
}
/*
* Timers
*/
struct Timer {
typedef std::chrono::high_resolution_clock ClockT;
typedef std::chrono::high_resolution_clock::time_point TimePointT;
typedef std::chrono::high_resolution_clock::duration DurationT;
typedef std::chrono::duration<double> SecondsT;
TimePointT start;
DurationT elapsed;
Timer() { Reset(); }
void Reset() {
elapsed = DurationT::zero();
Start();
}
void Start() { start = ClockT::now(); }
void Stop() { elapsed += ClockT::now() - start; }
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
void PrintElapsed(std::string label) {
printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count());
Reset();
}
};
/*
* Memory
*/
@ -444,7 +437,7 @@ class bulk_allocator {
template <typename T>
void allocate_dvec(int device_idx, char *ptr, dvec<T> *first_vec,
size_t first_size) {
size_t first_size) {
first_vec->external_allocate(device_idx, static_cast<void *>(ptr),
first_size);
}
@ -470,8 +463,7 @@ class bulk_allocator {
template <typename T, typename... Args>
size_t get_size_bytes(dvec2<T> *first_vec, size_t first_size, Args... args) {
return get_size_bytes<T>(first_vec, first_size) +
get_size_bytes(args...);
return get_size_bytes<T>(first_vec, first_size) + get_size_bytes(args...);
}
template <typename T>
@ -497,6 +489,7 @@ class bulk_allocator {
if (!(d_ptr[i] == nullptr)) {
safe_cuda(cudaSetDevice(_device_idx[i]));
safe_cuda(cudaFree(d_ptr[i]));
d_ptr[i] = nullptr;
}
}
}
@ -642,7 +635,8 @@ __global__ void LbsKernel(coordinate_t *d_coordinates,
for (auto item : dh::block_stride_range(int(0), int(tile_num_rows + 1))) {
temp_storage.tile_segment_end_offsets[item] =
segment_end_offsets[min(tile_start_coord.x + item, num_segments - 1)];
segment_end_offsets[min(static_cast<size_t>(tile_start_coord.x + item),
static_cast<size_t>(num_segments - 1))];
}
__syncthreads();
@ -693,8 +687,8 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory,
BLOCK_THREADS, segments, num_segments, count);
LbsKernel<TILE_SIZE, ITEMS_PER_THREAD, BLOCK_THREADS, offset_t>
<<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates, segments + 1, f,
num_segments);
<<<uint32_t(num_tiles), BLOCK_THREADS>>>(tmp_tile_coordinates,
segments + 1, f, num_segments);
}
template <typename func_t, typename offset_t>
@ -836,4 +830,125 @@ void gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
});
}
/**
* \class AllReducer
*
* \brief All reducer class that manages its own communication group and
* streams. Must be initialised before use. If XGBoost is compiled without NCCL this is a dummy class that will error if used with more than one GPU.
*/
class AllReducer {
bool initialised;
#ifdef XGBOOST_USE_NCCL
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals;
#endif
public:
AllReducer() : initialised(false) {}
/**
* \fn void Init(const std::vector<int> &device_ordinals)
*
* \brief Initialise with the desired device ordinals for this communication
* group.
*
* \param device_ordinals The device ordinals.
*/
void Init(const std::vector<int> &device_ordinals) {
#ifdef XGBOOST_USE_NCCL
this->device_ordinals = device_ordinals;
comms.resize(device_ordinals.size());
dh::safe_nccl(ncclCommInitAll(comms.data(),
static_cast<int>(device_ordinals.size()),
device_ordinals.data()));
streams.resize(device_ordinals.size());
for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals[i]));
safe_cuda(cudaStreamCreate(&streams[i]));
}
initialised = true;
#else
CHECK_EQ(device_ordinals.size(), 1) << "XGBoost must be compiled with NCCL to use more than one GPU.";
#endif
}
~AllReducer() {
#ifdef XGBOOST_USE_NCCL
if (initialised) {
for (auto &stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
for (auto &comm : comms) {
ncclCommDestroy(comm);
}
}
#endif
}
/**
* \fn void AllReduceSum(int communication_group_idx, const double *sendbuff,
* double *recvbuff, int count)
*
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the
* communication group. \param sendbuff The sendbuff. \param
* sendbuff The sendbuff. \param [in,out] recvbuff
* The recvbuff. \param count Number of.
*/
void AllReduceSum(int communication_group_idx, const double *sendbuff,
double *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised);
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
comms[communication_group_idx],
streams[communication_group_idx]));
#endif
}
/**
* \fn void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, int64_t *recvbuff, int count)
*
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group. \param
* sendbuff The sendbuff. \param sendbuff
* The sendbuff. \param [in,out] recvbuff The recvbuff.
* \param count Number of.
* \param sendbuff The sendbuff.
* \param [in,out] recvbuff If non-null, the recvbuff.
* \param count Number of.
*/
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
int64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised);
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
comms[communication_group_idx],
streams[communication_group_idx]));
#endif
}
/**
* \fn void Synchronize()
*
* \brief Synchronizes the entire communication group.
*/
void Synchronize() {
#ifdef XGBOOST_USE_NCCL
for (int i = 0; i < device_ordinals.size(); i++) {
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
}
#endif
}
};
} // namespace dh

77
src/common/timer.h Normal file
View File

@ -0,0 +1,77 @@
/*!
* Copyright by Contributors 2017
*/
#pragma once
#include <chrono>
#include <iostream>
#include <map>
#include <string>
namespace xgboost {
namespace common {
struct Timer {
typedef std::chrono::high_resolution_clock ClockT;
typedef std::chrono::high_resolution_clock::time_point TimePointT;
typedef std::chrono::high_resolution_clock::duration DurationT;
typedef std::chrono::duration<double> SecondsT;
TimePointT start;
DurationT elapsed;
Timer() { Reset(); }
void Reset() {
elapsed = DurationT::zero();
Start();
}
void Start() { start = ClockT::now(); }
void Stop() { elapsed += ClockT::now() - start; }
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
void PrintElapsed(std::string label) {
printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count());
Reset();
}
};
/**
* \struct Monitor
*
* \brief Timing utility used to measure total method execution time over the
* lifetime of the containing object.
*/
struct Monitor {
bool debug_verbose = false;
std::string label = "";
std::map<std::string, Timer> timer_map;
Timer self_timer;
Monitor() { self_timer.Start(); }
~Monitor() {
if (!debug_verbose) return;
std::cout << "========\n";
std::cout << "Monitor: " << label << "\n";
std::cout << "========\n";
for (auto &kv : timer_map) {
kv.second.PrintElapsed(kv.first);
}
self_timer.Stop();
self_timer.PrintElapsed(label + " Lifetime");
}
void Init(std::string label, bool debug_verbose) {
this->debug_verbose = debug_verbose;
this->label = label;
}
void Start(const std::string &name) { timer_map[name].Start(); }
void Stop(const std::string &name) {
if (debug_verbose) {
#ifdef __CUDACC__
#include "device_helpers.cuh"
dh::synchronize_all();
#endif
}
timer_map[name].Stop();
}
};
} // namespace common
} // namespace xgboost

View File

@ -20,6 +20,7 @@
#include "../common/common.h"
#include "../common/random.h"
#include "gbtree_model.h"
#include "../common/timer.h"
namespace xgboost {
namespace gbm {
@ -158,6 +159,7 @@ class GBTree : public GradientBooster {
// configure predictor
predictor = std::unique_ptr<Predictor>(Predictor::Create(tparam.predictor));
predictor->Init(cfg, cache_);
monitor.Init("GBTree", tparam.debug_verbose);
}
void Load(dmlc::Stream* fi) override {
@ -183,6 +185,7 @@ class GBTree : public GradientBooster {
const std::vector<bst_gpair>& gpair = *in_gpair;
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
const int ngroup = model_.param.num_output_group;
monitor.Start("BoostNewTrees");
if (ngroup == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(gpair, p_fmat, 0, &ret);
@ -202,13 +205,12 @@ class GBTree : public GradientBooster {
new_trees.push_back(std::move(ret));
}
}
double tstart = dmlc::GetTime();
monitor.Stop("BoostNewTrees");
monitor.Start("CommitModel");
for (int gid = 0; gid < ngroup; ++gid) {
this->CommitModel(std::move(new_trees[gid]), gid);
}
if (tparam.debug_verbose > 0) {
LOG(INFO) << "CommitModel(): " << dmlc::GetTime() - tstart << " sec";
}
monitor.Stop("CommitModel");
}
void PredictBatch(DMatrix* p_fmat,
@ -308,6 +310,7 @@ class GBTree : public GradientBooster {
// Cached matrices
std::vector<std::shared_ptr<DMatrix>> cache_;
std::unique_ptr<Predictor> predictor;
common::Monitor monitor;
};
// dart

View File

@ -18,6 +18,7 @@
#include "./common/common.h"
#include "./common/io.h"
#include "./common/random.h"
#include "common/timer.h"
namespace xgboost {
// implementation of base learner.
@ -202,6 +203,7 @@ class LearnerImpl : public Learner {
const std::vector<std::pair<std::string, std::string> >& args) override {
// add to configurations
tparam.InitAllowUnknown(args);
monitor.Init("Learner", tparam.debug_verbose);
cfg_.clear();
for (const auto& kv : args) {
if (kv.first == "eval_metric") {
@ -359,29 +361,37 @@ class LearnerImpl : public Learner {
}
void UpdateOneIter(int iter, DMatrix* train) override {
monitor.Start("UpdateOneIter");
CHECK(ModelInitialized())
<< "Always call InitModel or LoadModel before update";
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
}
this->LazyInitDMatrix(train);
monitor.Start("PredictRaw");
this->PredictRaw(train, &preds_);
monitor.Stop("PredictRaw");
monitor.Start("GetGradient");
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
monitor.Stop("GetGradient");
gbm_->DoBoost(train, &gpair_, obj_.get());
monitor.Stop("UpdateOneIter");
}
void BoostOneIter(int iter, DMatrix* train,
std::vector<bst_gpair>* in_gpair) override {
monitor.Start("BoostOneIter");
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
}
this->LazyInitDMatrix(train);
gbm_->DoBoost(train, in_gpair);
monitor.Stop("BoostOneIter");
}
std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names) override {
double tstart = dmlc::GetTime();
monitor.Start("EvalOneIter");
std::ostringstream os;
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0) {
@ -396,9 +406,7 @@ class LearnerImpl : public Learner {
}
}
if (tparam.debug_verbose > 0) {
LOG(INFO) << "EvalOneIter(): " << dmlc::GetTime() - tstart << " sec";
}
monitor.Stop("EvalOneIter");
return os.str();
}
@ -460,10 +468,11 @@ class LearnerImpl : public Learner {
// if not, initialize the column access.
inline void LazyInitDMatrix(DMatrix* p_train) {
if (tparam.tree_method == 3 || tparam.tree_method == 4 ||
tparam.tree_method == 5) {
tparam.tree_method == 5 || tparam.tree_method == 6) {
return;
}
monitor.Start("LazyInitDMatrix");
if (!p_train->HaveColAccess()) {
int ncol = static_cast<int>(p_train->info().num_col);
std::vector<bool> enabled(ncol, true);
@ -504,6 +513,7 @@ class LearnerImpl : public Learner {
gbm_->Configure(cfg_.begin(), cfg_.end());
}
}
monitor.Stop("LazyInitDMatrix");
}
// return whether model is already initialized.
@ -568,6 +578,8 @@ class LearnerImpl : public Learner {
static const int kRandSeedMagic = 127;
// internal cached dmatrix
std::vector<std::shared_ptr<DMatrix> > cache_;
common::Monitor monitor;
};
Learner* Learner::Create(

View File

@ -65,9 +65,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
bst_float p = rec[k];
const bst_float h = 2.0f * p * (1.0f - p) * wt;
if (label == k) {
out_gpair->at(i * nclass + k) = bst_gpair((p - 1.0f) * wt, h);
(*out_gpair)[i * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
} else {
out_gpair->at(i * nclass + k) = bst_gpair(p* wt, h);
(*out_gpair)[i * nclass + k] = bst_gpair(p* wt, h);
}
}
}

View File

@ -1,6 +1,6 @@
/*!
* Copyright 2015 by Contributors
* \file regression.cc
* \file regression_obj.cc
* \brief Definition of single-value regression and classification objectives.
* \author Tianqi Chen, Kailong Chen
*/
@ -68,53 +68,61 @@ struct LogisticRaw : public LogisticRegression {
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
float scale_pos_weight;
int nthread;
// declare parameters
DMLC_DECLARE_PARAMETER(RegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
}
};
// regression loss function
template<typename Loss>
template <typename Loss>
class RegLossObj : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
RegLossObj() : labels_checked(false) {}
void Configure(
const std::vector<std::pair<std::string, std::string> > &args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
void GetGradient(const std::vector<bst_float> &preds, const MetaInfo &info,
int iter, std::vector<bst_gpair> *out_gpair) override {
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.size(), info.labels.size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.size() << ", label.size=" << info.labels.size();
<< "preds.size=" << preds.size()
<< ", label.size=" << info.labels.size();
this->LazyCheckLabels(info.labels);
out_gpair->resize(preds.size());
// check if label in range
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.size());
#pragma omp parallel for schedule(static)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
#pragma omp parallel for schedule(static) num_threads(nthread)
for (omp_ulong i = 0; i < ndata; ++i) {
auto y = info.labels[i];
bst_float p = Loss::PredTransform(preds[i]);
bst_float w = info.GetWeight(i);
if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
if (!Loss::CheckLabel(info.labels[i])) label_correct = false;
out_gpair->at(i) = bst_gpair(Loss::FirstOrderGradient(p, info.labels[i]) * w,
Loss::SecondOrderGradient(p, info.labels[i]) * w);
}
if (!label_correct) {
LOG(FATAL) << Loss::LabelErrorMsg();
// Branchless version of the below function
// The branch is particularly slow as the cpu cannot predict the label
// with any accuracy resulting in frequent pipeline stalls
// if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
w += y * ((param_.scale_pos_weight * w) - w);
(*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w,
Loss::SecondOrderGradient(p, y) * w);
}
}
const char* DefaultEvalMetric() const override {
const char *DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
preds[j] = Loss::PredTransform(preds[j]);
}
@ -124,7 +132,15 @@ class RegLossObj : public ObjFunction {
}
protected:
void LazyCheckLabels(const std::vector<float> &labels) {
if (labels_checked) return;
for (auto &y : labels) {
CHECK(Loss::CheckLabel(y)) << Loss::LabelErrorMsg();
}
labels_checked = true;
}
RegLossParam param_;
bool labels_checked;
};
// register the objective functions
@ -149,10 +165,13 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
// declare parameter
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
float max_delta_step;
int nthread;
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
.describe("Maximum delta step we allow each weight estimation to be." \
" This parameter is required for possion regression.");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
}
};
@ -175,13 +194,14 @@ class PoissonRegression : public ObjFunction {
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
#pragma omp parallel for schedule(static) num_threads(nthread)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds[i];
bst_float w = info.GetWeight(i);
bst_float y = info.labels[i];
if (y >= 0.0f) {
out_gpair->at(i) = bst_gpair((std::exp(p) - y) * w,
(*out_gpair)[i] = bst_gpair((std::exp(p) - y) * w,
std::exp(p + param_.max_delta_step) * w);
} else {
label_correct = false;
@ -192,7 +212,8 @@ class PoissonRegression : public ObjFunction {
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
#pragma omp parallel for schedule(static) num_threads(nthread)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
@ -242,7 +263,7 @@ class GammaRegression : public ObjFunction {
bst_float w = info.GetWeight(i);
bst_float y = info.labels[i];
if (y >= 0.0f) {
out_gpair->at(i) = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
(*out_gpair)[i] = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
} else {
label_correct = false;
}
@ -276,9 +297,12 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
// declare parameter
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
float tweedie_variance_power;
int nthread;
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
.describe("Tweedie variance power. Must be between in range [1, 2).");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
}
};
@ -301,7 +325,8 @@ class TweedieRegression : public ObjFunction {
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
#pragma omp parallel for schedule(static) num_threads(nthread)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds[i];
bst_float w = info.GetWeight(i);
@ -311,7 +336,7 @@ class TweedieRegression : public ObjFunction {
bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
bst_float hess = -y * (1 - rho) * \
std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
out_gpair->at(i) = bst_gpair(grad * w, hess * w);
(*out_gpair)[i] = bst_gpair(grad * w, hess * w);
} else {
label_correct = false;
}
@ -321,7 +346,8 @@ class TweedieRegression : public ObjFunction {
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
#pragma omp parallel for schedule(static) num_threads(nthread)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}

View File

@ -12,9 +12,51 @@
#include "../common/random.h"
#include "param.h"
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
#else
__device__ __forceinline__ double atomicAdd(double* address, double val) {
unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
// NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
namespace xgboost {
namespace tree {
// Atomic add function for double precision gradients
__device__ __forceinline__ void AtomicAddGpair(bst_gpair_precise* dest,
const bst_gpair& gpair) {
auto dst_ptr = reinterpret_cast<double*>(dest);
atomicAdd(dst_ptr, static_cast<double>(gpair.GetGrad()));
atomicAdd(dst_ptr + 1, static_cast<double>(gpair.GetHess()));
}
// For integer gradients
__device__ __forceinline__ void AtomicAddGpair(bst_gpair_integer* dest,
const bst_gpair& gpair) {
auto dst_ptr = reinterpret_cast<unsigned long long int*>(dest); // NOLINT
bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<bst_gpair_integer::value_t*>(&tmp);
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
}
/**
* \fn void CheckGradientMax(const dh::dvec<bst_gpair>& gpair)
*
@ -22,15 +64,11 @@ namespace tree {
* overflow when using integer gradient summation.
*/
inline void CheckGradientMax(const dh::dvec<bst_gpair>& gpair) {
auto dptr = thrust::device_ptr<const float>(
reinterpret_cast<const float*>(gpair.data()));
float abs_max = thrust::reduce(dptr, dptr + (gpair.size() * 2), 0.f,
[=] __device__(float a, float b) {
a = abs(a);
b = abs(b);
return max(a, b);
});
inline void CheckGradientMax(const std::vector<bst_gpair>& gpair) {
auto* ptr = reinterpret_cast<const float*>(gpair.data());
float abs_max =
std::accumulate(ptr, ptr + (gpair.size() * 2), 0.f,
[=](float a, float b) { return max(abs(a), abs(b)); });
CHECK_LT(abs_max, std::pow(2.0f, 16.0f))
<< "Labels are too large for this algorithm. Rescale to less than 2^16.";
@ -321,8 +359,74 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
std::sort(features.begin(), features.end());
return features;
}
/**
* \class ColumnSampler
*
* \brief Handles selection of columns due to colsample_bytree and
* colsample_bylevel parameters. Should be initialised the before tree
* construction and to reset When tree construction is completed.
*/
class ColumnSampler {
std::vector<int> feature_set_tree;
std::map<int, std::vector<int>> feature_set_level;
TrainParam param;
public:
/**
* \fn void Init(int64_t num_col, const TrainParam& param)
*
* \brief Initialise this object before use.
*
* \param num_col Number of cols.
* \param param The parameter.
*/
void Init(int64_t num_col, const TrainParam& param) {
this->Reset();
this->param = param;
feature_set_tree.resize(num_col);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
}
/**
* \fn void Reset()
*
* \brief Resets this object.
*/
void Reset() {
feature_set_tree.clear();
feature_set_level.clear();
}
/**
* \fn bool ColumnUsed(int column, int depth)
*
* \brief Whether the current column should be considered as a split.
*
* \param column The column index.
* \param depth The current tree depth.
*
* \return True if it should be used, false if it should not be used.
*/
bool ColumnUsed(int column, int depth) {
if (feature_set_level.count(depth) == 0) {
feature_set_level[depth] =
col_sample(feature_set_tree, param.colsample_bylevel);
}
return std::binary_search(feature_set_level[depth].begin(),
feature_set_level[depth].end(), column);
}
};
} // namespace tree
} // namespace xgboost

View File

@ -8,6 +8,7 @@
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/timer.h"
#include "param.h"
#include "updater_gpu_common.cuh"
@ -17,7 +18,6 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
typedef bst_gpair_integer gpair_sum_t;
static const ncclDataType_t nccl_sum_t = ncclInt64;
// Helper for explicit template specialisation
template <int N>
@ -63,15 +63,7 @@ struct HistHelper {
__device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
int hist_idx = nidx * n_bins + gidx;
auto dst_ptr =
reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]); // NOLINT
gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
AtomicAddGpair(d_hist + hist_idx, gpair);
}
__device__ gpair_sum_t Get(int gidx, int nidx) const {
return d_hist[nidx * n_bins + gidx];
@ -244,22 +236,7 @@ class GPUHistMaker : public TreeUpdater {
is_dense(false),
p_last_fmat_(nullptr),
prediction_cache_initialised(false) {}
~GPUHistMaker() {
if (initialised) {
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
ncclCommDestroy(comms[d_idx]);
dh::safe_cuda(cudaSetDevice(dList[d_idx]));
dh::safe_cuda(cudaStreamDestroy(*(streams[d_idx])));
}
for (int num_d = 1; num_d <= n_devices;
++num_d) { // loop over number of devices used
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
ncclCommDestroy(find_split_comms[num_d - 1][d_idx]);
}
}
}
}
~GPUHistMaker() {}
void Init(
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
@ -290,7 +267,7 @@ class GPUHistMaker : public TreeUpdater {
void InitData(const std::vector<bst_gpair>& gpair, DMatrix& fmat, // NOLINT
const RegTree& tree) {
dh::Timer time1;
common::Timer time1;
// set member num_rows and n_devices for rest of GPUHistBuilder members
info = &fmat.info();
CHECK(info->num_row < std::numeric_limits<bst_uint>::max());
@ -298,6 +275,12 @@ class GPUHistMaker : public TreeUpdater {
n_devices = dh::n_devices(param.n_gpus, num_rows);
if (!initialised) {
// Check gradients are within acceptable size range
CheckGradientMax(gpair);
// Check compute capability is high enough
dh::check_compute_capability();
// reset static timers used across iterations
cpu_init_time = 0;
gpu_init_time = 0;
@ -312,57 +295,10 @@ class GPUHistMaker : public TreeUpdater {
}
// initialize nccl
comms.resize(n_devices);
streams.resize(n_devices);
dh::safe_nccl(ncclCommInitAll(comms.data(), n_devices,
dList.data())); // initialize communicator
// (One communicator per
// process)
// printf("# NCCL: Using devices\n");
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
streams[d_idx] =
reinterpret_cast<cudaStream_t*>(malloc(sizeof(cudaStream_t)));
dh::safe_cuda(cudaSetDevice(dList[d_idx]));
dh::safe_cuda(cudaStreamCreate(streams[d_idx]));
int cudaDev;
int rank;
cudaDeviceProp prop;
dh::safe_nccl(ncclCommCuDevice(comms[d_idx], &cudaDev));
dh::safe_nccl(ncclCommUserRank(comms[d_idx], &rank));
dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev));
// printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev,
// prop.pciBusID, prop.name);
// cudaDriverGetVersion(&driverVersion);
// cudaRuntimeGetVersion(&runtimeVersion);
std::ostringstream oss;
oss << "CUDA Capability Major/Minor version number: " << prop.major
<< "." << prop.minor << " is insufficient. Need >=3.5.";
int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
CHECK(failed == 0) << oss.str();
}
// local find_split group of comms for each case of reduced number of
// GPUs to use
find_split_comms.resize(
n_devices,
std::vector<ncclComm_t>(n_devices)); // TODO(JCM): Excessive, but
// ok, and best to do
// here instead of
// repeatedly
for (int num_d = 1; num_d <= n_devices;
++num_d) { // loop over number of devices used
dh::safe_nccl(
ncclCommInitAll(find_split_comms[num_d - 1].data(), num_d,
dList.data())); // initialize communicator
// (One communicator per
// process)
}
reducer.Init(dList);
is_dense = info->num_nonzero == info->num_col * info->num_row;
dh::Timer time0;
common::Timer time0;
hmat_.Init(&fmat, param.max_bin);
cpu_init_time += time0.ElapsedSeconds();
if (param.debug_verbose) { // Only done once for each training session
@ -397,8 +333,8 @@ class GPUHistMaker : public TreeUpdater {
fflush(stdout);
}
int n_bins = static_cast<int >(hmat_.row_ptr.back());
int n_features = static_cast<int >(hmat_.row_ptr.size() - 1);
int n_bins = static_cast<int>(hmat_.row_ptr.back());
int n_features = static_cast<int>(hmat_.row_ptr.size() - 1);
// deliniate data onto multiple gpus
device_row_segments.push_back(0);
@ -442,10 +378,7 @@ class GPUHistMaker : public TreeUpdater {
temp_memory.resize(n_devices);
hist_vec.resize(n_devices);
nodes.resize(n_devices);
nodes_temp.resize(n_devices);
nodes_child_temp.resize(n_devices);
left_child_smallest.resize(n_devices);
left_child_smallest_temp.resize(n_devices);
feature_flags.resize(n_devices);
fidx_min_map.resize(n_devices);
feature_segments.resize(n_devices);
@ -457,12 +390,6 @@ class GPUHistMaker : public TreeUpdater {
gidx_feature_map.resize(n_devices);
gidx_fvalue_map.resize(n_devices);
int find_split_n_devices = static_cast<int >(std::pow(2, std::floor(std::log2(n_devices))));
find_split_n_devices =
std::min(n_nodes_level(param.max_depth), find_split_n_devices);
int max_num_nodes_device =
n_nodes_level(param.max_depth) / find_split_n_devices;
// num_rows_segment: for sharding rows onto gpus for splitting data
// num_elements_segment: for sharding rows (of elements) onto gpus for
// splitting data
@ -476,26 +403,31 @@ class GPUHistMaker : public TreeUpdater {
device_row_segments[d_idx + 1] - device_row_segments[d_idx];
bst_ulong num_elements_segment =
device_element_segments[d_idx + 1] - device_element_segments[d_idx];
// ensure allocation doesn't overflow
size_t hist_size = static_cast<size_t>(n_nodes(param.max_depth - 1)) *
static_cast<size_t>(n_bins);
size_t nodes_size = static_cast<size_t>(n_nodes(param.max_depth));
size_t hmat_size = static_cast<size_t>(hmat_.min_val.size());
size_t buffer_size = static_cast<size_t>(
common::CompressedBufferWriter::CalculateBufferSize(
static_cast<size_t>(num_elements_segment),
static_cast<size_t>(n_bins)));
ba.allocate(
device_idx, param.silent, &(hist_vec[d_idx].data),
n_nodes(param.max_depth - 1) * n_bins, &nodes[d_idx],
n_nodes(param.max_depth), &nodes_temp[d_idx], max_num_nodes_device,
&nodes_child_temp[d_idx], max_num_nodes_device,
&left_child_smallest[d_idx], n_nodes(param.max_depth),
&left_child_smallest_temp[d_idx], max_num_nodes_device,
&feature_flags[d_idx],
device_idx, param.silent, &(hist_vec[d_idx].data), hist_size,
&nodes[d_idx], n_nodes(param.max_depth),
&left_child_smallest[d_idx], nodes_size, &feature_flags[d_idx],
n_features, // may change but same on all devices
&fidx_min_map[d_idx],
hmat_.min_val.size(), // constant and same on all devices
hmat_size, // constant and same on all devices
&feature_segments[d_idx],
h_feature_segments.size(), // constant and same on all devices
&prediction_cache[d_idx], num_rows_segment, &position[d_idx],
num_rows_segment, &position_tmp[d_idx], num_rows_segment,
&device_gpair[d_idx], num_rows_segment,
&device_matrix[d_idx].gidx_buffer,
common::CompressedBufferWriter::CalculateBufferSize(
num_elements_segment,
n_bins), // constant and same on all devices
buffer_size, // constant and same on all devices
&device_matrix[d_idx].row_ptr, num_rows_segment + 1,
&gidx_feature_map[d_idx],
n_bins, // constant and same on all devices
@ -529,17 +461,12 @@ class GPUHistMaker : public TreeUpdater {
dh::safe_cuda(cudaSetDevice(device_idx));
nodes[d_idx].fill(DeviceNodeStats());
nodes_temp[d_idx].fill(DeviceNodeStats());
nodes_child_temp[d_idx].fill(DeviceNodeStats());
position[d_idx].fill(0);
device_gpair[d_idx].copy(gpair.begin() + device_row_segments[d_idx],
gpair.begin() + device_row_segments[d_idx + 1]);
// Check gradients are within acceptable size range
CheckGradientMax(device_gpair[d_idx]);
subsample_gpair(&device_gpair[d_idx], param.subsample,
device_row_segments[d_idx]);
@ -618,21 +545,16 @@ class GPUHistMaker : public TreeUpdater {
// fprintf(stderr,"sizeof(bst_gpair)/sizeof(float)=%d\n",sizeof(bst_gpair)/sizeof(float));
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
dh::safe_nccl(ncclAllReduce(
reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) /
sizeof(gpair_sum_t::value_t),
nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx])));
reducer.AllReduceSum(device_idx,
reinterpret_cast<gpair_sum_t::value_t*>(
hist_vec[d_idx].GetLevelPtr(depth)),
reinterpret_cast<gpair_sum_t::value_t*>(
hist_vec[d_idx].GetLevelPtr(depth)),
hist_vec[d_idx].LevelSize(depth) *
sizeof(gpair_sum_t) /
sizeof(gpair_sum_t::value_t));
}
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx];
dh::safe_cuda(cudaSetDevice(device_idx));
dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx])));
}
// if no NCCL, then presume only 1 GPU, then already correct
reducer.Synchronize();
// time.printElapsed("Reduce-Add Time");
@ -955,7 +877,7 @@ class GPUHistMaker : public TreeUpdater {
}
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree* p_tree) {
dh::Timer time0;
common::Timer time0;
this->InitData(gpair, *p_fmat, *p_tree);
this->InitFirstNode(gpair);
@ -1019,10 +941,7 @@ class GPUHistMaker : public TreeUpdater {
std::vector<dh::CubMemory> temp_memory;
std::vector<DeviceHist> hist_vec;
std::vector<dh::dvec<DeviceNodeStats>> nodes;
std::vector<dh::dvec<DeviceNodeStats>> nodes_temp;
std::vector<dh::dvec<DeviceNodeStats>> nodes_child_temp;
std::vector<dh::dvec<bool>> left_child_smallest;
std::vector<dh::dvec<bool>> left_child_smallest_temp;
std::vector<dh::dvec<int>> feature_flags;
std::vector<dh::dvec<float>> fidx_min_map;
std::vector<dh::dvec<int>> feature_segments;
@ -1034,13 +953,11 @@ class GPUHistMaker : public TreeUpdater {
std::vector<dh::dvec<int>> gidx_feature_map;
std::vector<dh::dvec<float>> gidx_fvalue_map;
std::vector<cudaStream_t*> streams;
std::vector<ncclComm_t> comms;
std::vector<std::vector<ncclComm_t>> find_split_comms;
dh::AllReducer reducer;
double cpu_init_time;
double gpu_init_time;
dh::Timer cpu_time;
common::Timer cpu_time;
double gpu_time;
};

View File

@ -1,8 +1,9 @@
/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/count.h>
#include <thrust/sort.h>
#include <thrust/reduce.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
#include <memory>
@ -12,6 +13,7 @@
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/timer.h"
#include "param.h"
#include "updater_gpu_common.cuh"
@ -20,19 +22,20 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist_experimental);
template <int BLOCK_THREADS, typename reduce_t, typename temp_storage_t>
__device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin,
const bst_gpair_integer* end,
temp_storage_t* temp_storage) {
__shared__ cub::Uninitialized<bst_gpair_integer> uninitialized_sum;
bst_gpair_integer& shared_sum = uninitialized_sum.Alias();
typedef bst_gpair_precise gpair_sum_t;
bst_gpair_integer local_sum = bst_gpair_integer();
template <int BLOCK_THREADS, typename reduce_t, typename temp_storage_t>
__device__ gpair_sum_t ReduceFeature(const gpair_sum_t* begin,
const gpair_sum_t* end,
temp_storage_t* temp_storage) {
__shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
gpair_sum_t& shared_sum = uninitialized_sum.Alias();
gpair_sum_t local_sum = gpair_sum_t();
for (auto itr = begin; itr < end; itr += BLOCK_THREADS) {
bool thread_active = itr + threadIdx.x < end;
// Scan histogram
bst_gpair_integer bin =
thread_active ? *(itr + threadIdx.x) : bst_gpair_integer();
gpair_sum_t bin = thread_active ? *(itr + threadIdx.x) : gpair_sum_t();
local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum());
}
@ -47,7 +50,7 @@ __device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin,
template <int BLOCK_THREADS, typename reduce_t, typename scan_t,
typename max_reduce_t, typename temp_storage_t>
__device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
__device__ void EvaluateFeature(int fidx, const gpair_sum_t* hist,
const int* feature_segments, float min_fvalue,
const float* gidx_fvalue_map,
DeviceSplitCandidate* best_split,
@ -57,22 +60,22 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
int gidx_begin = feature_segments[fidx];
int gidx_end = feature_segments[fidx + 1];
bst_gpair_integer feature_sum = ReduceFeature<BLOCK_THREADS, reduce_t>(
gpair_sum_t feature_sum = ReduceFeature<BLOCK_THREADS, reduce_t>(
hist + gidx_begin, hist + gidx_end, temp_storage);
auto prefix_op = SumCallbackOp<bst_gpair_integer>();
auto prefix_op = SumCallbackOp<gpair_sum_t>();
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = scan_begin + threadIdx.x < gidx_end;
bst_gpair_integer bin =
thread_active ? hist[scan_begin + threadIdx.x] : bst_gpair_integer();
gpair_sum_t bin =
thread_active ? hist[scan_begin + threadIdx.x] : gpair_sum_t();
scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
// Calculate gain
bst_gpair_integer parent_sum = bst_gpair_integer(node.sum_gradients);
// Calculate gain
gpair_sum_t parent_sum = gpair_sum_t(node.sum_gradients);
bst_gpair_integer missing = parent_sum - feature_sum;
gpair_sum_t missing = parent_sum - feature_sum;
bool missing_left = true;
const float null_gain = -FLT_MAX;
@ -102,8 +105,8 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
float fvalue =
gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1];
bst_gpair_integer left = missing_left ? bin + missing : bin;
bst_gpair_integer right = parent_sum - left;
gpair_sum_t left = missing_left ? bin + missing : bin;
gpair_sum_t right = parent_sum - left;
best_split->Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx,
left, right, param);
@ -114,17 +117,16 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist,
template <int BLOCK_THREADS>
__global__ void evaluate_split_kernel(
const bst_gpair_integer* d_hist, int nidx, uint64_t n_features,
const gpair_sum_t* d_hist, int nidx, uint64_t n_features,
DeviceNodeStats nodes, const int* d_feature_segments,
const float* d_fidx_min_map, const float* d_gidx_fvalue_map,
GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split) {
typedef cub::KeyValuePair<int, float> ArgMaxT;
typedef cub::BlockScan<bst_gpair_integer, BLOCK_THREADS,
cub::BLOCK_SCAN_WARP_SCANS>
typedef cub::BlockScan<gpair_sum_t, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
BlockScanT;
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
typedef cub::BlockReduce<bst_gpair_integer, BLOCK_THREADS> SumReduceT;
typedef cub::BlockReduce<gpair_sum_t, BLOCK_THREADS> SumReduceT;
union TempStorage {
typename BlockScanT::TempStorage scan;
@ -159,13 +161,6 @@ __global__ void evaluate_split_kernel(
template <typename gidx_iter_t>
__device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data,
int fidx_begin, int fidx_end) {
// for(auto i = begin; i < end; i++)
//{
// auto gidx = data[i];
// if (gidx >= fidx_begin&&gidx < fidx_end) return gidx;
//}
// return -1;
bst_uint previous_middle = UINT32_MAX;
while (end != begin) {
auto middle = begin + (end - begin) / 2;
@ -190,48 +185,63 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data,
struct DeviceHistogram {
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
dh::dvec<bst_gpair_integer> data;
std::map<int, bst_gpair_integer*> node_map;
dh::dvec<gpair_sum_t> data;
int n_bins;
void Init(int device_idx, int max_nodes, int n_bins, bool silent) {
this->n_bins = n_bins;
ba.allocate(device_idx, silent, &data, max_nodes * n_bins);
ba.allocate(device_idx, silent, &data, size_t(max_nodes) * size_t(n_bins));
}
void Reset() {
data.fill(bst_gpair_integer());
node_map.clear();
}
void Reset() { data.fill(gpair_sum_t()); }
gpair_sum_t* GetHistPtr(int nidx) { return data.data() + nidx * n_bins; }
void AddNode(int nidx) {
CHECK_EQ(node_map.count(nidx), 0)
<< nidx << " already exists in the histogram.";
node_map[nidx] = data.data() + n_bins * node_map.size();
void PrintNidx(int nidx) const {
auto h_data = data.as_vector();
std::cout << "nidx " << nidx << ":\n";
for (int i = n_bins * nidx; i < n_bins * (nidx + 1); i++) {
std::cout << h_data[i] << " ";
}
std::cout << "\n";
}
};
// Manage memory for a single GPU
struct DeviceShard {
struct Segment {
size_t begin;
size_t end;
Segment() : begin(0), end(0) {}
Segment(size_t begin, size_t end) : begin(begin), end(end) {
CHECK_GE(end, begin);
}
size_t Size() const { return end - begin; }
};
int device_idx;
int normalised_device_idx; // Device index counting from param.gpu_id
dh::bulk_allocator<dh::memory_type::DEVICE> ba;
dh::dvec<common::compressed_byte_t> gidx_buffer;
dh::dvec<bst_gpair> gpair;
dh::dvec2<bst_uint> ridx;
dh::dvec2<bst_uint> ridx; // Row index relative to this shard
dh::dvec2<int> position;
std::vector<std::pair<int64_t, int64_t>> ridx_segments;
std::vector<Segment> ridx_segments;
dh::dvec<int> feature_segments;
dh::dvec<float> gidx_fvalue_map;
dh::dvec<float> min_fvalue;
std::vector<bst_gpair> node_sum_gradients;
common::CompressedIterator<uint32_t> gidx;
int row_stride;
bst_uint row_start_idx;
bst_uint row_begin_idx; // The row offset for this shard
bst_uint row_end_idx;
bst_uint n_rows;
int n_bins;
int null_gidx_value;
DeviceHistogram hist;
TrainParam param;
int64_t* tmp_pinned; // Small amount of staging memory
std::vector<cudaStream_t> streams;
@ -242,11 +252,12 @@ struct DeviceShard {
bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_start_idx(row_begin),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins) {
null_gidx_value(n_bins),
param(param) {
// Convert to ELLPACK matrix representation
int max_elements_row = 0;
for (auto i = row_begin; i < row_end; i++) {
@ -260,7 +271,8 @@ struct DeviceShard {
for (auto i = row_begin; i < row_end; i++) {
int row_count = 0;
for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
ellpack_matrix[i * row_stride + row_count] = gmat.index[j];
ellpack_matrix[(i - row_begin) * row_stride + row_count] =
gmat.index[j];
row_count++;
}
}
@ -296,12 +308,15 @@ struct DeviceShard {
// Init histogram
hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
}
~DeviceShard() {
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
dh::safe_cuda(cudaFreeHost(tmp_pinned));
}
// Get vector of at least n initialised streams
@ -324,81 +339,135 @@ struct DeviceShard {
// Reset values for each update iteration
void Reset(const std::vector<bst_gpair>& host_gpair) {
dh::safe_cuda(cudaSetDevice(device_idx));
position.current_dvec().fill(0);
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
bst_gpair());
// TODO(rory): support subsampling
thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend(),
row_start_idx);
std::fill(ridx_segments.begin(), ridx_segments.end(), std::make_pair(0, 0));
ridx_segments.front() = std::make_pair(0, ridx.size());
this->gpair.copy(host_gpair.begin() + row_start_idx,
thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend());
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
ridx_segments.front() = Segment(0, ridx.size());
this->gpair.copy(host_gpair.begin() + row_begin_idx,
host_gpair.begin() + row_end_idx);
// Check gradients are within acceptable size range
CheckGradientMax(gpair);
subsample_gpair(&gpair, param.subsample, row_begin_idx);
hist.Reset();
}
__device__ void IncrementHist(bst_gpair gpair, int gidx,
bst_gpair_integer* node_hist) const {
auto dst_ptr =
reinterpret_cast<unsigned long long int*>(&node_hist[gidx]); // NOLINT
bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess());
auto src_ptr = reinterpret_cast<bst_gpair_integer::value_t*>(&tmp);
atomicAdd(dst_ptr,
static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1,
static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
}
void BuildHist(int nidx) {
hist.AddNode(nidx);
auto d_node_hist = hist.node_map[nidx];
auto segment = ridx_segments[nidx];
auto d_node_hist = hist.GetHistPtr(nidx);
auto d_gidx = gidx;
auto d_ridx = ridx.current();
auto d_gpair = gpair.data();
auto row_stride = this->row_stride;
auto null_gidx_value = this->null_gidx_value;
auto segment = ridx_segments[nidx];
auto n_elements = (segment.second - segment.first) * row_stride;
auto n_elements = segment.Size() * row_stride;
dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) {
int relative_ridx = d_ridx[(idx / row_stride) + segment.first];
int gidx = d_gidx[relative_ridx * row_stride + idx % row_stride];
int ridx = d_ridx[(idx / row_stride) + segment.begin];
int gidx = d_gidx[ridx * row_stride + idx % row_stride];
if (gidx != null_gidx_value) {
bst_gpair gpair = d_gpair[relative_ridx];
IncrementHist(gpair, gidx, d_node_hist);
AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
}
});
}
void SortPosition(const std::pair<bst_uint, bst_uint>& segment, int left_nidx,
int right_nidx) {
auto n = segment.second - segment.first;
void SubtractionTrick(int nidx_parent, int nidx_histogram,
int nidx_subtraction) {
auto d_node_hist_parent = hist.GetHistPtr(nidx_parent);
auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram);
auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction);
dh::launch_n(device_idx, hist.n_bins, [=] __device__(size_t idx) {
d_node_hist_subtraction[idx] =
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
});
}
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
}
void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx,
int split_gidx, bool default_dir_left, bool is_dense,
int fidx_begin, int fidx_end) {
dh::safe_cuda(cudaSetDevice(device_idx));
temp_memory.LazyAllocate(sizeof(int64_t));
auto d_left_count = temp_memory.Pointer<int64_t>();
dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t)));
auto segment = ridx_segments[nidx];
auto d_ridx = ridx.current();
auto d_position = position.current();
auto d_gidx = gidx;
auto row_stride = this->row_stride;
dh::launch_n<1, 512>(
device_idx, segment.Size(), [=] __device__(bst_uint idx) {
idx += segment.begin;
auto ridx = d_ridx[idx];
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = d_gidx[row_begin + fidx];
} else {
gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin,
fidx_end);
}
int position;
if (gidx >= 0) {
// Feature is found
position = gidx <= split_gidx ? left_nidx : right_nidx;
} else {
// Feature is missing
position = default_dir_left ? left_nidx : right_nidx;
}
CountLeft(d_left_count, position, left_nidx);
d_position[idx] = position;
});
dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost));
auto left_count = *tmp_pinned;
SortPosition(segment, left_nidx, right_nidx);
// dh::safe_cuda(cudaStreamSynchronize(stream));
ridx_segments[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments[right_nidx] =
Segment(segment.begin + left_count, segment.end);
}
void SortPosition(const Segment& segment, int left_nidx, int right_nidx) {
int min_bits = 0;
int max_bits = static_cast<int>(
std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)));
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes, position.current() + segment.first,
position.other() + segment.first, ridx.current() + segment.first,
ridx.other() + segment.first, n, min_bits, max_bits);
nullptr, temp_storage_bytes, position.current() + segment.begin,
position.other() + segment.begin, ridx.current() + segment.begin,
ridx.other() + segment.begin, segment.Size(), min_bits, max_bits);
temp_memory.LazyAllocate(temp_storage_bytes);
cub::DeviceRadixSort::SortPairs(
temp_memory.d_temp_storage, temp_memory.temp_storage_bytes,
position.current() + segment.first, position.other() + segment.first,
ridx.current() + segment.first, ridx.other() + segment.first, n,
min_bits, max_bits);
dh::safe_cuda(cudaMemcpy(position.current() + segment.first,
position.other() + segment.first, n * sizeof(int),
cudaMemcpyDeviceToDevice));
dh::safe_cuda(cudaMemcpy(ridx.current() + segment.first,
ridx.other() + segment.first, n * sizeof(bst_uint),
cudaMemcpyDeviceToDevice));
//}
position.current() + segment.begin, position.other() + segment.begin,
ridx.current() + segment.begin, ridx.other() + segment.begin,
segment.Size(), min_bits, max_bits);
dh::safe_cuda(cudaMemcpy(
position.current() + segment.begin, position.other() + segment.begin,
segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
dh::safe_cuda(cudaMemcpy(
ridx.current() + segment.begin, ridx.other() + segment.begin,
segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
}
};
@ -412,10 +481,10 @@ class GPUHistMakerExperimental : public TreeUpdater {
const std::vector<std::pair<std::string, std::string>>& args) override {
param.InitAllowUnknown(args);
CHECK(param.n_gpus != 0) << "Must have at least one device";
CHECK(param.n_gpus <= 1 && param.n_gpus != -1)
<< "Only one GPU currently supported";
n_devices = param.n_gpus;
dh::check_compute_capability();
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand_.reset(new ExpandQueue(loss_guide));
} else {
@ -426,6 +495,7 @@ class GPUHistMakerExperimental : public TreeUpdater {
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update");
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
@ -439,36 +509,119 @@ class GPUHistMakerExperimental : public TreeUpdater {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
monitor.Stop("Update");
}
void InitDataOnce(DMatrix* dmat) {
info = &dmat->info();
monitor.Start("Quantiles");
hmat_.Init(dmat, param.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(dmat);
monitor.Stop("Quantiles");
n_bins = hmat_.row_ptr.back();
shards.emplace_back(param.gpu_id, 0, gmat_, 0, info->num_row, n_bins,
param);
int n_devices = dh::n_devices(param.n_gpus, info->num_row);
bst_uint row_begin = 0;
bst_uint shard_size =
std::ceil(static_cast<double>(info->num_row) / n_devices);
std::vector<int> dList(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices();
dList[d_idx] = device_idx;
}
reducer.Init(dList);
// Partition input matrix into row segments
std::vector<size_t> row_segments;
shards.resize(n_devices);
row_segments.push_back(0);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
bst_uint row_end =
std::min(static_cast<size_t>(row_begin + shard_size), info->num_row);
row_segments.push_back(row_end);
row_begin = row_end;
}
// Create device shards
omp_set_num_threads(shards.size());
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
shards[cpu_thread_id] = std::unique_ptr<DeviceShard>(
new DeviceShard(dList[cpu_thread_id], cpu_thread_id, gmat_,
row_segments[cpu_thread_id],
row_segments[cpu_thread_id + 1], n_bins, param));
}
initialised = true;
}
void InitData(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const RegTree& tree) {
monitor.Start("InitDataOnce");
if (!initialised) {
CheckGradientMax(gpair);
this->InitDataOnce(dmat);
}
monitor.Stop("InitDataOnce");
this->ColSampleTree();
column_sampler.Init(info->num_col, param);
// Copy gpair & reset memory
for (auto& shard : shards) {
shard.Reset(gpair);
monitor.Start("InitDataReset");
omp_set_num_threads(shards.size());
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
shards[cpu_thread_id]->Reset(gpair);
}
monitor.Stop("InitDataReset");
}
void BuildHist(int nidx) {
void AllReduceHist(int nidx) {
for (auto& shard : shards) {
shard.BuildHist(nidx);
auto d_node_hist = shard->hist.GetHistPtr(nidx);
reducer.AllReduceSum(
shard->normalised_device_idx,
reinterpret_cast<gpair_sum_t::value_t*>(d_node_hist),
reinterpret_cast<gpair_sum_t::value_t*>(d_node_hist),
n_bins * (sizeof(gpair_sum_t) / sizeof(gpair_sum_t::value_t)));
}
reducer.Synchronize();
}
void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) {
size_t left_node_max_elements = 0;
size_t right_node_max_elements = 0;
for (auto& shard : shards) {
left_node_max_elements = (std::max)(
left_node_max_elements, shard->ridx_segments[nidx_left].Size());
right_node_max_elements = (std::max)(
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
}
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
if (right_node_max_elements < left_node_max_elements) {
build_hist_nidx = nidx_right;
subtraction_trick_nidx = nidx_left;
}
for (auto& shard : shards) {
shard->BuildHist(build_hist_nidx);
}
this->AllReduceHist(build_hist_nidx);
for (auto& shard : shards) {
shard->SubtractionTrick(nidx_parent, build_hist_nidx,
subtraction_trick_nidx);
}
}
@ -481,36 +634,41 @@ class GPUHistMakerExperimental : public TreeUpdater {
columns);
// Use first device
auto& shard = shards.front();
dh::safe_cuda(cudaSetDevice(shard.device_idx));
shard.temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns *
nidx_set.size());
auto d_split = shard.temp_memory.Pointer<DeviceSplitCandidate>();
dh::safe_cuda(cudaSetDevice(shard->device_idx));
shard->temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns *
nidx_set.size());
auto d_split = shard->temp_memory.Pointer<DeviceSplitCandidate>();
auto& streams = shard.GetStreams(static_cast<int>(nidx_set.size()));
auto& streams = shard->GetStreams(static_cast<int>(nidx_set.size()));
// Use streams to process nodes concurrently
for (auto i = 0; i < nidx_set.size(); i++) {
auto nidx = nidx_set[i];
DeviceNodeStats node(shard.node_sum_gradients[nidx], nidx, param);
DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param);
const int BLOCK_THREADS = 256;
evaluate_split_kernel<BLOCK_THREADS>
<<<uint32_t(columns), BLOCK_THREADS, 0, streams[i]>>>(
shard.hist.node_map[nidx], nidx, info->num_col, node,
shard.feature_segments.data(), shard.min_fvalue.data(),
shard.gidx_fvalue_map.data(), GPUTrainingParam(param),
shard->hist.GetHistPtr(nidx), nidx, info->num_col, node,
shard->feature_segments.data(), shard->min_fvalue.data(),
shard->gidx_fvalue_map.data(), GPUTrainingParam(param),
d_split + i * columns);
}
dh::safe_cuda(
cudaMemcpy(candidate_splits.data(), shard.temp_memory.d_temp_storage,
cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage,
sizeof(DeviceSplitCandidate) * columns * nidx_set.size(),
cudaMemcpyDeviceToHost));
for (auto i = 0; i < nidx_set.size(); i++) {
auto nidx = nidx_set[i];
DeviceSplitCandidate nidx_best;
for (auto fidx = 0; fidx < columns; fidx++) {
nidx_best.Update(candidate_splits[i * columns + fidx], param);
auto& candidate = candidate_splits[i * columns + fidx];
if (column_sampler.ColumnUsed(candidate.findex,
p_tree->GetDepth(nidx))) {
nidx_best.Update(candidate_splits[i * columns + fidx], param);
}
}
best_splits[i] = nidx_best;
}
@ -518,15 +676,28 @@ class GPUHistMakerExperimental : public TreeUpdater {
}
void InitRoot(const std::vector<bst_gpair>& gpair, RegTree* p_tree) {
int root_nidx = 0;
BuildHist(root_nidx);
// TODO(rory): support sub sampling
// TODO(rory): not asynchronous
bst_gpair sum_gradient;
for (auto& shard : shards) {
sum_gradient += thrust::reduce(shard.gpair.tbegin(), shard.gpair.tend());
auto root_nidx = 0;
// Sum gradients
std::vector<bst_gpair> tmp_sums(shards.size());
omp_set_num_threads(shards.size());
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx));
tmp_sums[cpu_thread_id] =
thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory),
shards[cpu_thread_id]->gpair.tbegin(),
shards[cpu_thread_id]->gpair.tend());
}
auto sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair());
// Generate root histogram
for (auto& shard : shards) {
shard->BuildHist(root_nidx);
}
this->AllReduceHist(root_nidx);
// Remember root stats
p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess();
@ -534,33 +705,17 @@ class GPUHistMakerExperimental : public TreeUpdater {
// Store sum gradients
for (auto& shard : shards) {
shard.node_sum_gradients[root_nidx] = sum_gradient;
shard->node_sum_gradients[root_nidx] = sum_gradient;
}
// Generate first split
auto splits = this->EvaluateSplits({root_nidx}, p_tree);
// Generate candidate
qexpand_->push(
ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0));
}
struct MatchingFunctor : public thrust::unary_function<int, int> {
int val;
__host__ __device__ MatchingFunctor(int val) : val(val) {}
__host__ __device__ int operator()(int x) const { return x == val; }
};
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
}
void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) {
auto nidx = candidate.nid;
auto is_dense = info->num_nonzero == info->num_row * info->num_col;
auto left_nidx = (*p_tree)[nidx].cleft();
auto right_nidx = (*p_tree)[nidx].cright();
@ -577,58 +732,15 @@ class GPUHistMakerExperimental : public TreeUpdater {
}
}
for (auto& shard : shards) {
monitor.Start("update position kernel");
shard.temp_memory.LazyAllocate(sizeof(int64_t));
auto d_left_count = shard.temp_memory.Pointer<int64_t>();
dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t)));
dh::safe_cuda(cudaSetDevice(shard.device_idx));
auto segment = shard.ridx_segments[nidx];
CHECK_GT(segment.second - segment.first, 0);
auto d_ridx = shard.ridx.current();
auto d_position = shard.position.current();
auto d_gidx = shard.gidx;
auto row_stride = shard.row_stride;
dh::launch_n<1, 512>(
shard.device_idx, segment.second - segment.first,
[=] __device__(bst_uint idx) {
idx += segment.first;
auto ridx = d_ridx[idx];
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
if (is_dense) {
gidx = d_gidx[row_begin + fidx];
} else {
gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin,
fidx_end);
}
auto is_dense = info->num_nonzero == info->num_row * info->num_col;
int position;
if (gidx >= 0) {
// Feature is found
position = gidx <= split_gidx ? left_nidx : right_nidx;
} else {
// Feature is missing
position = default_dir_left ? left_nidx : right_nidx;
}
CountLeft(d_left_count, position, left_nidx);
d_position[idx] = position;
});
int64_t left_count;
dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost));
monitor.Stop("update position kernel");
monitor.Start("sort");
shard.SortPosition(segment, left_nidx, right_nidx);
monitor.Stop("sort");
shard.ridx_segments[left_nidx] =
std::make_pair(segment.first, segment.first + left_count);
shard.ridx_segments[right_nidx] =
std::make_pair(segment.first + left_count, segment.second);
omp_set_num_threads(shards.size());
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
shards[cpu_thread_id]->UpdatePosition(nidx, left_nidx, right_nidx, fidx,
split_gidx, default_dir_left,
is_dense, fidx_begin, fidx_end);
}
}
@ -654,41 +766,12 @@ class GPUHistMakerExperimental : public TreeUpdater {
tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess();
// Store sum gradients
for (auto& shard : shards) {
shard.node_sum_gradients[parent.cleft()] = candidate.split.left_sum;
shard.node_sum_gradients[parent.cright()] = candidate.split.right_sum;
shard->node_sum_gradients[parent.cleft()] = candidate.split.left_sum;
shard->node_sum_gradients[parent.cright()] = candidate.split.right_sum;
}
this->UpdatePosition(candidate, p_tree);
}
void ColSampleTree() {
if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return;
feature_set_tree.resize(info->num_col);
std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0);
feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree);
}
struct Monitor {
bool debug_verbose = false;
std::string label = "";
std::map<std::string, dh::Timer> timer_map;
~Monitor() {
if (!debug_verbose) return;
std::cout << "Monitor: " << label << "\n";
for (auto& kv : timer_map) {
kv.second.PrintElapsed(kv.first);
}
}
void Init(std::string label, bool debug_verbose) {
this->debug_verbose = debug_verbose;
this->label = label;
}
void Start(const std::string& name) { timer_map[name].Start(); }
void Stop(const std::string& name) { timer_map[name].Stop(); }
};
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree* p_tree) {
auto& tree = *p_tree;
@ -720,8 +803,8 @@ class GPUHistMakerExperimental : public TreeUpdater {
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("BuildHist");
this->BuildHist(left_child_nidx);
this->BuildHist(right_child_nidx);
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
right_child_nidx);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
@ -793,14 +876,14 @@ class GPUHistMakerExperimental : public TreeUpdater {
int n_devices;
int n_bins;
std::vector<DeviceShard> shards;
std::vector<int> feature_set_tree;
std::vector<int> feature_set_level;
std::vector<std::unique_ptr<DeviceShard>> shards;
ColumnSampler column_sampler;
typedef std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>
ExpandQueue;
std::unique_ptr<ExpandQueue> qexpand_;
Monitor monitor;
common::Monitor monitor;
dh::AllReducer reducer;
};
XGBOOST_REGISTER_TREE_UPDATER(GPUHistMakerExperimental,

View File

@ -11,20 +11,36 @@ rng = np.random.RandomState(1994)
def run_benchmark(args):
print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
tmp = time.time()
X, y = make_classification(args.rows, n_features=args.columns, random_state=7)
if args.sparsity < 1.0:
X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
print ("DMatrix Start")
dtrain = xgb.DMatrix(X_train, y_train, nthread=-1)
dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
try:
dtest = xgb.DMatrix('dtest.dm')
dtrain = xgb.DMatrix('dtrain.dm')
if not (dtest.num_col() == args.columns \
and dtrain.num_col() == args.columns):
raise ValueError("Wrong cols")
if not (dtest.num_row() == args.rows * args.test_size \
and dtrain.num_row() == args.rows * (1-args.test_size)):
raise ValueError("Wrong rows")
except:
print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
tmp = time.time()
X, y = make_classification(args.rows, n_features=args.columns, n_redundant=0, n_informative=args.columns, n_repeated=0, random_state=7)
if args.sparsity < 1.0:
X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7)
print ("Generate Time: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
print ("DMatrix Start")
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test, nthread=-1)
print ("DMatrix Time: %s seconds" % (str(time.time() - tmp)))
dtest.save_binary('dtest.dm')
dtrain.save_binary('dtrain.dm')
param = {'objective': 'binary:logistic'}
if args.params is not '':

View File

@ -6,6 +6,7 @@
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "gtest/gtest.h"
#include "../../../src/common/timer.h"
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
thrust::host_vector<int> *row_ptr,
@ -35,7 +36,7 @@ void SpeedTest() {
thrust::device_vector<int> output_row(h_rows.size());
auto d_output_row = output_row.data();
dh::Timer t;
xgboost::common::Timer t;
dh::TransformLbs(
0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false,
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });

View File

@ -7,8 +7,8 @@
#include "../helpers.h"
#include "gtest/gtest.h"
#include "../../../src/tree/updater_gpu_hist_experimental.cu"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/tree/updater_gpu_hist_experimental.cu"
namespace xgboost {
namespace tree {
@ -22,7 +22,9 @@ TEST(gpu_hist_experimental, TestSparseShard) {
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
ncclComm_t comm;
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
TrainParam());
ASSERT_LT(shard.row_stride, columns);
@ -54,7 +56,9 @@ TEST(gpu_hist_experimental, TestDenseShard) {
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam());
ncclComm_t comm;
DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(),
TrainParam());
ASSERT_EQ(shard.row_stride, columns);

View File

@ -8,6 +8,7 @@ import numpy as np
import unittest
from nose.plugins.attrib import attr
from sklearn.datasets import load_digits, load_boston, load_breast_cancer, make_regression
import itertools as it
rng = np.random.RandomState(1994)
@ -15,8 +16,9 @@ rng = np.random.RandomState(1994)
def non_increasing(L, tolerance):
return all((y - x) < tolerance for x, y in zip(L, L[1:]))
#Check result is always decreasing and final accuracy is within tolerance
def assert_accuracy(res, tree_method, comparison_tree_method, tolerance):
# Check result is always decreasing and final accuracy is within tolerance
def assert_accuracy(res, tree_method, comparison_tree_method, tolerance, param):
assert non_increasing(res[tree_method], tolerance)
assert np.allclose(res[tree_method][-1], res[comparison_tree_method][-1], 1e-3, 1e-2)
@ -26,13 +28,14 @@ def train_boston(param_in, comparison_tree_method):
dtrain = xgb.DMatrix(data.data, label=data.target)
param = {}
param.update(param_in)
param['max_depth'] = 2
res_tmp = {}
res = {}
num_rounds = 10
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[param['tree_method']] = res_tmp['train']['rmse']
param["tree_method"] = comparison_tree_method
xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
res[comparison_tree_method] = res_tmp['train']['rmse']
return res
@ -92,17 +95,24 @@ def train_sparse(param_in, comparison_tree_method):
return res
# Enumerates all permutations of variable parameters
def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, tolerance):
param = {'tree_method': tree_method}
for k, set in variable_param.items():
for val in set:
param_tmp = param.copy()
param_tmp[k] = val
print(param_tmp, file=sys.stderr)
assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance)
param = {'tree_method': tree_method }
names = sorted(variable_param)
combinations = it.product(*(variable_param[Name] for Name in names))
for set in combinations:
print(names, file=sys.stderr)
print(set, file=sys.stderr)
param_tmp = param.copy()
for i, name in enumerate(names):
param_tmp[name] = set[i]
print(param_tmp, file=sys.stderr)
assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp)
assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp)
assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp)
assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp)
@attr('gpu')
@ -116,5 +126,5 @@ class TestGPU(unittest.TestCase):
assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02)
def test_gpu_hist_experimental(self):
variable_param = {'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]}
variable_param = {'n_gpus': [1, -1], 'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]}
assert_updater_accuracy('gpu_hist_experimental', 'hist', variable_param, 0.01)