From a185ddfe03f881735e2e59dc5c1e268500309bbd Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Fri, 20 Apr 2018 14:56:35 +1200 Subject: [PATCH] Implement GPU accelerated coordinate descent algorithm (#3178) * Implement GPU accelerated coordinate descent algorithm. * Exclude external memory tests for GPU --- include/xgboost/linear_updater.h | 3 +- python-package/xgboost/core.py | 5 +- rabit | 2 +- src/common/device_helpers.cuh | 107 ++++++--- src/gbm/gblinear.cc | 4 +- src/linear/linear_updater.cc | 3 + src/linear/updater_coordinate.cc | 22 +- src/linear/updater_gpu_coordinate.cu | 346 +++++++++++++++++++++++++++ src/linear/updater_shotgun.cc | 15 +- tests/cpp/linear/test_linear.cc | 12 +- tests/python-gpu/test_gpu_linear.py | 14 ++ tests/python/test_linear.py | 3 +- 12 files changed, 473 insertions(+), 63 deletions(-) create mode 100644 src/linear/updater_gpu_coordinate.cu create mode 100644 tests/python-gpu/test_gpu_linear.py diff --git a/include/xgboost/linear_updater.h b/include/xgboost/linear_updater.h index 3d5d75f13..f083f8fa9 100644 --- a/include/xgboost/linear_updater.h +++ b/include/xgboost/linear_updater.h @@ -11,6 +11,7 @@ #include #include #include "../../src/gbm/gblinear_model.h" +#include "../../src/common/host_device_vector.h" namespace xgboost { /*! @@ -36,7 +37,7 @@ class LinearUpdater { * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty. */ - virtual void Update(std::vector* in_gpair, DMatrix* data, + virtual void Update(HostDeviceVector* in_gpair, DMatrix* data, gbm::GBLinearModel* model, double sum_instance_weight) = 0; diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 9b29df695..8b5ec7033 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -626,9 +626,8 @@ class DMatrix(object): feature_names : list or None """ if self._feature_names is None: - return ['f{0}'.format(i) for i in range(self.num_col())] - else: - return self._feature_names + self._feature_names = ['f{0}'.format(i) for i in range(self.num_col())] + return self._feature_names @property def feature_types(self): diff --git a/rabit b/rabit index 7bc46b8c7..a764d45cf 160000 --- a/rabit +++ b/rabit @@ -1 +1 @@ -Subproject commit 7bc46b8c75a6d530b2ad4efcf407b6aeab71e44f +Subproject commit a764d45cfb438cc9f15cf47ce586c02ff2c65d0f diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 2413e065a..ce2119ed7 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -374,10 +374,10 @@ class DVec { safe_cuda(cudaSetDevice(this->DeviceIdx())); if (end - begin != Size()) { throw std::runtime_error( - "Cannot copy assign vector to DVec, sizes are different"); + "Cannot copy assign vector to dvec, sizes are different"); } - safe_cuda(cudaMemcpy(this->Data(), begin.get(), - Size() * sizeof(T), cudaMemcpyDefault)); + safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T), + cudaMemcpyDefault)); } }; @@ -544,7 +544,7 @@ struct CubMemory { size_t temp_storage_bytes; // Thrust - using ValueT = char; + using value_type = char; // NOLINT CubMemory() : d_temp_storage(nullptr), temp_storage_bytes(0) {} @@ -807,18 +807,20 @@ void SumReduction(dh::CubMemory &tmp_mem, dh::DVec &in, dh::DVec &out, * @param nVals number of elements in the input array */ template -T SumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) { +typename std::iterator_traits::value_type SumReduction(dh::CubMemory &tmp_mem, T in, int nVals) { + using ValueT = typename 
std::iterator_traits::value_type; size_t tmpSize; dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals)); // Allocate small extra memory for the return value - tmp_mem.LazyAllocate(tmpSize + sizeof(T)); - auto ptr = reinterpret_cast(tmp_mem.d_temp_storage) + 1; + tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT)); + auto ptr = reinterpret_cast(tmp_mem.d_temp_storage) + 1; dh::safe_cuda(cub::DeviceReduce::Sum( - reinterpret_cast(ptr), tmpSize, in, - reinterpret_cast(tmp_mem.d_temp_storage), nVals)); - T sum; - dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T), - cudaMemcpyDeviceToHost)); + reinterpret_cast(ptr), tmpSize, in, + reinterpret_cast(tmp_mem.d_temp_storage), + nVals)); + ValueT sum; + dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(ValueT), + cudaMemcpyDeviceToHost)); return sum; } @@ -876,7 +878,8 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) { * \class AllReducer * * \brief All reducer class that manages its own communication group and - * streams. Must be initialised before use. If XGBoost is compiled without NCCL this is a dummy class that will error if used with more than one GPU. + * streams. Must be initialised before use. If XGBoost is compiled without NCCL + * this is a dummy class that will error if used with more than one GPU. */ class AllReducer { @@ -912,7 +915,8 @@ class AllReducer { } initialised = true; #else - CHECK_EQ(device_ordinals.size(), 1) << "XGBoost must be compiled with NCCL to use more than one GPU."; + CHECK_EQ(device_ordinals.size(), 1) + << "XGBoost must be compiled with NCCL to use more than one GPU."; #endif } ~AllReducer() { @@ -929,16 +933,13 @@ class AllReducer { } /** - * \fn void AllReduceSum(int communication_group_idx, const double *sendbuff, - * double *recvbuff, int count) - * * \brief Allreduce. Use in exactly the same way as NCCL but without needing * streams or comms. * - * \param communication_group_idx Zero-based index of the - * communication group. \param sendbuff The sendbuff. \param - * sendbuff The sendbuff. \param [in,out] recvbuff - * The recvbuff. \param count Number of. + * \param communication_group_idx Zero-based index of the communication group. + * \param sendbuff The sendbuff. + * \param recvbuff The recvbuff. + * \param count Number of elements. */ void AllReduceSum(int communication_group_idx, const double *sendbuff, @@ -954,17 +955,14 @@ class AllReducer { } /** - * \fn void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, int64_t *recvbuff, int count) - * * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms. * - * \param communication_group_idx Zero-based index of the communication group. \param - * sendbuff The sendbuff. \param sendbuff - * The sendbuff. \param [in,out] recvbuff The recvbuff. - * \param count Number of. - * \param sendbuff The sendbuff. - * \param [in,out] recvbuff If non-null, the recvbuff. - * \param count Number of. + * \param count Number of. + * + * \param communication_group_idx Zero-based index of the communication group. \param sendbuff. + * \param sendbuff The sendbuff. + * \param recvbuff The recvbuff. + * \param count Number of. */ void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, @@ -993,4 +991,53 @@ class AllReducer { #endif } }; + +/** + * \brief Executes some operation on each element of the input vector, using a + * single controlling thread for each element. + * + * \tparam T Generic type parameter. 
+ * \tparam FunctionT Type of the function t. + * \param shards The shards. + * \param f The func_t to process. + */ + +template +void ExecuteShards(std::vector *shards, FunctionT f) { + auto previous_num_threads = omp_get_max_threads(); + omp_set_num_threads(shards->size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + f(shards->at(cpu_thread_id)); + } + omp_set_num_threads(previous_num_threads); +} + +/** + * \brief Executes some operation on each element of the input vector, using a single controlling + * thread for each element, returns the sum of the results. + * + * \tparam ReduceT Type of the reduce t. + * \tparam T Generic type parameter. + * \tparam FunctionT Type of the function t. + * \param shards The shards. + * \param f The func_t to process. + * + * \return A reduce_t. + */ + +template +ReduceT ReduceShards(std::vector *shards, FunctionT f) { + auto previous_num_threads = omp_get_max_threads(); + omp_set_num_threads(shards->size()); + std::vector sums(shards->size()); +#pragma omp parallel + { + auto cpu_thread_id = omp_get_thread_num(); + sums[cpu_thread_id] = f(shards->at(cpu_thread_id)); + } + omp_set_num_threads(previous_num_threads); + return std::accumulate(sums.begin(), sums.end(), ReduceT()); +} } // namespace dh diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index d1ea3a306..9de055b79 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -84,15 +84,17 @@ class GBLinear : public GradientBooster { monitor_.Start("DoBoost"); if (!p_fmat->HaveColAccess(false)) { + monitor_.Start("InitColAccess"); std::vector enabled(p_fmat->Info().num_col_, true); p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false); + monitor_.Stop("InitColAccess"); } model_.LazyInitModel(); this->LazySumWeights(p_fmat); if (!this->CheckConvergence()) { - updater_->Update(&in_gpair->HostVector(), p_fmat, &model_, sum_instance_weight_); + updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_); } this->UpdatePredictionCache(); diff --git a/src/linear/linear_updater.cc b/src/linear/linear_updater.cc index 9041a57f3..4e12cd865 100644 --- a/src/linear/linear_updater.cc +++ b/src/linear/linear_updater.cc @@ -25,5 +25,8 @@ namespace linear { // List of files that will be force linked in static links. 
DMLC_REGISTRY_LINK_TAG(updater_shotgun); DMLC_REGISTRY_LINK_TAG(updater_coordinate); +#ifdef XGBOOST_USE_CUDA +DMLC_REGISTRY_LINK_TAG(updater_gpu_coordinate); +#endif } // namespace linear } // namespace xgboost diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc index 8de2b6d97..d4c294c9a 100644 --- a/src/linear/updater_coordinate.cc +++ b/src/linear/updater_coordinate.cc @@ -84,48 +84,46 @@ class CoordinateUpdater : public LinearUpdater { selector.reset(FeatureSelector::Create(param.feature_selector)); monitor.Init("CoordinateUpdater", param.debug_verbose); } - - void Update(std::vector *in_gpair, DMatrix *p_fmat, + void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, double sum_instance_weight) override { param.DenormalizePenalties(sum_instance_weight); const int ngroup = model->param.num_output_group; // update bias for (int group_idx = 0; group_idx < ngroup; ++group_idx) { - auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat); + auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->HostVector(), p_fmat); auto dbias = static_cast(param.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->bias()[group_idx] += dbias; - UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat); + UpdateBiasResidualParallel(group_idx, ngroup, + dbias, &in_gpair->HostVector(), p_fmat); } // prepare for updating the weights - selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm, + selector->Setup(*model, in_gpair->HostVector(), p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm, param.top_k); // update weights for (int group_idx = 0; group_idx < ngroup; ++group_idx) { for (unsigned i = 0U; i < model->param.num_feature; i++) { - int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat, + int fidx = selector->NextFeature(i, *model, group_idx, in_gpair->HostVector(), p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm); if (fidx < 0) break; - this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model); + this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model); } } + monitor.Stop("UpdateFeature"); } inline void UpdateFeature(int fidx, int group_idx, std::vector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model) { const int ngroup = model->param.num_output_group; bst_float &w = (*model)[fidx][group_idx]; - monitor.Start("GetGradientParallel"); - auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat); - monitor.Stop("GetGradientParallel"); + auto gradient = + GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat); auto dw = static_cast( param.learning_rate * CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm, param.reg_lambda_denorm)); w += dw; - monitor.Start("UpdateResidualParallel"); UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat); - monitor.Stop("UpdateResidualParallel"); } // training parameter diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu new file mode 100644 index 000000000..3f4dea972 --- /dev/null +++ b/src/linear/updater_gpu_coordinate.cu @@ -0,0 +1,346 @@ +/*! 
+ * Copyright 2018 by Contributors + * \author Rory Mitchell + */ + +#include +#include +#include +#include "../common/device_helpers.cuh" +#include "../common/timer.h" +#include "coordinate_common.h" + +namespace xgboost { +namespace linear { + +DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate); + +// training parameter +struct GPUCoordinateTrainParam + : public dmlc::Parameter { + /*! \brief learning_rate */ + float learning_rate; + /*! \brief regularization weight for L2 norm */ + float reg_lambda; + /*! \brief regularization weight for L1 norm */ + float reg_alpha; + int feature_selector; + int top_k; + int debug_verbose; + int n_gpus; + int gpu_id; + bool silent; + // declare parameters + DMLC_DECLARE_PARAMETER(GPUCoordinateTrainParam) { + DMLC_DECLARE_FIELD(learning_rate) + .set_lower_bound(0.0f) + .set_default(1.0f) + .describe("Learning rate of each update."); + DMLC_DECLARE_FIELD(reg_lambda) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L2 regularization on weights."); + DMLC_DECLARE_FIELD(reg_alpha) + .set_lower_bound(0.0f) + .set_default(0.0f) + .describe("L1 regularization on weights."); + DMLC_DECLARE_FIELD(feature_selector) + .set_default(kCyclic) + .add_enum("cyclic", kCyclic) + .add_enum("shuffle", kShuffle) + .add_enum("thrifty", kThrifty) + .add_enum("greedy", kGreedy) + .add_enum("random", kRandom) + .describe("Feature selection or ordering method."); + DMLC_DECLARE_FIELD(top_k).set_lower_bound(0).set_default(0).describe( + "The number of top features to select in 'thrifty' feature_selector. " + "The value of zero means using all the features."); + DMLC_DECLARE_FIELD(debug_verbose) + .set_lower_bound(0) + .set_default(0) + .describe("flag to print out detailed breakdown of runtime"); + DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe( + "Number of devices to use."); + DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe( + "Primary device ordinal."); + DMLC_DECLARE_FIELD(silent).set_default(false).describe( + "Do not print information during trainig."); + // alias of parameters + DMLC_DECLARE_ALIAS(learning_rate, eta); + DMLC_DECLARE_ALIAS(reg_lambda, lambda); + DMLC_DECLARE_ALIAS(reg_alpha, alpha); + } + /*! 
\brief Denormalizes the regularization penalties - to be called at each + * update */ + void DenormalizePenalties(double sum_instance_weight) { + reg_lambda_denorm = reg_lambda * sum_instance_weight; + reg_alpha_denorm = reg_alpha * sum_instance_weight; + } + // denormalizated regularization penalties + float reg_lambda_denorm; + float reg_alpha_denorm; +}; + +void RescaleIndices(size_t ridx_begin, dh::DVec *data) { + auto d_data = data->Data(); + dh::LaunchN(data->DeviceIdx(), data->Size(), + [=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; }); +} + +class DeviceShard { + int device_idx_; + int normalised_device_idx_; // Device index counting from param.gpu_id + dh::BulkAllocator ba_; + std::vector row_ptr_; + dh::DVec data_; + dh::DVec gpair_; + dh::CubMemory temp_; + size_t ridx_begin_; + size_t ridx_end_; + + public: + DeviceShard(int device_idx, int normalised_device_idx, const ColBatch &batch, + bst_uint row_begin, bst_uint row_end, + const GPUCoordinateTrainParam ¶m, + const gbm::GBLinearModelParam &model_param) + : device_idx_(device_idx), + normalised_device_idx_(normalised_device_idx), + ridx_begin_(row_begin), + ridx_end_(row_end) { + dh::safe_cuda(cudaSetDevice(device_idx)); + // The begin and end indices for the section of each column associated with + // this shard + std::vector> column_segments; + row_ptr_ = {0}; + for (auto fidx = 0; fidx < batch.size; fidx++) { + auto col = batch[fidx]; + auto cmp = [](SparseBatch::Entry e1, SparseBatch::Entry e2) { + return e1.index < e2.index; + }; + auto column_begin = + std::lower_bound(col.data, col.data + col.length, + SparseBatch::Entry(row_begin, 0.0f), cmp); + auto column_end = + std::upper_bound(col.data, col.data + col.length, + SparseBatch::Entry(row_end, 0.0f), cmp); + column_segments.push_back( + std::make_pair(column_begin - col.data, column_end - col.data)); + row_ptr_.push_back(row_ptr_.back() + column_end - column_begin); + } + ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_, + (row_end - row_begin) * model_param.num_output_group); + + for (int fidx = 0; fidx < batch.size; fidx++) { + ColBatch::Inst col = batch[fidx]; + thrust::copy(col.data + column_segments[fidx].first, + col.data + column_segments[fidx].second, + data_.tbegin() + row_ptr_[fidx]); + } + // Rescale indices with respect to current shard + RescaleIndices(ridx_begin_, &data_); + } + void UpdateGpair(const std::vector &host_gpair, + const gbm::GBLinearModelParam &model_param) { + gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group, + host_gpair.begin() + ridx_end_ * model_param.num_output_group); + } + + GradientPair GetBiasGradient(int group_idx, int num_group) { + auto counting = thrust::make_counting_iterator(0ull); + auto f = [=] __device__(size_t idx) { + return idx * num_group + group_idx; + }; // NOLINT + thrust::transform_iterator skip( + counting, f); + auto perm = thrust::make_permutation_iterator(gpair_.tbegin(), skip); + + return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_); + } + + void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { + if (dbias == 0.0f) return; + auto d_gpair = gpair_.Data(); + dh::LaunchN(device_idx_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) { + auto &g = d_gpair[idx * num_groups + group_idx]; + g += GradientPair(g.GetHess() * dbias, 0); + }); + } + + GradientPair GetGradient(int group_idx, int num_group, int fidx) { + auto d_col = data_.Data() + row_ptr_[fidx]; + size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; + auto 
d_gpair = gpair_.Data(); + auto counting = thrust::make_counting_iterator(0ull); + auto f = [=] __device__(size_t idx) { + auto entry = d_col[idx]; + auto g = d_gpair[entry.index * num_group + group_idx]; + return GradientPair(g.GetGrad() * entry.fvalue, + g.GetHess() * entry.fvalue * entry.fvalue); + }; // NOLINT + thrust::transform_iterator + multiply_iterator(counting, f); + return dh::SumReduction(temp_, multiply_iterator, col_size); + } + + void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) { + auto d_gpair = gpair_.Data(); + auto d_col = data_.Data() + row_ptr_[fidx]; + size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; + dh::LaunchN(device_idx_, col_size, [=] __device__(size_t idx) { + auto entry = d_col[idx]; + auto &g = d_gpair[entry.index * num_groups + group_idx]; + g += GradientPair(g.GetHess() * dw * entry.fvalue, 0); + }); + } +}; + +/** + * \class GPUCoordinateUpdater + * + * \brief Coordinate descent algorithm that updates one feature per iteration + */ + +class GPUCoordinateUpdater : public LinearUpdater { + public: + // set training parameter + void Init( + const std::vector> &args) override { + param.InitAllowUnknown(args); + selector.reset(FeatureSelector::Create(param.feature_selector)); + monitor.Init("GPUCoordinateUpdater", param.debug_verbose); + } + + void LazyInitShards(DMatrix *p_fmat, + const gbm::GBLinearModelParam &model_param) { + if (!shards.empty()) return; + int n_devices = dh::NDevices(param.n_gpus, p_fmat->Info().num_row_); + bst_uint row_begin = 0; + bst_uint shard_size = + std::ceil(static_cast(p_fmat->Info().num_row_) / n_devices); + + device_list.resize(n_devices); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + int device_idx = (param.gpu_id + d_idx) % dh::NVisibleDevices(); + device_list[d_idx] = device_idx; + } + // Partition input matrix into row segments + std::vector row_segments; + row_segments.push_back(0); + for (int d_idx = 0; d_idx < n_devices; ++d_idx) { + bst_uint row_end = std::min(static_cast(row_begin + shard_size), + p_fmat->Info().num_row_); + row_segments.push_back(row_end); + row_begin = row_end; + } + + dmlc::DataIter *iter = p_fmat->ColIterator(); + CHECK(p_fmat->SingleColBlock()); + iter->Next(); + auto batch = iter->Value(); + + shards.resize(n_devices); + // Create device shards + dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { + auto idx = &shard - &shards[0]; + shard = std::unique_ptr( + new DeviceShard(device_list[idx], idx, batch, row_segments[idx], + row_segments[idx + 1], param, model_param)); + }); + } + void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, + gbm::GBLinearModel *model, double sum_instance_weight) override { + param.DenormalizePenalties(sum_instance_weight); + monitor.Start("LazyInitShards"); + this->LazyInitShards(p_fmat, model->param); + monitor.Stop("LazyInitShards"); + + monitor.Start("UpdateGpair"); + // Update gpair + dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { + shard->UpdateGpair(in_gpair->HostVector(), model->param); + }); + monitor.Stop("UpdateGpair"); + + monitor.Start("UpdateBias"); + this->UpdateBias(p_fmat, model); + monitor.Stop("UpdateBias"); + // prepare for updating the weights + selector->Setup(*model, in_gpair->HostVector(), p_fmat, + param.reg_alpha_denorm, param.reg_lambda_denorm, + param.top_k); + monitor.Start("UpdateFeature"); + for (auto group_idx = 0; group_idx < model->param.num_output_group; + ++group_idx) { + for (auto i = 0U; i < model->param.num_feature; i++) { + auto fidx = selector->NextFeature( + i, *model, 
group_idx, in_gpair->HostVector(), p_fmat, + param.reg_alpha_denorm, param.reg_lambda_denorm); + if (fidx < 0) break; + this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), model); + } + } + monitor.Stop("UpdateFeature"); + } + + void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) { + for (int group_idx = 0; group_idx < model->param.num_output_group; + ++group_idx) { + // Get gradient + auto grad = dh::ReduceShards( + &shards, [&](std::unique_ptr &shard) { + return shard->GetBiasGradient(group_idx, + model->param.num_output_group); + }); + + auto dbias = static_cast( + param.learning_rate * + CoordinateDeltaBias(grad.GetGrad(), grad.GetHess())); + model->bias()[group_idx] += dbias; + + // Update residual + dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { + shard->UpdateBiasResidual(dbias, group_idx, + model->param.num_output_group); + }); + } + } + + void UpdateFeature(int fidx, int group_idx, + std::vector *in_gpair, + gbm::GBLinearModel *model) { + bst_float &w = (*model)[fidx][group_idx]; + // Get gradient + auto grad = dh::ReduceShards( + &shards, [&](std::unique_ptr &shard) { + return shard->GetGradient(group_idx, model->param.num_output_group, + fidx); + }); + + auto dw = static_cast(param.learning_rate * + CoordinateDelta(grad.GetGrad(), grad.GetHess(), + w, param.reg_alpha_denorm, + param.reg_lambda_denorm)); + w += dw; + + dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { + shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx); + }); + } + + // training parameter + GPUCoordinateTrainParam param; + std::unique_ptr selector; + common::Monitor monitor; + + std::vector> shards; + std::vector device_list; +}; + +DMLC_REGISTER_PARAMETER(GPUCoordinateTrainParam); +XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent") + .describe( + "Update linear model according to coordinate descent algorithm. 
GPU " + "accelerated.") + .set_body([]() { return new GPUCoordinateUpdater(); }); +} // namespace linear +} // namespace xgboost diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc index 4cd52d36e..6f42391c8 100644 --- a/src/linear/updater_shotgun.cc +++ b/src/linear/updater_shotgun.cc @@ -61,32 +61,31 @@ class ShotgunUpdater : public LinearUpdater { param_.InitAllowUnknown(args); selector_.reset(FeatureSelector::Create(param_.feature_selector)); } - - void Update(std::vector *in_gpair, DMatrix *p_fmat, + void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, double sum_instance_weight) override { + std::vector &gpair = in_gpair->HostVector(); param_.DenormalizePenalties(sum_instance_weight); - std::vector &gpair = *in_gpair; const int ngroup = model->param.num_output_group; // update bias for (int gid = 0; gid < ngroup; ++gid) { - auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat); + auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->HostVector(), p_fmat); auto dbias = static_cast(param_.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->bias()[gid] += dbias; - UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat); + UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat); } // lock-free parallel updates of weights - selector_->Setup(*model, *in_gpair, p_fmat, param_.reg_alpha_denorm, - param_.reg_lambda_denorm, 0); + selector_->Setup(*model, in_gpair->HostVector(), p_fmat, + param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0); dmlc::DataIter *iter = p_fmat->ColIterator(); while (iter->Next()) { const ColBatch &batch = iter->Value(); const auto nfeat = static_cast(batch.size); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nfeat; ++i) { - int ii = selector_->NextFeature(i, *model, 0, *in_gpair, p_fmat, + int ii = selector_->NextFeature(i, *model, 0, in_gpair->HostVector(), p_fmat, param_.reg_alpha_denorm, param_.reg_lambda_denorm); if (ii < 0) continue; const bst_uint fid = batch.col_index[ii]; diff --git a/tests/cpp/linear/test_linear.cc b/tests/cpp/linear/test_linear.cc index 7f58c8be1..ac6a01e9b 100644 --- a/tests/cpp/linear/test_linear.cc +++ b/tests/cpp/linear/test_linear.cc @@ -13,13 +13,13 @@ TEST(Linear, shotgun) { auto updater = std::unique_ptr( xgboost::LinearUpdater::Create("shotgun")); updater->Init({{"eta", "1."}}); - std::vector gpair(mat->Info().num_row_, - xgboost::GradientPair(-5, 1.0)); + xgboost::HostDeviceVector gpair( + mat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); xgboost::gbm::GBLinearModel model; model.param.num_feature = mat->Info().num_col_; model.param.num_output_group = 1; model.LazyInitModel(); - updater->Update(&gpair, mat.get(), &model, gpair.size()); + updater->Update(&gpair, mat.get(), &model, gpair.Size()); ASSERT_EQ(model.bias()[0], 5.0f); } @@ -32,13 +32,13 @@ TEST(Linear, coordinate) { auto updater = std::unique_ptr( xgboost::LinearUpdater::Create("coord_descent")); updater->Init({}); - std::vector gpair(mat->Info().num_row_, - xgboost::GradientPair(-5, 1.0)); + xgboost::HostDeviceVector gpair( + mat->Info().num_row_, xgboost::GradientPair(-5, 1.0)); xgboost::gbm::GBLinearModel model; model.param.num_feature = mat->Info().num_col_; model.param.num_output_group = 1; model.LazyInitModel(); - updater->Update(&gpair, mat.get(), &model, gpair.size()); + updater->Update(&gpair, mat.get(), &model, gpair.Size()); ASSERT_EQ(model.bias()[0], 5.0f); } \ No newline at end of file diff --git 
a/tests/python-gpu/test_gpu_linear.py b/tests/python-gpu/test_gpu_linear.py new file mode 100644 index 000000000..ad727afeb --- /dev/null +++ b/tests/python-gpu/test_gpu_linear.py @@ -0,0 +1,14 @@ +import sys + +sys.path.append('tests/python/') +import test_linear +import testing as tm +import unittest + + +class TestGPULinear(unittest.TestCase): + def test_gpu_coordinate(self): + tm._skip_if_no_sklearn() + variable_param = {'alpha': [.005, .1], 'lambda': [0.005], + 'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1, 1]} + test_linear.assert_updater_accuracy('gpu_coord_descent', variable_param) diff --git a/tests/python/test_linear.py b/tests/python/test_linear.py index 26e91ec93..05d3fcdca 100644 --- a/tests/python/test_linear.py +++ b/tests/python/test_linear.py @@ -144,7 +144,8 @@ def assert_updater_accuracy(linear_updater, variable_param): train_classification(param_tmp) train_classification_multi(param_tmp) train_breast_cancer(param_tmp) - train_external_mem(param_tmp) + if 'gpu' not in linear_updater: + train_external_mem(param_tmp) class TestLinear(unittest.TestCase):
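Usage note: the sketch below is illustrative only and shows how the newly registered 'gpu_coord_descent' updater might be selected from the Python API. The parameter names ('updater', 'eta', 'lambda', 'alpha', 'feature_selector', 'n_gpus') come from GPUCoordinateTrainParam in updater_gpu_coordinate.cu above; the synthetic dataset, learning rate, objective, and round count are assumptions made for the example, not part of this patch.

# Minimal sketch (assumed setup): train a gblinear model with the GPU
# coordinate descent updater registered by this patch.
import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.randn(1000, 10)
y = X @ np.arange(1, 11) + rng.randn(1000) * 0.1  # linear signal plus noise
dtrain = xgb.DMatrix(X, label=y)

params = {
    'booster': 'gblinear',
    'updater': 'gpu_coord_descent',  # name registered via XGBOOST_REGISTER_LINEAR_UPDATER above
    'eta': 0.5,                      # assumed learning rate for the example
    'lambda': 0.005,
    'alpha': 0.005,
    'feature_selector': 'cyclic',
    'n_gpus': 1,
    'objective': 'reg:linear',       # assumed objective for the example
}
bst = xgb.train(params, dtrain, num_boost_round=20)
print(bst.get_dump()[0])  # dump of the learned bias and feature weights

Per boosting round, each DeviceShard receives the current gradient pairs through UpdateGpair, per-feature gradient sums are combined on the host by dh::ReduceShards, and residuals are updated in place on the device, mirroring the structure of the CPU coord_descent updater.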