mgpu predictor using explicit offsets (#4438)
* mgpu prediction using explicit sharding
This commit is contained in:
parent
d16d9a9988
commit
be0f346ec9
@ -110,6 +110,9 @@ class GPUDistribution {
|
|||||||
return GPUDistribution(devices, granularity, 0, std::vector<size_t>());
|
return GPUDistribution(devices, granularity, 0, std::vector<size_t>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE(rongou): Explicit offsets don't necessarily cover the whole vector. Sections before the
|
||||||
|
// first shard or after the last shard may be on host only. This windowing is done in the GPU
|
||||||
|
// predictor for external memory support.
|
||||||
static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) {
|
static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) {
|
||||||
return GPUDistribution(devices, 1, 0, std::move(offsets));
|
return GPUDistribution(devices, 1, 0, std::move(offsets));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -221,7 +221,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void DeviceOffsets(const HostDeviceVector<size_t>& data, std::vector<size_t>* out_offsets) {
|
void DeviceOffsets(const HostDeviceVector<size_t>& data,
|
||||||
|
size_t total_size,
|
||||||
|
std::vector<size_t>* out_offsets) {
|
||||||
auto& offsets = *out_offsets;
|
auto& offsets = *out_offsets;
|
||||||
offsets.resize(devices_.Size() + 1);
|
offsets.resize(devices_.Size() + 1);
|
||||||
offsets[0] = 0;
|
offsets[0] = 0;
|
||||||
@ -230,13 +232,35 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
int device = devices_.DeviceId(shard);
|
int device = devices_.DeviceId(shard);
|
||||||
auto data_span = data.DeviceSpan(device);
|
auto data_span = data.DeviceSpan(device);
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
// copy the last element from every shard
|
if (data_span.size() == 0) {
|
||||||
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
|
offsets[shard + 1] = total_size;
|
||||||
&data_span[data_span.size()-1],
|
} else {
|
||||||
sizeof(size_t), cudaMemcpyDeviceToHost));
|
// copy the last element from every shard
|
||||||
|
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
|
||||||
|
&data_span[data_span.size()-1],
|
||||||
|
sizeof(size_t), cudaMemcpyDeviceToHost));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function populates the explicit offsets that can be used to create a window into the
|
||||||
|
// underlying host vector. The window starts from the `batch_offset` and has a size of
|
||||||
|
// `batch_size`, and is sharded across all the devices. Each shard is granular depending on
|
||||||
|
// the number of output classes `n_classes`.
|
||||||
|
void PredictionDeviceOffsets(size_t total_size, size_t batch_offset, size_t batch_size,
|
||||||
|
int n_classes, std::vector<size_t>* out_offsets) {
|
||||||
|
auto& offsets = *out_offsets;
|
||||||
|
size_t n_shards = devices_.Size();
|
||||||
|
offsets.resize(n_shards + 2);
|
||||||
|
size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards);
|
||||||
|
for (size_t shard = 0; shard < devices_.Size(); ++shard) {
|
||||||
|
size_t n_rows = std::min(batch_size, shard * rows_per_shard);
|
||||||
|
offsets[shard] = batch_offset + n_rows * n_classes;
|
||||||
|
}
|
||||||
|
offsets[n_shards] = batch_offset + batch_size * n_classes;
|
||||||
|
offsets[n_shards + 1] = total_size;
|
||||||
|
}
|
||||||
|
|
||||||
struct DeviceShard {
|
struct DeviceShard {
|
||||||
DeviceShard() : device_{-1} {}
|
DeviceShard() : device_{-1} {}
|
||||||
void Init(int device) {
|
void Init(int device) {
|
||||||
@ -246,11 +270,11 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
void PredictInternal
|
void PredictInternal
|
||||||
(const SparsePage& batch, const MetaInfo& info,
|
(const SparsePage& batch, const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* predictions,
|
HostDeviceVector<bst_float>* predictions,
|
||||||
const size_t batch_offset,
|
|
||||||
const gbm::GBTreeModel& model,
|
const gbm::GBTreeModel& model,
|
||||||
const thrust::host_vector<size_t>& h_tree_segments,
|
const thrust::host_vector<size_t>& h_tree_segments,
|
||||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||||
size_t tree_begin, size_t tree_end) {
|
size_t tree_begin, size_t tree_end) {
|
||||||
|
if (predictions->DeviceSize(device_) == 0) { return; }
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
nodes_.resize(h_nodes.size());
|
nodes_.resize(h_nodes.size());
|
||||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
|
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
|
||||||
@ -269,8 +293,6 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
|
|
||||||
const int BLOCK_THREADS = 128;
|
const int BLOCK_THREADS = 128;
|
||||||
size_t num_rows = batch.offset.DeviceSize(device_) - 1;
|
size_t num_rows = batch.offset.DeviceSize(device_) - 1;
|
||||||
if (num_rows < 1) { return; }
|
|
||||||
|
|
||||||
const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
|
const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||||
|
|
||||||
int shared_memory_bytes = static_cast<int>
|
int shared_memory_bytes = static_cast<int>
|
||||||
@ -285,8 +307,8 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
data_distr.Devices().Index(device_));
|
data_distr.Devices().Index(device_));
|
||||||
|
|
||||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||||
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_).subspan(batch_offset),
|
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_),
|
||||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
|
dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
|
||||||
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
||||||
num_rows, entry_start, use_shared, model.param.num_output_group);
|
num_rows, entry_start, use_shared, model.param.num_output_group);
|
||||||
}
|
}
|
||||||
@ -324,22 +346,30 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t i_batch = 0;
|
|
||||||
size_t batch_offset = 0;
|
size_t batch_offset = 0;
|
||||||
for (auto &batch : dmat->GetRowBatches()) {
|
for (auto &batch : dmat->GetRowBatches()) {
|
||||||
CHECK(i_batch == 0 || devices_.Size() == 1) << "External memory not supported for multi-GPU";
|
bool is_external_memory = batch.Size() < dmat->Info().num_row_;
|
||||||
// out_preds have been sharded and resized in InitOutPredictions()
|
if (is_external_memory) {
|
||||||
|
std::vector<size_t> out_preds_offsets;
|
||||||
|
PredictionDeviceOffsets(out_preds->Size(), batch_offset, batch.Size(),
|
||||||
|
model.param.num_output_group, &out_preds_offsets);
|
||||||
|
out_preds->Reshard(GPUDistribution::Explicit(devices_, out_preds_offsets));
|
||||||
|
}
|
||||||
|
|
||||||
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
|
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
|
||||||
std::vector<size_t> device_offsets;
|
std::vector<size_t> device_offsets;
|
||||||
DeviceOffsets(batch.offset, &device_offsets);
|
DeviceOffsets(batch.offset, batch.data.Size(), &device_offsets);
|
||||||
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
|
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
|
||||||
|
|
||||||
|
// TODO(rongou): only copy the model once for all the batches.
|
||||||
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
|
||||||
shard.PredictInternal(batch, dmat->Info(), out_preds, batch_offset, model,
|
shard.PredictInternal(batch, dmat->Info(), out_preds, model,
|
||||||
h_tree_segments, h_nodes, tree_begin, tree_end);
|
h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||||
});
|
});
|
||||||
batch_offset += batch.Size() * model.param.num_output_group;
|
batch_offset += batch.Size() * model.param.num_output_group;
|
||||||
i_batch++;
|
|
||||||
}
|
}
|
||||||
|
out_preds->Reshard(GPUDistribution::Granular(devices_, model.param.num_output_group));
|
||||||
|
|
||||||
monitor_.StopCuda("DevicePredictInternal");
|
monitor_.StopCuda("DevicePredictInternal");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -198,9 +198,7 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
|
|||||||
|
|
||||||
CheckCAPICall(XGBoosterFree(bst2));
|
CheckCAPICall(XGBoosterFree(bst2));
|
||||||
}
|
}
|
||||||
#endif // defined(XGBOOST_USE_NCCL)
|
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_NCCL)
|
|
||||||
// multi-GPU predictor test
|
// multi-GPU predictor test
|
||||||
TEST(gpu_predictor, MGPU_Test) {
|
TEST(gpu_predictor, MGPU_Test) {
|
||||||
std::unique_ptr<Predictor> gpu_predictor =
|
std::unique_ptr<Predictor> gpu_predictor =
|
||||||
@ -233,6 +231,34 @@ TEST(gpu_predictor, MGPU_Test) {
|
|||||||
delete dmat;
|
delete dmat;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// multi-GPU predictor external memory test
|
||||||
|
TEST(gpu_predictor, MGPU_ExternalMemoryTest) {
|
||||||
|
std::unique_ptr<Predictor> gpu_predictor =
|
||||||
|
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
|
||||||
|
gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "-1")}, {});
|
||||||
|
|
||||||
|
gbm::GBTreeModel model = CreateTestModel();
|
||||||
|
const int n_classes = 3;
|
||||||
|
model.param.num_output_group = n_classes;
|
||||||
|
std::vector<std::unique_ptr<DMatrix>> dmats;
|
||||||
|
dmats.push_back(CreateSparsePageDMatrix(9, 64UL));
|
||||||
|
dmats.push_back(CreateSparsePageDMatrix(128, 128UL));
|
||||||
|
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL));
|
||||||
|
|
||||||
|
for (const auto& dmat: dmats) {
|
||||||
|
// Test predict batch
|
||||||
|
HostDeviceVector<float> out_predictions;
|
||||||
|
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||||
|
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
|
||||||
|
const std::vector<float> &host_vector = out_predictions.ConstHostVector();
|
||||||
|
for (int i = 0; i < host_vector.size() / n_classes; i++) {
|
||||||
|
ASSERT_EQ(host_vector[i * n_classes], 1.5);
|
||||||
|
ASSERT_EQ(host_vector[i * n_classes + 1], 0.);
|
||||||
|
ASSERT_EQ(host_vector[i * n_classes + 2], 0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif // defined(XGBOOST_USE_NCCL)
|
#endif // defined(XGBOOST_USE_NCCL)
|
||||||
} // namespace predictor
|
} // namespace predictor
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user