Make HostDeviceVector single gpu only (#4773)
* Make HostDeviceVector single gpu only
This commit is contained in:
@@ -19,12 +19,6 @@ namespace linear {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
|
||||
|
||||
void RescaleIndices(int device_idx, size_t ridx_begin,
|
||||
common::Span<xgboost::Entry> data) {
|
||||
dh::LaunchN(device_idx, data.size(),
|
||||
[=] __device__(size_t idx) { data[idx].index -= ridx_begin; });
|
||||
}
|
||||
|
||||
class DeviceShard {
|
||||
int device_id_;
|
||||
dh::BulkAllocator ba_;
|
||||
@@ -32,18 +26,16 @@ class DeviceShard {
|
||||
common::Span<xgboost::Entry> data_;
|
||||
common::Span<GradientPair> gpair_;
|
||||
dh::CubMemory temp_;
|
||||
size_t ridx_begin_;
|
||||
size_t ridx_end_;
|
||||
size_t shard_size_;
|
||||
|
||||
public:
|
||||
DeviceShard(int device_id,
|
||||
const SparsePage &batch, // column batch
|
||||
bst_uint row_begin, bst_uint row_end,
|
||||
bst_uint shard_size,
|
||||
const LinearTrainParam ¶m,
|
||||
const gbm::GBLinearModelParam &model_param)
|
||||
: device_id_(device_id),
|
||||
ridx_begin_(row_begin),
|
||||
ridx_end_(row_end) {
|
||||
shard_size_(shard_size) {
|
||||
if ( IsEmpty() ) { return; }
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
// The begin and end indices for the section of each column associated with
|
||||
@@ -51,25 +43,25 @@ class DeviceShard {
|
||||
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
||||
row_ptr_ = {0};
|
||||
// iterate through columns
|
||||
for (auto fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
common::Span<Entry const> col = batch[fidx];
|
||||
auto cmp = [](Entry e1, Entry e2) {
|
||||
return e1.index < e2.index;
|
||||
};
|
||||
auto column_begin =
|
||||
std::lower_bound(col.cbegin(), col.cend(),
|
||||
xgboost::Entry(row_begin, 0.0f), cmp);
|
||||
xgboost::Entry(0, 0.0f), cmp);
|
||||
auto column_end =
|
||||
std::lower_bound(col.cbegin(), col.cend(),
|
||||
xgboost::Entry(row_end, 0.0f), cmp);
|
||||
xgboost::Entry(shard_size_, 0.0f), cmp);
|
||||
column_segments.emplace_back(
|
||||
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
|
||||
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
|
||||
}
|
||||
ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_,
|
||||
(row_end - row_begin) * model_param.num_output_group);
|
||||
shard_size_ * model_param.num_output_group);
|
||||
|
||||
for (int fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
auto col = batch[fidx];
|
||||
auto seg = column_segments[fidx];
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
@@ -77,23 +69,21 @@ class DeviceShard {
|
||||
col.data() + seg.first,
|
||||
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
|
||||
}
|
||||
// Rescale indices with respect to current shard
|
||||
RescaleIndices(device_id_, ridx_begin_, data_);
|
||||
}
|
||||
|
||||
~DeviceShard() {
|
||||
~DeviceShard() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
}
|
||||
|
||||
bool IsEmpty() {
|
||||
return (ridx_end_ - ridx_begin_) == 0;
|
||||
return shard_size_ == 0;
|
||||
}
|
||||
|
||||
void UpdateGpair(const std::vector<GradientPair> &host_gpair,
|
||||
const gbm::GBLinearModelParam &model_param) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair_.data(),
|
||||
host_gpair.data() + ridx_begin_ * model_param.num_output_group,
|
||||
host_gpair.data(),
|
||||
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
@@ -107,13 +97,13 @@ class DeviceShard {
|
||||
counting, f);
|
||||
auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);
|
||||
|
||||
return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_);
|
||||
return dh::SumReduction(temp_, perm, shard_size_);
|
||||
}
|
||||
|
||||
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||
if (dbias == 0.0f) return;
|
||||
auto d_gpair = gpair_;
|
||||
dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
|
||||
dh::LaunchN(device_id_, shard_size_, [=] __device__(size_t idx) {
|
||||
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
});
|
||||
@@ -154,7 +144,7 @@ class DeviceShard {
|
||||
* \brief Coordinate descent algorithm that updates one feature per iteration
|
||||
*/
|
||||
|
||||
class GPUCoordinateUpdater : public LinearUpdater {
|
||||
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
public:
|
||||
// set training parameter
|
||||
void Configure(Args const& args) override {
|
||||
@@ -165,37 +155,23 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
|
||||
void LazyInitShards(DMatrix *p_fmat,
|
||||
const gbm::GBLinearModelParam &model_param) {
|
||||
if (!shards_.empty()) return;
|
||||
if (shard_) return;
|
||||
|
||||
dist_ = GPUDistribution::Block(GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus,
|
||||
p_fmat->Info().num_row_));
|
||||
auto devices = dist_.Devices();
|
||||
device_ = learner_param_->gpu_id;
|
||||
|
||||
size_t n_devices = static_cast<size_t>(devices.Size());
|
||||
size_t row_begin = 0;
|
||||
size_t num_row = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||
auto num_row = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||
|
||||
// Partition input matrix into row segments
|
||||
std::vector<size_t> row_segments;
|
||||
row_segments.push_back(0);
|
||||
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||
size_t shard_size = dist_.ShardSize(num_row, d_idx);
|
||||
size_t row_end = row_begin + shard_size;
|
||||
row_segments.push_back(row_end);
|
||||
row_begin = row_end;
|
||||
}
|
||||
size_t shard_size = num_row;
|
||||
row_segments.push_back(shard_size);
|
||||
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
|
||||
|
||||
shards_.resize(n_devices);
|
||||
// Create device shards
|
||||
dh::ExecuteIndexShards(&shards_,
|
||||
[&](int i, std::unique_ptr<DeviceShard>& shard) {
|
||||
shard = std::unique_ptr<DeviceShard>(
|
||||
new DeviceShard(devices.DeviceId(i), batch, row_segments[i],
|
||||
row_segments[i + 1], tparam_, model_param));
|
||||
});
|
||||
// Create device shard
|
||||
shard_.reset(new DeviceShard(device_, batch, shard_size, tparam_, model_param));
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||
@@ -208,11 +184,9 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
monitor_.Start("UpdateGpair");
|
||||
auto &in_gpair_host = in_gpair->ConstHostVector();
|
||||
// Update gpair
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
if (!shard->IsEmpty()) {
|
||||
shard->UpdateGpair(in_gpair_host, model->param);
|
||||
}
|
||||
});
|
||||
if (shard_) {
|
||||
shard_->UpdateGpair(in_gpair_host, model->param);
|
||||
}
|
||||
monitor_.Stop("UpdateGpair");
|
||||
|
||||
monitor_.Start("UpdateBias");
|
||||
@@ -237,32 +211,21 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
}
|
||||
|
||||
void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||
for (int group_idx = 0; group_idx < model->param.num_output_group;
|
||||
++group_idx) {
|
||||
for (int group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) {
|
||||
// Get gradient
|
||||
auto grad = dh::ReduceShards<GradientPair>(
|
||||
&shards_, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||
if (!shard->IsEmpty()) {
|
||||
GradientPair result =
|
||||
shard->GetBiasGradient(group_idx,
|
||||
model->param.num_output_group);
|
||||
return result;
|
||||
}
|
||||
return GradientPair(0, 0);
|
||||
});
|
||||
|
||||
auto grad = GradientPair(0, 0);
|
||||
if (shard_) {
|
||||
grad = shard_->GetBiasGradient(group_idx, model->param.num_output_group);
|
||||
}
|
||||
auto dbias = static_cast<float>(
|
||||
tparam_.learning_rate *
|
||||
CoordinateDeltaBias(grad.GetGrad(), grad.GetHess()));
|
||||
CoordinateDeltaBias(grad.GetGrad(), grad.GetHess()));
|
||||
model->bias()[group_idx] += dbias;
|
||||
|
||||
// Update residual
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
|
||||
if (!shard->IsEmpty()) {
|
||||
shard->UpdateBiasResidual(dbias, group_idx,
|
||||
model->param.num_output_group);
|
||||
}
|
||||
});
|
||||
if (shard_) {
|
||||
shard_->UpdateBiasResidual(dbias, group_idx, model->param.num_output_group);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,38 +234,30 @@ class GPUCoordinateUpdater : public LinearUpdater {
|
||||
gbm::GBLinearModel *model) {
|
||||
bst_float &w = (*model)[fidx][group_idx];
|
||||
// Get gradient
|
||||
auto grad = dh::ReduceShards<GradientPair>(
|
||||
&shards_, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||
if (!shard->IsEmpty()) {
|
||||
return shard->GetGradient(group_idx, model->param.num_output_group,
|
||||
fidx);
|
||||
}
|
||||
return GradientPair(0, 0);
|
||||
});
|
||||
|
||||
auto grad = GradientPair(0, 0);
|
||||
if (shard_) {
|
||||
grad = shard_->GetGradient(group_idx, model->param.num_output_group, fidx);
|
||||
}
|
||||
auto dw = static_cast<float>(tparam_.learning_rate *
|
||||
CoordinateDelta(grad.GetGrad(), grad.GetHess(),
|
||||
w, tparam_.reg_alpha_denorm,
|
||||
tparam_.reg_lambda_denorm));
|
||||
w += dw;
|
||||
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx,
|
||||
std::unique_ptr<DeviceShard> &shard) {
|
||||
if (!shard->IsEmpty()) {
|
||||
shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||
}
|
||||
});
|
||||
if (shard_) {
|
||||
shard_->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// training parameter
|
||||
LinearTrainParam tparam_;
|
||||
CoordinateParam coord_param_;
|
||||
GPUDistribution dist_;
|
||||
int device_{};
|
||||
std::unique_ptr<FeatureSelector> selector_;
|
||||
common::Monitor monitor_;
|
||||
|
||||
std::vector<std::unique_ptr<DeviceShard>> shards_;
|
||||
std::unique_ptr<DeviceShard> shard_{nullptr};
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")
|
||||
|
||||
Reference in New Issue
Block a user