Implement GPU accelerated coordinate descent algorithm (#3178)
* Implement GPU accelerated coordinate descent algorithm.
* Exclude external memory tests for GPU
parent ccf80703ef
commit a185ddfe03
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>
 #include "../../src/gbm/gblinear_model.h"
+#include "../../src/common/host_device_vector.h"

 namespace xgboost {
 /*!
@@ -36,7 +37,7 @@ class LinearUpdater {
   * \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
   */
-  virtual void Update(std::vector<GradientPair>* in_gpair, DMatrix* data,
+  virtual void Update(HostDeviceVector<GradientPair>* in_gpair, DMatrix* data,
                       gbm::GBLinearModel* model,
                       double sum_instance_weight) = 0;
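The core of the change is this signature: Update now receives a HostDeviceVector<GradientPair> instead of a plain std::vector, so a GPU updater can keep gradients resident on device while CPU updaters fall back to the host copy. A minimal sketch of a conforming implementation (the DummyUpdater name is hypothetical, shown only to illustrate the new contract):

#include <utility>
#include <vector>
#include <xgboost/linear_updater.h>

namespace xgboost {
// Hypothetical updater, not part of this commit: demonstrates the new
// HostDeviceVector-based Update() contract.
class DummyUpdater : public LinearUpdater {
 public:
  void Init(
      const std::vector<std::pair<std::string, std::string>> &args) override {}
  void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *data,
              gbm::GBLinearModel *model, double sum_instance_weight) override {
    // CPU-side code can still obtain a synchronised host copy:
    std::vector<GradientPair> &host_gpair = in_gpair->HostVector();
    (void)host_gpair;  // a real updater would compute weight deltas here
    (void)model;
    (void)sum_instance_weight;
  }
};
}  // namespace xgboost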
@@ -626,9 +626,8 @@ class DMatrix(object):
         feature_names : list or None
         """
         if self._feature_names is None:
-            return ['f{0}'.format(i) for i in range(self.num_col())]
-        else:
-            return self._feature_names
+            self._feature_names = ['f{0}'.format(i) for i in range(self.num_col())]
+        return self._feature_names

     @property
     def feature_types(self):
rabit (submodule)
@@ -1 +1 @@
-Subproject commit 7bc46b8c75a6d530b2ad4efcf407b6aeab71e44f
+Subproject commit a764d45cfb438cc9f15cf47ce586c02ff2c65d0f
@@ -374,10 +374,10 @@ class DVec {
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (end - begin != Size()) {
       throw std::runtime_error(
-          "Cannot copy assign vector to DVec, sizes are different");
+          "Cannot copy assign vector to dvec, sizes are different");
     }
-    safe_cuda(cudaMemcpy(this->Data(), begin.get(),
-                         Size() * sizeof(T), cudaMemcpyDefault));
+    safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
+                         cudaMemcpyDefault));
   }
 };
@@ -544,7 +544,7 @@ struct CubMemory {
   size_t temp_storage_bytes;

   // Thrust
-  using ValueT = char;
+  using value_type = char;  // NOLINT

   CubMemory() : d_temp_storage(nullptr), temp_storage_bytes(0) {}
@@ -807,18 +807,20 @@ void SumReduction(dh::CubMemory &tmp_mem, dh::DVec<T> &in, dh::DVec<T> &out,
  * @param nVals number of elements in the input array
  */
 template <typename T>
-T SumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
+typename std::iterator_traits<T>::value_type SumReduction(dh::CubMemory &tmp_mem, T in, int nVals) {
+  using ValueT = typename std::iterator_traits<T>::value_type;
   size_t tmpSize;
   dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
   // Allocate small extra memory for the return value
-  tmp_mem.LazyAllocate(tmpSize + sizeof(T));
-  auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
+  tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT));
+  auto ptr = reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage) + 1;
   dh::safe_cuda(cub::DeviceReduce::Sum(
       reinterpret_cast<void *>(ptr), tmpSize, in,
-      reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
-  T sum;
-  dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
+      reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage),
+      nVals));
+  ValueT sum;
+  dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(ValueT),
                            cudaMemcpyDeviceToHost));
   return sum;
 }
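SumReduction is generalised here from a raw device pointer T * to any iterator type T, with the result type recovered through std::iterator_traits. The point of the change is visible later in this commit: DeviceShard::GetBiasGradient reduces over a thrust permutation iterator, so a strided gradient column can be summed without materialising a copy. A hedged sketch of the call pattern (assumes dh::CubMemory temp, a device pointer d_gpair of GradientPair, and integers num_group, group_idx, n_rows already exist):

// Sketch only: mirrors DeviceShard::GetBiasGradient further down in this
// commit. Sums every num_group-th GradientPair starting at group_idx.
auto counting = thrust::make_counting_iterator(0ull);
auto stride = [=] __device__(size_t idx) { return idx * num_group + group_idx; };
thrust::transform_iterator<decltype(stride), decltype(counting), size_t> skip(
    counting, stride);
auto perm = thrust::make_permutation_iterator(d_gpair, skip);
// Raw pointers and fancy iterators now go through the same overload:
GradientPair sum = dh::SumReduction(temp, perm, n_rows);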
@@ -876,7 +878,8 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
  * \class AllReducer
  *
  * \brief All reducer class that manages its own communication group and
- * streams. Must be initialised before use. If XGBoost is compiled without NCCL this is a dummy class that will error if used with more than one GPU.
+ * streams. Must be initialised before use. If XGBoost is compiled without NCCL
+ * this is a dummy class that will error if used with more than one GPU.
  */

 class AllReducer {
@@ -912,7 +915,8 @@ class AllReducer {
     }
     initialised = true;
 #else
-    CHECK_EQ(device_ordinals.size(), 1) << "XGBoost must be compiled with NCCL to use more than one GPU.";
+    CHECK_EQ(device_ordinals.size(), 1)
+        << "XGBoost must be compiled with NCCL to use more than one GPU.";
 #endif
   }
   ~AllReducer() {
@@ -929,16 +933,13 @@ class AllReducer {
   }

   /**
-   * \fn void AllReduceSum(int communication_group_idx, const double *sendbuff,
-   * double *recvbuff, int count)
-   *
    * \brief Allreduce. Use in exactly the same way as NCCL but without needing
    * streams or comms.
    *
-   * \param communication_group_idx Zero-based index of the
-   * communication group. \param sendbuff The sendbuff. \param
-   * sendbuff The sendbuff. \param [in,out] recvbuff
-   * The recvbuff. \param count Number of.
+   * \param communication_group_idx Zero-based index of the communication group.
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
    */

   void AllReduceSum(int communication_group_idx, const double *sendbuff,
@@ -954,17 +955,14 @@ class AllReducer {
   }

   /**
-   * \fn void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, int64_t *recvbuff, int count)
-   *
-   * \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
-   *
-   * \param communication_group_idx Zero-based index of the communication group. \param
-   * sendbuff The sendbuff. \param sendbuff
-   * The sendbuff. \param [in,out] recvbuff The recvbuff.
-   * \param count Number of.
-   * \param sendbuff The sendbuff.
-   * \param [in,out] recvbuff If non-null, the recvbuff.
-   * \param count Number of.
-   * \param count Number of.
-   *
+   * \brief Allreduce. Use in exactly the same way as NCCL but without needing
+   * streams or comms.
+   *
+   * \param communication_group_idx Zero-based index of the communication group.
+   * \param sendbuff The sendbuff.
+   * \param recvbuff The recvbuff.
+   * \param count Number of elements.
    */

   void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
@@ -993,4 +991,53 @@ class AllReducer {
 #endif
   }
 };
+
+/**
+ * \brief Executes some operation on each element of the input vector, using a
+ * single controlling thread for each element.
+ *
+ * \tparam T         Generic type parameter.
+ * \tparam FunctionT Type of the function.
+ * \param shards The shards.
+ * \param f      The function to apply to each shard.
+ */
+
+template <typename T, typename FunctionT>
+void ExecuteShards(std::vector<T> *shards, FunctionT f) {
+  auto previous_num_threads = omp_get_max_threads();
+  omp_set_num_threads(shards->size());
+#pragma omp parallel
+  {
+    auto cpu_thread_id = omp_get_thread_num();
+    f(shards->at(cpu_thread_id));
+  }
+  omp_set_num_threads(previous_num_threads);
+}
+
+/**
+ * \brief Executes some operation on each element of the input vector, using a
+ * single controlling thread for each element, and returns the sum of the
+ * results.
+ *
+ * \tparam ReduceT   Type of the reduction result.
+ * \tparam T         Generic type parameter.
+ * \tparam FunctionT Type of the function.
+ * \param shards The shards.
+ * \param f      The function to apply to each shard.
+ *
+ * \return The sum of the per-shard results.
+ */
+
+template <typename ReduceT, typename T, typename FunctionT>
+ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
+  auto previous_num_threads = omp_get_max_threads();
+  omp_set_num_threads(shards->size());
+  std::vector<ReduceT> sums(shards->size());
+#pragma omp parallel
+  {
+    auto cpu_thread_id = omp_get_thread_num();
+    sums[cpu_thread_id] = f(shards->at(cpu_thread_id));
+  }
+  omp_set_num_threads(previous_num_threads);
+  return std::accumulate(sums.begin(), sums.end(), ReduceT());
+}
 }  // namespace dh
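ExecuteShards pins one OpenMP thread to each shard so that per-GPU work runs concurrently; ReduceShards does the same and folds the per-shard results with std::accumulate. A self-contained toy version of the pattern, with plain ints standing in for DeviceShard (the helper is restated under a different name only so the example compiles on its own):

#include <omp.h>
#include <vector>

// Restated from dh::ExecuteShards above: one controlling thread per shard.
template <typename T, typename FunctionT>
void ExecuteShardsToy(std::vector<T> *shards, FunctionT f) {
  auto previous_num_threads = omp_get_max_threads();
  omp_set_num_threads(shards->size());
#pragma omp parallel
  { f(shards->at(omp_get_thread_num())); }
  omp_set_num_threads(previous_num_threads);
}

int main() {
  std::vector<int> shards = {1, 2, 3, 4};  // one element per "device"
  ExecuteShardsToy(&shards, [](int &s) { s *= 2; });
  // shards is now {2, 4, 6, 8}; in the real code each element is a
  // std::unique_ptr<DeviceShard> and f launches work on that shard's GPU.
  return 0;
}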
@@ -84,15 +84,17 @@ class GBLinear : public GradientBooster {
     monitor_.Start("DoBoost");

     if (!p_fmat->HaveColAccess(false)) {
       monitor_.Start("InitColAccess");
       std::vector<bool> enabled(p_fmat->Info().num_col_, true);
       p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false);
       monitor_.Stop("InitColAccess");
     }

     model_.LazyInitModel();
     this->LazySumWeights(p_fmat);

     if (!this->CheckConvergence()) {
-      updater_->Update(&in_gpair->HostVector(), p_fmat, &model_, sum_instance_weight_);
+      updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
     }
     this->UpdatePredictionCache();
@@ -25,5 +25,8 @@ namespace linear {
 // List of files that will be force linked in static links.
 DMLC_REGISTRY_LINK_TAG(updater_shotgun);
 DMLC_REGISTRY_LINK_TAG(updater_coordinate);
+#ifdef XGBOOST_USE_CUDA
+DMLC_REGISTRY_LINK_TAG(updater_gpu_coordinate);
+#endif
 }  // namespace linear
 }  // namespace xgboost
@@ -84,48 +84,46 @@ class CoordinateUpdater : public LinearUpdater {
     selector.reset(FeatureSelector::Create(param.feature_selector));
     monitor.Init("CoordinateUpdater", param.debug_verbose);
   }

-  void Update(std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
+  void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
               gbm::GBLinearModel *model, double sum_instance_weight) override {
     param.DenormalizePenalties(sum_instance_weight);
     const int ngroup = model->param.num_output_group;
     // update bias
     for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
-      auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat);
+      auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->HostVector(), p_fmat);
       auto dbias = static_cast<float>(param.learning_rate *
                                       CoordinateDeltaBias(grad.first, grad.second));
       model->bias()[group_idx] += dbias;
-      UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat);
+      UpdateBiasResidualParallel(group_idx, ngroup,
+                                 dbias, &in_gpair->HostVector(), p_fmat);
     }
     // prepare for updating the weights
-    selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm,
+    selector->Setup(*model, in_gpair->HostVector(), p_fmat, param.reg_alpha_denorm,
                     param.reg_lambda_denorm, param.top_k);
     // update weights
     for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
       for (unsigned i = 0U; i < model->param.num_feature; i++) {
-        int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat,
+        int fidx = selector->NextFeature(i, *model, group_idx, in_gpair->HostVector(), p_fmat,
                                          param.reg_alpha_denorm, param.reg_lambda_denorm);
         if (fidx < 0) break;
-        this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model);
+        this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model);
       }
     }
     monitor.Stop("UpdateFeature");
   }

   inline void UpdateFeature(int fidx, int group_idx, std::vector<GradientPair> *in_gpair,
                             DMatrix *p_fmat, gbm::GBLinearModel *model) {
     const int ngroup = model->param.num_output_group;
     bst_float &w = (*model)[fidx][group_idx];
-    monitor.Start("GetGradientParallel");
-    auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
-    monitor.Stop("GetGradientParallel");
+    auto gradient =
+        GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
     auto dw = static_cast<float>(
         param.learning_rate *
         CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
                         param.reg_lambda_denorm));
     w += dw;
     monitor.Start("UpdateResidualParallel");
     UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
     monitor.Stop("UpdateResidualParallel");
   }

   // training parameter
src/linear/updater_gpu_coordinate.cu (new file, 346 lines)
@@ -0,0 +1,346 @@
/*!
 * Copyright 2018 by Contributors
 * \author Rory Mitchell
 */

#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include <xgboost/linear_updater.h>
#include "../common/device_helpers.cuh"
#include "../common/timer.h"
#include "coordinate_common.h"

namespace xgboost {
namespace linear {

DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);

// training parameter
struct GPUCoordinateTrainParam
    : public dmlc::Parameter<GPUCoordinateTrainParam> {
  /*! \brief learning_rate */
  float learning_rate;
  /*! \brief regularization weight for L2 norm */
  float reg_lambda;
  /*! \brief regularization weight for L1 norm */
  float reg_alpha;
  int feature_selector;
  int top_k;
  int debug_verbose;
  int n_gpus;
  int gpu_id;
  bool silent;
  // declare parameters
  DMLC_DECLARE_PARAMETER(GPUCoordinateTrainParam) {
    DMLC_DECLARE_FIELD(learning_rate)
        .set_lower_bound(0.0f)
        .set_default(1.0f)
        .describe("Learning rate of each update.");
    DMLC_DECLARE_FIELD(reg_lambda)
        .set_lower_bound(0.0f)
        .set_default(0.0f)
        .describe("L2 regularization on weights.");
    DMLC_DECLARE_FIELD(reg_alpha)
        .set_lower_bound(0.0f)
        .set_default(0.0f)
        .describe("L1 regularization on weights.");
    DMLC_DECLARE_FIELD(feature_selector)
        .set_default(kCyclic)
        .add_enum("cyclic", kCyclic)
        .add_enum("shuffle", kShuffle)
        .add_enum("thrifty", kThrifty)
        .add_enum("greedy", kGreedy)
        .add_enum("random", kRandom)
        .describe("Feature selection or ordering method.");
    DMLC_DECLARE_FIELD(top_k).set_lower_bound(0).set_default(0).describe(
        "The number of top features to select in 'thrifty' feature_selector. "
        "The value of zero means using all the features.");
    DMLC_DECLARE_FIELD(debug_verbose)
        .set_lower_bound(0)
        .set_default(0)
        .describe("flag to print out detailed breakdown of runtime");
    DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
        "Number of devices to use.");
    DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
        "Primary device ordinal.");
    DMLC_DECLARE_FIELD(silent).set_default(false).describe(
        "Do not print information during training.");
    // alias of parameters
    DMLC_DECLARE_ALIAS(learning_rate, eta);
    DMLC_DECLARE_ALIAS(reg_lambda, lambda);
    DMLC_DECLARE_ALIAS(reg_alpha, alpha);
  }
  /*! \brief Denormalizes the regularization penalties - to be called at each
   * update */
  void DenormalizePenalties(double sum_instance_weight) {
    reg_lambda_denorm = reg_lambda * sum_instance_weight;
    reg_alpha_denorm = reg_alpha * sum_instance_weight;
  }
  // denormalized regularization penalties
  float reg_lambda_denorm;
  float reg_alpha_denorm;
};

void RescaleIndices(size_t ridx_begin, dh::DVec<SparseBatch::Entry> *data) {
  auto d_data = data->Data();
  dh::LaunchN(data->DeviceIdx(), data->Size(),
              [=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; });
}

class DeviceShard {
  int device_idx_;
  int normalised_device_idx_;  // Device index counting from param.gpu_id
  dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
  std::vector<size_t> row_ptr_;
  dh::DVec<SparseBatch::Entry> data_;
  dh::DVec<GradientPair> gpair_;
  dh::CubMemory temp_;
  size_t ridx_begin_;
  size_t ridx_end_;

 public:
  DeviceShard(int device_idx, int normalised_device_idx, const ColBatch &batch,
              bst_uint row_begin, bst_uint row_end,
              const GPUCoordinateTrainParam &param,
              const gbm::GBLinearModelParam &model_param)
      : device_idx_(device_idx),
        normalised_device_idx_(normalised_device_idx),
        ridx_begin_(row_begin),
        ridx_end_(row_end) {
    dh::safe_cuda(cudaSetDevice(device_idx));
    // The begin and end indices for the section of each column associated with
    // this shard
    std::vector<std::pair<bst_uint, bst_uint>> column_segments;
    row_ptr_ = {0};
    for (auto fidx = 0; fidx < batch.size; fidx++) {
      auto col = batch[fidx];
      auto cmp = [](SparseBatch::Entry e1, SparseBatch::Entry e2) {
        return e1.index < e2.index;
      };
      auto column_begin =
          std::lower_bound(col.data, col.data + col.length,
                           SparseBatch::Entry(row_begin, 0.0f), cmp);
      auto column_end =
          std::upper_bound(col.data, col.data + col.length,
                           SparseBatch::Entry(row_end, 0.0f), cmp);
      column_segments.push_back(
          std::make_pair(column_begin - col.data, column_end - col.data));
      row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
    }
    ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_,
                 (row_end - row_begin) * model_param.num_output_group);

    for (int fidx = 0; fidx < batch.size; fidx++) {
      ColBatch::Inst col = batch[fidx];
      thrust::copy(col.data + column_segments[fidx].first,
                   col.data + column_segments[fidx].second,
                   data_.tbegin() + row_ptr_[fidx]);
    }
    // Rescale indices with respect to current shard
    RescaleIndices(ridx_begin_, &data_);
  }
  void UpdateGpair(const std::vector<GradientPair> &host_gpair,
                   const gbm::GBLinearModelParam &model_param) {
    gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group,
                host_gpair.begin() + ridx_end_ * model_param.num_output_group);
  }

  GradientPair GetBiasGradient(int group_idx, int num_group) {
    auto counting = thrust::make_counting_iterator(0ull);
    auto f = [=] __device__(size_t idx) {
      return idx * num_group + group_idx;
    };  // NOLINT
    thrust::transform_iterator<decltype(f), decltype(counting), size_t> skip(
        counting, f);
    auto perm = thrust::make_permutation_iterator(gpair_.tbegin(), skip);

    return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_);
  }

  void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
    if (dbias == 0.0f) return;
    auto d_gpair = gpair_.Data();
    dh::LaunchN(device_idx_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
      auto &g = d_gpair[idx * num_groups + group_idx];
      g += GradientPair(g.GetHess() * dbias, 0);
    });
  }

  GradientPair GetGradient(int group_idx, int num_group, int fidx) {
    auto d_col = data_.Data() + row_ptr_[fidx];
    size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
    auto d_gpair = gpair_.Data();
    auto counting = thrust::make_counting_iterator(0ull);
    auto f = [=] __device__(size_t idx) {
      auto entry = d_col[idx];
      auto g = d_gpair[entry.index * num_group + group_idx];
      return GradientPair(g.GetGrad() * entry.fvalue,
                          g.GetHess() * entry.fvalue * entry.fvalue);
    };  // NOLINT
    thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
        multiply_iterator(counting, f);
    return dh::SumReduction(temp_, multiply_iterator, col_size);
  }

  void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
    auto d_gpair = gpair_.Data();
    auto d_col = data_.Data() + row_ptr_[fidx];
    size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
    dh::LaunchN(device_idx_, col_size, [=] __device__(size_t idx) {
      auto entry = d_col[idx];
      auto &g = d_gpair[entry.index * num_groups + group_idx];
      g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
    });
  }
};
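Each DeviceShard owns a contiguous block of rows and stores its slice of every column back to back, CSC-style: row_ptr_[fidx] is the offset of feature fidx's entries in data_, and RescaleIndices shifts row indices so each shard addresses its gpair_ buffer from zero. An illustrative layout (the numbers are made up for the example):

// Shard owning rows [4, 8) of a 3-feature matrix. Each column's slice is
// located with std::lower_bound/std::upper_bound, concatenated, then rescaled:
//
//   row_ptr_ = {0, 2, 3, 5}      // feature fidx occupies [row_ptr_[fidx], row_ptr_[fidx+1])
//   data_    = {(4,v0), (7,v1),  // feature 0 slice
//               (5,v2),          // feature 1 slice
//               (4,v3), (6,v4)}  // feature 2 slice
//
// After RescaleIndices(4, &data_) the stored indices become 0, 3, 1, 0, 2,
// so entry.index * num_groups + group_idx indexes gpair_ directly.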
/**
 * \class GPUCoordinateUpdater
 *
 * \brief Coordinate descent algorithm that updates one feature per iteration
 */

class GPUCoordinateUpdater : public LinearUpdater {
 public:
  // set training parameter
  void Init(
      const std::vector<std::pair<std::string, std::string>> &args) override {
    param.InitAllowUnknown(args);
    selector.reset(FeatureSelector::Create(param.feature_selector));
    monitor.Init("GPUCoordinateUpdater", param.debug_verbose);
  }

  void LazyInitShards(DMatrix *p_fmat,
                      const gbm::GBLinearModelParam &model_param) {
    if (!shards.empty()) return;
    int n_devices = dh::NDevices(param.n_gpus, p_fmat->Info().num_row_);
    bst_uint row_begin = 0;
    bst_uint shard_size =
        std::ceil(static_cast<double>(p_fmat->Info().num_row_) / n_devices);

    device_list.resize(n_devices);
    for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
      int device_idx = (param.gpu_id + d_idx) % dh::NVisibleDevices();
      device_list[d_idx] = device_idx;
    }
    // Partition input matrix into row segments
    std::vector<size_t> row_segments;
    row_segments.push_back(0);
    for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
      bst_uint row_end = std::min(static_cast<size_t>(row_begin + shard_size),
                                  p_fmat->Info().num_row_);
      row_segments.push_back(row_end);
      row_begin = row_end;
    }

    dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
    CHECK(p_fmat->SingleColBlock());
    iter->Next();
    auto batch = iter->Value();

    shards.resize(n_devices);
    // Create device shards
    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
      auto idx = &shard - &shards[0];
      shard = std::unique_ptr<DeviceShard>(
          new DeviceShard(device_list[idx], idx, batch, row_segments[idx],
                          row_segments[idx + 1], param, model_param));
    });
  }
  void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
              gbm::GBLinearModel *model, double sum_instance_weight) override {
    param.DenormalizePenalties(sum_instance_weight);
    monitor.Start("LazyInitShards");
    this->LazyInitShards(p_fmat, model->param);
    monitor.Stop("LazyInitShards");

    monitor.Start("UpdateGpair");
    // Update gpair
    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
      shard->UpdateGpair(in_gpair->HostVector(), model->param);
    });
    monitor.Stop("UpdateGpair");

    monitor.Start("UpdateBias");
    this->UpdateBias(p_fmat, model);
    monitor.Stop("UpdateBias");
    // prepare for updating the weights
    selector->Setup(*model, in_gpair->HostVector(), p_fmat,
                    param.reg_alpha_denorm, param.reg_lambda_denorm,
                    param.top_k);
    monitor.Start("UpdateFeature");
    for (auto group_idx = 0; group_idx < model->param.num_output_group;
         ++group_idx) {
      for (auto i = 0U; i < model->param.num_feature; i++) {
        auto fidx = selector->NextFeature(
            i, *model, group_idx, in_gpair->HostVector(), p_fmat,
            param.reg_alpha_denorm, param.reg_lambda_denorm);
        if (fidx < 0) break;
        this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), model);
      }
    }
    monitor.Stop("UpdateFeature");
  }

  void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) {
    for (int group_idx = 0; group_idx < model->param.num_output_group;
         ++group_idx) {
      // Get gradient
      auto grad = dh::ReduceShards<GradientPair>(
          &shards, [&](std::unique_ptr<DeviceShard> &shard) {
            return shard->GetBiasGradient(group_idx,
                                          model->param.num_output_group);
          });

      auto dbias = static_cast<float>(
          param.learning_rate *
          CoordinateDeltaBias(grad.GetGrad(), grad.GetHess()));
      model->bias()[group_idx] += dbias;

      // Update residual
      dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
        shard->UpdateBiasResidual(dbias, group_idx,
                                  model->param.num_output_group);
      });
    }
  }

  void UpdateFeature(int fidx, int group_idx,
                     std::vector<GradientPair> *in_gpair,
                     gbm::GBLinearModel *model) {
    bst_float &w = (*model)[fidx][group_idx];
    // Get gradient
    auto grad = dh::ReduceShards<GradientPair>(
        &shards, [&](std::unique_ptr<DeviceShard> &shard) {
          return shard->GetGradient(group_idx, model->param.num_output_group,
                                    fidx);
        });

    auto dw = static_cast<float>(param.learning_rate *
                                 CoordinateDelta(grad.GetGrad(), grad.GetHess(),
                                                 w, param.reg_alpha_denorm,
                                                 param.reg_lambda_denorm));
    w += dw;

    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
      shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
    });
  }

  // training parameter
  GPUCoordinateTrainParam param;
  std::unique_ptr<FeatureSelector> selector;
  common::Monitor monitor;

  std::vector<std::unique_ptr<DeviceShard>> shards;
  std::vector<int> device_list;
};

DMLC_REGISTER_PARAMETER(GPUCoordinateTrainParam);
XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")
    .describe(
        "Update linear model according to coordinate descent algorithm. GPU "
        "accelerated.")
    .set_body([]() { return new GPUCoordinateUpdater(); });
}  // namespace linear
}  // namespace xgboost
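With the registration above, the GPU updater is obtained through the same factory as the CPU updaters. A hedged sketch of driving it directly, mirroring the C++ unit tests for "shotgun" and "coord_descent" later in this commit (assumes an initialised DMatrix and GBLinearModel):

#include <memory>
#include <xgboost/linear_updater.h>

// Sketch: exercise the newly registered updater the way the unit tests do.
void RunGpuCoordDescent(xgboost::DMatrix *mat,
                        xgboost::gbm::GBLinearModel *model) {
  auto updater = std::unique_ptr<xgboost::LinearUpdater>(
      xgboost::LinearUpdater::Create("gpu_coord_descent"));
  updater->Init({{"eta", "1."}, {"n_gpus", "1"}});
  // Gradients now live in a HostDeviceVector, matching the new interface:
  xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
      mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
  updater->Update(&gpair, mat, model, gpair.Size());
}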
@@ -61,32 +61,31 @@ class ShotgunUpdater : public LinearUpdater {
     param_.InitAllowUnknown(args);
     selector_.reset(FeatureSelector::Create(param_.feature_selector));
   }

-  void Update(std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
+  void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
               gbm::GBLinearModel *model, double sum_instance_weight) override {
+    std::vector<GradientPair> &gpair = in_gpair->HostVector();
     param_.DenormalizePenalties(sum_instance_weight);
-    std::vector<GradientPair> &gpair = *in_gpair;
     const int ngroup = model->param.num_output_group;

     // update bias
     for (int gid = 0; gid < ngroup; ++gid) {
-      auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat);
+      auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->HostVector(), p_fmat);
       auto dbias = static_cast<bst_float>(param_.learning_rate *
                                           CoordinateDeltaBias(grad.first, grad.second));
       model->bias()[gid] += dbias;
-      UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat);
+      UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
     }

     // lock-free parallel updates of weights
-    selector_->Setup(*model, *in_gpair, p_fmat, param_.reg_alpha_denorm,
-                     param_.reg_lambda_denorm, 0);
+    selector_->Setup(*model, in_gpair->HostVector(), p_fmat,
+                     param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0);
     dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
     while (iter->Next()) {
       const ColBatch &batch = iter->Value();
       const auto nfeat = static_cast<bst_omp_uint>(batch.size);
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nfeat; ++i) {
-        int ii = selector_->NextFeature(i, *model, 0, *in_gpair, p_fmat,
+        int ii = selector_->NextFeature(i, *model, 0, in_gpair->HostVector(), p_fmat,
                                         param_.reg_alpha_denorm, param_.reg_lambda_denorm);
         if (ii < 0) continue;
         const bst_uint fid = batch.col_index[ii];
@@ -13,13 +13,13 @@ TEST(Linear, shotgun) {
   auto updater = std::unique_ptr<xgboost::LinearUpdater>(
       xgboost::LinearUpdater::Create("shotgun"));
   updater->Init({{"eta", "1."}});
-  std::vector<xgboost::GradientPair> gpair(mat->Info().num_row_,
-                                           xgboost::GradientPair(-5, 1.0));
+  xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
+      mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
   xgboost::gbm::GBLinearModel model;
   model.param.num_feature = mat->Info().num_col_;
   model.param.num_output_group = 1;
   model.LazyInitModel();
-  updater->Update(&gpair, mat.get(), &model, gpair.size());
+  updater->Update(&gpair, mat.get(), &model, gpair.Size());

   ASSERT_EQ(model.bias()[0], 5.0f);
 }
@@ -32,13 +32,13 @@ TEST(Linear, coordinate) {
   auto updater = std::unique_ptr<xgboost::LinearUpdater>(
       xgboost::LinearUpdater::Create("coord_descent"));
   updater->Init({});
-  std::vector<xgboost::GradientPair> gpair(mat->Info().num_row_,
-                                           xgboost::GradientPair(-5, 1.0));
+  xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
+      mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
   xgboost::gbm::GBLinearModel model;
   model.param.num_feature = mat->Info().num_col_;
   model.param.num_output_group = 1;
   model.LazyInitModel();
-  updater->Update(&gpair, mat.get(), &model, gpair.size());
+  updater->Update(&gpair, mat.get(), &model, gpair.Size());

   ASSERT_EQ(model.bias()[0], 5.0f);
 }
tests/python-gpu/test_gpu_linear.py (new file, 14 lines)
@@ -0,0 +1,14 @@
import sys

sys.path.append('tests/python/')
import test_linear
import testing as tm
import unittest


class TestGPULinear(unittest.TestCase):
    def test_gpu_coordinate(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [.005, .1], 'lambda': [0.005],
                          'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1, 1]}
        test_linear.assert_updater_accuracy('gpu_coord_descent', variable_param)
@@ -144,7 +144,8 @@ def assert_updater_accuracy(linear_updater, variable_param):
         train_classification(param_tmp)
         train_classification_multi(param_tmp)
         train_breast_cancer(param_tmp)
-        train_external_mem(param_tmp)
+        if 'gpu' not in linear_updater:
+            train_external_mem(param_tmp)


 class TestLinear(unittest.TestCase):