Multi-GPU support in GPUPredictor. (#3738)
* Multi-GPU support in GPUPredictor. - GPUPredictor is multi-GPU - removed DeviceMatrix, as it has been made obsolete by using HostDeviceVector in DMatrix * Replaced pointers with spans in GPUPredictor. * Added a multi-GPU predictor test. * Fix multi-gpu test. * Fix n_rows < n_gpus. * Reinitialize shards when GPUSet is changed. * Tests range of data. * Remove commented code. * Remove commented code.
This commit is contained in:
parent
32de54fdee
commit
2a59ff2f9b
@ -120,7 +120,7 @@ class SpanIterator {
|
||||
|
||||
using reference = typename std::conditional< // NOLINT
|
||||
IsConst, const ElementType, ElementType>::type&;
|
||||
using pointer = typename std::add_pointer<reference>::type&; // NOLINT
|
||||
using pointer = typename std::add_pointer<reference>::type; // NOLINT
|
||||
|
||||
XGBOOST_DEVICE constexpr SpanIterator() : span_{nullptr}, index_{0} {}
|
||||
|
||||
|
||||
@ -194,8 +194,9 @@ class GBTree : public GradientBooster {
|
||||
CHECK_EQ(in_gpair->Size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
// TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
|
||||
HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
|
||||
GradientPair(), in_gpair->Distribution());
|
||||
HostDeviceVector<GradientPair> tmp
|
||||
(in_gpair->Size() / ngroup, GradientPair(),
|
||||
GPUDistribution::Block(in_gpair->Distribution().Devices()));
|
||||
const auto& gpair_h = in_gpair->ConstHostVector();
|
||||
auto nsize = static_cast<bst_omp_uint>(tmp.Size());
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
|
||||
@ -22,7 +22,7 @@ struct HingeObjParam : public dmlc::Parameter<HingeObjParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(HingeObjParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(0).set_lower_bound(0)
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
|
||||
@ -31,7 +31,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
|
||||
DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
|
||||
.describe("Number of output class in the multi-class classification.");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
@ -64,10 +64,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
|
||||
|
||||
// clear out device memory;
|
||||
out_gpair->Reshard(GPUSet::Empty());
|
||||
preds.Reshard(GPUSet::Empty());
|
||||
|
||||
out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
|
||||
info.labels_.Reshard(GPUDistribution::Block(devices_));
|
||||
info.weights_.Reshard(GPUDistribution::Block(devices_));
|
||||
@ -109,11 +105,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
}, common::Range{0, ndata}, devices_, false)
|
||||
.Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
|
||||
|
||||
out_gpair->Reshard(GPUSet::Empty());
|
||||
out_gpair->Reshard(GPUDistribution::Block(devices_));
|
||||
preds.Reshard(GPUSet::Empty());
|
||||
preds.Reshard(GPUDistribution::Block(devices_));
|
||||
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag != 1) {
|
||||
@ -136,7 +127,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
|
||||
max_preds_.Resize(ndata);
|
||||
|
||||
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
|
||||
if (prob) {
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
@ -166,8 +156,6 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
io_preds->Resize(max_preds_.Size());
|
||||
io_preds->Copy(max_preds_);
|
||||
}
|
||||
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
|
||||
io_preds->Reshard(GPUDistribution::Block(devices_));
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@ -34,7 +34,7 @@ struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
||||
DMLC_DECLARE_PARAMETER(RegLossParam) {
|
||||
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
|
||||
.describe("Scale the weight of positive examples by this factor");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
|
||||
@ -27,10 +27,10 @@ struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
|
||||
bool silent;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
|
||||
DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
|
||||
DMLC_DECLARE_FIELD(gpu_id).set_lower_bound(0).set_default(0).describe(
|
||||
"Device ordinal for GPU prediction.");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
|
||||
"Number of devices to use for prediction (NOT IMPLEMENTED).");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_lower_bound(-1).set_default(1).describe(
|
||||
"Number of devices to use for prediction.");
|
||||
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
|
||||
"Do not print information during trainig.");
|
||||
}
|
||||
@ -43,53 +43,12 @@ void IncrementOffset(IterT begin_itr, IterT end_itr, size_t amount) {
|
||||
[=] __device__(size_t elem) { return elem + amount; });
|
||||
}
|
||||
|
||||
/**
|
||||
* \struct DeviceMatrix
|
||||
*
|
||||
* \brief A csr representation of the input matrix allocated on the device.
|
||||
*/
|
||||
|
||||
struct DeviceMatrix {
|
||||
DMatrix* p_mat; // Pointer to the original matrix on the host
|
||||
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
|
||||
dh::DVec<size_t> row_ptr;
|
||||
dh::DVec<Entry> data;
|
||||
thrust::device_vector<float> predictions;
|
||||
|
||||
DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
const auto& info = dmat->Info();
|
||||
ba.Allocate(device_idx, silent, &row_ptr, info.num_row_ + 1, &data,
|
||||
info.num_nonzero_);
|
||||
size_t data_offset = 0;
|
||||
for (const auto &batch : dmat->GetRowBatches()) {
|
||||
const auto& offset_vec = batch.offset.HostVector();
|
||||
const auto& data_vec = batch.data.HostVector();
|
||||
// Copy row ptr
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
row_ptr.Data() + batch.base_rowid, offset_vec.data(),
|
||||
sizeof(size_t) * offset_vec.size(), cudaMemcpyHostToDevice));
|
||||
if (batch.base_rowid > 0) {
|
||||
auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
|
||||
auto end_itr = begin_itr + batch.Size() + 1;
|
||||
IncrementOffset(begin_itr, end_itr, batch.base_rowid);
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(data.Data() + data_offset, data_vec.data(),
|
||||
sizeof(Entry) * data_vec.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
// Copy data
|
||||
data_offset += batch.data.Size();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \struct DevicePredictionNode
|
||||
*
|
||||
* \brief Packed 16 byte representation of a tree node for use in device
|
||||
* prediction
|
||||
*/
|
||||
|
||||
struct DevicePredictionNode {
|
||||
XGBOOST_DEVICE DevicePredictionNode()
|
||||
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
||||
@ -105,6 +64,7 @@ struct DevicePredictionNode {
|
||||
NodeValue val;
|
||||
|
||||
DevicePredictionNode(const RegTree::Node& n) { // NOLINT
|
||||
static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
|
||||
this->left_child_idx = n.LeftChild();
|
||||
this->right_child_idx = n.RightChild();
|
||||
this->fidx = n.SplitIndex();
|
||||
@ -140,19 +100,21 @@ struct DevicePredictionNode {
|
||||
|
||||
struct ElementLoader {
|
||||
bool use_shared;
|
||||
size_t* d_row_ptr;
|
||||
Entry* d_data;
|
||||
common::Span<const size_t> d_row_ptr;
|
||||
common::Span<const Entry> d_data;
|
||||
int num_features;
|
||||
float* smem;
|
||||
size_t entry_start;
|
||||
|
||||
__device__ ElementLoader(bool use_shared, size_t* row_ptr,
|
||||
Entry* entry, int num_features,
|
||||
float* smem, int num_rows)
|
||||
__device__ ElementLoader(bool use_shared, common::Span<const size_t> row_ptr,
|
||||
common::Span<const Entry> entry, int num_features,
|
||||
float* smem, int num_rows, size_t entry_start)
|
||||
: use_shared(use_shared),
|
||||
d_row_ptr(row_ptr),
|
||||
d_data(entry),
|
||||
num_features(num_features),
|
||||
smem(smem) {
|
||||
smem(smem),
|
||||
entry_start(entry_start) {
|
||||
// Copy instances
|
||||
if (use_shared) {
|
||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@ -163,7 +125,7 @@ struct ElementLoader {
|
||||
bst_uint elem_begin = d_row_ptr[global_idx];
|
||||
bst_uint elem_end = d_row_ptr[global_idx + 1];
|
||||
for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
|
||||
Entry elem = d_data[elem_idx];
|
||||
Entry elem = d_data[elem_idx - entry_start];
|
||||
smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
|
||||
}
|
||||
}
|
||||
@ -175,9 +137,9 @@ struct ElementLoader {
|
||||
return smem[threadIdx.x * num_features + fidx];
|
||||
} else {
|
||||
// Binary search
|
||||
auto begin_ptr = d_data + d_row_ptr[ridx];
|
||||
auto end_ptr = d_data + d_row_ptr[ridx + 1];
|
||||
Entry* previous_middle = nullptr;
|
||||
auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);
|
||||
auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);
|
||||
common::Span<const Entry>::iterator previous_middle;
|
||||
while (end_ptr != begin_ptr) {
|
||||
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
||||
if (middle == previous_middle) {
|
||||
@ -220,22 +182,25 @@ __device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
|
||||
}
|
||||
|
||||
template <int BLOCK_THREADS>
|
||||
__global__ void PredictKernel(const DevicePredictionNode* d_nodes,
|
||||
float* d_out_predictions, size_t* d_tree_segments,
|
||||
int* d_tree_group, size_t* d_row_ptr,
|
||||
Entry* d_data, size_t tree_begin,
|
||||
__global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t> d_tree_segments,
|
||||
common::Span<int> d_tree_group,
|
||||
common::Span<const size_t> d_row_ptr,
|
||||
common::Span<const Entry> d_data, size_t tree_begin,
|
||||
size_t tree_end, size_t num_features,
|
||||
size_t num_rows, bool use_shared, int num_group) {
|
||||
size_t num_rows, size_t entry_start,
|
||||
bool use_shared, int num_group) {
|
||||
extern __shared__ float smem[];
|
||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
|
||||
num_rows);
|
||||
num_rows, entry_start);
|
||||
if (global_idx >= num_rows) return;
|
||||
if (num_group == 1) {
|
||||
float sum = 0;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const DevicePredictionNode* d_tree =
|
||||
d_nodes + d_tree_segments[tree_idx - tree_begin];
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
sum += GetLeafWeight(global_idx, d_tree, &loader);
|
||||
}
|
||||
d_out_predictions[global_idx] += sum;
|
||||
@ -243,7 +208,7 @@ __global__ void PredictKernel(const DevicePredictionNode* d_nodes,
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
int tree_group = d_tree_group[tree_idx];
|
||||
const DevicePredictionNode* d_tree =
|
||||
d_nodes + d_tree_segments[tree_idx - tree_begin];
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||
d_out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(global_idx, d_tree, &loader);
|
||||
@ -259,31 +224,89 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
};
|
||||
|
||||
private:
|
||||
void DeviceOffsets(const HostDeviceVector<size_t>& data, std::vector<size_t>* out_offsets) {
|
||||
auto& offsets = *out_offsets;
|
||||
offsets.resize(devices_.Size() + 1);
|
||||
offsets[0] = 0;
|
||||
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
|
||||
for (int shard = 0; shard < devices_.Size(); ++shard) {
|
||||
int device = devices_[shard];
|
||||
auto data_span = data.DeviceSpan(device);
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
// copy the last element from every shard
|
||||
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
|
||||
&data_span[data_span.size()-1],
|
||||
sizeof(size_t), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
struct DeviceShard {
|
||||
DeviceShard() : device_(-1) {}
|
||||
void Init(int device) {
|
||||
this->device_ = device;
|
||||
max_shared_memory_bytes = dh::MaxSharedMemory(this->device_);
|
||||
}
|
||||
void PredictInternal
|
||||
(const SparsePage& batch, const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
nodes.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group.resize(model.tree_info.size());
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
const int BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.offset.DeviceSize(device_) - 1;
|
||||
|
||||
const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
|
||||
int shared_memory_bytes = static_cast<int>
|
||||
(sizeof(float) * info.num_col_ * BLOCK_THREADS);
|
||||
bool use_shared = true;
|
||||
if (shared_memory_bytes > max_shared_memory_bytes) {
|
||||
shared_memory_bytes = 0;
|
||||
use_shared = false;
|
||||
}
|
||||
const auto& data_distr = batch.data.Distribution();
|
||||
int index = data_distr.Devices().Index(device_);
|
||||
size_t entry_start = data_distr.ShardStart(batch.data.Size(), index);
|
||||
|
||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||
(dh::ToSpan(nodes), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments),
|
||||
dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
|
||||
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
||||
num_rows, entry_start, use_shared, model.param.num_output_group);
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
int device_;
|
||||
thrust::device_vector<DevicePredictionNode> nodes;
|
||||
thrust::device_vector<size_t> tree_segments;
|
||||
thrust::device_vector<int> tree_group;
|
||||
size_t max_shared_memory_bytes;
|
||||
};
|
||||
|
||||
void DevicePredictInternal(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, size_t tree_begin,
|
||||
size_t tree_end) {
|
||||
if (tree_end - tree_begin == 0) {
|
||||
return;
|
||||
}
|
||||
if (tree_end - tree_begin == 0) { return; }
|
||||
|
||||
std::shared_ptr<DeviceMatrix> device_matrix;
|
||||
// Matrix is not in host cache, create a temporary matrix
|
||||
if (this->cache_.find(dmat) == this->cache_.end()) {
|
||||
device_matrix = std::shared_ptr<DeviceMatrix>(
|
||||
new DeviceMatrix(dmat, param.gpu_id, param.silent));
|
||||
} else {
|
||||
// Create this matrix on device if doesn't exist
|
||||
if (this->device_matrix_cache_.find(dmat) ==
|
||||
this->device_matrix_cache_.end()) {
|
||||
this->device_matrix_cache_.emplace(
|
||||
dmat, std::shared_ptr<DeviceMatrix>(
|
||||
new DeviceMatrix(dmat, param.gpu_id, param.silent)));
|
||||
}
|
||||
device_matrix = device_matrix_cache_.find(dmat)->second;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||
// Copy decision trees to device
|
||||
thrust::host_vector<size_t> h_tree_segments;
|
||||
@ -291,61 +314,33 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t sum = 0;
|
||||
h_tree_segments.push_back(sum);
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
sum += model.trees[tree_idx]->GetNodes().size();
|
||||
sum += model.trees.at(tree_idx)->GetNodes().size();
|
||||
h_tree_segments.push_back(sum);
|
||||
}
|
||||
|
||||
thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
|
||||
nodes.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
size_t i_batch = 0;
|
||||
|
||||
device_matrix->predictions.resize(out_preds->Size());
|
||||
auto& predictions = device_matrix->predictions;
|
||||
out_preds->GatherTo(predictions.data(),
|
||||
predictions.data() + predictions.size());
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
|
||||
const int BLOCK_THREADS = 128;
|
||||
const int GRID_SIZE = static_cast<int>(
|
||||
dh::DivRoundUp(device_matrix->row_ptr.Size() - 1, BLOCK_THREADS));
|
||||
|
||||
int shared_memory_bytes = static_cast<int>(
|
||||
sizeof(float) * device_matrix->p_mat->Info().num_col_ * BLOCK_THREADS);
|
||||
bool use_shared = true;
|
||||
if (shared_memory_bytes > max_shared_memory_bytes) {
|
||||
shared_memory_bytes = 0;
|
||||
use_shared = false;
|
||||
for (const auto &batch : dmat->GetRowBatches()) {
|
||||
CHECK_EQ(i_batch, 0) << "External memory not supported";
|
||||
size_t n_rows = batch.offset.Size() - 1;
|
||||
// out_preds have been resharded and resized in InitOutPredictions()
|
||||
batch.offset.Reshard(GPUDistribution::Overlap(devices_, 1));
|
||||
std::vector<size_t> device_offsets;
|
||||
DeviceOffsets(batch.offset, &device_offsets);
|
||||
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
|
||||
dh::ExecuteShards(&shards, [&](DeviceShard& shard){
|
||||
shard.PredictInternal(batch, dmat->Info(), out_preds, model, h_tree_segments,
|
||||
h_nodes, tree_begin, tree_end);
|
||||
});
|
||||
i_batch++;
|
||||
}
|
||||
|
||||
PredictKernel<BLOCK_THREADS>
|
||||
<<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>(
|
||||
dh::Raw(nodes), dh::Raw(device_matrix->predictions),
|
||||
dh::Raw(tree_segments), dh::Raw(tree_group),
|
||||
device_matrix->row_ptr.Data(), device_matrix->data.Data(),
|
||||
tree_begin, tree_end, device_matrix->p_mat->Info().num_col_,
|
||||
device_matrix->p_mat->Info().num_row_, use_shared,
|
||||
model.param.num_output_group);
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
out_preds->ScatterFrom(predictions.data(),
|
||||
predictions.data() + predictions.size());
|
||||
}
|
||||
|
||||
public:
|
||||
@ -354,6 +349,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
GPUSet devices = GPUSet::All(
|
||||
param.n_gpus, dmat->Info().num_row_).Normalised(param.gpu_id);
|
||||
ConfigureShards(devices);
|
||||
|
||||
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||
return;
|
||||
}
|
||||
@ -372,9 +371,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
void InitOutPredictions(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model) const {
|
||||
size_t n = model.param.num_output_group * info.num_row_;
|
||||
size_t n_classes = model.param.num_output_group;
|
||||
size_t n = n_classes * info.num_row_;
|
||||
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
|
||||
out_preds->Reshard(devices);
|
||||
out_preds->Reshard(GPUDistribution::Granular(devices_, n_classes));
|
||||
out_preds->Resize(n);
|
||||
if (base_margin.Size() != 0) {
|
||||
CHECK_EQ(out_preds->Size(), n);
|
||||
@ -392,14 +392,13 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
if (it != cache_.end()) {
|
||||
const HostDeviceVector<bst_float>& y = it->second.predictions;
|
||||
if (y.Size() != 0) {
|
||||
out_preds->Reshard(devices);
|
||||
out_preds->Reshard(y.Distribution());
|
||||
out_preds->Resize(y.Size());
|
||||
out_preds->Copy(y);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -464,24 +463,33 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
Predictor::Init(cfg, cache);
|
||||
cpu_predictor->Init(cfg, cache);
|
||||
param.InitAllowUnknown(cfg);
|
||||
devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
|
||||
max_shared_memory_bytes = dh::MaxSharedMemory(param.gpu_id);
|
||||
|
||||
GPUSet devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
|
||||
ConfigureShards(devices);
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Re configure shards when GPUSet is changed. */
|
||||
void ConfigureShards(GPUSet devices) {
|
||||
if (devices_ == devices) return;
|
||||
|
||||
devices_ = devices;
|
||||
shards.clear();
|
||||
shards.resize(devices_.Size());
|
||||
dh::ExecuteIndexShards(&shards, [=](size_t i, DeviceShard& shard){
|
||||
shard.Init(devices_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
GPUPredictionParam param;
|
||||
std::unique_ptr<Predictor> cpu_predictor;
|
||||
std::unordered_map<DMatrix*, std::shared_ptr<DeviceMatrix>>
|
||||
device_matrix_cache_;
|
||||
thrust::device_vector<DevicePredictionNode> nodes;
|
||||
thrust::device_vector<size_t> tree_segments;
|
||||
thrust::device_vector<int> tree_group;
|
||||
thrust::device_vector<bst_float> preds;
|
||||
GPUSet devices;
|
||||
size_t max_shared_memory_bytes;
|
||||
std::vector<DeviceShard> shards;
|
||||
GPUSet devices_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
.describe("Make predictions using GPU.")
|
||||
.set_body([]() { return new GPUPredictor(); });
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
|
||||
TEST(gpu_predictor, Test) {
|
||||
std::unique_ptr<Predictor> gpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
|
||||
@ -41,8 +42,7 @@ TEST(gpu_predictor, Test) {
|
||||
std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
|
||||
float abs_tolerance = 0.001;
|
||||
for (int i = 0; i < gpu_out_predictions.Size(); i++) {
|
||||
ASSERT_LT(std::abs(gpu_out_predictions_h[i] - cpu_out_predictions_h[i]),
|
||||
abs_tolerance);
|
||||
ASSERT_NEAR(gpu_out_predictions_h[i], cpu_out_predictions_h[i], abs_tolerance);
|
||||
}
|
||||
// Test predict instance
|
||||
const auto &batch = *(*dmat)->GetRowBatches().begin();
|
||||
@ -76,5 +76,46 @@ TEST(gpu_predictor, Test) {
|
||||
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
// multi-GPU predictor test
|
||||
TEST(gpu_predictor, MGPU_Test) {
|
||||
std::unique_ptr<Predictor> gpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
|
||||
|
||||
gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "-1")}, {});
|
||||
cpu_predictor->Init({}, {});
|
||||
|
||||
for (size_t i = 1; i < 33; i *= 2) {
|
||||
int n_row = i, n_col = i;
|
||||
auto dmat = CreateDMatrix(n_row, n_col, 0);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
||||
trees.back()->InitModel();
|
||||
(*trees.back())[0].SetLeaf(1.5f);
|
||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||
gbm::GBTreeModel model(0.5);
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
model.param.num_output_group = 1;
|
||||
|
||||
// Test predict batch
|
||||
HostDeviceVector<float> gpu_out_predictions;
|
||||
HostDeviceVector<float> cpu_out_predictions;
|
||||
|
||||
gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0);
|
||||
cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector();
|
||||
std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
|
||||
float abs_tolerance = 0.001;
|
||||
for (int i = 0; i < gpu_out_predictions.Size(); i++) {
|
||||
ASSERT_NEAR(gpu_out_predictions_h[i], cpu_out_predictions_h[i], abs_tolerance);
|
||||
}
|
||||
delete dmat;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user