diff --git a/CMakeLists.txt b/CMakeLists.txt index ab9f634f7..18a2ec04f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ msvc_use_static_runtime() # Options option(USE_CUDA "Build with GPU acceleration") +option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA") option(JVM_BINDINGS "Build JVM bindings" OFF) option(GOOGLE_TEST "Build google tests" OFF) option(R_LIB "Build shared library for R package" OFF) @@ -97,26 +98,39 @@ set(LINK_LIBRARIES dmlccore rabit) if(USE_CUDA) - find_package(CUDA 7.5 REQUIRED) + find_package(CUDA 8.0 REQUIRED) cmake_minimum_required(VERSION 3.5) add_definitions(-DXGBOOST_USE_CUDA) - include_directories( - nccl/src - cub - ) + include_directories(cub) + if(USE_NCCL) + include_directories(nccl/src) + add_definitions(-DXGBOOST_USE_NCCL) + endif() + + if((CUDA_VERSION_MAJOR EQUAL 9) OR (CUDA_VERSION_MAJOR GREATER 9)) + message("CUDA 9.0 detected, adding Volta compute capability (7.0).") + set(GPU_COMPUTE_VER "${GPU_COMPUTE_VER};70") + endif() + set(GENCODE_FLAGS "") format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;${GENCODE_FLAGS};-lineinfo;") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;--expt-relaxed-constexpr;${GENCODE_FLAGS};-lineinfo;") if(NOT MSVC) set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC; -std=c++11") endif() - add_subdirectory(nccl) + if(USE_NCCL) + add_subdirectory(nccl) + endif() + cuda_add_library(gpuxgboost ${CUDA_SOURCES} STATIC) - target_link_libraries(gpuxgboost nccl) + + if(USE_NCCL) + target_link_libraries(gpuxgboost nccl) + endif() list(APPEND LINK_LIBRARIES gpuxgboost) endif() diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index dce0e3be2..ace609b6e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -15,7 +15,10 @@ #include #include #include + +#ifdef XGBOOST_USE_NCCL #include "nccl.h" +#endif // Uncomment to enable #define TIMERS @@ -44,6 +47,7 @@ inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, return code; } +#ifdef XGBOOST_USE_NCCL #define safe_nccl(ans) throw_on_nccl_error((ans), __FILE__, __LINE__) inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file, @@ -57,6 +61,7 @@ inline ncclResult_t throw_on_nccl_error(ncclResult_t code, const char *file, return code; } +#endif template T *raw(thrust::device_vector &v) { // NOLINT @@ -137,6 +142,19 @@ inline int get_device_idx(int gpu_id) { return (std::abs(gpu_id) + 0) % dh::n_visible_devices(); } +inline void check_compute_capability() { + int n_devices = n_visible_devices(); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + cudaDeviceProp prop; + safe_cuda(cudaGetDeviceProperties(&prop, d_idx)); + std::ostringstream oss; + oss << "CUDA Capability Major/Minor version number: " << prop.major << "." + << prop.minor << " is insufficient. Need >=3.5"; + int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5; + if (failed) LOG(WARNING) << oss.str() << " for device: " << d_idx; + } +} + /* * Range iterator */ @@ -241,37 +259,12 @@ inline void launch_n(int device_idx, size_t n, L lambda) { } safe_cuda(cudaSetDevice(device_idx)); - const int GRID_SIZE = static_cast(div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS)); + const int GRID_SIZE = + static_cast(div_round_up(n, ITEMS_PER_THREAD * BLOCK_THREADS)); launch_n_kernel<<>>(static_cast(0), n, lambda); } -/* - * Timers - */ - -struct Timer { - typedef std::chrono::high_resolution_clock ClockT; - typedef std::chrono::high_resolution_clock::time_point TimePointT; - typedef std::chrono::high_resolution_clock::duration DurationT; - typedef std::chrono::duration SecondsT; - - TimePointT start; - DurationT elapsed; - Timer() { Reset(); } - void Reset() { - elapsed = DurationT::zero(); - Start(); - } - void Start() { start = ClockT::now(); } - void Stop() { elapsed += ClockT::now() - start; } - double ElapsedSeconds() const { return SecondsT(elapsed).count(); } - void PrintElapsed(std::string label) { - printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count()); - Reset(); - } -}; - /* * Memory */ @@ -444,7 +437,7 @@ class bulk_allocator { template void allocate_dvec(int device_idx, char *ptr, dvec *first_vec, - size_t first_size) { + size_t first_size) { first_vec->external_allocate(device_idx, static_cast(ptr), first_size); } @@ -470,8 +463,7 @@ class bulk_allocator { template size_t get_size_bytes(dvec2 *first_vec, size_t first_size, Args... args) { - return get_size_bytes(first_vec, first_size) + - get_size_bytes(args...); + return get_size_bytes(first_vec, first_size) + get_size_bytes(args...); } template @@ -497,6 +489,7 @@ class bulk_allocator { if (!(d_ptr[i] == nullptr)) { safe_cuda(cudaSetDevice(_device_idx[i])); safe_cuda(cudaFree(d_ptr[i])); + d_ptr[i] = nullptr; } } } @@ -642,7 +635,8 @@ __global__ void LbsKernel(coordinate_t *d_coordinates, for (auto item : dh::block_stride_range(int(0), int(tile_num_rows + 1))) { temp_storage.tile_segment_end_offsets[item] = - segment_end_offsets[min(tile_start_coord.x + item, num_segments - 1)]; + segment_end_offsets[min(static_cast(tile_start_coord.x + item), + static_cast(num_segments - 1))]; } __syncthreads(); @@ -693,8 +687,8 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, BLOCK_THREADS, segments, num_segments, count); LbsKernel - <<>>(tmp_tile_coordinates, segments + 1, f, - num_segments); + <<>>(tmp_tile_coordinates, + segments + 1, f, num_segments); } template @@ -836,4 +830,125 @@ void gather(int device_idx, T *out, const T *in, const int *instId, int nVals) { }); } +/** + * \class AllReducer + * + * \brief All reducer class that manages its own communication group and + * streams. Must be initialised before use. If XGBoost is compiled without NCCL this is a dummy class that will error if used with more than one GPU. + */ + +class AllReducer { + bool initialised; +#ifdef XGBOOST_USE_NCCL + std::vector comms; + std::vector streams; + std::vector device_ordinals; +#endif + public: + AllReducer() : initialised(false) {} + + /** + * \fn void Init(const std::vector &device_ordinals) + * + * \brief Initialise with the desired device ordinals for this communication + * group. + * + * \param device_ordinals The device ordinals. + */ + + void Init(const std::vector &device_ordinals) { +#ifdef XGBOOST_USE_NCCL + this->device_ordinals = device_ordinals; + comms.resize(device_ordinals.size()); + dh::safe_nccl(ncclCommInitAll(comms.data(), + static_cast(device_ordinals.size()), + device_ordinals.data())); + streams.resize(device_ordinals.size()); + for (size_t i = 0; i < device_ordinals.size(); i++) { + safe_cuda(cudaSetDevice(device_ordinals[i])); + safe_cuda(cudaStreamCreate(&streams[i])); + } + initialised = true; +#else + CHECK_EQ(device_ordinals.size(), 1) << "XGBoost must be compiled with NCCL to use more than one GPU."; +#endif + } + ~AllReducer() { +#ifdef XGBOOST_USE_NCCL + if (initialised) { + for (auto &stream : streams) { + dh::safe_cuda(cudaStreamDestroy(stream)); + } + for (auto &comm : comms) { + ncclCommDestroy(comm); + } + } +#endif + } + + /** + * \fn void AllReduceSum(int communication_group_idx, const double *sendbuff, + * double *recvbuff, int count) + * + * \brief Allreduce. Use in exactly the same way as NCCL but without needing + * streams or comms. + * + * \param communication_group_idx Zero-based index of the + * communication group. \param sendbuff The sendbuff. \param + * sendbuff The sendbuff. \param [in,out] recvbuff + * The recvbuff. \param count Number of. + */ + + void AllReduceSum(int communication_group_idx, const double *sendbuff, + double *recvbuff, int count) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised); + + dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx])); + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, + comms[communication_group_idx], + streams[communication_group_idx])); +#endif + } + + /** + * \fn void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, int64_t *recvbuff, int count) + * + * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms. + * + * \param communication_group_idx Zero-based index of the communication group. \param + * sendbuff The sendbuff. \param sendbuff + * The sendbuff. \param [in,out] recvbuff The recvbuff. + * \param count Number of. + * \param sendbuff The sendbuff. + * \param [in,out] recvbuff If non-null, the recvbuff. + * \param count Number of. + */ + + void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, + int64_t *recvbuff, int count) { +#ifdef XGBOOST_USE_NCCL + CHECK(initialised); + + dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx])); + dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, + comms[communication_group_idx], + streams[communication_group_idx])); +#endif + } + + /** + * \fn void Synchronize() + * + * \brief Synchronizes the entire communication group. + */ + void Synchronize() { +#ifdef XGBOOST_USE_NCCL + for (int i = 0; i < device_ordinals.size(); i++) { + dh::safe_cuda(cudaSetDevice(device_ordinals[i])); + dh::safe_cuda(cudaStreamSynchronize(streams[i])); + } +#endif + } +}; } // namespace dh diff --git a/src/common/timer.h b/src/common/timer.h new file mode 100644 index 000000000..81eec7ab0 --- /dev/null +++ b/src/common/timer.h @@ -0,0 +1,77 @@ +/*! + * Copyright by Contributors 2017 + */ +#pragma once +#include +#include +#include +#include + +namespace xgboost { +namespace common { +struct Timer { + typedef std::chrono::high_resolution_clock ClockT; + typedef std::chrono::high_resolution_clock::time_point TimePointT; + typedef std::chrono::high_resolution_clock::duration DurationT; + typedef std::chrono::duration SecondsT; + + TimePointT start; + DurationT elapsed; + Timer() { Reset(); } + void Reset() { + elapsed = DurationT::zero(); + Start(); + } + void Start() { start = ClockT::now(); } + void Stop() { elapsed += ClockT::now() - start; } + double ElapsedSeconds() const { return SecondsT(elapsed).count(); } + void PrintElapsed(std::string label) { + printf("%s:\t %fs\n", label.c_str(), SecondsT(elapsed).count()); + Reset(); + } +}; + +/** + * \struct Monitor + * + * \brief Timing utility used to measure total method execution time over the + * lifetime of the containing object. + */ + +struct Monitor { + bool debug_verbose = false; + std::string label = ""; + std::map timer_map; + Timer self_timer; + + Monitor() { self_timer.Start(); } + + ~Monitor() { + if (!debug_verbose) return; + + std::cout << "========\n"; + std::cout << "Monitor: " << label << "\n"; + std::cout << "========\n"; + for (auto &kv : timer_map) { + kv.second.PrintElapsed(kv.first); + } + self_timer.Stop(); + self_timer.PrintElapsed(label + " Lifetime"); + } + void Init(std::string label, bool debug_verbose) { + this->debug_verbose = debug_verbose; + this->label = label; + } + void Start(const std::string &name) { timer_map[name].Start(); } + void Stop(const std::string &name) { + if (debug_verbose) { +#ifdef __CUDACC__ +#include "device_helpers.cuh" + dh::synchronize_all(); +#endif + } + timer_map[name].Stop(); + } +}; +} // namespace common +} // namespace xgboost diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 990327e5f..c9055d7c0 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -20,6 +20,7 @@ #include "../common/common.h" #include "../common/random.h" #include "gbtree_model.h" +#include "../common/timer.h" namespace xgboost { namespace gbm { @@ -158,6 +159,7 @@ class GBTree : public GradientBooster { // configure predictor predictor = std::unique_ptr(Predictor::Create(tparam.predictor)); predictor->Init(cfg, cache_); + monitor.Init("GBTree", tparam.debug_verbose); } void Load(dmlc::Stream* fi) override { @@ -183,6 +185,7 @@ class GBTree : public GradientBooster { const std::vector& gpair = *in_gpair; std::vector > > new_trees; const int ngroup = model_.param.num_output_group; + monitor.Start("BoostNewTrees"); if (ngroup == 1) { std::vector > ret; BoostNewTrees(gpair, p_fmat, 0, &ret); @@ -202,13 +205,12 @@ class GBTree : public GradientBooster { new_trees.push_back(std::move(ret)); } } - double tstart = dmlc::GetTime(); + monitor.Stop("BoostNewTrees"); + monitor.Start("CommitModel"); for (int gid = 0; gid < ngroup; ++gid) { this->CommitModel(std::move(new_trees[gid]), gid); } - if (tparam.debug_verbose > 0) { - LOG(INFO) << "CommitModel(): " << dmlc::GetTime() - tstart << " sec"; - } + monitor.Stop("CommitModel"); } void PredictBatch(DMatrix* p_fmat, @@ -308,6 +310,7 @@ class GBTree : public GradientBooster { // Cached matrices std::vector> cache_; std::unique_ptr predictor; + common::Monitor monitor; }; // dart diff --git a/src/learner.cc b/src/learner.cc index 6979225e1..32e807137 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -18,6 +18,7 @@ #include "./common/common.h" #include "./common/io.h" #include "./common/random.h" +#include "common/timer.h" namespace xgboost { // implementation of base learner. @@ -202,6 +203,7 @@ class LearnerImpl : public Learner { const std::vector >& args) override { // add to configurations tparam.InitAllowUnknown(args); + monitor.Init("Learner", tparam.debug_verbose); cfg_.clear(); for (const auto& kv : args) { if (kv.first == "eval_metric") { @@ -359,29 +361,37 @@ class LearnerImpl : public Learner { } void UpdateOneIter(int iter, DMatrix* train) override { + monitor.Start("UpdateOneIter"); CHECK(ModelInitialized()) << "Always call InitModel or LoadModel before update"; if (tparam.seed_per_iteration || rabit::IsDistributed()) { common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter); } this->LazyInitDMatrix(train); + monitor.Start("PredictRaw"); this->PredictRaw(train, &preds_); + monitor.Stop("PredictRaw"); + monitor.Start("GetGradient"); obj_->GetGradient(preds_, train->info(), iter, &gpair_); + monitor.Stop("GetGradient"); gbm_->DoBoost(train, &gpair_, obj_.get()); + monitor.Stop("UpdateOneIter"); } void BoostOneIter(int iter, DMatrix* train, std::vector* in_gpair) override { + monitor.Start("BoostOneIter"); if (tparam.seed_per_iteration || rabit::IsDistributed()) { common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter); } this->LazyInitDMatrix(train); gbm_->DoBoost(train, in_gpair); + monitor.Stop("BoostOneIter"); } std::string EvalOneIter(int iter, const std::vector& data_sets, const std::vector& data_names) override { - double tstart = dmlc::GetTime(); + monitor.Start("EvalOneIter"); std::ostringstream os; os << '[' << iter << ']' << std::setiosflags(std::ios::fixed); if (metrics_.size() == 0) { @@ -396,9 +406,7 @@ class LearnerImpl : public Learner { } } - if (tparam.debug_verbose > 0) { - LOG(INFO) << "EvalOneIter(): " << dmlc::GetTime() - tstart << " sec"; - } + monitor.Stop("EvalOneIter"); return os.str(); } @@ -460,10 +468,11 @@ class LearnerImpl : public Learner { // if not, initialize the column access. inline void LazyInitDMatrix(DMatrix* p_train) { if (tparam.tree_method == 3 || tparam.tree_method == 4 || - tparam.tree_method == 5) { + tparam.tree_method == 5 || tparam.tree_method == 6) { return; } + monitor.Start("LazyInitDMatrix"); if (!p_train->HaveColAccess()) { int ncol = static_cast(p_train->info().num_col); std::vector enabled(ncol, true); @@ -504,6 +513,7 @@ class LearnerImpl : public Learner { gbm_->Configure(cfg_.begin(), cfg_.end()); } } + monitor.Stop("LazyInitDMatrix"); } // return whether model is already initialized. @@ -568,6 +578,8 @@ class LearnerImpl : public Learner { static const int kRandSeedMagic = 127; // internal cached dmatrix std::vector > cache_; + + common::Monitor monitor; }; Learner* Learner::Create( diff --git a/src/objective/multiclass_obj.cc b/src/objective/multiclass_obj.cc index 51925c8d1..dad4a3d60 100644 --- a/src/objective/multiclass_obj.cc +++ b/src/objective/multiclass_obj.cc @@ -65,9 +65,9 @@ class SoftmaxMultiClassObj : public ObjFunction { bst_float p = rec[k]; const bst_float h = 2.0f * p * (1.0f - p) * wt; if (label == k) { - out_gpair->at(i * nclass + k) = bst_gpair((p - 1.0f) * wt, h); + (*out_gpair)[i * nclass + k] = bst_gpair((p - 1.0f) * wt, h); } else { - out_gpair->at(i * nclass + k) = bst_gpair(p* wt, h); + (*out_gpair)[i * nclass + k] = bst_gpair(p* wt, h); } } } diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index ac7e5a4fe..2597d2de4 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -1,6 +1,6 @@ /*! * Copyright 2015 by Contributors - * \file regression.cc + * \file regression_obj.cc * \brief Definition of single-value regression and classification objectives. * \author Tianqi Chen, Kailong Chen */ @@ -68,53 +68,61 @@ struct LogisticRaw : public LogisticRegression { struct RegLossParam : public dmlc::Parameter { float scale_pos_weight; + int nthread; // declare parameters DMLC_DECLARE_PARAMETER(RegLossParam) { DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) .describe("Scale the weight of positive examples by this factor"); + DMLC_DECLARE_FIELD(nthread).set_default(0).describe( + "Number of threads to use."); } }; // regression loss function -template +template class RegLossObj : public ObjFunction { public: - void Configure(const std::vector >& args) override { + RegLossObj() : labels_checked(false) {} + + void Configure( + const std::vector > &args) override { param_.InitAllowUnknown(args); } - void GetGradient(const std::vector &preds, - const MetaInfo &info, - int iter, - std::vector *out_gpair) override { + void GetGradient(const std::vector &preds, const MetaInfo &info, + int iter, std::vector *out_gpair) override { CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty"; CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided" - << "preds.size=" << preds.size() << ", label.size=" << info.labels.size(); + << "preds.size=" << preds.size() + << ", label.size=" << info.labels.size(); + + this->LazyCheckLabels(info.labels); out_gpair->resize(preds.size()); - // check if label in range - bool label_correct = true; + // start calculating gradient const omp_ulong ndata = static_cast(preds.size()); - #pragma omp parallel for schedule(static) + int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; +#pragma omp parallel for schedule(static) num_threads(nthread) for (omp_ulong i = 0; i < ndata; ++i) { + auto y = info.labels[i]; bst_float p = Loss::PredTransform(preds[i]); bst_float w = info.GetWeight(i); - if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight; - if (!Loss::CheckLabel(info.labels[i])) label_correct = false; - out_gpair->at(i) = bst_gpair(Loss::FirstOrderGradient(p, info.labels[i]) * w, - Loss::SecondOrderGradient(p, info.labels[i]) * w); - } - if (!label_correct) { - LOG(FATAL) << Loss::LabelErrorMsg(); + // Branchless version of the below function + // The branch is particularly slow as the cpu cannot predict the label + // with any accuracy resulting in frequent pipeline stalls + // if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight; + w += y * ((param_.scale_pos_weight * w) - w); + (*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w, + Loss::SecondOrderGradient(p, y) * w); } } - const char* DefaultEvalMetric() const override { + const char *DefaultEvalMetric() const override { return Loss::DefaultEvalMetric(); } void PredTransform(std::vector *io_preds) override { std::vector &preds = *io_preds; const bst_omp_uint ndata = static_cast(preds.size()); - #pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (bst_omp_uint j = 0; j < ndata; ++j) { preds[j] = Loss::PredTransform(preds[j]); } @@ -124,7 +132,15 @@ class RegLossObj : public ObjFunction { } protected: + void LazyCheckLabels(const std::vector &labels) { + if (labels_checked) return; + for (auto &y : labels) { + CHECK(Loss::CheckLabel(y)) << Loss::LabelErrorMsg(); + } + labels_checked = true; + } RegLossParam param_; + bool labels_checked; }; // register the objective functions @@ -149,10 +165,13 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw") // declare parameter struct PoissonRegressionParam : public dmlc::Parameter { float max_delta_step; + int nthread; DMLC_DECLARE_PARAMETER(PoissonRegressionParam) { DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f) .describe("Maximum delta step we allow each weight estimation to be." \ " This parameter is required for possion regression."); + DMLC_DECLARE_FIELD(nthread).set_default(0).describe( + "Number of threads to use."); } }; @@ -175,13 +194,14 @@ class PoissonRegression : public ObjFunction { bool label_correct = true; // start calculating gradient const omp_ulong ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) + int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; + #pragma omp parallel for schedule(static) num_threads(nthread) for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) bst_float p = preds[i]; bst_float w = info.GetWeight(i); bst_float y = info.labels[i]; if (y >= 0.0f) { - out_gpair->at(i) = bst_gpair((std::exp(p) - y) * w, + (*out_gpair)[i] = bst_gpair((std::exp(p) - y) * w, std::exp(p + param_.max_delta_step) * w); } else { label_correct = false; @@ -192,7 +212,8 @@ class PoissonRegression : public ObjFunction { void PredTransform(std::vector *io_preds) override { std::vector &preds = *io_preds; const long ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) + int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; + #pragma omp parallel for schedule(static) num_threads(nthread) for (long j = 0; j < ndata; ++j) { // NOLINT(*) preds[j] = std::exp(preds[j]); } @@ -242,7 +263,7 @@ class GammaRegression : public ObjFunction { bst_float w = info.GetWeight(i); bst_float y = info.labels[i]; if (y >= 0.0f) { - out_gpair->at(i) = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w); + (*out_gpair)[i] = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w); } else { label_correct = false; } @@ -276,9 +297,12 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma") // declare parameter struct TweedieRegressionParam : public dmlc::Parameter { float tweedie_variance_power; + int nthread; DMLC_DECLARE_PARAMETER(TweedieRegressionParam) { DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f) .describe("Tweedie variance power. Must be between in range [1, 2)."); + DMLC_DECLARE_FIELD(nthread).set_default(0).describe( + "Number of threads to use."); } }; @@ -301,7 +325,8 @@ class TweedieRegression : public ObjFunction { bool label_correct = true; // start calculating gradient const omp_ulong ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) + int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; + #pragma omp parallel for schedule(static) num_threads(nthread) for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) bst_float p = preds[i]; bst_float w = info.GetWeight(i); @@ -311,7 +336,7 @@ class TweedieRegression : public ObjFunction { bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p); bst_float hess = -y * (1 - rho) * \ std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p); - out_gpair->at(i) = bst_gpair(grad * w, hess * w); + (*out_gpair)[i] = bst_gpair(grad * w, hess * w); } else { label_correct = false; } @@ -321,7 +346,8 @@ class TweedieRegression : public ObjFunction { void PredTransform(std::vector *io_preds) override { std::vector &preds = *io_preds; const long ndata = static_cast(preds.size()); // NOLINT(*) - #pragma omp parallel for schedule(static) + int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; + #pragma omp parallel for schedule(static) num_threads(nthread) for (long j = 0; j < ndata; ++j) { // NOLINT(*) preds[j] = std::exp(preds[j]); } diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 8edaf8578..a15c0242c 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -12,9 +12,51 @@ #include "../common/random.h" #include "param.h" +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 + +#else +__device__ __forceinline__ double atomicAdd(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT + unsigned long long int old = *address_as_ull, assumed; // NOLINT + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != + // NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + namespace xgboost { namespace tree { +// Atomic add function for double precision gradients +__device__ __forceinline__ void AtomicAddGpair(bst_gpair_precise* dest, + const bst_gpair& gpair) { + auto dst_ptr = reinterpret_cast(dest); + + atomicAdd(dst_ptr, static_cast(gpair.GetGrad())); + atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess())); +} + +// For integer gradients +__device__ __forceinline__ void AtomicAddGpair(bst_gpair_integer* dest, + const bst_gpair& gpair) { + auto dst_ptr = reinterpret_cast(dest); // NOLINT + bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess()); + auto src_ptr = reinterpret_cast(&tmp); + + atomicAdd(dst_ptr, + static_cast(*src_ptr)); // NOLINT + atomicAdd(dst_ptr + 1, + static_cast(*(src_ptr + 1))); // NOLINT +} + /** * \fn void CheckGradientMax(const dh::dvec& gpair) * @@ -22,15 +64,11 @@ namespace tree { * overflow when using integer gradient summation. */ -inline void CheckGradientMax(const dh::dvec& gpair) { - auto dptr = thrust::device_ptr( - reinterpret_cast(gpair.data())); - float abs_max = thrust::reduce(dptr, dptr + (gpair.size() * 2), 0.f, - [=] __device__(float a, float b) { - a = abs(a); - b = abs(b); - return max(a, b); - }); +inline void CheckGradientMax(const std::vector& gpair) { + auto* ptr = reinterpret_cast(gpair.data()); + float abs_max = + std::accumulate(ptr, ptr + (gpair.size() * 2), 0.f, + [=](float a, float b) { return max(abs(a), abs(b)); }); CHECK_LT(abs_max, std::pow(2.0f, 16.0f)) << "Labels are too large for this algorithm. Rescale to less than 2^16."; @@ -321,8 +359,74 @@ inline std::vector col_sample(std::vector features, float colsample) { std::shuffle(features.begin(), features.end(), common::GlobalRandom()); features.resize(n); + std::sort(features.begin(), features.end()); return features; } + +/** + * \class ColumnSampler + * + * \brief Handles selection of columns due to colsample_bytree and + * colsample_bylevel parameters. Should be initialised the before tree + * construction and to reset When tree construction is completed. + */ + +class ColumnSampler { + std::vector feature_set_tree; + std::map> feature_set_level; + TrainParam param; + + public: + /** + * \fn void Init(int64_t num_col, const TrainParam& param) + * + * \brief Initialise this object before use. + * + * \param num_col Number of cols. + * \param param The parameter. + */ + + void Init(int64_t num_col, const TrainParam& param) { + this->Reset(); + this->param = param; + feature_set_tree.resize(num_col); + std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); + feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree); + } + + /** + * \fn void Reset() + * + * \brief Resets this object. + */ + + void Reset() { + feature_set_tree.clear(); + feature_set_level.clear(); + } + + /** + * \fn bool ColumnUsed(int column, int depth) + * + * \brief Whether the current column should be considered as a split. + * + * \param column The column index. + * \param depth The current tree depth. + * + * \return True if it should be used, false if it should not be used. + */ + + bool ColumnUsed(int column, int depth) { + if (feature_set_level.count(depth) == 0) { + feature_set_level[depth] = + col_sample(feature_set_tree, param.colsample_bylevel); + } + + return std::binary_search(feature_set_level[depth].begin(), + feature_set_level[depth].end(), column); + } +}; + } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index dd3f1a8e9..77b48b157 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -8,6 +8,7 @@ #include "../common/compressed_iterator.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" +#include "../common/timer.h" #include "param.h" #include "updater_gpu_common.cuh" @@ -17,7 +18,6 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); typedef bst_gpair_integer gpair_sum_t; -static const ncclDataType_t nccl_sum_t = ncclInt64; // Helper for explicit template specialisation template @@ -63,15 +63,7 @@ struct HistHelper { __device__ void Add(bst_gpair gpair, int gidx, int nidx) const { int hist_idx = nidx * n_bins + gidx; - auto dst_ptr = - reinterpret_cast(&d_hist[hist_idx]); // NOLINT - gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess()); - auto src_ptr = reinterpret_cast(&tmp); - - atomicAdd(dst_ptr, - static_cast(*src_ptr)); // NOLINT - atomicAdd(dst_ptr + 1, - static_cast(*(src_ptr + 1))); // NOLINT + AtomicAddGpair(d_hist + hist_idx, gpair); } __device__ gpair_sum_t Get(int gidx, int nidx) const { return d_hist[nidx * n_bins + gidx]; @@ -244,22 +236,7 @@ class GPUHistMaker : public TreeUpdater { is_dense(false), p_last_fmat_(nullptr), prediction_cache_initialised(false) {} - ~GPUHistMaker() { - if (initialised) { - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - ncclCommDestroy(comms[d_idx]); - - dh::safe_cuda(cudaSetDevice(dList[d_idx])); - dh::safe_cuda(cudaStreamDestroy(*(streams[d_idx]))); - } - for (int num_d = 1; num_d <= n_devices; - ++num_d) { // loop over number of devices used - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - ncclCommDestroy(find_split_comms[num_d - 1][d_idx]); - } - } - } - } + ~GPUHistMaker() {} void Init( const std::vector>& args) override { param.InitAllowUnknown(args); @@ -290,7 +267,7 @@ class GPUHistMaker : public TreeUpdater { void InitData(const std::vector& gpair, DMatrix& fmat, // NOLINT const RegTree& tree) { - dh::Timer time1; + common::Timer time1; // set member num_rows and n_devices for rest of GPUHistBuilder members info = &fmat.info(); CHECK(info->num_row < std::numeric_limits::max()); @@ -298,6 +275,12 @@ class GPUHistMaker : public TreeUpdater { n_devices = dh::n_devices(param.n_gpus, num_rows); if (!initialised) { + // Check gradients are within acceptable size range + CheckGradientMax(gpair); + + // Check compute capability is high enough + dh::check_compute_capability(); + // reset static timers used across iterations cpu_init_time = 0; gpu_init_time = 0; @@ -312,57 +295,10 @@ class GPUHistMaker : public TreeUpdater { } // initialize nccl - - comms.resize(n_devices); - streams.resize(n_devices); - dh::safe_nccl(ncclCommInitAll(comms.data(), n_devices, - dList.data())); // initialize communicator - // (One communicator per - // process) - - // printf("# NCCL: Using devices\n"); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - streams[d_idx] = - reinterpret_cast(malloc(sizeof(cudaStream_t))); - dh::safe_cuda(cudaSetDevice(dList[d_idx])); - dh::safe_cuda(cudaStreamCreate(streams[d_idx])); - - int cudaDev; - int rank; - cudaDeviceProp prop; - dh::safe_nccl(ncclCommCuDevice(comms[d_idx], &cudaDev)); - dh::safe_nccl(ncclCommUserRank(comms[d_idx], &rank)); - dh::safe_cuda(cudaGetDeviceProperties(&prop, cudaDev)); - // printf("# Rank %2d uses device %2d [0x%02x] %s\n", rank, cudaDev, - // prop.pciBusID, prop.name); - // cudaDriverGetVersion(&driverVersion); - // cudaRuntimeGetVersion(&runtimeVersion); - std::ostringstream oss; - oss << "CUDA Capability Major/Minor version number: " << prop.major - << "." << prop.minor << " is insufficient. Need >=3.5."; - int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5; - CHECK(failed == 0) << oss.str(); - } - - // local find_split group of comms for each case of reduced number of - // GPUs to use - find_split_comms.resize( - n_devices, - std::vector(n_devices)); // TODO(JCM): Excessive, but - // ok, and best to do - // here instead of - // repeatedly - for (int num_d = 1; num_d <= n_devices; - ++num_d) { // loop over number of devices used - dh::safe_nccl( - ncclCommInitAll(find_split_comms[num_d - 1].data(), num_d, - dList.data())); // initialize communicator - // (One communicator per - // process) - } + reducer.Init(dList); is_dense = info->num_nonzero == info->num_col * info->num_row; - dh::Timer time0; + common::Timer time0; hmat_.Init(&fmat, param.max_bin); cpu_init_time += time0.ElapsedSeconds(); if (param.debug_verbose) { // Only done once for each training session @@ -397,8 +333,8 @@ class GPUHistMaker : public TreeUpdater { fflush(stdout); } - int n_bins = static_cast(hmat_.row_ptr.back()); - int n_features = static_cast(hmat_.row_ptr.size() - 1); + int n_bins = static_cast(hmat_.row_ptr.back()); + int n_features = static_cast(hmat_.row_ptr.size() - 1); // deliniate data onto multiple gpus device_row_segments.push_back(0); @@ -442,10 +378,7 @@ class GPUHistMaker : public TreeUpdater { temp_memory.resize(n_devices); hist_vec.resize(n_devices); nodes.resize(n_devices); - nodes_temp.resize(n_devices); - nodes_child_temp.resize(n_devices); left_child_smallest.resize(n_devices); - left_child_smallest_temp.resize(n_devices); feature_flags.resize(n_devices); fidx_min_map.resize(n_devices); feature_segments.resize(n_devices); @@ -457,12 +390,6 @@ class GPUHistMaker : public TreeUpdater { gidx_feature_map.resize(n_devices); gidx_fvalue_map.resize(n_devices); - int find_split_n_devices = static_cast(std::pow(2, std::floor(std::log2(n_devices)))); - find_split_n_devices = - std::min(n_nodes_level(param.max_depth), find_split_n_devices); - int max_num_nodes_device = - n_nodes_level(param.max_depth) / find_split_n_devices; - // num_rows_segment: for sharding rows onto gpus for splitting data // num_elements_segment: for sharding rows (of elements) onto gpus for // splitting data @@ -476,26 +403,31 @@ class GPUHistMaker : public TreeUpdater { device_row_segments[d_idx + 1] - device_row_segments[d_idx]; bst_ulong num_elements_segment = device_element_segments[d_idx + 1] - device_element_segments[d_idx]; + + // ensure allocation doesn't overflow + size_t hist_size = static_cast(n_nodes(param.max_depth - 1)) * + static_cast(n_bins); + size_t nodes_size = static_cast(n_nodes(param.max_depth)); + size_t hmat_size = static_cast(hmat_.min_val.size()); + size_t buffer_size = static_cast( + common::CompressedBufferWriter::CalculateBufferSize( + static_cast(num_elements_segment), + static_cast(n_bins))); + ba.allocate( - device_idx, param.silent, &(hist_vec[d_idx].data), - n_nodes(param.max_depth - 1) * n_bins, &nodes[d_idx], - n_nodes(param.max_depth), &nodes_temp[d_idx], max_num_nodes_device, - &nodes_child_temp[d_idx], max_num_nodes_device, - &left_child_smallest[d_idx], n_nodes(param.max_depth), - &left_child_smallest_temp[d_idx], max_num_nodes_device, - &feature_flags[d_idx], + device_idx, param.silent, &(hist_vec[d_idx].data), hist_size, + &nodes[d_idx], n_nodes(param.max_depth), + &left_child_smallest[d_idx], nodes_size, &feature_flags[d_idx], n_features, // may change but same on all devices &fidx_min_map[d_idx], - hmat_.min_val.size(), // constant and same on all devices + hmat_size, // constant and same on all devices &feature_segments[d_idx], h_feature_segments.size(), // constant and same on all devices &prediction_cache[d_idx], num_rows_segment, &position[d_idx], num_rows_segment, &position_tmp[d_idx], num_rows_segment, &device_gpair[d_idx], num_rows_segment, &device_matrix[d_idx].gidx_buffer, - common::CompressedBufferWriter::CalculateBufferSize( - num_elements_segment, - n_bins), // constant and same on all devices + buffer_size, // constant and same on all devices &device_matrix[d_idx].row_ptr, num_rows_segment + 1, &gidx_feature_map[d_idx], n_bins, // constant and same on all devices @@ -529,17 +461,12 @@ class GPUHistMaker : public TreeUpdater { dh::safe_cuda(cudaSetDevice(device_idx)); nodes[d_idx].fill(DeviceNodeStats()); - nodes_temp[d_idx].fill(DeviceNodeStats()); - nodes_child_temp[d_idx].fill(DeviceNodeStats()); position[d_idx].fill(0); device_gpair[d_idx].copy(gpair.begin() + device_row_segments[d_idx], gpair.begin() + device_row_segments[d_idx + 1]); - // Check gradients are within acceptable size range - CheckGradientMax(device_gpair[d_idx]); - subsample_gpair(&device_gpair[d_idx], param.subsample, device_row_segments[d_idx]); @@ -618,21 +545,16 @@ class GPUHistMaker : public TreeUpdater { // fprintf(stderr,"sizeof(bst_gpair)/sizeof(float)=%d\n",sizeof(bst_gpair)/sizeof(float)); for (int d_idx = 0; d_idx < n_devices; d_idx++) { int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - dh::safe_nccl(ncclAllReduce( - reinterpret_cast(hist_vec[d_idx].GetLevelPtr(depth)), - reinterpret_cast(hist_vec[d_idx].GetLevelPtr(depth)), - hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) / - sizeof(gpair_sum_t::value_t), - nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx]))); + reducer.AllReduceSum(device_idx, + reinterpret_cast( + hist_vec[d_idx].GetLevelPtr(depth)), + reinterpret_cast( + hist_vec[d_idx].GetLevelPtr(depth)), + hist_vec[d_idx].LevelSize(depth) * + sizeof(gpair_sum_t) / + sizeof(gpair_sum_t::value_t)); } - - for (int d_idx = 0; d_idx < n_devices; d_idx++) { - int device_idx = dList[d_idx]; - dh::safe_cuda(cudaSetDevice(device_idx)); - dh::safe_cuda(cudaStreamSynchronize(*(streams[d_idx]))); - } - // if no NCCL, then presume only 1 GPU, then already correct + reducer.Synchronize(); // time.printElapsed("Reduce-Add Time"); @@ -955,7 +877,7 @@ class GPUHistMaker : public TreeUpdater { } void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, RegTree* p_tree) { - dh::Timer time0; + common::Timer time0; this->InitData(gpair, *p_fmat, *p_tree); this->InitFirstNode(gpair); @@ -1019,10 +941,7 @@ class GPUHistMaker : public TreeUpdater { std::vector temp_memory; std::vector hist_vec; std::vector> nodes; - std::vector> nodes_temp; - std::vector> nodes_child_temp; std::vector> left_child_smallest; - std::vector> left_child_smallest_temp; std::vector> feature_flags; std::vector> fidx_min_map; std::vector> feature_segments; @@ -1034,13 +953,11 @@ class GPUHistMaker : public TreeUpdater { std::vector> gidx_feature_map; std::vector> gidx_fvalue_map; - std::vector streams; - std::vector comms; - std::vector> find_split_comms; + dh::AllReducer reducer; double cpu_init_time; double gpu_init_time; - dh::Timer cpu_time; + common::Timer cpu_time; double gpu_time; }; diff --git a/src/tree/updater_gpu_hist_experimental.cu b/src/tree/updater_gpu_hist_experimental.cu index 977e3c637..d351fa183 100644 --- a/src/tree/updater_gpu_hist_experimental.cu +++ b/src/tree/updater_gpu_hist_experimental.cu @@ -1,8 +1,9 @@ /*! * Copyright 2017 XGBoost contributors */ -#include -#include +#include +#include +#include #include #include #include @@ -12,6 +13,7 @@ #include "../common/compressed_iterator.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" +#include "../common/timer.h" #include "param.h" #include "updater_gpu_common.cuh" @@ -20,19 +22,20 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpu_hist_experimental); -template -__device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin, - const bst_gpair_integer* end, - temp_storage_t* temp_storage) { - __shared__ cub::Uninitialized uninitialized_sum; - bst_gpair_integer& shared_sum = uninitialized_sum.Alias(); +typedef bst_gpair_precise gpair_sum_t; - bst_gpair_integer local_sum = bst_gpair_integer(); +template +__device__ gpair_sum_t ReduceFeature(const gpair_sum_t* begin, + const gpair_sum_t* end, + temp_storage_t* temp_storage) { + __shared__ cub::Uninitialized uninitialized_sum; + gpair_sum_t& shared_sum = uninitialized_sum.Alias(); + + gpair_sum_t local_sum = gpair_sum_t(); for (auto itr = begin; itr < end; itr += BLOCK_THREADS) { bool thread_active = itr + threadIdx.x < end; // Scan histogram - bst_gpair_integer bin = - thread_active ? *(itr + threadIdx.x) : bst_gpair_integer(); + gpair_sum_t bin = thread_active ? *(itr + threadIdx.x) : gpair_sum_t(); local_sum += reduce_t(temp_storage->sum_reduce).Reduce(bin, cub::Sum()); } @@ -47,7 +50,7 @@ __device__ bst_gpair_integer ReduceFeature(const bst_gpair_integer* begin, template -__device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist, +__device__ void EvaluateFeature(int fidx, const gpair_sum_t* hist, const int* feature_segments, float min_fvalue, const float* gidx_fvalue_map, DeviceSplitCandidate* best_split, @@ -57,22 +60,22 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist, int gidx_begin = feature_segments[fidx]; int gidx_end = feature_segments[fidx + 1]; - bst_gpair_integer feature_sum = ReduceFeature( + gpair_sum_t feature_sum = ReduceFeature( hist + gidx_begin, hist + gidx_end, temp_storage); - auto prefix_op = SumCallbackOp(); + auto prefix_op = SumCallbackOp(); for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) { bool thread_active = scan_begin + threadIdx.x < gidx_end; - bst_gpair_integer bin = - thread_active ? hist[scan_begin + threadIdx.x] : bst_gpair_integer(); + gpair_sum_t bin = + thread_active ? hist[scan_begin + threadIdx.x] : gpair_sum_t(); scan_t(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); - // Calculate gain - bst_gpair_integer parent_sum = bst_gpair_integer(node.sum_gradients); + // Calculate gain + gpair_sum_t parent_sum = gpair_sum_t(node.sum_gradients); - bst_gpair_integer missing = parent_sum - feature_sum; + gpair_sum_t missing = parent_sum - feature_sum; bool missing_left = true; const float null_gain = -FLT_MAX; @@ -102,8 +105,8 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist, float fvalue = gidx == gidx_begin ? min_fvalue : gidx_fvalue_map[gidx - 1]; - bst_gpair_integer left = missing_left ? bin + missing : bin; - bst_gpair_integer right = parent_sum - left; + gpair_sum_t left = missing_left ? bin + missing : bin; + gpair_sum_t right = parent_sum - left; best_split->Update(gain, missing_left ? LeftDir : RightDir, fvalue, fidx, left, right, param); @@ -114,17 +117,16 @@ __device__ void EvaluateFeature(int fidx, const bst_gpair_integer* hist, template __global__ void evaluate_split_kernel( - const bst_gpair_integer* d_hist, int nidx, uint64_t n_features, + const gpair_sum_t* d_hist, int nidx, uint64_t n_features, DeviceNodeStats nodes, const int* d_feature_segments, const float* d_fidx_min_map, const float* d_gidx_fvalue_map, GPUTrainingParam gpu_param, DeviceSplitCandidate* d_split) { typedef cub::KeyValuePair ArgMaxT; - typedef cub::BlockScan + typedef cub::BlockScan BlockScanT; typedef cub::BlockReduce MaxReduceT; - typedef cub::BlockReduce SumReduceT; + typedef cub::BlockReduce SumReduceT; union TempStorage { typename BlockScanT::TempStorage scan; @@ -159,13 +161,6 @@ __global__ void evaluate_split_kernel( template __device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data, int fidx_begin, int fidx_end) { - // for(auto i = begin; i < end; i++) - //{ - // auto gidx = data[i]; - // if (gidx >= fidx_begin&&gidx < fidx_end) return gidx; - //} - // return -1; - bst_uint previous_middle = UINT32_MAX; while (end != begin) { auto middle = begin + (end - begin) / 2; @@ -190,48 +185,63 @@ __device__ int BinarySearchRow(bst_uint begin, bst_uint end, gidx_iter_t data, struct DeviceHistogram { dh::bulk_allocator ba; - dh::dvec data; - std::map node_map; + dh::dvec data; int n_bins; void Init(int device_idx, int max_nodes, int n_bins, bool silent) { this->n_bins = n_bins; - ba.allocate(device_idx, silent, &data, max_nodes * n_bins); + ba.allocate(device_idx, silent, &data, size_t(max_nodes) * size_t(n_bins)); } - void Reset() { - data.fill(bst_gpair_integer()); - node_map.clear(); - } + void Reset() { data.fill(gpair_sum_t()); } + gpair_sum_t* GetHistPtr(int nidx) { return data.data() + nidx * n_bins; } - void AddNode(int nidx) { - CHECK_EQ(node_map.count(nidx), 0) - << nidx << " already exists in the histogram."; - node_map[nidx] = data.data() + n_bins * node_map.size(); + void PrintNidx(int nidx) const { + auto h_data = data.as_vector(); + std::cout << "nidx " << nidx << ":\n"; + for (int i = n_bins * nidx; i < n_bins * (nidx + 1); i++) { + std::cout << h_data[i] << " "; + } + std::cout << "\n"; } }; // Manage memory for a single GPU struct DeviceShard { + struct Segment { + size_t begin; + size_t end; + + Segment() : begin(0), end(0) {} + + Segment(size_t begin, size_t end) : begin(begin), end(end) { + CHECK_GE(end, begin); + } + size_t Size() const { return end - begin; } + }; + int device_idx; int normalised_device_idx; // Device index counting from param.gpu_id dh::bulk_allocator ba; dh::dvec gidx_buffer; dh::dvec gpair; - dh::dvec2 ridx; + dh::dvec2 ridx; // Row index relative to this shard dh::dvec2 position; - std::vector> ridx_segments; + std::vector ridx_segments; dh::dvec feature_segments; dh::dvec gidx_fvalue_map; dh::dvec min_fvalue; std::vector node_sum_gradients; common::CompressedIterator gidx; int row_stride; - bst_uint row_start_idx; + bst_uint row_begin_idx; // The row offset for this shard bst_uint row_end_idx; bst_uint n_rows; int n_bins; int null_gidx_value; DeviceHistogram hist; + TrainParam param; + + int64_t* tmp_pinned; // Small amount of staging memory std::vector streams; @@ -242,11 +252,12 @@ struct DeviceShard { bst_uint row_end, int n_bins, TrainParam param) : device_idx(device_idx), normalised_device_idx(normalised_device_idx), - row_start_idx(row_begin), + row_begin_idx(row_begin), row_end_idx(row_end), n_rows(row_end - row_begin), n_bins(n_bins), - null_gidx_value(n_bins) { + null_gidx_value(n_bins), + param(param) { // Convert to ELLPACK matrix representation int max_elements_row = 0; for (auto i = row_begin; i < row_end; i++) { @@ -260,7 +271,8 @@ struct DeviceShard { for (auto i = row_begin; i < row_end; i++) { int row_count = 0; for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) { - ellpack_matrix[i * row_stride + row_count] = gmat.index[j]; + ellpack_matrix[(i - row_begin) * row_stride + row_count] = + gmat.index[j]; row_count++; } } @@ -296,12 +308,15 @@ struct DeviceShard { // Init histogram hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent); + + dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t))); } ~DeviceShard() { for (auto& stream : streams) { dh::safe_cuda(cudaStreamDestroy(stream)); } + dh::safe_cuda(cudaFreeHost(tmp_pinned)); } // Get vector of at least n initialised streams @@ -324,81 +339,135 @@ struct DeviceShard { // Reset values for each update iteration void Reset(const std::vector& host_gpair) { + dh::safe_cuda(cudaSetDevice(device_idx)); position.current_dvec().fill(0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), bst_gpair()); - // TODO(rory): support subsampling - thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend(), - row_start_idx); - std::fill(ridx_segments.begin(), ridx_segments.end(), std::make_pair(0, 0)); - ridx_segments.front() = std::make_pair(0, ridx.size()); - this->gpair.copy(host_gpair.begin() + row_start_idx, + + thrust::sequence(ridx.current_dvec().tbegin(), ridx.current_dvec().tend()); + + std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); + ridx_segments.front() = Segment(0, ridx.size()); + this->gpair.copy(host_gpair.begin() + row_begin_idx, host_gpair.begin() + row_end_idx); - // Check gradients are within acceptable size range - CheckGradientMax(gpair); + subsample_gpair(&gpair, param.subsample, row_begin_idx); hist.Reset(); } - __device__ void IncrementHist(bst_gpair gpair, int gidx, - bst_gpair_integer* node_hist) const { - auto dst_ptr = - reinterpret_cast(&node_hist[gidx]); // NOLINT - bst_gpair_integer tmp(gpair.GetGrad(), gpair.GetHess()); - auto src_ptr = reinterpret_cast(&tmp); - - atomicAdd(dst_ptr, - static_cast(*src_ptr)); // NOLINT - atomicAdd(dst_ptr + 1, - static_cast(*(src_ptr + 1))); // NOLINT - } - void BuildHist(int nidx) { - hist.AddNode(nidx); - auto d_node_hist = hist.node_map[nidx]; + auto segment = ridx_segments[nidx]; + auto d_node_hist = hist.GetHistPtr(nidx); auto d_gidx = gidx; auto d_ridx = ridx.current(); auto d_gpair = gpair.data(); auto row_stride = this->row_stride; auto null_gidx_value = this->null_gidx_value; - auto segment = ridx_segments[nidx]; - auto n_elements = (segment.second - segment.first) * row_stride; + auto n_elements = segment.Size() * row_stride; dh::launch_n(device_idx, n_elements, [=] __device__(size_t idx) { - int relative_ridx = d_ridx[(idx / row_stride) + segment.first]; - int gidx = d_gidx[relative_ridx * row_stride + idx % row_stride]; + int ridx = d_ridx[(idx / row_stride) + segment.begin]; + int gidx = d_gidx[ridx * row_stride + idx % row_stride]; + if (gidx != null_gidx_value) { - bst_gpair gpair = d_gpair[relative_ridx]; - IncrementHist(gpair, gidx, d_node_hist); + AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]); } }); } - void SortPosition(const std::pair& segment, int left_nidx, - int right_nidx) { - auto n = segment.second - segment.first; + void SubtractionTrick(int nidx_parent, int nidx_histogram, + int nidx_subtraction) { + auto d_node_hist_parent = hist.GetHistPtr(nidx_parent); + auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram); + auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction); + + dh::launch_n(device_idx, hist.n_bins, [=] __device__(size_t idx) { + d_node_hist_subtraction[idx] = + d_node_hist_parent[idx] - d_node_hist_histogram[idx]; + }); + } + + __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { + unsigned ballot = __ballot(val == left_nidx); + if (threadIdx.x % 32 == 0) { + atomicAdd(reinterpret_cast(d_count), // NOLINT + static_cast(__popc(ballot))); // NOLINT + } + } + + void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx, + int split_gidx, bool default_dir_left, bool is_dense, + int fidx_begin, int fidx_end) { + dh::safe_cuda(cudaSetDevice(device_idx)); + temp_memory.LazyAllocate(sizeof(int64_t)); + auto d_left_count = temp_memory.Pointer(); + dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t))); + auto segment = ridx_segments[nidx]; + auto d_ridx = ridx.current(); + auto d_position = position.current(); + auto d_gidx = gidx; + auto row_stride = this->row_stride; + dh::launch_n<1, 512>( + device_idx, segment.Size(), [=] __device__(bst_uint idx) { + idx += segment.begin; + auto ridx = d_ridx[idx]; + auto row_begin = row_stride * ridx; + auto row_end = row_begin + row_stride; + auto gidx = -1; + if (is_dense) { + gidx = d_gidx[row_begin + fidx]; + } else { + gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, + fidx_end); + } + + int position; + if (gidx >= 0) { + // Feature is found + position = gidx <= split_gidx ? left_nidx : right_nidx; + } else { + // Feature is missing + position = default_dir_left ? left_nidx : right_nidx; + } + + CountLeft(d_left_count, position, left_nidx); + d_position[idx] = position; + }); + + dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count, sizeof(int64_t), + cudaMemcpyDeviceToHost)); + auto left_count = *tmp_pinned; + + SortPosition(segment, left_nidx, right_nidx); + // dh::safe_cuda(cudaStreamSynchronize(stream)); + ridx_segments[left_nidx] = + Segment(segment.begin, segment.begin + left_count); + ridx_segments[right_nidx] = + Segment(segment.begin + left_count, segment.end); + } + + void SortPosition(const Segment& segment, int left_nidx, int right_nidx) { int min_bits = 0; int max_bits = static_cast( std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1))); size_t temp_storage_bytes = 0; cub::DeviceRadixSort::SortPairs( - nullptr, temp_storage_bytes, position.current() + segment.first, - position.other() + segment.first, ridx.current() + segment.first, - ridx.other() + segment.first, n, min_bits, max_bits); + nullptr, temp_storage_bytes, position.current() + segment.begin, + position.other() + segment.begin, ridx.current() + segment.begin, + ridx.other() + segment.begin, segment.Size(), min_bits, max_bits); temp_memory.LazyAllocate(temp_storage_bytes); cub::DeviceRadixSort::SortPairs( temp_memory.d_temp_storage, temp_memory.temp_storage_bytes, - position.current() + segment.first, position.other() + segment.first, - ridx.current() + segment.first, ridx.other() + segment.first, n, - min_bits, max_bits); - dh::safe_cuda(cudaMemcpy(position.current() + segment.first, - position.other() + segment.first, n * sizeof(int), - cudaMemcpyDeviceToDevice)); - dh::safe_cuda(cudaMemcpy(ridx.current() + segment.first, - ridx.other() + segment.first, n * sizeof(bst_uint), - cudaMemcpyDeviceToDevice)); - //} + position.current() + segment.begin, position.other() + segment.begin, + ridx.current() + segment.begin, ridx.other() + segment.begin, + segment.Size(), min_bits, max_bits); + dh::safe_cuda(cudaMemcpy( + position.current() + segment.begin, position.other() + segment.begin, + segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice)); + dh::safe_cuda(cudaMemcpy( + ridx.current() + segment.begin, ridx.other() + segment.begin, + segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice)); } }; @@ -412,10 +481,10 @@ class GPUHistMakerExperimental : public TreeUpdater { const std::vector>& args) override { param.InitAllowUnknown(args); CHECK(param.n_gpus != 0) << "Must have at least one device"; - CHECK(param.n_gpus <= 1 && param.n_gpus != -1) - << "Only one GPU currently supported"; n_devices = param.n_gpus; + dh::check_compute_capability(); + if (param.grow_policy == TrainParam::kLossGuide) { qexpand_.reset(new ExpandQueue(loss_guide)); } else { @@ -426,6 +495,7 @@ class GPUHistMakerExperimental : public TreeUpdater { } void Update(const std::vector& gpair, DMatrix* dmat, const std::vector& trees) override { + monitor.Start("Update"); GradStats::CheckInfo(dmat->info()); // rescale learning rate according to size of trees float lr = param.learning_rate; @@ -439,36 +509,119 @@ class GPUHistMakerExperimental : public TreeUpdater { LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl; } param.learning_rate = lr; + monitor.Stop("Update"); } void InitDataOnce(DMatrix* dmat) { info = &dmat->info(); + monitor.Start("Quantiles"); hmat_.Init(dmat, param.max_bin); gmat_.cut = &hmat_; gmat_.Init(dmat); + monitor.Stop("Quantiles"); n_bins = hmat_.row_ptr.back(); - shards.emplace_back(param.gpu_id, 0, gmat_, 0, info->num_row, n_bins, - param); + + int n_devices = dh::n_devices(param.n_gpus, info->num_row); + + bst_uint row_begin = 0; + bst_uint shard_size = + std::ceil(static_cast(info->num_row) / n_devices); + + std::vector dList(n_devices); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + int device_idx = (param.gpu_id + d_idx) % dh::n_visible_devices(); + dList[d_idx] = device_idx; + } + + reducer.Init(dList); + + // Partition input matrix into row segments + std::vector row_segments; + shards.resize(n_devices); + row_segments.push_back(0); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + bst_uint row_end = + std::min(static_cast(row_begin + shard_size), info->num_row); + row_segments.push_back(row_end); + row_begin = row_end; + } + + // Create device shards + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id] = std::unique_ptr( + new DeviceShard(dList[cpu_thread_id], cpu_thread_id, gmat_, + row_segments[cpu_thread_id], + row_segments[cpu_thread_id + 1], n_bins, param)); + } + initialised = true; } void InitData(const std::vector& gpair, DMatrix* dmat, const RegTree& tree) { + monitor.Start("InitDataOnce"); if (!initialised) { + CheckGradientMax(gpair); this->InitDataOnce(dmat); } + monitor.Stop("InitDataOnce"); - this->ColSampleTree(); + column_sampler.Init(info->num_col, param); // Copy gpair & reset memory - for (auto& shard : shards) { - shard.Reset(gpair); + monitor.Start("InitDataReset"); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id]->Reset(gpair); } + monitor.Stop("InitDataReset"); } - void BuildHist(int nidx) { + void AllReduceHist(int nidx) { for (auto& shard : shards) { - shard.BuildHist(nidx); + auto d_node_hist = shard->hist.GetHistPtr(nidx); + reducer.AllReduceSum( + shard->normalised_device_idx, + reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + n_bins * (sizeof(gpair_sum_t) / sizeof(gpair_sum_t::value_t))); + } + + reducer.Synchronize(); + } + + void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) { + size_t left_node_max_elements = 0; + size_t right_node_max_elements = 0; + for (auto& shard : shards) { + left_node_max_elements = (std::max)( + left_node_max_elements, shard->ridx_segments[nidx_left].Size()); + right_node_max_elements = (std::max)( + right_node_max_elements, shard->ridx_segments[nidx_right].Size()); + } + + auto build_hist_nidx = nidx_left; + auto subtraction_trick_nidx = nidx_right; + + if (right_node_max_elements < left_node_max_elements) { + build_hist_nidx = nidx_right; + subtraction_trick_nidx = nidx_left; + } + + for (auto& shard : shards) { + shard->BuildHist(build_hist_nidx); + } + + this->AllReduceHist(build_hist_nidx); + + for (auto& shard : shards) { + shard->SubtractionTrick(nidx_parent, build_hist_nidx, + subtraction_trick_nidx); } } @@ -481,36 +634,41 @@ class GPUHistMakerExperimental : public TreeUpdater { columns); // Use first device auto& shard = shards.front(); - dh::safe_cuda(cudaSetDevice(shard.device_idx)); - shard.temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns * - nidx_set.size()); - auto d_split = shard.temp_memory.Pointer(); + dh::safe_cuda(cudaSetDevice(shard->device_idx)); + shard->temp_memory.LazyAllocate(sizeof(DeviceSplitCandidate) * columns * + nidx_set.size()); + auto d_split = shard->temp_memory.Pointer(); - auto& streams = shard.GetStreams(static_cast(nidx_set.size())); + auto& streams = shard->GetStreams(static_cast(nidx_set.size())); // Use streams to process nodes concurrently for (auto i = 0; i < nidx_set.size(); i++) { auto nidx = nidx_set[i]; - DeviceNodeStats node(shard.node_sum_gradients[nidx], nidx, param); + DeviceNodeStats node(shard->node_sum_gradients[nidx], nidx, param); const int BLOCK_THREADS = 256; evaluate_split_kernel <<>>( - shard.hist.node_map[nidx], nidx, info->num_col, node, - shard.feature_segments.data(), shard.min_fvalue.data(), - shard.gidx_fvalue_map.data(), GPUTrainingParam(param), + shard->hist.GetHistPtr(nidx), nidx, info->num_col, node, + shard->feature_segments.data(), shard->min_fvalue.data(), + shard->gidx_fvalue_map.data(), GPUTrainingParam(param), d_split + i * columns); } dh::safe_cuda( - cudaMemcpy(candidate_splits.data(), shard.temp_memory.d_temp_storage, + cudaMemcpy(candidate_splits.data(), shard->temp_memory.d_temp_storage, sizeof(DeviceSplitCandidate) * columns * nidx_set.size(), cudaMemcpyDeviceToHost)); for (auto i = 0; i < nidx_set.size(); i++) { + auto nidx = nidx_set[i]; DeviceSplitCandidate nidx_best; for (auto fidx = 0; fidx < columns; fidx++) { - nidx_best.Update(candidate_splits[i * columns + fidx], param); + auto& candidate = candidate_splits[i * columns + fidx]; + if (column_sampler.ColumnUsed(candidate.findex, + p_tree->GetDepth(nidx))) { + nidx_best.Update(candidate_splits[i * columns + fidx], param); + } } best_splits[i] = nidx_best; } @@ -518,15 +676,28 @@ class GPUHistMakerExperimental : public TreeUpdater { } void InitRoot(const std::vector& gpair, RegTree* p_tree) { - int root_nidx = 0; - BuildHist(root_nidx); - - // TODO(rory): support sub sampling - // TODO(rory): not asynchronous - bst_gpair sum_gradient; - for (auto& shard : shards) { - sum_gradient += thrust::reduce(shard.gpair.tbegin(), shard.gpair.tend()); + auto root_nidx = 0; + // Sum gradients + std::vector tmp_sums(shards.size()); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx)); + tmp_sums[cpu_thread_id] = + thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory), + shards[cpu_thread_id]->gpair.tbegin(), + shards[cpu_thread_id]->gpair.tend()); } + auto sum_gradient = + std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair()); + + // Generate root histogram + for (auto& shard : shards) { + shard->BuildHist(root_nidx); + } + + this->AllReduceHist(root_nidx); // Remember root stats p_tree->stat(root_nidx).sum_hess = sum_gradient.GetHess(); @@ -534,33 +705,17 @@ class GPUHistMakerExperimental : public TreeUpdater { // Store sum gradients for (auto& shard : shards) { - shard.node_sum_gradients[root_nidx] = sum_gradient; + shard->node_sum_gradients[root_nidx] = sum_gradient; } + // Generate first split auto splits = this->EvaluateSplits({root_nidx}, p_tree); - - // Generate candidate qexpand_->push( ExpandEntry(root_nidx, p_tree->GetDepth(root_nidx), splits.front(), 0)); } - struct MatchingFunctor : public thrust::unary_function { - int val; - __host__ __device__ MatchingFunctor(int val) : val(val) {} - __host__ __device__ int operator()(int x) const { return x == val; } - }; - - __device__ void CountLeft(int64_t* d_count, int val, int left_nidx) { - unsigned ballot = __ballot(val == left_nidx); - if (threadIdx.x % 32 == 0) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT - } - } - void UpdatePosition(const ExpandEntry& candidate, RegTree* p_tree) { auto nidx = candidate.nid; - auto is_dense = info->num_nonzero == info->num_row * info->num_col; auto left_nidx = (*p_tree)[nidx].cleft(); auto right_nidx = (*p_tree)[nidx].cright(); @@ -577,58 +732,15 @@ class GPUHistMakerExperimental : public TreeUpdater { } } - for (auto& shard : shards) { - monitor.Start("update position kernel"); - shard.temp_memory.LazyAllocate(sizeof(int64_t)); - auto d_left_count = shard.temp_memory.Pointer(); - dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t))); - dh::safe_cuda(cudaSetDevice(shard.device_idx)); - auto segment = shard.ridx_segments[nidx]; - CHECK_GT(segment.second - segment.first, 0); - auto d_ridx = shard.ridx.current(); - auto d_position = shard.position.current(); - auto d_gidx = shard.gidx; - auto row_stride = shard.row_stride; - dh::launch_n<1, 512>( - shard.device_idx, segment.second - segment.first, - [=] __device__(bst_uint idx) { - idx += segment.first; - auto ridx = d_ridx[idx]; - auto row_begin = row_stride * ridx; - auto row_end = row_begin + row_stride; - auto gidx = -1; - if (is_dense) { - gidx = d_gidx[row_begin + fidx]; - } else { - gidx = BinarySearchRow(row_begin, row_end, d_gidx, fidx_begin, - fidx_end); - } + auto is_dense = info->num_nonzero == info->num_row * info->num_col; - int position; - if (gidx >= 0) { - // Feature is found - position = gidx <= split_gidx ? left_nidx : right_nidx; - } else { - // Feature is missing - position = default_dir_left ? left_nidx : right_nidx; - } - - CountLeft(d_left_count, position, left_nidx); - d_position[idx] = position; - }); - - int64_t left_count; - dh::safe_cuda(cudaMemcpy(&left_count, d_left_count, sizeof(int64_t), - cudaMemcpyDeviceToHost)); - monitor.Stop("update position kernel"); - - monitor.Start("sort"); - shard.SortPosition(segment, left_nidx, right_nidx); - monitor.Stop("sort"); - shard.ridx_segments[left_nidx] = - std::make_pair(segment.first, segment.first + left_count); - shard.ridx_segments[right_nidx] = - std::make_pair(segment.first + left_count, segment.second); + omp_set_num_threads(shards.size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + shards[cpu_thread_id]->UpdatePosition(nidx, left_nidx, right_nidx, fidx, + split_gidx, default_dir_left, + is_dense, fidx_begin, fidx_end); } } @@ -654,41 +766,12 @@ class GPUHistMakerExperimental : public TreeUpdater { tree.stat(parent.cright()).sum_hess = candidate.split.right_sum.GetHess(); // Store sum gradients for (auto& shard : shards) { - shard.node_sum_gradients[parent.cleft()] = candidate.split.left_sum; - shard.node_sum_gradients[parent.cright()] = candidate.split.right_sum; + shard->node_sum_gradients[parent.cleft()] = candidate.split.left_sum; + shard->node_sum_gradients[parent.cright()] = candidate.split.right_sum; } this->UpdatePosition(candidate, p_tree); } - void ColSampleTree() { - if (param.colsample_bylevel == 1.0 && param.colsample_bytree == 1.0) return; - - feature_set_tree.resize(info->num_col); - std::iota(feature_set_tree.begin(), feature_set_tree.end(), 0); - feature_set_tree = col_sample(feature_set_tree, param.colsample_bytree); - } - - struct Monitor { - bool debug_verbose = false; - std::string label = ""; - std::map timer_map; - - ~Monitor() { - if (!debug_verbose) return; - - std::cout << "Monitor: " << label << "\n"; - for (auto& kv : timer_map) { - kv.second.PrintElapsed(kv.first); - } - } - void Init(std::string label, bool debug_verbose) { - this->debug_verbose = debug_verbose; - this->label = label; - } - void Start(const std::string& name) { timer_map[name].Start(); } - void Stop(const std::string& name) { timer_map[name].Stop(); } - }; - void UpdateTree(const std::vector& gpair, DMatrix* p_fmat, RegTree* p_tree) { auto& tree = *p_tree; @@ -720,8 +803,8 @@ class GPUHistMakerExperimental : public TreeUpdater { if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { monitor.Start("BuildHist"); - this->BuildHist(left_child_nidx); - this->BuildHist(right_child_nidx); + this->BuildHistLeftRight(candidate.nid, left_child_nidx, + right_child_nidx); monitor.Stop("BuildHist"); monitor.Start("EvaluateSplits"); @@ -793,14 +876,14 @@ class GPUHistMakerExperimental : public TreeUpdater { int n_devices; int n_bins; - std::vector shards; - std::vector feature_set_tree; - std::vector feature_set_level; + std::vector> shards; + ColumnSampler column_sampler; typedef std::priority_queue, std::function> ExpandQueue; std::unique_ptr qexpand_; - Monitor monitor; + common::Monitor monitor; + dh::AllReducer reducer; }; XGBOOST_REGISTER_TREE_UPDATER(GPUHistMakerExperimental, diff --git a/tests/benchmark/benchmark.py b/tests/benchmark/benchmark.py index 2ee17ffb8..d913ee769 100644 --- a/tests/benchmark/benchmark.py +++ b/tests/benchmark/benchmark.py @@ -11,20 +11,36 @@ rng = np.random.RandomState(1994) def run_benchmark(args): - print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns)) - print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size)) - tmp = time.time() - X, y = make_classification(args.rows, n_features=args.columns, random_state=7) - if args.sparsity < 1.0: - X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X]) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7) - print ("Generate Time: %s seconds" % (str(time.time() - tmp))) - tmp = time.time() - print ("DMatrix Start") - dtrain = xgb.DMatrix(X_train, y_train, nthread=-1) - dtest = xgb.DMatrix(X_test, y_test, nthread=-1) - print ("DMatrix Time: %s seconds" % (str(time.time() - tmp))) + try: + dtest = xgb.DMatrix('dtest.dm') + dtrain = xgb.DMatrix('dtrain.dm') + + if not (dtest.num_col() == args.columns \ + and dtrain.num_col() == args.columns): + raise ValueError("Wrong cols") + if not (dtest.num_row() == args.rows * args.test_size \ + and dtrain.num_row() == args.rows * (1-args.test_size)): + raise ValueError("Wrong rows") + except: + + print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns)) + print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size)) + tmp = time.time() + X, y = make_classification(args.rows, n_features=args.columns, n_redundant=0, n_informative=args.columns, n_repeated=0, random_state=7) + if args.sparsity < 1.0: + X = np.array([[np.nan if rng.uniform(0, 1) < args.sparsity else x for x in x_row] for x_row in X]) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_size, random_state=7) + print ("Generate Time: %s seconds" % (str(time.time() - tmp))) + tmp = time.time() + print ("DMatrix Start") + dtrain = xgb.DMatrix(X_train, y_train) + dtest = xgb.DMatrix(X_test, y_test, nthread=-1) + print ("DMatrix Time: %s seconds" % (str(time.time() - tmp))) + + dtest.save_binary('dtest.dm') + dtrain.save_binary('dtrain.dm') param = {'objective': 'binary:logistic'} if args.params is not '': diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 56b98a1cc..c51afb3f8 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -6,6 +6,7 @@ #include #include "../../../src/common/device_helpers.cuh" #include "gtest/gtest.h" +#include "../../../src/common/timer.h" void CreateTestData(xgboost::bst_uint num_rows, int max_row_size, thrust::host_vector *row_ptr, @@ -35,7 +36,7 @@ void SpeedTest() { thrust::device_vector output_row(h_rows.size()); auto d_output_row = output_row.data(); - dh::Timer t; + xgboost::common::Timer t; dh::TransformLbs( 0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false, [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; }); diff --git a/tests/cpp/tree/test_gpu_hist_experimental.cu b/tests/cpp/tree/test_gpu_hist_experimental.cu index 46fd99d5f..481a1b254 100644 --- a/tests/cpp/tree/test_gpu_hist_experimental.cu +++ b/tests/cpp/tree/test_gpu_hist_experimental.cu @@ -7,8 +7,8 @@ #include "../helpers.h" #include "gtest/gtest.h" -#include "../../../src/tree/updater_gpu_hist_experimental.cu" #include "../../../src/gbm/gbtree_model.h" +#include "../../../src/tree/updater_gpu_hist_experimental.cu" namespace xgboost { namespace tree { @@ -22,7 +22,9 @@ TEST(gpu_hist_experimental, TestSparseShard) { hmat.Init(dmat.get(), max_bins); gmat.cut = &hmat; gmat.Init(dmat.get()); - DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam()); + ncclComm_t comm; + DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), + TrainParam()); ASSERT_LT(shard.row_stride, columns); @@ -54,7 +56,9 @@ TEST(gpu_hist_experimental, TestDenseShard) { hmat.Init(dmat.get(), max_bins); gmat.cut = &hmat; gmat.Init(dmat.get()); - DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), TrainParam()); + ncclComm_t comm; + DeviceShard shard(0, 0, gmat, 0, rows, hmat.row_ptr.back(), + TrainParam()); ASSERT_EQ(shard.row_stride, columns); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index cb430ed55..f7dbbb489 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -8,6 +8,7 @@ import numpy as np import unittest from nose.plugins.attrib import attr from sklearn.datasets import load_digits, load_boston, load_breast_cancer, make_regression +import itertools as it rng = np.random.RandomState(1994) @@ -15,8 +16,9 @@ rng = np.random.RandomState(1994) def non_increasing(L, tolerance): return all((y - x) < tolerance for x, y in zip(L, L[1:])) -#Check result is always decreasing and final accuracy is within tolerance -def assert_accuracy(res, tree_method, comparison_tree_method, tolerance): + +# Check result is always decreasing and final accuracy is within tolerance +def assert_accuracy(res, tree_method, comparison_tree_method, tolerance, param): assert non_increasing(res[tree_method], tolerance) assert np.allclose(res[tree_method][-1], res[comparison_tree_method][-1], 1e-3, 1e-2) @@ -26,13 +28,14 @@ def train_boston(param_in, comparison_tree_method): dtrain = xgb.DMatrix(data.data, label=data.target) param = {} param.update(param_in) + param['max_depth'] = 2 res_tmp = {} res = {} num_rounds = 10 - xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp) + bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp) res[param['tree_method']] = res_tmp['train']['rmse'] param["tree_method"] = comparison_tree_method - xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp) + bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp) res[comparison_tree_method] = res_tmp['train']['rmse'] return res @@ -92,17 +95,24 @@ def train_sparse(param_in, comparison_tree_method): return res +# Enumerates all permutations of variable parameters def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, tolerance): - param = {'tree_method': tree_method} - for k, set in variable_param.items(): - for val in set: - param_tmp = param.copy() - param_tmp[k] = val - print(param_tmp, file=sys.stderr) - assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance) - assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance) - assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance) - assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance) + param = {'tree_method': tree_method } + names = sorted(variable_param) + combinations = it.product(*(variable_param[Name] for Name in names)) + + for set in combinations: + print(names, file=sys.stderr) + print(set, file=sys.stderr) + param_tmp = param.copy() + for i, name in enumerate(names): + param_tmp[name] = set[i] + + print(param_tmp, file=sys.stderr) + assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp) + assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp) + assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp) + assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance, param_tmp) @attr('gpu') @@ -116,5 +126,5 @@ class TestGPU(unittest.TestCase): assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02) def test_gpu_hist_experimental(self): - variable_param = {'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]} + variable_param = {'n_gpus': [1, -1], 'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024]} assert_updater_accuracy('gpu_hist_experimental', 'hist', variable_param, 0.01)