Make HostDeviceVector single gpu only (#4773)

* Make HostDeviceVector single gpu only
Author: Rong Ou (2019-08-25 14:51:13 -07:00)
Committed by: Rory Mitchell
Parent: 41227d1933
Commit: 38ab79f889
54 changed files with 641 additions and 1621 deletions
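The heart of the change in the GPU sketcher below: GPUSketcher no longer splits rows across a vector of per-device shards driven by GPUDistribution and ExecuteIndexShards; it builds one DeviceShard covering the whole batch on generic_param_.gpu_id. A minimal, self-contained sketch of that refactor pattern, using illustrative stand-in types (Worker, Sketcher) rather than the real XGBoost classes:

    #include <cstddef>
    #include <memory>

    // Stand-in for DeviceShard: after this commit it takes a device id and a
    // row count instead of a [row_begin, row_end) slice of a distribution.
    struct Worker {
      int device;
      std::size_t n_rows;
      Worker(int device, std::size_t n_rows) : device(device), n_rows(n_rows) {}
      void Sketch() { /* quantile-sketch all n_rows on this one device */ }
    };

    struct Sketcher {
      // Before: std::vector<std::unique_ptr<Worker>> workers_, one per GPU.
      std::unique_ptr<Worker> worker_;  // After: a single worker, single GPU.

      void Run(int device, std::size_t n_rows) {
        worker_.reset(new Worker(device, n_rows));
        worker_->Sketch();
      }
    };

    int main() {
      Sketcher s;
      s.Run(/*device=*/0, /*n_rows=*/1000);
      return 0;
    }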

@@ -35,7 +35,6 @@ __global__ void FindCutsK
if (icut >= ncuts) {
return;
}
WXQSketch::Entry v;
int isample = 0;
if (icut == 0) {
isample = 0;
@@ -59,11 +58,14 @@ struct IsNotNaN {
__device__ bool operator()(float a) const { return !isnan(a); }
};
__global__ void UnpackFeaturesK
(float* __restrict__ fvalues, float* __restrict__ feature_weights,
const size_t* __restrict__ row_ptrs, const float* __restrict__ weights,
Entry* entries, size_t nrows_array, int ncols, size_t row_begin_ptr,
size_t nrows) {
__global__ void UnpackFeaturesK(float* __restrict__ fvalues,
float* __restrict__ feature_weights,
const size_t* __restrict__ row_ptrs,
const float* __restrict__ weights,
Entry* entries,
size_t nrows_array,
size_t row_begin_ptr,
size_t nrows) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
if (irow >= nrows) {
return;
@@ -102,8 +104,9 @@ struct SketchContainer {
const MetaInfo &info = dmat->Info();
// Initialize Sketches for this dmatrix
sketches_.resize(info.num_col_);
#pragma omp parallel for schedule(static) if (info.num_col_ > kOmpNumColsParallelizeLimit)
for (int icol = 0; icol < info.num_col_; ++icol) {
#pragma omp parallel for default(none) shared(info, param) schedule(static) \
if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT
sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
}
}
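The pragma rewrite above adds default(none), which forces every shared variable to be listed explicitly, and keeps the if(...) clause that skips spawning threads when there are few columns. A small, self-contained sketch of that guarded pattern, with an illustrative function and threshold that are not from this file:

    #include <vector>

    void Scale(std::vector<double>* values) {
      int n = static_cast<int>(values->size());
      // default(none): every variable used inside must be scoped explicitly.
      // if(...): only go parallel when the loop is wide enough to pay off.
    #pragma omp parallel for default(none) shared(values, n) schedule(static) \
        if (n > 1000)
      for (int i = 0; i < n; ++i) {
        (*values)[i] *= 2.0;
      }
    }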
@@ -120,8 +123,6 @@ struct GPUSketcher {
// manage memory for a single GPU
class DeviceShard {
int device_;
bst_uint row_begin_; // The row offset for this shard
bst_uint row_end_;
bst_uint n_rows_;
int num_cols_{0};
size_t n_cuts_{0};
@@ -131,27 +132,31 @@ struct GPUSketcher {
tree::TrainParam param_;
SketchContainer *sketch_container_;
dh::device_vector<size_t> row_ptrs_;
dh::device_vector<Entry> entries_;
dh::device_vector<bst_float> fvalues_;
dh::device_vector<bst_float> feature_weights_;
dh::device_vector<bst_float> fvalues_cur_;
dh::device_vector<WXQSketch::Entry> cuts_d_;
thrust::host_vector<WXQSketch::Entry> cuts_h_;
dh::device_vector<bst_float> weights_;
dh::device_vector<bst_float> weights2_;
std::vector<size_t> n_cuts_cur_;
dh::device_vector<size_t> num_elements_;
dh::device_vector<char> tmp_storage_;
dh::device_vector<size_t> row_ptrs_{};
dh::device_vector<Entry> entries_{};
dh::device_vector<bst_float> fvalues_{};
dh::device_vector<bst_float> feature_weights_{};
dh::device_vector<bst_float> fvalues_cur_{};
dh::device_vector<WXQSketch::Entry> cuts_d_{};
thrust::host_vector<WXQSketch::Entry> cuts_h_{};
dh::device_vector<bst_float> weights_{};
dh::device_vector<bst_float> weights2_{};
std::vector<size_t> n_cuts_cur_{};
dh::device_vector<size_t> num_elements_{};
dh::device_vector<char> tmp_storage_{};
public:
DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
tree::TrainParam param, SketchContainer *sketch_container) :
device_(device), row_begin_(row_begin), row_end_(row_end),
n_rows_(row_end - row_begin), param_(std::move(param)), sketch_container_(sketch_container) {
DeviceShard(int device,
bst_uint n_rows,
tree::TrainParam param,
SketchContainer* sketch_container) :
device_(device),
n_rows_(n_rows),
param_(std::move(param)),
sketch_container_(sketch_container) {
}
~DeviceShard() {
~DeviceShard() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_));
}
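The destructor above only makes the shard's device current. The dh::device_vector members declared earlier are destroyed right after the destructor body runs, so setting the device first ensures their allocations are released on the GPU that owns them. A rough, general sketch of the same idea using plain Thrust instead of XGBoost's dh:: wrappers (error checking omitted):

    #include <cuda_runtime.h>
    #include <thrust/device_vector.h>

    class Shard {
      int device_;
      thrust::device_vector<float> buffer_;

     public:
      explicit Shard(int device) : device_(device) {
        cudaSetDevice(device_);   // allocate buffer_ on the owning device
        buffer_.resize(1 << 20);
      }

      ~Shard() {
        // Members are destroyed after this body runs, so make the owning
        // device current before buffer_ frees its device memory.
        cudaSetDevice(device_);
      }
    };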
@@ -319,19 +324,18 @@ struct GPUSketcher {
const auto& offset_vec = row_batch.offset.HostVector();
const auto& data_vec = row_batch.data.HostVector();
size_t n_entries = offset_vec[row_begin_ + batch_row_end] -
offset_vec[row_begin_ + batch_row_begin];
size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
// copy the batch to the GPU
dh::safe_cuda
(cudaMemcpyAsync(entries_.data().get(),
data_vec.data() + offset_vec[row_begin_ + batch_row_begin],
data_vec.data() + offset_vec[batch_row_begin],
n_entries * sizeof(Entry), cudaMemcpyDefault));
// copy the weights if necessary
if (has_weights_) {
const auto& weights_vec = info.weights_.HostVector();
dh::safe_cuda
(cudaMemcpyAsync(weights_.data().get(),
weights_vec.data() + row_begin_ + batch_row_begin,
weights_vec.data() + batch_row_begin,
batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
}
@@ -349,8 +353,7 @@ struct GPUSketcher {
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
row_ptrs_.data().get() + batch_row_begin,
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_, num_cols_,
offset_vec[row_begin_ + batch_row_begin], batch_nrows);
gpu_batch_nrows_, offset_vec[batch_row_begin], batch_nrows);
for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol);
@@ -358,7 +361,7 @@ struct GPUSketcher {
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for schedule(static) \
#pragma omp parallel for default(none) schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
WXQSketch::SummaryContainer summary;
@@ -391,8 +394,7 @@ struct GPUSketcher {
dh::safe_cuda(cudaSetDevice(device_));
const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs_.resize(n_rows_ + 1);
thrust::copy(offset_vec.data() + row_begin_,
offset_vec.data() + row_end_ + 1, row_ptrs_.begin());
thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
SketchBatch(row_batch, info, gpu_batch);
@@ -401,32 +403,18 @@ struct GPUSketcher {
};
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
GPUDistribution dist =
GPUDistribution::Block(GPUSet::All(generic_param_.gpu_id, generic_param_.n_gpus,
batch.Size()));
auto device = generic_param_.gpu_id;
// create device shards
shards_.resize(dist.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist.ShardStart(batch.Size(), i);
size_t size = dist.ShardSize(batch.Size(), i);
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(dist.Devices().DeviceId(i), start,
start + size, param_, sketch_container_.get()));
});
// create device shard
shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get()));
// compute sketches for each shard
dh::ExecuteIndexShards(&shards_,
[&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info, gpu_batch_nrows_);
shard->Sketch(batch, info);
shard->ComputeRowStride();
});
// compute sketches for the shard
shard_->Init(batch, info, gpu_batch_nrows_);
shard_->Sketch(batch, info);
shard_->ComputeRowStride();
// compute row stride across all shards
for (const auto &shard : shards_) {
row_stride_ = std::max(row_stride_, shard->GetRowStride());
}
// compute row stride
row_stride_ = shard_->GetRowStride();
}
GPUSketcher(const tree::TrainParam &param, const GenericParameter &generic_param, int gpu_nrows)
@@ -444,13 +432,13 @@ struct GPUSketcher {
this->SketchBatch(batch, info);
}
hmat->Init(&sketch_container_.get()->sketches_, param_.max_bin);
hmat->Init(&sketch_container_->sketches_, param_.max_bin);
return row_stride_;
}
private:
std::vector<std::unique_ptr<DeviceShard>> shards_;
std::unique_ptr<DeviceShard> shard_;
const tree::TrainParam &param_;
const GenericParameter &generic_param_;
int gpu_batch_nrows_;