Fix gpu coordinate running on multi-gpu. (#3893)
This commit is contained in:
parent
0ddb8a7661
commit
97984f4890
@ -89,7 +89,7 @@ void RescaleIndices(size_t ridx_begin, dh::DVec<Entry> *data) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class DeviceShard {
|
class DeviceShard {
|
||||||
int device_idx_;
|
int device_id_;
|
||||||
dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
|
dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
|
||||||
std::vector<size_t> row_ptr_;
|
std::vector<size_t> row_ptr_;
|
||||||
dh::DVec<Entry> data_;
|
dh::DVec<Entry> data_;
|
||||||
@ -99,14 +99,14 @@ class DeviceShard {
|
|||||||
size_t ridx_end_;
|
size_t ridx_end_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DeviceShard(int device_idx, const SparsePage &batch,
|
DeviceShard(int device_id, const SparsePage &batch,
|
||||||
bst_uint row_begin, bst_uint row_end,
|
bst_uint row_begin, bst_uint row_end,
|
||||||
const GPUCoordinateTrainParam &param,
|
const GPUCoordinateTrainParam &param,
|
||||||
const gbm::GBLinearModelParam &model_param)
|
const gbm::GBLinearModelParam &model_param)
|
||||||
: device_idx_(device_idx),
|
: device_id_(device_id),
|
||||||
ridx_begin_(row_begin),
|
ridx_begin_(row_begin),
|
||||||
ridx_end_(row_end) {
|
ridx_end_(row_end) {
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||||
// The begin and end indices for the section of each column associated with
|
// The begin and end indices for the section of each column associated with
|
||||||
// this shard
|
// this shard
|
||||||
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
||||||
@ -126,7 +126,7 @@ class DeviceShard {
|
|||||||
std::make_pair(column_begin - col.data(), column_end - col.data()));
|
std::make_pair(column_begin - col.data(), column_end - col.data()));
|
||||||
row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
|
row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
|
||||||
}
|
}
|
||||||
ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_,
|
ba_.Allocate(device_id_, param.silent, &data_, row_ptr_.back(), &gpair_,
|
||||||
(row_end - row_begin) * model_param.num_output_group);
|
(row_end - row_begin) * model_param.num_output_group);
|
||||||
|
|
||||||
for (int fidx = 0; fidx < batch.Size(); fidx++) {
|
for (int fidx = 0; fidx < batch.Size(); fidx++) {
|
||||||
@ -146,6 +146,7 @@ class DeviceShard {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||||
auto counting = thrust::make_counting_iterator(0ull);
|
auto counting = thrust::make_counting_iterator(0ull);
|
||||||
auto f = [=] __device__(size_t idx) {
|
auto f = [=] __device__(size_t idx) {
|
||||||
return idx * num_group + group_idx;
|
return idx * num_group + group_idx;
|
||||||
@ -160,13 +161,14 @@ class DeviceShard {
|
|||||||
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||||
if (dbias == 0.0f) return;
|
if (dbias == 0.0f) return;
|
||||||
auto d_gpair = gpair_.Data();
|
auto d_gpair = gpair_.Data();
|
||||||
dh::LaunchN(device_idx_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
|
dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
|
||||||
auto &g = d_gpair[idx * num_groups + group_idx];
|
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||||
g += GradientPair(g.GetHess() * dbias, 0);
|
g += GradientPair(g.GetHess() * dbias, 0);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||||
auto d_col = data_.Data() + row_ptr_[fidx];
|
auto d_col = data_.Data() + row_ptr_[fidx];
|
||||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||||
auto d_gpair = gpair_.Data();
|
auto d_gpair = gpair_.Data();
|
||||||
@ -186,7 +188,7 @@ class DeviceShard {
|
|||||||
auto d_gpair = gpair_.Data();
|
auto d_gpair = gpair_.Data();
|
||||||
auto d_col = data_.Data() + row_ptr_[fidx];
|
auto d_col = data_.Data() + row_ptr_[fidx];
|
||||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||||
dh::LaunchN(device_idx_, col_size, [=] __device__(size_t idx) {
|
dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) {
|
||||||
auto entry = d_col[idx];
|
auto entry = d_col[idx];
|
||||||
auto &g = d_gpair[entry.index * num_groups + group_idx];
|
auto &g = d_gpair[entry.index * num_groups + group_idx];
|
||||||
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
|
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
|
||||||
|
|||||||
70
tests/cpp/linear/test_linear.cu
Normal file
70
tests/cpp/linear/test_linear.cu
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
// Copyright by Contributors
|
||||||
|
#include <xgboost/linear_updater.h>
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "xgboost/gbm.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
// Smoke test for the "gpu_coord_descent" linear updater with n_gpus = 1.
// Every row is given the gradient pair (-5, 1.0); after a single Update()
// call with eta = 1 the first bias term is expected to be exactly 5.
TEST(Linear, GPUCoordinate) {
  dh::safe_cuda(cudaSetDevice(0));
  // The pointer returned by CreateDMatrix is released manually at the end of
  // this test (see the delete below).
  auto mat = xgboost::CreateDMatrix(10, 10, 0);
  auto updater = std::unique_ptr<xgboost::LinearUpdater>(
      xgboost::LinearUpdater::Create("gpu_coord_descent"));
  updater->Init({{"eta", "1."}, {"n_gpus", "1"}});
  xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
      (*mat)->Info().num_row_, xgboost::GradientPair(-5, 1.0));
  xgboost::gbm::GBLinearModel model;
  model.param.num_feature = (*mat)->Info().num_col_;
  model.param.num_output_group = 1;
  model.LazyInitModel();
  updater->Update(&gpair, (*mat).get(), &model, gpair.Size());

  // Non-fatal check on purpose: ASSERT_EQ would return from the test body on
  // failure and skip the delete below, leaking the matrix.
  EXPECT_EQ(model.bias()[0], 5.0f);

  delete mat;
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
// Multi-GPU variant of the "gpu_coord_descent" smoke test. It runs the same
// scenario twice with n_gpus = -1: once with default device selection and
// once with gpu_id = 1. In both runs every row carries the gradient pair
// (-5, 1.0) and the first bias term is expected to be exactly 5 after one
// Update() call with eta = 1.
// NOTE(review): n_gpus = -1 presumably means "use all visible devices" and
// gpu_id = 1 presumably selects the starting device -- confirm against the
// updater's parameter documentation.
TEST(Linear, MGPU_GPUCoordinate) {
  dh::safe_cuda(cudaSetDevice(0));
  {
    // The pointer returned by CreateDMatrix is released manually at the end
    // of this scope.
    auto mat = xgboost::CreateDMatrix(10, 10, 0);
    auto updater = std::unique_ptr<xgboost::LinearUpdater>(
        xgboost::LinearUpdater::Create("gpu_coord_descent"));
    updater->Init({{"eta", "1."}, {"n_gpus", "-1"}});
    xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
        (*mat)->Info().num_row_, xgboost::GradientPair(-5, 1.0));
    xgboost::gbm::GBLinearModel model;
    model.param.num_feature = (*mat)->Info().num_col_;
    model.param.num_output_group = 1;
    model.LazyInitModel();
    updater->Update(&gpair, (*mat).get(), &model, gpair.Size());

    // Non-fatal check on purpose: ASSERT_EQ would return from the test body
    // on failure, skipping the delete below and the second scenario.
    EXPECT_EQ(model.bias()[0], 5.0f);
    delete mat;
  }

  // Reset the active device before the second scenario.
  dh::safe_cuda(cudaSetDevice(0));
  {
    auto mat = xgboost::CreateDMatrix(10, 10, 0);
    auto updater = std::unique_ptr<xgboost::LinearUpdater>(
        xgboost::LinearUpdater::Create("gpu_coord_descent"));
    updater->Init({
        {"eta", "1."},
        {"n_gpus", "-1"},
        {"gpu_id", "1"}});
    xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
        (*mat)->Info().num_row_, xgboost::GradientPair(-5, 1.0));
    xgboost::gbm::GBLinearModel model;
    model.param.num_feature = (*mat)->Info().num_col_;
    model.param.num_output_group = 1;
    model.LazyInitModel();
    updater->Update(&gpair, (*mat).get(), &model, gpair.Size());

    // Non-fatal for the same reason as above: the delete must always run.
    EXPECT_EQ(model.bias()[0], 5.0f);
    delete mat;
  }
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user