remove device shards (#4867)
This commit is contained in:
parent
0b89cd1dfa
commit
562bb0ae31
@ -27,9 +27,11 @@ namespace common {
|
||||
|
||||
using WXQSketch = DenseCuts::WXQSketch;
|
||||
|
||||
__global__ void FindCutsK
|
||||
(WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data,
|
||||
const float* __restrict__ cum_weights, int nsamples, int ncuts) {
|
||||
__global__ void FindCutsK(WXQSketch::Entry* __restrict__ cuts,
|
||||
const bst_float* __restrict__ data,
|
||||
const float* __restrict__ cum_weights,
|
||||
int nsamples,
|
||||
int ncuts) {
|
||||
// ncuts < nsamples
|
||||
int icut = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (icut >= ncuts) {
|
||||
@ -42,7 +44,7 @@ __global__ void FindCutsK
|
||||
isample = nsamples - 1;
|
||||
} else {
|
||||
bst_float rank = cum_weights[nsamples - 1] / static_cast<float>(ncuts - 1)
|
||||
* static_cast<float>(icut);
|
||||
* static_cast<float>(icut);
|
||||
// -1 is used because cum_weights is an inclusive sum
|
||||
isample = dh::UpperBound(cum_weights, nsamples, rank);
|
||||
isample = max(0, min(isample, nsamples - 1));
|
||||
@ -99,9 +101,8 @@ struct SketchContainer {
|
||||
std::vector<std::mutex> col_locks_; // NOLINT
|
||||
static constexpr int kOmpNumColsParallelizeLimit = 1000;
|
||||
|
||||
SketchContainer(int max_bin, DMatrix *dmat) :
|
||||
col_locks_(dmat->Info().num_col_) {
|
||||
const MetaInfo &info = dmat->Info();
|
||||
SketchContainer(int max_bin, DMatrix* dmat) : col_locks_(dmat->Info().num_col_) {
|
||||
const MetaInfo& info = dmat->Info();
|
||||
// Initialize Sketches for this dmatrix
|
||||
sketches_.resize(info.num_col_);
|
||||
#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \
|
||||
@ -119,328 +120,339 @@ if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
|
||||
};
|
||||
|
||||
// finds quantiles on the GPU
|
||||
struct GPUSketcher {
|
||||
// manage memory for a single GPU
|
||||
class DeviceShard {
|
||||
int device_;
|
||||
bst_uint n_rows_;
|
||||
int num_cols_{0};
|
||||
size_t n_cuts_{0};
|
||||
size_t gpu_batch_nrows_{0};
|
||||
bool has_weights_{false};
|
||||
size_t row_stride_{0};
|
||||
|
||||
const int max_bin_;
|
||||
SketchContainer *sketch_container_;
|
||||
dh::device_vector<size_t> row_ptrs_{};
|
||||
dh::device_vector<Entry> entries_{};
|
||||
dh::device_vector<bst_float> fvalues_{};
|
||||
dh::device_vector<bst_float> feature_weights_{};
|
||||
dh::device_vector<bst_float> fvalues_cur_{};
|
||||
dh::device_vector<WXQSketch::Entry> cuts_d_{};
|
||||
thrust::host_vector<WXQSketch::Entry> cuts_h_{};
|
||||
dh::device_vector<bst_float> weights_{};
|
||||
dh::device_vector<bst_float> weights2_{};
|
||||
std::vector<size_t> n_cuts_cur_{};
|
||||
dh::device_vector<size_t> num_elements_{};
|
||||
dh::device_vector<char> tmp_storage_{};
|
||||
|
||||
public:
|
||||
DeviceShard(int device,
|
||||
bst_uint n_rows,
|
||||
int max_bin,
|
||||
SketchContainer* sketch_container) :
|
||||
device_(device),
|
||||
n_rows_(n_rows),
|
||||
max_bin_(max_bin),
|
||||
sketch_container_(sketch_container) {
|
||||
}
|
||||
|
||||
~DeviceShard() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
}
|
||||
|
||||
inline size_t GetRowStride() const {
|
||||
return row_stride_;
|
||||
}
|
||||
|
||||
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
|
||||
num_cols_ = info.num_col_;
|
||||
has_weights_ = info.weights_.Size() > 0;
|
||||
|
||||
// find the batch size
|
||||
if (gpu_batch_nrows == 0) {
|
||||
// By default, use no more than 1/16th of GPU memory
|
||||
gpu_batch_nrows_ = dh::TotalMemory(device_) /
|
||||
(16 * num_cols_ * sizeof(Entry));
|
||||
} else if (gpu_batch_nrows == -1) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
} else {
|
||||
gpu_batch_nrows_ = gpu_batch_nrows;
|
||||
}
|
||||
if (gpu_batch_nrows_ > n_rows_) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
}
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bin_);
|
||||
size_t dummy_nlevel;
|
||||
WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
|
||||
|
||||
// allocate necessary GPU buffers
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
|
||||
entries_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_cur_.resize(gpu_batch_nrows_);
|
||||
cuts_d_.resize(n_cuts_ * num_cols_);
|
||||
cuts_h_.resize(n_cuts_ * num_cols_);
|
||||
weights_.resize(gpu_batch_nrows_);
|
||||
weights2_.resize(gpu_batch_nrows_);
|
||||
num_elements_.resize(1);
|
||||
|
||||
if (has_weights_) {
|
||||
feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
}
|
||||
n_cuts_cur_.resize(num_cols_);
|
||||
|
||||
// allocate storage for CUB algorithms; the size is the maximum of the sizes
|
||||
// required for various algorithm
|
||||
size_t tmp_size = 0, cur_tmp_size = 0;
|
||||
// size for sorting
|
||||
if (has_weights_) {
|
||||
cub::DeviceRadixSort::SortPairs
|
||||
(nullptr, cur_tmp_size, fvalues_cur_.data().get(),
|
||||
fvalues_.data().get(), weights_.data().get(), weights2_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortKeys
|
||||
(nullptr, cur_tmp_size, fvalues_cur_.data().get(), fvalues_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
}
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for inclusive scan
|
||||
if (has_weights_) {
|
||||
cub::DeviceScan::InclusiveSum
|
||||
(nullptr, cur_tmp_size, weights2_.begin(), weights_.begin(), gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
}
|
||||
// size for reduction by key
|
||||
cub::DeviceReduce::ReduceByKey
|
||||
(nullptr, cur_tmp_size, fvalues_.begin(),
|
||||
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
|
||||
num_elements_.begin(), thrust::maximum<bst_float>(), gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for filtering
|
||||
cub::DeviceSelect::If
|
||||
(nullptr, cur_tmp_size, fvalues_.begin(), fvalues_cur_.begin(),
|
||||
num_elements_.begin(), gpu_batch_nrows_, IsNotNaN());
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
|
||||
tmp_storage_.resize(tmp_size);
|
||||
}
|
||||
|
||||
void FindColumnCuts(size_t batch_nrows, size_t icol) {
|
||||
size_t tmp_size = tmp_storage_.size();
|
||||
// filter out NaNs in feature values
|
||||
auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
|
||||
fvalues_cur_.data(), num_elements_.begin(), batch_nrows, IsNotNaN());
|
||||
size_t nfvalues_cur = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);
|
||||
|
||||
// compute cumulative weights using a prefix scan
|
||||
if (has_weights_) {
|
||||
// filter out NaNs in weights;
|
||||
// since cub::DeviceSelect::If performs stable filtering,
|
||||
// the weights are stored in the correct positions
|
||||
auto feature_weights_begin = feature_weights_.data() +
|
||||
icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If
|
||||
(tmp_storage_.data().get(), tmp_size, feature_weights_begin,
|
||||
weights_.data().get(), num_elements_.begin(), batch_nrows, IsNotNaN());
|
||||
|
||||
// sort the values and weights
|
||||
cub::DeviceRadixSort::SortPairs
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_cur_.data().get(),
|
||||
fvalues_begin.get(), weights_.data().get(), weights2_.data().get(),
|
||||
nfvalues_cur);
|
||||
|
||||
// sum the weights to get cumulative weight values
|
||||
cub::DeviceScan::InclusiveSum
|
||||
(tmp_storage_.data().get(), tmp_size, weights2_.begin(),
|
||||
weights_.begin(), nfvalues_cur);
|
||||
} else {
|
||||
// sort the batch values
|
||||
cub::DeviceRadixSort::SortKeys
|
||||
(tmp_storage_.data().get(), tmp_size,
|
||||
fvalues_cur_.data().get(), fvalues_begin.get(), nfvalues_cur);
|
||||
|
||||
// fill in cumulative weights with counting iterator
|
||||
thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur,
|
||||
weights_.begin());
|
||||
}
|
||||
|
||||
// remove repeated items and sum the weights across them;
|
||||
// non-negative weights are assumed
|
||||
cub::DeviceReduce::ReduceByKey
|
||||
(tmp_storage_.data().get(), tmp_size, fvalues_begin,
|
||||
fvalues_cur_.begin(), weights_.begin(), weights2_.begin(),
|
||||
num_elements_.begin(), thrust::maximum<bst_float>(), nfvalues_cur);
|
||||
size_t n_unique = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &n_unique);
|
||||
|
||||
// extract cuts
|
||||
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
|
||||
// if less elements than cuts: copy all elements with their weights
|
||||
if (n_cuts_ > n_unique) {
|
||||
float* weights2_ptr = weights2_.data().get();
|
||||
float* fvalues_ptr = fvalues_cur_.data().get();
|
||||
WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
|
||||
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
|
||||
bst_float rmax = weights2_ptr[i];
|
||||
bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
|
||||
cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
|
||||
});
|
||||
} else if (n_cuts_cur_[icol] > 0) {
|
||||
// if more elements than cuts: use binary search on cumulative weights
|
||||
int block = 256;
|
||||
FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>
|
||||
(cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
|
||||
weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
|
||||
dh::safe_cuda(cudaGetLastError()); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
void SketchBatch(const SparsePage& row_batch, const MetaInfo& info,
|
||||
size_t gpu_batch) {
|
||||
// compute start and end indices
|
||||
size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
|
||||
size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
|
||||
static_cast<size_t>(n_rows_));
|
||||
size_t batch_nrows = batch_row_end - batch_row_begin;
|
||||
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
const auto& data_vec = row_batch.data.HostVector();
|
||||
|
||||
size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
|
||||
// copy the batch to the GPU
|
||||
dh::safe_cuda
|
||||
(cudaMemcpyAsync(entries_.data().get(),
|
||||
data_vec.data() + offset_vec[batch_row_begin],
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
// copy the weights if necessary
|
||||
if (has_weights_) {
|
||||
const auto& weights_vec = info.weights_.HostVector();
|
||||
dh::safe_cuda
|
||||
(cudaMemcpyAsync(weights_.data().get(),
|
||||
weights_vec.data() + batch_row_begin,
|
||||
batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
// unpack the features; also unpack weights if present
|
||||
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
|
||||
if (has_weights_) {
|
||||
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
|
||||
}
|
||||
|
||||
dim3 block3(16, 64, 1);
|
||||
// NOTE: This will typically support ~ 4M features - 64K*64
|
||||
dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
|
||||
common::DivRoundUp(num_cols_, block3.y), 1);
|
||||
UnpackFeaturesK<<<grid3, block3>>>
|
||||
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
|
||||
row_ptrs_.data().get() + batch_row_begin,
|
||||
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
|
||||
gpu_batch_nrows_, offset_vec[batch_row_begin], batch_nrows);
|
||||
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
FindColumnCuts(batch_nrows, icol);
|
||||
}
|
||||
|
||||
// add cuts into sketches
|
||||
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
|
||||
#pragma omp parallel for default(none) schedule(static) \
|
||||
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
WXQSketch::SummaryContainer summary;
|
||||
summary.Reserve(n_cuts_);
|
||||
summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
|
||||
|
||||
std::lock_guard<std::mutex> lock(sketch_container_->col_locks_[icol]);
|
||||
sketch_container_->sketches_[icol].PushSummary(summary);
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeRowStride() {
|
||||
// Find the row stride for this batch
|
||||
auto row_iter = row_ptrs_.begin();
|
||||
// Functor for finding the maximum row size for this batch
|
||||
auto get_size = [=] __device__(size_t row) {
|
||||
return row_iter[row + 1] - row_iter[row];
|
||||
}; // NOLINT
|
||||
|
||||
auto counting = thrust::make_counting_iterator(size_t(0));
|
||||
using TransformT = thrust::transform_iterator<decltype(get_size),
|
||||
decltype(counting), size_t>;
|
||||
TransformT row_size_iter = TransformT(counting, get_size);
|
||||
row_stride_ = thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0,
|
||||
thrust::maximum<size_t>());
|
||||
}
|
||||
|
||||
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
|
||||
// copy rows to the device
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
row_ptrs_.resize(n_rows_ + 1);
|
||||
thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
|
||||
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
SketchBatch(row_batch, info, gpu_batch);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
|
||||
// create device shard
|
||||
shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get()));
|
||||
|
||||
// compute sketches for the shard
|
||||
shard_->Init(batch, info, gpu_batch_nrows_);
|
||||
shard_->Sketch(batch, info);
|
||||
shard_->ComputeRowStride();
|
||||
|
||||
// compute row stride
|
||||
row_stride_ = shard_->GetRowStride();
|
||||
}
|
||||
|
||||
class GPUSketcher {
|
||||
public:
|
||||
GPUSketcher(int device, int max_bin, int gpu_nrows)
|
||||
: device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {}
|
||||
|
||||
~GPUSketcher() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
}
|
||||
|
||||
void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
|
||||
n_rows_ = batch.Size();
|
||||
|
||||
Init(batch, info, gpu_batch_nrows_);
|
||||
Sketch(batch, info);
|
||||
ComputeRowStride();
|
||||
}
|
||||
|
||||
/* Builds the sketches on the GPU for the dmatrix and returns the row stride
|
||||
* for the entire dataset */
|
||||
size_t Sketch(DMatrix *dmat, DenseCuts *hmat) {
|
||||
const MetaInfo &info = dmat->Info();
|
||||
const MetaInfo& info = dmat->Info();
|
||||
|
||||
row_stride_ = 0;
|
||||
sketch_container_.reset(new SketchContainer(max_bin_, dmat));
|
||||
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->SketchBatch(batch, info);
|
||||
}
|
||||
|
||||
hmat->Init(&sketch_container_->sketches_, max_bin_);
|
||||
|
||||
return row_stride_;
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void ComputeRowStride() {
|
||||
// Find the row stride for this batch
|
||||
auto row_iter = row_ptrs_.begin();
|
||||
// Functor for finding the maximum row size for this batch
|
||||
auto get_size = [=] __device__(size_t row) {
|
||||
return row_iter[row + 1] - row_iter[row];
|
||||
}; // NOLINT
|
||||
|
||||
auto counting = thrust::make_counting_iterator(size_t(0));
|
||||
using TransformT = thrust::transform_iterator<decltype(get_size), decltype(counting), size_t>;
|
||||
TransformT row_size_iter = TransformT(counting, get_size);
|
||||
row_stride_ =
|
||||
thrust::reduce(row_size_iter, row_size_iter + n_rows_, 0, thrust::maximum<size_t>());
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void FindColumnCuts(size_t batch_nrows, size_t icol) {
|
||||
size_t tmp_size = tmp_storage_.size();
|
||||
// filter out NaNs in feature values
|
||||
auto fvalues_begin = fvalues_.data() + icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
fvalues_begin,
|
||||
fvalues_cur_.data(),
|
||||
num_elements_.begin(),
|
||||
batch_nrows,
|
||||
IsNotNaN());
|
||||
size_t nfvalues_cur = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &nfvalues_cur);
|
||||
|
||||
// compute cumulative weights using a prefix scan
|
||||
if (has_weights_) {
|
||||
// filter out NaNs in weights;
|
||||
// since cub::DeviceSelect::If performs stable filtering,
|
||||
// the weights are stored in the correct positions
|
||||
auto feature_weights_begin = feature_weights_.data() + icol * gpu_batch_nrows_;
|
||||
cub::DeviceSelect::If(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
feature_weights_begin,
|
||||
weights_.data().get(),
|
||||
num_elements_.begin(),
|
||||
batch_nrows,
|
||||
IsNotNaN());
|
||||
|
||||
// sort the values and weights
|
||||
cub::DeviceRadixSort::SortPairs(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
fvalues_cur_.data().get(),
|
||||
fvalues_begin.get(),
|
||||
weights_.data().get(),
|
||||
weights2_.data().get(),
|
||||
nfvalues_cur);
|
||||
|
||||
// sum the weights to get cumulative weight values
|
||||
cub::DeviceScan::InclusiveSum(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
weights2_.begin(),
|
||||
weights_.begin(),
|
||||
nfvalues_cur);
|
||||
} else {
|
||||
// sort the batch values
|
||||
cub::DeviceRadixSort::SortKeys(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
fvalues_cur_.data().get(),
|
||||
fvalues_begin.get(),
|
||||
nfvalues_cur);
|
||||
|
||||
// fill in cumulative weights with counting iterator
|
||||
thrust::copy_n(thrust::make_counting_iterator(1), nfvalues_cur, weights_.begin());
|
||||
}
|
||||
|
||||
// remove repeated items and sum the weights across them;
|
||||
// non-negative weights are assumed
|
||||
cub::DeviceReduce::ReduceByKey(tmp_storage_.data().get(),
|
||||
tmp_size,
|
||||
fvalues_begin,
|
||||
fvalues_cur_.begin(),
|
||||
weights_.begin(),
|
||||
weights2_.begin(),
|
||||
num_elements_.begin(),
|
||||
thrust::maximum<bst_float>(),
|
||||
nfvalues_cur);
|
||||
size_t n_unique = 0;
|
||||
thrust::copy_n(num_elements_.begin(), 1, &n_unique);
|
||||
|
||||
// extract cuts
|
||||
n_cuts_cur_[icol] = std::min(n_cuts_, n_unique);
|
||||
// if less elements than cuts: copy all elements with their weights
|
||||
if (n_cuts_ > n_unique) {
|
||||
float* weights2_ptr = weights2_.data().get();
|
||||
float* fvalues_ptr = fvalues_cur_.data().get();
|
||||
WXQSketch::Entry* cuts_ptr = cuts_d_.data().get() + icol * n_cuts_;
|
||||
dh::LaunchN(device_, n_unique, [=]__device__(size_t i) {
|
||||
bst_float rmax = weights2_ptr[i];
|
||||
bst_float rmin = i > 0 ? weights2_ptr[i - 1] : 0;
|
||||
cuts_ptr[i] = WXQSketch::Entry(rmin, rmax, rmax - rmin, fvalues_ptr[i]);
|
||||
});
|
||||
} else if (n_cuts_cur_[icol] > 0) {
|
||||
// if more elements than cuts: use binary search on cumulative weights
|
||||
int block = 256;
|
||||
FindCutsK<<<common::DivRoundUp(n_cuts_cur_[icol], block), block>>>(
|
||||
cuts_d_.data().get() + icol * n_cuts_,
|
||||
fvalues_cur_.data().get(),
|
||||
weights2_.data().get(),
|
||||
n_unique,
|
||||
n_cuts_cur_[icol]);
|
||||
dh::safe_cuda(cudaGetLastError()); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceShard> shard_;
|
||||
void Init(const SparsePage& row_batch, const MetaInfo& info, int gpu_batch_nrows) {
|
||||
num_cols_ = info.num_col_;
|
||||
has_weights_ = info.weights_.Size() > 0;
|
||||
|
||||
// find the batch size
|
||||
if (gpu_batch_nrows == 0) {
|
||||
// By default, use no more than 1/16th of GPU memory
|
||||
gpu_batch_nrows_ = dh::TotalMemory(device_) / (16 * num_cols_ * sizeof(Entry));
|
||||
} else if (gpu_batch_nrows == -1) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
} else {
|
||||
gpu_batch_nrows_ = gpu_batch_nrows;
|
||||
}
|
||||
if (gpu_batch_nrows_ > n_rows_) {
|
||||
gpu_batch_nrows_ = n_rows_;
|
||||
}
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bin_);
|
||||
size_t dummy_nlevel;
|
||||
WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_);
|
||||
|
||||
// allocate necessary GPU buffers
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
|
||||
entries_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
fvalues_cur_.resize(gpu_batch_nrows_);
|
||||
cuts_d_.resize(n_cuts_ * num_cols_);
|
||||
cuts_h_.resize(n_cuts_ * num_cols_);
|
||||
weights_.resize(gpu_batch_nrows_);
|
||||
weights2_.resize(gpu_batch_nrows_);
|
||||
num_elements_.resize(1);
|
||||
|
||||
if (has_weights_) {
|
||||
feature_weights_.resize(gpu_batch_nrows_ * num_cols_);
|
||||
}
|
||||
n_cuts_cur_.resize(num_cols_);
|
||||
|
||||
// allocate storage for CUB algorithms; the size is the maximum of the sizes
|
||||
// required for various algorithm
|
||||
size_t tmp_size = 0, cur_tmp_size = 0;
|
||||
// size for sorting
|
||||
if (has_weights_) {
|
||||
cub::DeviceRadixSort::SortPairs(nullptr,
|
||||
cur_tmp_size,
|
||||
fvalues_cur_.data().get(),
|
||||
fvalues_.data().get(),
|
||||
weights_.data().get(),
|
||||
weights2_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortKeys(nullptr,
|
||||
cur_tmp_size,
|
||||
fvalues_cur_.data().get(),
|
||||
fvalues_.data().get(),
|
||||
gpu_batch_nrows_);
|
||||
}
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for inclusive scan
|
||||
if (has_weights_) {
|
||||
cub::DeviceScan::InclusiveSum(nullptr,
|
||||
cur_tmp_size,
|
||||
weights2_.begin(),
|
||||
weights_.begin(),
|
||||
gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
}
|
||||
// size for reduction by key
|
||||
cub::DeviceReduce::ReduceByKey(nullptr,
|
||||
cur_tmp_size,
|
||||
fvalues_.begin(),
|
||||
fvalues_cur_.begin(),
|
||||
weights_.begin(),
|
||||
weights2_.begin(),
|
||||
num_elements_.begin(),
|
||||
thrust::maximum<bst_float>(),
|
||||
gpu_batch_nrows_);
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
// size for filtering
|
||||
cub::DeviceSelect::If(nullptr,
|
||||
cur_tmp_size,
|
||||
fvalues_.begin(),
|
||||
fvalues_cur_.begin(),
|
||||
num_elements_.begin(),
|
||||
gpu_batch_nrows_,
|
||||
IsNotNaN());
|
||||
tmp_size = std::max(tmp_size, cur_tmp_size);
|
||||
|
||||
tmp_storage_.resize(tmp_size);
|
||||
}
|
||||
|
||||
void Sketch(const SparsePage& row_batch, const MetaInfo& info) {
|
||||
// copy rows to the device
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
row_ptrs_.resize(n_rows_ + 1);
|
||||
thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
|
||||
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
SketchBatch(row_batch, info, gpu_batch);
|
||||
}
|
||||
}
|
||||
|
||||
void SketchBatch(const SparsePage& row_batch, const MetaInfo& info, size_t gpu_batch) {
|
||||
// compute start and end indices
|
||||
size_t batch_row_begin = gpu_batch * gpu_batch_nrows_;
|
||||
size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_,
|
||||
static_cast<size_t>(n_rows_));
|
||||
size_t batch_nrows = batch_row_end - batch_row_begin;
|
||||
|
||||
const auto& offset_vec = row_batch.offset.HostVector();
|
||||
const auto& data_vec = row_batch.data.HostVector();
|
||||
|
||||
size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
|
||||
// copy the batch to the GPU
|
||||
dh::safe_cuda(cudaMemcpyAsync(entries_.data().get(),
|
||||
data_vec.data() + offset_vec[batch_row_begin],
|
||||
n_entries * sizeof(Entry),
|
||||
cudaMemcpyDefault));
|
||||
// copy the weights if necessary
|
||||
if (has_weights_) {
|
||||
const auto& weights_vec = info.weights_.HostVector();
|
||||
dh::safe_cuda(cudaMemcpyAsync(weights_.data().get(),
|
||||
weights_vec.data() + batch_row_begin,
|
||||
batch_nrows * sizeof(bst_float),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
// unpack the features; also unpack weights if present
|
||||
thrust::fill(fvalues_.begin(), fvalues_.end(), NAN);
|
||||
if (has_weights_) {
|
||||
thrust::fill(feature_weights_.begin(), feature_weights_.end(), NAN);
|
||||
}
|
||||
|
||||
dim3 block3(16, 64, 1);
|
||||
// NOTE: This will typically support ~ 4M features - 64K*64
|
||||
dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
|
||||
common::DivRoundUp(num_cols_, block3.y), 1);
|
||||
UnpackFeaturesK<<<grid3, block3>>>(
|
||||
fvalues_.data().get(),
|
||||
has_weights_ ? feature_weights_.data().get() : nullptr,
|
||||
row_ptrs_.data().get() + batch_row_begin,
|
||||
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
|
||||
gpu_batch_nrows_,
|
||||
offset_vec[batch_row_begin],
|
||||
batch_nrows);
|
||||
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
FindColumnCuts(batch_nrows, icol);
|
||||
}
|
||||
|
||||
// add cuts into sketches
|
||||
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
|
||||
#pragma omp parallel for default(none) schedule(static) \
|
||||
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
|
||||
for (int icol = 0; icol < num_cols_; ++icol) {
|
||||
WXQSketch::SummaryContainer summary;
|
||||
summary.Reserve(n_cuts_);
|
||||
summary.MakeFromSorted(&cuts_h_[n_cuts_ * icol], n_cuts_cur_[icol]);
|
||||
|
||||
std::lock_guard<std::mutex> lock(sketch_container_->col_locks_[icol]);
|
||||
sketch_container_->sketches_[icol].PushSummary(summary);
|
||||
}
|
||||
}
|
||||
|
||||
const int device_;
|
||||
const int max_bin_;
|
||||
int gpu_batch_nrows_;
|
||||
size_t row_stride_;
|
||||
std::unique_ptr<SketchContainer> sketch_container_;
|
||||
|
||||
bst_uint n_rows_{};
|
||||
int num_cols_{0};
|
||||
size_t n_cuts_{0};
|
||||
bool has_weights_{false};
|
||||
|
||||
dh::device_vector<size_t> row_ptrs_{};
|
||||
dh::device_vector<Entry> entries_{};
|
||||
dh::device_vector<bst_float> fvalues_{};
|
||||
dh::device_vector<bst_float> feature_weights_{};
|
||||
dh::device_vector<bst_float> fvalues_cur_{};
|
||||
dh::device_vector<WXQSketch::Entry> cuts_d_{};
|
||||
thrust::host_vector<WXQSketch::Entry> cuts_h_{};
|
||||
dh::device_vector<bst_float> weights_{};
|
||||
dh::device_vector<bst_float> weights2_{};
|
||||
std::vector<size_t> n_cuts_cur_{};
|
||||
dh::device_vector<size_t> num_elements_{};
|
||||
dh::device_vector<char> tmp_storage_{};
|
||||
};
|
||||
|
||||
size_t DeviceSketch(int device,
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
* Initialization/Allocation:<br/>
|
||||
* One can choose to initialize the vector on CPU or GPU during constructor.
|
||||
* (use the 'devices' argument) Or, can choose to use the 'Resize' method to
|
||||
* allocate/resize memory explicitly, and use the 'Shard' method
|
||||
* to specify the devices.
|
||||
* allocate/resize memory explicitly, and use the 'SetDevice' method
|
||||
* to specify the device.
|
||||
*
|
||||
* Accessing underlying data:<br/>
|
||||
* Use 'HostVector' method to explicitly query for the underlying std::vector.
|
||||
|
||||
@ -73,7 +73,7 @@ void EllpackPageImpl::Init(int device, int max_bin, int gpu_batch_nrows) {
|
||||
const auto& info = dmat_->Info();
|
||||
auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_;
|
||||
|
||||
// Init global data for each shard
|
||||
// Init global data
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
InitCompressedData(device, hmat, row_stride, is_dense);
|
||||
monitor_.StopCuda("InitCompressedData");
|
||||
|
||||
@ -19,27 +19,39 @@ namespace linear {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
|
||||
|
||||
class DeviceShard {
|
||||
int device_id_;
|
||||
dh::BulkAllocator ba_;
|
||||
std::vector<size_t> row_ptr_;
|
||||
common::Span<xgboost::Entry> data_;
|
||||
common::Span<GradientPair> gpair_;
|
||||
dh::CubMemory temp_;
|
||||
size_t shard_size_;
|
||||
/**
|
||||
* \class GPUCoordinateUpdater
|
||||
*
|
||||
* \brief Coordinate descent algorithm that updates one feature per iteration
|
||||
*/
|
||||
|
||||
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
public:
|
||||
DeviceShard(int device_id,
|
||||
const SparsePage &batch, // column batch
|
||||
bst_uint shard_size,
|
||||
const LinearTrainParam ¶m,
|
||||
const gbm::GBLinearModelParam &model_param)
|
||||
: device_id_(device_id),
|
||||
shard_size_(shard_size) {
|
||||
~GPUCoordinateUpdater() { // NOLINT
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
||||
}
|
||||
}
|
||||
|
||||
// set training parameter
|
||||
void Configure(Args const& args) override {
|
||||
tparam_.InitAllowUnknown(args);
|
||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||
monitor_.Init("GPUCoordinateUpdater");
|
||||
}
|
||||
|
||||
void LazyInitDevice(DMatrix *p_fmat, const gbm::GBLinearModelParam &model_param) {
|
||||
if (learner_param_->gpu_id < 0) return;
|
||||
|
||||
num_row_ = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
|
||||
|
||||
if ( IsEmpty() ) { return; }
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
||||
// The begin and end indices for the section of each column associated with
|
||||
// this shard
|
||||
// this device
|
||||
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
||||
row_ptr_ = {0};
|
||||
// iterate through columns
|
||||
@ -53,13 +65,13 @@ class DeviceShard {
|
||||
xgboost::Entry(0, 0.0f), cmp);
|
||||
auto column_end =
|
||||
std::lower_bound(col.cbegin(), col.cend(),
|
||||
xgboost::Entry(shard_size_, 0.0f), cmp);
|
||||
xgboost::Entry(num_row_, 0.0f), cmp);
|
||||
column_segments.emplace_back(
|
||||
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
|
||||
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
|
||||
}
|
||||
ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_,
|
||||
shard_size_ * model_param.num_output_group);
|
||||
ba_.Allocate(learner_param_->gpu_id, &data_, row_ptr_.back(), &gpair_,
|
||||
num_row_ * model_param.num_output_group);
|
||||
|
||||
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
auto col = batch[fidx];
|
||||
@ -71,121 +83,18 @@ class DeviceShard {
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceShard() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
}
|
||||
|
||||
bool IsEmpty() {
|
||||
return shard_size_ == 0;
|
||||
}
|
||||
|
||||
void UpdateGpair(const std::vector<GradientPair> &host_gpair,
|
||||
const gbm::GBLinearModelParam &model_param) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair_.data(),
|
||||
host_gpair.data(),
|
||||
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
auto counting = thrust::make_counting_iterator(0ull);
|
||||
auto f = [=] __device__(size_t idx) {
|
||||
return idx * num_group + group_idx;
|
||||
}; // NOLINT
|
||||
thrust::transform_iterator<decltype(f), decltype(counting), size_t> skip(
|
||||
counting, f);
|
||||
auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);
|
||||
|
||||
return dh::SumReduction(temp_, perm, shard_size_);
|
||||
}
|
||||
|
||||
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||
if (dbias == 0.0f) return;
|
||||
auto d_gpair = gpair_;
|
||||
dh::LaunchN(device_id_, shard_size_, [=] __device__(size_t idx) {
|
||||
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
});
|
||||
}
|
||||
|
||||
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
auto counting = thrust::make_counting_iterator(0ull);
|
||||
auto f = [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
auto g = d_gpair[entry.index * num_group + group_idx];
|
||||
return GradientPair(g.GetGrad() * entry.fvalue,
|
||||
g.GetHess() * entry.fvalue * entry.fvalue);
|
||||
}; // NOLINT
|
||||
thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
|
||||
multiply_iterator(counting, f);
|
||||
return dh::SumReduction(temp_, multiply_iterator, col_size);
|
||||
}
|
||||
|
||||
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
dh::LaunchN(device_id_, col_size, [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
auto &g = d_gpair[entry.index * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \class GPUCoordinateUpdater
|
||||
*
|
||||
* \brief Coordinate descent algorithm that updates one feature per iteration
|
||||
*/
|
||||
|
||||
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
public:
|
||||
// set training parameter
|
||||
void Configure(Args const& args) override {
|
||||
tparam_.InitAllowUnknown(args);
|
||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||
monitor_.Init("GPUCoordinateUpdater");
|
||||
}
|
||||
|
||||
void LazyInitShards(DMatrix *p_fmat,
|
||||
const gbm::GBLinearModelParam &model_param) {
|
||||
if (shard_) return;
|
||||
|
||||
device_ = learner_param_->gpu_id;
|
||||
|
||||
auto num_row = static_cast<size_t>(p_fmat->Info().num_row_);
|
||||
|
||||
// Partition input matrix into row segments
|
||||
std::vector<size_t> row_segments;
|
||||
row_segments.push_back(0);
|
||||
size_t shard_size = num_row;
|
||||
row_segments.push_back(shard_size);
|
||||
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
|
||||
|
||||
// Create device shard
|
||||
shard_.reset(new DeviceShard(device_, batch, shard_size, tparam_, model_param));
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
tparam_.DenormalizePenalties(sum_instance_weight);
|
||||
monitor_.Start("LazyInitShards");
|
||||
this->LazyInitShards(p_fmat, model->param);
|
||||
monitor_.Stop("LazyInitShards");
|
||||
monitor_.Start("LazyInitDevice");
|
||||
this->LazyInitDevice(p_fmat, model->param);
|
||||
monitor_.Stop("LazyInitDevice");
|
||||
|
||||
monitor_.Start("UpdateGpair");
|
||||
auto &in_gpair_host = in_gpair->ConstHostVector();
|
||||
// Update gpair
|
||||
if (shard_) {
|
||||
shard_->UpdateGpair(in_gpair_host, model->param);
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
this->UpdateGpair(in_gpair_host, model->param);
|
||||
}
|
||||
monitor_.Stop("UpdateGpair");
|
||||
|
||||
@ -197,8 +106,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm,
|
||||
coord_param_.top_k);
|
||||
monitor_.Start("UpdateFeature");
|
||||
for (auto group_idx = 0; group_idx < model->param.num_output_group;
|
||||
++group_idx) {
|
||||
for (auto group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) {
|
||||
for (auto i = 0U; i < model->param.num_feature; i++) {
|
||||
auto fidx = selector_->NextFeature(
|
||||
i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
||||
@ -214,8 +122,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
for (int group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) {
|
||||
// Get gradient
|
||||
auto grad = GradientPair(0, 0);
|
||||
if (shard_) {
|
||||
grad = shard_->GetBiasGradient(group_idx, model->param.num_output_group);
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
grad = GetBiasGradient(group_idx, model->param.num_output_group);
|
||||
}
|
||||
auto dbias = static_cast<float>(
|
||||
tparam_.learning_rate *
|
||||
@ -223,8 +131,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
model->bias()[group_idx] += dbias;
|
||||
|
||||
// Update residual
|
||||
if (shard_) {
|
||||
shard_->UpdateBiasResidual(dbias, group_idx, model->param.num_output_group);
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
UpdateBiasResidual(dbias, group_idx, model->param.num_output_group);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -235,8 +143,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
bst_float &w = (*model)[fidx][group_idx];
|
||||
// Get gradient
|
||||
auto grad = GradientPair(0, 0);
|
||||
if (shard_) {
|
||||
grad = shard_->GetGradient(group_idx, model->param.num_output_group, fidx);
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
grad = GetGradient(group_idx, model->param.num_output_group, fidx);
|
||||
}
|
||||
auto dw = static_cast<float>(tparam_.learning_rate *
|
||||
CoordinateDelta(grad.GetGrad(), grad.GetHess(),
|
||||
@ -244,20 +152,90 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
tparam_.reg_lambda_denorm));
|
||||
w += dw;
|
||||
|
||||
if (shard_) {
|
||||
shard_->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||
if (learner_param_->gpu_id >= 0) {
|
||||
UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||
}
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
||||
auto counting = thrust::make_counting_iterator(0ull);
|
||||
auto f = [=] __device__(size_t idx) {
|
||||
return idx * num_group + group_idx;
|
||||
}; // NOLINT
|
||||
thrust::transform_iterator<decltype(f), decltype(counting), size_t> skip(
|
||||
counting, f);
|
||||
auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);
|
||||
|
||||
return dh::SumReduction(temp_, perm, num_row_);
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||
if (dbias == 0.0f) return;
|
||||
auto d_gpair = gpair_;
|
||||
dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) {
|
||||
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
});
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
||||
common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
auto counting = thrust::make_counting_iterator(0ull);
|
||||
auto f = [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
auto g = d_gpair[entry.index * num_group + group_idx];
|
||||
return GradientPair(g.GetGrad() * entry.fvalue,
|
||||
g.GetHess() * entry.fvalue * entry.fvalue);
|
||||
}; // NOLINT
|
||||
thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
|
||||
multiply_iterator(counting, f);
|
||||
return dh::SumReduction(temp_, multiply_iterator, col_size);
|
||||
}
|
||||
|
||||
// This needs to be public because of the __device__ lambda.
|
||||
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
|
||||
common::Span<GradientPair> d_gpair = gpair_;
|
||||
common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
|
||||
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||
dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) {
|
||||
auto entry = d_col[idx];
|
||||
auto &g = d_gpair[entry.index * num_groups + group_idx];
|
||||
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsEmpty() {
|
||||
return num_row_ == 0;
|
||||
}
|
||||
|
||||
void UpdateGpair(const std::vector<GradientPair> &host_gpair,
|
||||
const gbm::GBLinearModelParam &model_param) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair_.data(),
|
||||
host_gpair.data(),
|
||||
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// training parameter
|
||||
LinearTrainParam tparam_;
|
||||
CoordinateParam coord_param_;
|
||||
int device_{};
|
||||
std::unique_ptr<FeatureSelector> selector_;
|
||||
common::Monitor monitor_;
|
||||
|
||||
std::unique_ptr<DeviceShard> shard_{nullptr};
|
||||
dh::BulkAllocator ba_;
|
||||
std::vector<size_t> row_ptr_;
|
||||
common::Span<xgboost::Entry> data_;
|
||||
common::Span<GradientPair> gpair_;
|
||||
dh::CubMemory temp_;
|
||||
size_t num_row_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")
|
||||
|
||||
@ -33,9 +33,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
.describe("Number of output class in the multi-class classification.");
|
||||
}
|
||||
};
|
||||
// TODO(trivialfis): Currently the sharding in softmax is less than ideal
|
||||
// due to repeated copying data between CPU and GPUs. Maybe we just use single
|
||||
// GPU?
|
||||
|
||||
class SoftmaxMultiClassObj : public ObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||
|
||||
@ -195,77 +195,52 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
struct DeviceShard {
|
||||
DeviceShard() : device_{-1} {}
|
||||
void InitModel(const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
nodes_.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments_.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group_.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
this->tree_begin_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group_ = model.param.num_output_group;
|
||||
}
|
||||
|
||||
~DeviceShard() {
|
||||
if (device_ >= 0) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
}
|
||||
void PredictInternal(const SparsePage& batch,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
const int BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.Size();
|
||||
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
|
||||
int shared_memory_bytes = static_cast<int>
|
||||
(sizeof(float) * num_features * BLOCK_THREADS);
|
||||
bool use_shared = true;
|
||||
if (shared_memory_bytes > max_shared_memory_bytes_) {
|
||||
shared_memory_bytes = 0;
|
||||
use_shared = false;
|
||||
}
|
||||
size_t entry_start = 0;
|
||||
|
||||
void Init(int device) {
|
||||
this->device_ = device;
|
||||
max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_);
|
||||
}
|
||||
|
||||
void InitModel(const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<DevicePredictionNode>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
nodes_.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments_.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group_.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
this->tree_begin_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group_ = model.param.num_output_group;
|
||||
}
|
||||
|
||||
void PredictInternal(const SparsePage& batch,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
const int BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.Size();
|
||||
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
|
||||
int shared_memory_bytes = static_cast<int>
|
||||
(sizeof(float) * num_features * BLOCK_THREADS);
|
||||
bool use_shared = true;
|
||||
if (shared_memory_bytes > max_shared_memory_bytes_) {
|
||||
shared_memory_bytes = 0;
|
||||
use_shared = false;
|
||||
}
|
||||
size_t entry_start = 0;
|
||||
|
||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||
(dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
|
||||
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
dh::device_vector<DevicePredictionNode> nodes_;
|
||||
dh::device_vector<size_t> tree_segments_;
|
||||
dh::device_vector<int> tree_group_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
size_t tree_begin_;
|
||||
size_t tree_end_;
|
||||
int num_group_;
|
||||
};
|
||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||
(dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
|
||||
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
|
||||
void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||
@ -285,7 +260,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
shard_.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
}
|
||||
|
||||
void DevicePredictInternal(DMatrix* dmat,
|
||||
@ -301,7 +276,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
batch.offset.SetDevice(device_);
|
||||
batch.data.SetDevice(device_);
|
||||
shard_.PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
|
||||
PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
|
||||
batch_offset += batch.Size() * model.param.num_output_group;
|
||||
}
|
||||
|
||||
@ -309,14 +284,20 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
GPUPredictor() : device_{-1} {};
|
||||
GPUPredictor() : device_{-1} {}
|
||||
|
||||
~GPUPredictor() override {
|
||||
if (device_ >= 0) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
}
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
int device = learner_param_->gpu_id;
|
||||
CHECK_GE(device, 0);
|
||||
ConfigureShard(device);
|
||||
ConfigureDevice(device);
|
||||
|
||||
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||
return;
|
||||
@ -433,22 +414,29 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
|
||||
int device = learner_param_->gpu_id;
|
||||
if (device >= 0) {
|
||||
ConfigureShard(device);
|
||||
ConfigureDevice(device);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Reconfigure the shard when GPU is changed. */
|
||||
void ConfigureShard(int device) {
|
||||
/*! \brief Reconfigure the device when GPU is changed. */
|
||||
void ConfigureDevice(int device) {
|
||||
if (device_ == device) return;
|
||||
|
||||
device_ = device;
|
||||
shard_.Init(device_);
|
||||
if (device_ >= 0) {
|
||||
max_shared_memory_bytes_ = dh::MaxSharedMemory(device_);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceShard shard_;
|
||||
int device_;
|
||||
common::Monitor monitor_;
|
||||
dh::device_vector<DevicePredictionNode> nodes_;
|
||||
dh::device_vector<size_t> tree_segments_;
|
||||
dh::device_vector<int> tree_group_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
size_t tree_begin_;
|
||||
size_t tree_end_;
|
||||
int num_group_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
|
||||
@ -435,7 +435,7 @@ __global__ void SharedMemHistKernel(xgboost::ELLPackMatrix matrix,
|
||||
|
||||
// Manage memory for a single GPU
|
||||
template <typename GradientSumT>
|
||||
struct DeviceShard {
|
||||
struct GPUHistMakerDevice {
|
||||
int device_id;
|
||||
EllpackPageImpl* page;
|
||||
|
||||
@ -474,12 +474,12 @@ struct DeviceShard {
|
||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||
std::unique_ptr<ExpandQueue> qexpand;
|
||||
|
||||
DeviceShard(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
bst_uint _n_rows,
|
||||
TrainParam _param,
|
||||
uint32_t column_sampler_seed,
|
||||
uint32_t n_features)
|
||||
GPUHistMakerDevice(int _device_id,
|
||||
EllpackPageImpl* _page,
|
||||
bst_uint _n_rows,
|
||||
TrainParam _param,
|
||||
uint32_t column_sampler_seed,
|
||||
uint32_t n_features)
|
||||
: device_id(_device_id),
|
||||
page(_page),
|
||||
n_rows(_n_rows),
|
||||
@ -487,12 +487,12 @@ struct DeviceShard {
|
||||
prediction_cache_initialised(false),
|
||||
column_sampler(column_sampler_seed),
|
||||
interaction_constraints(param, n_features) {
|
||||
monitor.Init(std::string("DeviceShard") + std::to_string(device_id));
|
||||
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
|
||||
}
|
||||
|
||||
void InitHistogram();
|
||||
|
||||
~DeviceShard() { // NOLINT
|
||||
~GPUHistMakerDevice() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
for (auto& stream : streams) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
@ -781,7 +781,7 @@ struct DeviceShard {
|
||||
auto left_node_rows = row_partitioner->GetRows(nidx_left).size();
|
||||
auto right_node_rows = row_partitioner->GetRows(nidx_right).size();
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Find the largest number of training instances on any given Shard
|
||||
// Find the largest number of training instances on any given device
|
||||
// Assume this will be the bottleneck and avoid building this node if
|
||||
// possible
|
||||
std::vector<size_t> max_reduce;
|
||||
@ -939,7 +939,7 @@ struct DeviceShard {
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
inline void DeviceShard<GradientSumT>::InitHistogram() {
|
||||
inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
|
||||
CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
@ -1026,19 +1026,17 @@ class GPUHistMakerSpecialised {
|
||||
page->Init(device_, param_.max_bin, hist_maker_param_.gpu_batch_nrows);
|
||||
}
|
||||
|
||||
// Create device shard
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
shard_.reset(new DeviceShard<GradientSumT>(device_,
|
||||
page,
|
||||
info_->num_row_,
|
||||
param_,
|
||||
column_sampling_seed,
|
||||
info_->num_col_));
|
||||
maker_.reset(new GPUHistMakerDevice<GradientSumT>(device_,
|
||||
page,
|
||||
info_->num_row_,
|
||||
param_,
|
||||
column_sampling_seed,
|
||||
info_->num_col_));
|
||||
|
||||
// Init global data for each shard
|
||||
monitor_.StartCuda("InitHistogram");
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
shard_->InitHistogram();
|
||||
maker_->InitHistogram();
|
||||
monitor_.StopCuda("InitHistogram");
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
@ -1077,18 +1075,17 @@ class GPUHistMakerSpecialised {
|
||||
monitor_.StopCuda("InitData");
|
||||
|
||||
gpair->SetDevice(device_);
|
||||
shard_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
maker_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
if (shard_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.StartCuda("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
dh::safe_cuda(cudaSetDevice(shard_->device_id));
|
||||
shard_->UpdatePredictionCache(p_out_preds->DevicePointer());
|
||||
maker_->UpdatePredictionCache(p_out_preds->DevicePointer());
|
||||
monitor_.StopCuda("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
@ -1096,7 +1093,7 @@ class GPUHistMakerSpecialised {
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::unique_ptr<DeviceShard<GradientSumT>> shard_; // NOLINT
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker_; // NOLINT
|
||||
|
||||
private:
|
||||
bool initialised_;
|
||||
|
||||
@ -71,40 +71,6 @@ class HistogramCutsWrapper : public common::HistogramCuts {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
|
||||
bst_float sparsity=0) {
|
||||
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
|
||||
const SparsePage& batch = *(*dmat)->GetBatches<xgboost::SparsePage>().begin();
|
||||
|
||||
HistogramCutsWrapper cmat;
|
||||
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
|
||||
// 24 cut fields, 3 cut fields for each feature (column).
|
||||
cmat.SetValues({0.30f, 0.67f, 1.64f,
|
||||
0.32f, 0.77f, 1.95f,
|
||||
0.29f, 0.70f, 1.80f,
|
||||
0.32f, 0.75f, 1.85f,
|
||||
0.18f, 0.59f, 1.69f,
|
||||
0.25f, 0.74f, 2.00f,
|
||||
0.26f, 0.74f, 1.98f,
|
||||
0.26f, 0.71f, 1.83f});
|
||||
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
|
||||
|
||||
auto is_dense = (*dmat)->Info().num_nonzero_ ==
|
||||
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
|
||||
size_t row_stride = 0;
|
||||
const auto &offset_vec = batch.offset.ConstHostVector();
|
||||
for (size_t i = 1; i < offset_vec.size(); ++i) {
|
||||
row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);
|
||||
}
|
||||
shard->InitHistogram(cmat, row_stride, is_dense);
|
||||
shard->CreateHistIndices(
|
||||
batch, cmat, RowStateOnDevice(batch.Size(), batch.Size()), -1);
|
||||
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||
// 24 bins, 3 bins for each feature (column).
|
||||
std::vector<GradientPairPrecise> hist_gpair = {
|
||||
@ -131,8 +97,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
};
|
||||
param.Init(args);
|
||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||
DeviceShard<GradientSumT> shard(0, page.get(), kNRows, param, kNCols, kNCols);
|
||||
shard.InitHistogram();
|
||||
GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols);
|
||||
maker.InitHistogram();
|
||||
|
||||
xgboost::SimpleLCG gen;
|
||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||
@ -150,13 +116,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
sizeof(common::CompressedByteT) * page->gidx_buffer.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
shard.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
shard.hist.AllocateHistogram(0);
|
||||
dh::CopyVectorToDeviceSpan(shard.gpair, h_gpair);
|
||||
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
maker.hist.AllocateHistogram(0);
|
||||
dh::CopyVectorToDeviceSpan(maker.gpair, h_gpair);
|
||||
|
||||
shard.use_shared_memory_histograms = use_shared_memory_histograms;
|
||||
shard.BuildHist(0);
|
||||
DeviceHistogram<GradientSumT> d_hist = shard.hist;
|
||||
maker.use_shared_memory_histograms = use_shared_memory_histograms;
|
||||
maker.BuildHist(0);
|
||||
DeviceHistogram<GradientSumT> d_hist = maker.hist;
|
||||
|
||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
||||
// d_hist.data stored in float, not gradient pair
|
||||
@ -230,30 +196,29 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
|
||||
int max_bins = 4;
|
||||
|
||||
// Initialize DeviceShard
|
||||
// Initialize GPUHistMakerDevice
|
||||
auto page = BuildEllpackPage(kNRows, kNCols);
|
||||
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
|
||||
new DeviceShard<GradientPairPrecise>(0, page.get(), kNRows, param, kNCols, kNCols)};
|
||||
// Initialize DeviceShard::node_sum_gradients
|
||||
shard->node_sum_gradients = {{6.4f, 12.8f}};
|
||||
GPUHistMakerDevice<GradientPairPrecise> maker(0, page.get(), kNRows, param, kNCols, kNCols);
|
||||
// Initialize GPUHistMakerDevice::node_sum_gradients
|
||||
maker.node_sum_gradients = {{6.4f, 12.8f}};
|
||||
|
||||
// Initialize DeviceShard::cut
|
||||
// Initialize GPUHistMakerDevice::cut
|
||||
auto cmat = GetHostCutMatrix();
|
||||
|
||||
// Copy cut matrix to device.
|
||||
shard->ba.Allocate(0,
|
||||
&(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
|
||||
&(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
|
||||
&(page->ellpack_matrix.gidx_fvalue_map), 24,
|
||||
&(shard->monotone_constraints), kNCols);
|
||||
maker.ba.Allocate(0,
|
||||
&(page->ellpack_matrix.feature_segments), cmat.Ptrs().size(),
|
||||
&(page->ellpack_matrix.min_fvalue), cmat.MinValues().size(),
|
||||
&(page->ellpack_matrix.gidx_fvalue_map), 24,
|
||||
&(maker.monotone_constraints), kNCols);
|
||||
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.feature_segments, cmat.Ptrs());
|
||||
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.gidx_fvalue_map, cmat.Values());
|
||||
dh::CopyVectorToDeviceSpan(shard->monotone_constraints, param.monotone_constraints);
|
||||
dh::CopyVectorToDeviceSpan(maker.monotone_constraints, param.monotone_constraints);
|
||||
dh::CopyVectorToDeviceSpan(page->ellpack_matrix.min_fvalue, cmat.MinValues());
|
||||
|
||||
// Initialize DeviceShard::hist
|
||||
shard->hist.Init(0, (max_bins - 1) * kNCols);
|
||||
shard->hist.AllocateHistogram(0);
|
||||
// Initialize GPUHistMakerDevice::hist
|
||||
maker.hist.Init(0, (max_bins - 1) * kNCols);
|
||||
maker.hist.AllocateHistogram(0);
|
||||
// Each row of hist_gpair represents gpairs for one feature.
|
||||
// Each entry represents a bin.
|
||||
std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair();
|
||||
@ -263,27 +228,26 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
hist.push_back(pair.GetHess());
|
||||
}
|
||||
|
||||
ASSERT_EQ(shard->hist.Data().size(), hist.size());
|
||||
ASSERT_EQ(maker.hist.Data().size(), hist.size());
|
||||
thrust::copy(hist.begin(), hist.end(),
|
||||
shard->hist.Data().begin());
|
||||
maker.hist.Data().begin());
|
||||
|
||||
shard->column_sampler.Init(kNCols,
|
||||
param.colsample_bynode,
|
||||
param.colsample_bylevel,
|
||||
param.colsample_bytree,
|
||||
false);
|
||||
maker.column_sampler.Init(kNCols,
|
||||
param.colsample_bynode,
|
||||
param.colsample_bylevel,
|
||||
param.colsample_bytree,
|
||||
false);
|
||||
|
||||
RegTree tree;
|
||||
MetaInfo info;
|
||||
info.num_row_ = kNRows;
|
||||
info.num_col_ = kNCols;
|
||||
|
||||
shard->node_value_constraints.resize(1);
|
||||
shard->node_value_constraints[0].lower_bound = -1.0;
|
||||
shard->node_value_constraints[0].upper_bound = 1.0;
|
||||
maker.node_value_constraints.resize(1);
|
||||
maker.node_value_constraints[0].lower_bound = -1.0;
|
||||
maker.node_value_constraints[0].upper_bound = 1.0;
|
||||
|
||||
std::vector<DeviceSplitCandidate> res =
|
||||
shard->EvaluateSplits({ 0,0 }, tree, kNCols);
|
||||
std::vector<DeviceSplitCandidate> res = maker.EvaluateSplits({0, 0 }, tree, kNCols);
|
||||
|
||||
ASSERT_EQ(res[0].findex, 7);
|
||||
ASSERT_EQ(res[1].findex, 7);
|
||||
@ -316,18 +280,18 @@ void TestHistogramIndexImpl() {
|
||||
hist_maker_ext.Configure(training_params, &generic_param);
|
||||
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
||||
|
||||
// Extract the device shard from the histogram makers and from that its compressed
|
||||
// Extract the device maker from the histogram makers and from that its compressed
|
||||
// histogram index
|
||||
const auto &dev_shard = hist_maker.shard_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->page->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->page->gidx_buffer);
|
||||
const auto &maker = hist_maker.maker_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer, maker->page->gidx_buffer);
|
||||
|
||||
const auto &dev_shard_ext = hist_maker_ext.shard_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->page->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->page->gidx_buffer);
|
||||
const auto &maker_ext = hist_maker_ext.maker_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, maker_ext->page->gidx_buffer);
|
||||
|
||||
ASSERT_EQ(dev_shard->page->n_bins, dev_shard_ext->page->n_bins);
|
||||
ASSERT_EQ(dev_shard->page->gidx_buffer.size(), dev_shard_ext->page->gidx_buffer.size());
|
||||
ASSERT_EQ(maker->page->n_bins, maker_ext->page->n_bins);
|
||||
ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size());
|
||||
|
||||
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user