diff --git a/src/common/span.h b/src/common/span.h
index a59c39e72..cb49b84cd 100644
--- a/src/common/span.h
+++ b/src/common/span.h
@@ -621,8 +621,8 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept ->  // NOLIN
   return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
 }
 
-}  // namespace common
-}  // namespace xgboost
+}  // namespace common NOLINT
+}  // namespace xgboost NOLINT
 
 #if defined(_MSC_VER) && _MSC_VER < 1910
 #undef constexpr
diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu
index 47c58198f..0d4cbe824 100644
--- a/src/linear/updater_gpu_coordinate.cu
+++ b/src/linear/updater_gpu_coordinate.cu
@@ -5,8 +5,10 @@
 #include <thrust/execution_policy.h>
 #include <xgboost/linear_updater.h>
+#include <xgboost/data.h>
 #include <vector>
 #include "../common/common.h"
+#include "../common/span.h"
 #include "../common/device_helpers.cuh"
 #include "../common/timer.h"
 #include "./param.h"
 #include "coordinate_common.h"
@@ -17,8 +19,8 @@ namespace linear {
 DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
 
-void RescaleIndices(size_t ridx_begin, dh::DVec<Entry> *data) {
-  auto d_data = data->Data();
+void RescaleIndices(size_t ridx_begin, dh::DVec<xgboost::Entry> *data) {
+  auto d_data = data->GetSpan();
   dh::LaunchN(data->DeviceIdx(), data->Size(), [=] __device__(size_t idx) {
     d_data[idx].index -= ridx_begin;
   });
 }
@@ -27,57 +29,66 @@ class DeviceShard {
   int device_id_;
   dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
   std::vector<size_t> row_ptr_;
-  dh::DVec<Entry> data_;
+  dh::DVec<xgboost::Entry> data_;
   dh::DVec<GradientPair> gpair_;
   dh::CubMemory temp_;
   size_t ridx_begin_;
   size_t ridx_end_;
 
 public:
-  DeviceShard(int device_id, const SparsePage &batch,
+  DeviceShard(int device_id,
+              const SparsePage &batch,  // column batch
               bst_uint row_begin, bst_uint row_end,
               const LinearTrainParam &param,
               const gbm::GBLinearModelParam &model_param)
       : device_id_(device_id), ridx_begin_(row_begin), ridx_end_(row_end) {
+    if (IsEmpty()) { return; }
     dh::safe_cuda(cudaSetDevice(device_id_));
     // The begin and end indices for the section of each column associated with
     // this shard
     std::vector<std::pair<size_t, size_t>> column_segments;
     row_ptr_ = {0};
+    // iterate through columns
     for (auto fidx = 0; fidx < batch.Size(); fidx++) {
-      auto col = batch[fidx];
+      common::Span<xgboost::Entry const> col = batch[fidx];
       auto cmp = [](Entry e1, Entry e2) { return e1.index < e2.index; };
       auto column_begin =
-          std::lower_bound(col.data(), col.data() + col.size(),
-                           Entry(row_begin, 0.0f), cmp);
+          std::lower_bound(col.cbegin(), col.cend(),
+                           xgboost::Entry(row_begin, 0.0f), cmp);
       auto column_end =
-          std::upper_bound(col.data(), col.data() + col.size(),
-                           Entry(row_end, 0.0f), cmp);
+          std::lower_bound(col.cbegin(), col.cend(),
+                           xgboost::Entry(row_end, 0.0f), cmp);
       column_segments.push_back(
-          std::make_pair(column_begin - col.data(), column_end - col.data()));
-      row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
+          std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
+      row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
     }
 
     ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_,
-                (row_end - row_begin) * model_param.num_output_group);
+                 (row_end - row_begin) * model_param.num_output_group);
 
     for (int fidx = 0; fidx < batch.Size(); fidx++) {
       auto col = batch[fidx];
       auto seg = column_segments[fidx];
       dh::safe_cuda(cudaMemcpy(
-          data_.Data() + row_ptr_[fidx], col.data() + seg.first,
+          data_.GetSpan().subspan(row_ptr_[fidx]).data(),
+          col.data() + seg.first,
           sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
     }
 
     // Rescale indices with respect to current shard
     RescaleIndices(ridx_begin_, &data_);
   }
+
+  bool IsEmpty() {
+    return (ridx_end_ - ridx_begin_) == 0;
+  }
+
   void
   UpdateGpair(const std::vector<GradientPair> &host_gpair,
              const gbm::GBLinearModelParam &model_param) {
     gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group,
-               host_gpair.begin() + ridx_end_ * model_param.num_output_group);
+                host_gpair.begin() + ridx_end_ * model_param.num_output_group);
   }
 
   GradientPair GetBiasGradient(int group_idx, int num_group) {
@@ -95,7 +106,7 @@ class DeviceShard {
   void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
     if (dbias == 0.0f) return;
-    auto d_gpair = gpair_.Data();
+    auto d_gpair = gpair_.GetSpan();
     dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
       auto &g = d_gpair[idx * num_groups + group_idx];
       g += GradientPair(g.GetHess() * dbias, 0);
     });
@@ -104,9 +115,9 @@
   GradientPair GetGradient(int group_idx, int num_group, int fidx) {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    auto d_col = data_.Data() + row_ptr_[fidx];
+    common::Span<xgboost::Entry> d_col = data_.GetSpan().subspan(row_ptr_[fidx]);
     size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
-    auto d_gpair = gpair_.Data();
+    common::Span<GradientPair> d_gpair = gpair_.GetSpan();
     auto counting = thrust::make_counting_iterator(0ull);
     auto f = [=] __device__(size_t idx) {
       auto entry = d_col[idx];
@@ -120,8 +131,8 @@
   void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
-    auto d_gpair = gpair_.Data();
-    auto d_col = data_.Data() + row_ptr_[fidx];
+    common::Span<GradientPair> d_gpair = gpair_.GetSpan();
+    common::Span<xgboost::Entry> d_col = data_.GetSpan().subspan(row_ptr_[fidx]);
     size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
     dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) {
       auto entry = d_col[idx];
@@ -158,21 +169,19 @@ class GPUCoordinateUpdater : public LinearUpdater {
     size_t n_devices = static_cast<size_t>(devices.Size());
     size_t row_begin = 0;
     size_t num_row = static_cast<size_t>(p_fmat->Info().num_row_);
-    // Use fast integer ceiling
-    // See https://stackoverflow.com/a/2745086
-    size_t shard_size = (num_row + n_devices - 1) / n_devices;
 
     // Partition input matrix into row segments
     std::vector<size_t> row_segments;
     row_segments.push_back(0);
     for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
-      size_t row_end = std::min(row_begin + shard_size, num_row);
+      size_t shard_size = dist_.ShardSize(num_row, d_idx);
+      size_t row_end = row_begin + shard_size;
       row_segments.push_back(row_end);
       row_begin = row_end;
     }
 
     CHECK(p_fmat->SingleColBlock());
-    const auto &batch = *p_fmat->GetColumnBatches().begin();
+    SparsePage const& batch = *(p_fmat->GetColumnBatches().begin());
     shards.resize(n_devices);
 
     // Create device shards
@@ -194,7 +203,9 @@ class GPUCoordinateUpdater : public LinearUpdater {
     monitor.Start("UpdateGpair");
     // Update gpair
     dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
-      shard->UpdateGpair(in_gpair->ConstHostVector(), model->param);
+      if (!shard->IsEmpty()) {
+        shard->UpdateGpair(in_gpair->ConstHostVector(), model->param);
+      }
     });
     monitor.Stop("UpdateGpair");
 
@@ -225,8 +236,13 @@ class GPUCoordinateUpdater : public LinearUpdater {
       // Get gradient
      auto grad = dh::ReduceShards<GradientPair>(
           &shards, [&](std::unique_ptr<DeviceShard> &shard) {
-            return shard->GetBiasGradient(group_idx,
-                                          model->param.num_output_group);
+            if (!shard->IsEmpty()) {
+              GradientPair result =
+                  shard->GetBiasGradient(group_idx,
+                                         model->param.num_output_group);
+              return result;
+            }
+            return GradientPair(0, 0);
           });
 
       auto dbias = static_cast<float>(
@@ -236,8 +252,10 @@ class GPUCoordinateUpdater : public LinearUpdater {
       // Update residual
       dh::ExecuteIndexShards(&shards, [&](int idx,
                                           std::unique_ptr<DeviceShard>& shard) {
-        shard->UpdateBiasResidual(dbias, group_idx,
-                                  model->param.num_output_group);
+        if (!shard->IsEmpty()) {
+          shard->UpdateBiasResidual(dbias, group_idx,
+                                    model->param.num_output_group);
+        }
       });
     }
   }
@@ -249,8 +267,11 @@ class GPUCoordinateUpdater : public LinearUpdater {
     // Get gradient
     auto grad = dh::ReduceShards<GradientPair>(
         &shards, [&](std::unique_ptr<DeviceShard> &shard) {
-          return shard->GetGradient(group_idx, model->param.num_output_group,
-                                    fidx);
+          if (!shard->IsEmpty()) {
+            return shard->GetGradient(group_idx, model->param.num_output_group,
+                                      fidx);
+          }
+          return GradientPair(0, 0);
         });
 
     auto dw = static_cast<float>(tparam_.learning_rate *
@@ -259,8 +280,11 @@ class GPUCoordinateUpdater : public LinearUpdater {
                      CoordinateDelta(grad.GetGrad(), grad.GetHess(), w,
                                      tparam_.reg_lambda_denorm));
     w += dw;
 
-    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
-      shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
+    dh::ExecuteIndexShards(&shards, [&](int idx,
+                                        std::unique_ptr<DeviceShard> &shard) {
+      if (!shard->IsEmpty()) {
+        shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
+      }
     });
   }
diff --git a/tests/cpp/linear/test_linear.cc b/tests/cpp/linear/test_linear.cc
index 2a479a313..4dd27e5de 100644
--- a/tests/cpp/linear/test_linear.cc
+++ b/tests/cpp/linear/test_linear.cc
@@ -1,4 +1,6 @@
-// Copyright by Contributors
+/*!
+ * Copyright 2018 by Contributors
+ */
 #include <gtest/gtest.h>
 #include "../helpers.h"
 #include "xgboost/gbm.h"
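Note for reviewers: the recurring edit in updater_gpu_coordinate.cu swaps raw device-pointer arithmetic (data_.Data() + row_ptr_[fidx]) for bounds-carrying views (data_.GetSpan().subspan(row_ptr_[fidx])), so each column kernel receives a view that knows the extent of the buffer it reads. Below is a minimal, host-only sketch of that pattern. MiniSpan is a hypothetical stand-in for xgboost::common::Span (which additionally works in device code and can check element access); the sample data and row_ptr values are illustrative only.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for xgboost::common::Span: a non-owning view over a
// contiguous buffer that carries its length alongside the pointer.
template <typename T>
struct MiniSpan {
  T* ptr;
  std::size_t len;

  T& operator[](std::size_t i) const { return ptr[i]; }  // real Span can bound-check here
  std::size_t size() const { return len; }

  // Equivalent of Span::subspan(offset): a view into the tail of the buffer.
  MiniSpan subspan(std::size_t offset) const { return {ptr + offset, len - offset}; }
};

int main() {
  // Flattened column storage as in DeviceShard: all column segments live in
  // one buffer, and row_ptr[fidx] is the offset where column fidx begins.
  std::vector<float> data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<std::size_t> row_ptr = {0, 2, 6};

  MiniSpan<float> all{data.data(), data.size()};

  // Before: raw arithmetic, with the length tracked separately.
  //   float* d_col = data.data() + row_ptr[1];
  // After: a sub-view that still knows how much storage lies behind it.
  MiniSpan<float> d_col = all.subspan(row_ptr[1]);
  std::size_t col_size = row_ptr[2] - row_ptr[1];

  for (std::size_t i = 0; i < col_size; ++i) {
    std::cout << d_col[i] << ' ';  // prints: 3 4 5 6
  }
  std::cout << '\n';
  return 0;
}

Keeping pointer and length in one object is also what makes the new IsEmpty() guards cheap: a shard that owns no rows simply never touches its (empty) views.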