Make HostDeviceVector single gpu only (#4773)

* Make HostDeviceVector single gpu only
Authored by Rong Ou on 2019-08-25 14:51:13 -07:00; committed by Rory Mitchell
parent 41227d1933
commit 38ab79f889
54 changed files with 641 additions and 1621 deletions

View File

@ -36,13 +36,12 @@ int main(int argc, char** argv) {
// https://xgboost.readthedocs.io/en/latest/parameter.html // https://xgboost.readthedocs.io/en/latest/parameter.html
safe_xgboost(XGBoosterSetParam(booster, "tree_method", use_gpu ? "gpu_hist" : "hist")); safe_xgboost(XGBoosterSetParam(booster, "tree_method", use_gpu ? "gpu_hist" : "hist"));
if (use_gpu) { if (use_gpu) {
// set the number of GPUs and the first GPU to use; // set the GPU to use;
// this is not necessary, but provided here as an illustration // this is not necessary, but provided here as an illustration
safe_xgboost(XGBoosterSetParam(booster, "n_gpus", "1"));
safe_xgboost(XGBoosterSetParam(booster, "gpu_id", "0")); safe_xgboost(XGBoosterSetParam(booster, "gpu_id", "0"));
} else { } else {
// avoid evaluating objective and metric on a GPU // avoid evaluating objective and metric on a GPU
safe_xgboost(XGBoosterSetParam(booster, "n_gpus", "0")); safe_xgboost(XGBoosterSetParam(booster, "gpu_id", "-1"));
} }
safe_xgboost(XGBoosterSetParam(booster, "objective", "binary:logistic")); safe_xgboost(XGBoosterSetParam(booster, "objective", "binary:logistic"));

View File

@ -19,10 +19,8 @@ struct GenericParameter : public dmlc::Parameter<GenericParameter> {
// number of threads to use if OpenMP is enabled // number of threads to use if OpenMP is enabled
// if equals 0, use system default // if equals 0, use system default
int nthread; int nthread;
// primary device. // primary device, -1 means no gpu.
int gpu_id; int gpu_id;
// number of devices to use, -1 implies using all available devices.
int n_gpus;
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(GenericParameter) { DMLC_DECLARE_PARAMETER(GenericParameter) {
DMLC_DECLARE_FIELD(seed).set_default(0).describe( DMLC_DECLARE_FIELD(seed).set_default(0).describe(
@ -36,15 +34,20 @@ struct GenericParameter : public dmlc::Parameter<GenericParameter> {
DMLC_DECLARE_FIELD(nthread).set_default(0).describe( DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use."); "Number of threads to use.");
DMLC_DECLARE_FIELD(gpu_id) DMLC_DECLARE_FIELD(gpu_id)
.set_default(0) .set_default(-1)
.set_lower_bound(-1)
.describe("The primary GPU device ordinal."); .describe("The primary GPU device ordinal.");
DMLC_DECLARE_FIELD(n_gpus) DMLC_DECLARE_FIELD(n_gpus)
.set_default(0) .set_default(0)
.set_range(0, 1) .set_range(0, 0)
.describe("Deprecated. Single process multi-GPU training is no longer supported. " .describe("Deprecated. Single process multi-GPU training is no longer supported. "
"Please switch to distributed training with one process per GPU. " "Please switch to distributed training with one process per GPU. "
"This can be done using Dask or Spark."); "This can be done using Dask or Spark.");
} }
private:
// number of devices to use (deprecated).
int n_gpus;
}; };
} // namespace xgboost } // namespace xgboost
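
With -1 now the documented "no GPU" value, device selection reduces to a single integer check. A minimal sketch of the consuming side; the UseGpu helper is an assumption, as is deriving gpu_id from a local rank in the one-process-per-GPU setup the deprecation message points to.

    #include <xgboost/logging.h>  // CHECK_* macros via dmlc-core

    // Hypothetical helper: decide whether work runs on a GPU under the new rule.
    inline bool UseGpu(int gpu_id, int n_visible_gpus) {
      if (gpu_id < 0) {
        return false;  // -1: everything stays on the CPU
      }
      CHECK_LT(gpu_id, n_visible_gpus) << "gpu_id must be a visible CUDA ordinal.";
      return true;
    }

    // In a one-process-per-GPU job each worker would typically pass its own
    // ordinal, e.g. gpu_id = local_rank (an assumption, not shown in this diff).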

View File

@ -60,8 +60,8 @@ class MyLogistic : public ObjFunction {
void PredTransform(HostDeviceVector<bst_float> *io_preds) override { void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
// transform margin value to probability. // transform margin value to probability.
std::vector<bst_float> &preds = io_preds->HostVector(); std::vector<bst_float> &preds = io_preds->HostVector();
for (size_t i = 0; i < preds.size(); ++i) { for (auto& pred : preds) {
preds[i] = 1.0f / (1.0f + std::exp(-preds[i])); pred = 1.0f / (1.0f + std::exp(-pred));
} }
} }
bst_float ProbToMargin(bst_float base_score) const override { bst_float ProbToMargin(bst_float base_score) const override {

View File

@ -22,48 +22,12 @@ using RandomThreadLocalStore = dmlc::ThreadLocalStore<RandomThreadLocalEntry>;
GlobalRandomEngine& GlobalRandom() { GlobalRandomEngine& GlobalRandom() {
return RandomThreadLocalStore::Get()->engine; return RandomThreadLocalStore::Get()->engine;
} }
} // namespace common
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
int AllVisibleImpl::AllVisible() { int AllVisibleGPUs() {
return 0; return 0;
} }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
constexpr GPUSet::GpuIdType GPUSet::kAll; } // namespace common
GPUSet GPUSet::All(GpuIdType gpu_id, GpuIdType n_gpus, int32_t n_rows) {
CHECK_GE(gpu_id, 0) << "gpu_id must be >= 0.";
CHECK_GE(n_gpus, -1) << "n_gpus must be >= -1.";
GpuIdType const n_devices_visible = AllVisible().Size();
CHECK_LE(n_gpus, n_devices_visible);
if (n_devices_visible == 0 || n_gpus == 0 || n_rows == 0) {
LOG(DEBUG) << "Runing on CPU.";
return Empty();
}
GpuIdType const n_available_devices = n_devices_visible - gpu_id;
if (n_gpus == kAll) { // Use all devices starting from `gpu_id'.
CHECK(gpu_id < n_devices_visible)
<< "\ngpu_id should be less than number of visible devices.\ngpu_id: "
<< gpu_id
<< ", number of visible devices: "
<< n_devices_visible;
GpuIdType n_devices =
n_available_devices < n_rows ? n_available_devices : n_rows;
LOG(DEBUG) << "GPU ID: " << gpu_id << ", Number of GPUs: " << n_devices;
return Range(gpu_id, n_devices);
} else { // Use devices in ( gpu_id, gpu_id + n_gpus ).
CHECK_LE(n_gpus, n_available_devices)
<< "Starting from gpu id: " << gpu_id << ", there are only "
<< n_available_devices << " available devices, while n_gpus is set to: "
<< n_gpus;
GpuIdType n_devices = n_gpus < n_rows ? n_gpus : n_rows;
LOG(DEBUG) << "GPU ID: " << gpu_id << ", Number of GPUs: " << n_devices;
return Range(gpu_id, n_devices);
}
}
} // namespace xgboost } // namespace xgboost

View File

@ -4,8 +4,9 @@
#include "common.h" #include "common.h"
namespace xgboost { namespace xgboost {
namespace common {
int AllVisibleImpl::AllVisible() { int AllVisibleGPUs() {
int n_visgpus = 0; int n_visgpus = 0;
try { try {
// When compiled with CUDA but running on CPU only device, // When compiled with CUDA but running on CPU only device,
@ -17,4 +18,5 @@ int AllVisibleImpl::AllVisible() {
return n_visgpus; return n_visgpus;
} }
} // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -140,88 +140,8 @@ class Range {
Iterator begin_; Iterator begin_;
Iterator end_; Iterator end_;
}; };
int AllVisibleGPUs();
} // namespace common } // namespace common
struct AllVisibleImpl {
static int AllVisible();
};
/* \brief set of devices across which HostDeviceVector can be distributed.
*
* Currently implemented as a range, but can be changed later to something else,
* e.g. a bitset
*/
class GPUSet {
public:
using GpuIdType = int;
static constexpr GpuIdType kAll = -1;
explicit GPUSet(int start = 0, int ndevices = 0)
: devices_(start, start + ndevices) {}
static GPUSet Empty() { return GPUSet(); }
static GPUSet Range(GpuIdType start, GpuIdType n_gpus) {
return n_gpus <= 0 ? Empty() : GPUSet{start, n_gpus};
}
/*! \brief n_gpus and num_rows both are upper bounds. */
static GPUSet All(GpuIdType gpu_id, GpuIdType n_gpus,
GpuIdType num_rows = std::numeric_limits<GpuIdType>::max());
static GPUSet AllVisible() {
GpuIdType n = AllVisibleImpl::AllVisible();
return Range(0, n);
}
size_t Size() const {
GpuIdType size = *devices_.end() - *devices_.begin();
GpuIdType res = size < 0 ? 0 : size;
return static_cast<size_t>(res);
}
/*
* By default, we have two configurations of identifying device, one
* is the device id obtained from `cudaGetDevice'. But we sometimes
* store objects that allocated one for each device in a list, which
* requires a zero-based index.
*
* Hence, `DeviceId' converts a zero-based index to actual device id,
* `Index' converts a device id to a zero-based index.
*/
GpuIdType DeviceId(size_t index) const {
GpuIdType result = *devices_.begin() + static_cast<GpuIdType>(index);
CHECK(Contains(result)) << "\nDevice " << result << " is not in GPUSet."
<< "\nIndex: " << index
<< "\nGPUSet: (" << *begin() << ", " << *end() << ")"
<< std::endl;
return result;
}
size_t Index(GpuIdType device) const {
CHECK(Contains(device)) << "\nDevice " << device << " is not in GPUSet."
<< "\nGPUSet: (" << *begin() << ", " << *end() << ")"
<< std::endl;
size_t result = static_cast<size_t>(device - *devices_.begin());
return result;
}
bool IsEmpty() const { return Size() == 0; }
bool Contains(GpuIdType device) const {
return *devices_.begin() <= device && device < *devices_.end();
}
common::Range::Iterator begin() const { return devices_.begin(); } // NOLINT
common::Range::Iterator end() const { return devices_.end(); } // NOLINT
friend bool operator==(const GPUSet& lhs, const GPUSet& rhs) {
return lhs.devices_ == rhs.devices_;
}
friend bool operator!=(const GPUSet& lhs, const GPUSet& rhs) {
return !(lhs == rhs);
}
private:
common::Range devices_;
};
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_ #endif // XGBOOST_COMMON_COMMON_H_
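
The device-set abstraction is gone; what remains is a plain count. A sketch of the replacement idiom for code that used to iterate a GPUSet, assuming common.h is included; the VisibleOrdinals helper is made up.

    #include <vector>

    // Enumerate every visible ordinal with the new free function instead of
    // building a GPUSet range; an empty result means CPU-only execution.
    inline std::vector<int> VisibleOrdinals() {
      std::vector<int> out;
      for (int d = 0; d < xgboost::common::AllVisibleGPUs(); ++d) {
        out.push_back(d);
      }
      return out;
    }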

View File

@ -72,22 +72,6 @@ const T *Raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data()); return raw_pointer_cast(v.data());
} }
// if n_devices=-1, then use all visible devices
inline void SynchronizeNDevices(xgboost::GPUSet devices) {
devices = devices.IsEmpty() ? xgboost::GPUSet::AllVisible() : devices;
for (auto const d : devices) {
safe_cuda(cudaSetDevice(d));
safe_cuda(cudaDeviceSynchronize());
}
}
inline void SynchronizeAll() {
for (int device_idx : xgboost::GPUSet::AllVisible()) {
safe_cuda(cudaSetDevice(device_idx));
safe_cuda(cudaDeviceSynchronize());
}
}
inline size_t AvailableMemory(int device_idx) { inline size_t AvailableMemory(int device_idx) {
size_t device_free = 0; size_t device_free = 0;
size_t device_total = 0; size_t device_total = 0;
@ -119,7 +103,7 @@ inline size_t MaxSharedMemory(int device_idx) {
} }
inline void CheckComputeCapability() { inline void CheckComputeCapability() {
for (int d_idx : xgboost::GPUSet::AllVisible()) { for (int d_idx = 0; d_idx < xgboost::common::AllVisibleGPUs(); ++d_idx) {
cudaDeviceProp prop; cudaDeviceProp prop;
safe_cuda(cudaGetDeviceProperties(&prop, d_idx)); safe_cuda(cudaGetDeviceProperties(&prop, d_idx));
std::ostringstream oss; std::ostringstream oss;
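
With SynchronizeNDevices and SynchronizeAll removed, synchronizing the single device a process owns comes down to two runtime calls. A sketch of an equivalent helper using the existing dh::safe_cuda wrapper; the function itself is hypothetical.

    // Synchronize the one device this process is bound to.
    inline void SynchronizeDevice(int device) {
      dh::safe_cuda(cudaSetDevice(device));
      dh::safe_cuda(cudaDeviceSynchronize());
    }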

View File

@ -35,7 +35,6 @@ __global__ void FindCutsK
if (icut >= ncuts) { if (icut >= ncuts) {
return; return;
} }
WXQSketch::Entry v;
int isample = 0; int isample = 0;
if (icut == 0) { if (icut == 0) {
isample = 0; isample = 0;
@ -59,11 +58,14 @@ struct IsNotNaN {
__device__ bool operator()(float a) const { return !isnan(a); } __device__ bool operator()(float a) const { return !isnan(a); }
}; };
__global__ void UnpackFeaturesK __global__ void UnpackFeaturesK(float* __restrict__ fvalues,
(float* __restrict__ fvalues, float* __restrict__ feature_weights, float* __restrict__ feature_weights,
const size_t* __restrict__ row_ptrs, const float* __restrict__ weights, const size_t* __restrict__ row_ptrs,
Entry* entries, size_t nrows_array, int ncols, size_t row_begin_ptr, const float* __restrict__ weights,
size_t nrows) { Entry* entries,
size_t nrows_array,
size_t row_begin_ptr,
size_t nrows) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x; size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
if (irow >= nrows) { if (irow >= nrows) {
return; return;
@ -102,8 +104,9 @@ struct SketchContainer {
const MetaInfo &info = dmat->Info(); const MetaInfo &info = dmat->Info();
// Initialize Sketches for this dmatrix // Initialize Sketches for this dmatrix
sketches_.resize(info.num_col_); sketches_.resize(info.num_col_);
#pragma omp parallel for schedule(static) if (info.num_col_ > kOmpNumColsParallelizeLimit) #pragma omp parallel for default(none) shared(info, param) schedule(static) \
for (int icol = 0; icol < info.num_col_; ++icol) { if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT
sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin)); sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin));
} }
} }
@ -120,8 +123,6 @@ struct GPUSketcher {
// manage memory for a single GPU // manage memory for a single GPU
class DeviceShard { class DeviceShard {
int device_; int device_;
bst_uint row_begin_; // The row offset for this shard
bst_uint row_end_;
bst_uint n_rows_; bst_uint n_rows_;
int num_cols_{0}; int num_cols_{0};
size_t n_cuts_{0}; size_t n_cuts_{0};
@ -131,27 +132,31 @@ struct GPUSketcher {
tree::TrainParam param_; tree::TrainParam param_;
SketchContainer *sketch_container_; SketchContainer *sketch_container_;
dh::device_vector<size_t> row_ptrs_; dh::device_vector<size_t> row_ptrs_{};
dh::device_vector<Entry> entries_; dh::device_vector<Entry> entries_{};
dh::device_vector<bst_float> fvalues_; dh::device_vector<bst_float> fvalues_{};
dh::device_vector<bst_float> feature_weights_; dh::device_vector<bst_float> feature_weights_{};
dh::device_vector<bst_float> fvalues_cur_; dh::device_vector<bst_float> fvalues_cur_{};
dh::device_vector<WXQSketch::Entry> cuts_d_; dh::device_vector<WXQSketch::Entry> cuts_d_{};
thrust::host_vector<WXQSketch::Entry> cuts_h_; thrust::host_vector<WXQSketch::Entry> cuts_h_{};
dh::device_vector<bst_float> weights_; dh::device_vector<bst_float> weights_{};
dh::device_vector<bst_float> weights2_; dh::device_vector<bst_float> weights2_{};
std::vector<size_t> n_cuts_cur_; std::vector<size_t> n_cuts_cur_{};
dh::device_vector<size_t> num_elements_; dh::device_vector<size_t> num_elements_{};
dh::device_vector<char> tmp_storage_; dh::device_vector<char> tmp_storage_{};
public: public:
DeviceShard(int device, bst_uint row_begin, bst_uint row_end, DeviceShard(int device,
tree::TrainParam param, SketchContainer *sketch_container) : bst_uint n_rows,
device_(device), row_begin_(row_begin), row_end_(row_end), tree::TrainParam param,
n_rows_(row_end - row_begin), param_(std::move(param)), sketch_container_(sketch_container) { SketchContainer* sketch_container) :
device_(device),
n_rows_(n_rows),
param_(std::move(param)),
sketch_container_(sketch_container) {
} }
~DeviceShard() { ~DeviceShard() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
} }
@ -319,19 +324,18 @@ struct GPUSketcher {
const auto& offset_vec = row_batch.offset.HostVector(); const auto& offset_vec = row_batch.offset.HostVector();
const auto& data_vec = row_batch.data.HostVector(); const auto& data_vec = row_batch.data.HostVector();
size_t n_entries = offset_vec[row_begin_ + batch_row_end] - size_t n_entries = offset_vec[batch_row_end] - offset_vec[batch_row_begin];
offset_vec[row_begin_ + batch_row_begin];
// copy the batch to the GPU // copy the batch to the GPU
dh::safe_cuda dh::safe_cuda
(cudaMemcpyAsync(entries_.data().get(), (cudaMemcpyAsync(entries_.data().get(),
data_vec.data() + offset_vec[row_begin_ + batch_row_begin], data_vec.data() + offset_vec[batch_row_begin],
n_entries * sizeof(Entry), cudaMemcpyDefault)); n_entries * sizeof(Entry), cudaMemcpyDefault));
// copy the weights if necessary // copy the weights if necessary
if (has_weights_) { if (has_weights_) {
const auto& weights_vec = info.weights_.HostVector(); const auto& weights_vec = info.weights_.HostVector();
dh::safe_cuda dh::safe_cuda
(cudaMemcpyAsync(weights_.data().get(), (cudaMemcpyAsync(weights_.data().get(),
weights_vec.data() + row_begin_ + batch_row_begin, weights_vec.data() + batch_row_begin,
batch_nrows * sizeof(bst_float), cudaMemcpyDefault)); batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
} }
@ -349,8 +353,7 @@ struct GPUSketcher {
(fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr, (fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr,
row_ptrs_.data().get() + batch_row_begin, row_ptrs_.data().get() + batch_row_begin,
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(), has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_, num_cols_, gpu_batch_nrows_, offset_vec[batch_row_begin], batch_nrows);
offset_vec[row_begin_ + batch_row_begin], batch_nrows);
for (int icol = 0; icol < num_cols_; ++icol) { for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol); FindColumnCuts(batch_nrows, icol);
@ -358,7 +361,7 @@ struct GPUSketcher {
// add cuts into sketches // add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
#pragma omp parallel for schedule(static) \ #pragma omp parallel for default(none) schedule(static) \
if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) { for (int icol = 0; icol < num_cols_; ++icol) {
WXQSketch::SummaryContainer summary; WXQSketch::SummaryContainer summary;
@ -391,8 +394,7 @@ struct GPUSketcher {
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
const auto& offset_vec = row_batch.offset.HostVector(); const auto& offset_vec = row_batch.offset.HostVector();
row_ptrs_.resize(n_rows_ + 1); row_ptrs_.resize(n_rows_ + 1);
thrust::copy(offset_vec.data() + row_begin_, thrust::copy(offset_vec.data(), offset_vec.data() + n_rows_ + 1, row_ptrs_.begin());
offset_vec.data() + row_end_ + 1, row_ptrs_.begin());
size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_); size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
SketchBatch(row_batch, info, gpu_batch); SketchBatch(row_batch, info, gpu_batch);
@ -401,32 +403,18 @@ struct GPUSketcher {
}; };
void SketchBatch(const SparsePage &batch, const MetaInfo &info) { void SketchBatch(const SparsePage &batch, const MetaInfo &info) {
GPUDistribution dist = auto device = generic_param_.gpu_id;
GPUDistribution::Block(GPUSet::All(generic_param_.gpu_id, generic_param_.n_gpus,
batch.Size()));
// create device shards // create device shard
shards_.resize(dist.Devices().Size()); shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get()));
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist.ShardStart(batch.Size(), i);
size_t size = dist.ShardSize(batch.Size(), i);
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(dist.Devices().DeviceId(i), start,
start + size, param_, sketch_container_.get()));
});
// compute sketches for each shard // compute sketches for the shard
dh::ExecuteIndexShards(&shards_, shard_->Init(batch, info, gpu_batch_nrows_);
[&](int idx, std::unique_ptr<DeviceShard>& shard) { shard_->Sketch(batch, info);
shard->Init(batch, info, gpu_batch_nrows_); shard_->ComputeRowStride();
shard->Sketch(batch, info);
shard->ComputeRowStride();
});
// compute row stride across all shards // compute row stride
for (const auto &shard : shards_) { row_stride_ = shard_->GetRowStride();
row_stride_ = std::max(row_stride_, shard->GetRowStride());
}
} }
GPUSketcher(const tree::TrainParam &param, const GenericParameter &generic_param, int gpu_nrows) GPUSketcher(const tree::TrainParam &param, const GenericParameter &generic_param, int gpu_nrows)
@ -444,13 +432,13 @@ struct GPUSketcher {
this->SketchBatch(batch, info); this->SketchBatch(batch, info);
} }
hmat->Init(&sketch_container_.get()->sketches_, param_.max_bin); hmat->Init(&sketch_container_->sketches_, param_.max_bin);
return row_stride_; return row_stride_;
} }
private: private:
std::vector<std::unique_ptr<DeviceShard>> shards_; std::unique_ptr<DeviceShard> shard_;
const tree::TrainParam &param_; const tree::TrainParam &param_;
const GenericParameter &generic_param_; const GenericParameter &generic_param_;
int gpu_batch_nrows_; int gpu_batch_nrows_;

View File

@ -30,19 +30,19 @@ struct HostDeviceVectorImpl {
}; };
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, T v, const GPUDistribution &) HostDeviceVector<T>::HostDeviceVector(size_t size, T v, int device)
: impl_(nullptr) { : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(size, v); impl_ = new HostDeviceVectorImpl<T>(size, v);
} }
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector(std::initializer_list<T> init, const GPUDistribution &) HostDeviceVector<T>::HostDeviceVector(std::initializer_list<T> init, int device)
: impl_(nullptr) { : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(init); impl_ = new HostDeviceVectorImpl<T>(init);
} }
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector(const std::vector<T>& init, const GPUDistribution &) HostDeviceVector<T>::HostDeviceVector(const std::vector<T>& init, int device)
: impl_(nullptr) { : impl_(nullptr) {
impl_ = new HostDeviceVectorImpl<T>(init); impl_ = new HostDeviceVectorImpl<T>(init);
} }
@ -75,29 +75,23 @@ template <typename T>
size_t HostDeviceVector<T>::Size() const { return impl_->Vec().size(); } size_t HostDeviceVector<T>::Size() const { return impl_->Vec().size(); }
template <typename T> template <typename T>
GPUSet HostDeviceVector<T>::Devices() const { return GPUSet::Empty(); } int HostDeviceVector<T>::DeviceIdx() const { return -1; }
template <typename T> template <typename T>
const GPUDistribution& HostDeviceVector<T>::Distribution() const { T* HostDeviceVector<T>::DevicePointer() { return nullptr; }
static GPUDistribution dummyInstance;
return dummyInstance;
}
template <typename T> template <typename T>
T* HostDeviceVector<T>::DevicePointer(int device) { return nullptr; } const T* HostDeviceVector<T>::ConstDevicePointer() const {
template <typename T>
const T* HostDeviceVector<T>::ConstDevicePointer(int device) const {
return nullptr; return nullptr;
} }
template <typename T> template <typename T>
common::Span<T> HostDeviceVector<T>::DeviceSpan(int device) { common::Span<T> HostDeviceVector<T>::DeviceSpan() {
return common::Span<T>(); return common::Span<T>();
} }
template <typename T> template <typename T>
common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan(int device) const { common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan() const {
return common::Span<const T>(); return common::Span<const T>();
} }
@ -115,10 +109,7 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
} }
template <typename T> template <typename T>
size_t HostDeviceVector<T>::DeviceStart(int device) const { return 0; } size_t HostDeviceVector<T>::DeviceSize() const { return 0; }
template <typename T>
size_t HostDeviceVector<T>::DeviceSize(int device) const { return 0; }
template <typename T> template <typename T>
void HostDeviceVector<T>::Fill(T v) { void HostDeviceVector<T>::Fill(T v) {
@ -149,18 +140,12 @@ bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
} }
template <typename T> template <typename T>
bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const { bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
return false; return false;
} }
template <typename T> template <typename T>
void HostDeviceVector<T>::Shard(const GPUDistribution& distribution) const { } void HostDeviceVector<T>::SetDevice(int device) const {}
template <typename T>
void HostDeviceVector<T>::Shard(GPUSet devices) const { }
template <typename T>
void Reshard(const GPUDistribution &distribution) { }
// explicit instantiations are required, as HostDeviceVector isn't header-only // explicit instantiations are required, as HostDeviceVector isn't header-only
template class HostDeviceVector<bst_float>; template class HostDeviceVector<bst_float>;
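
These CPU-only stubs let callers probe for a device without #ifdefs: in a non-CUDA build DeviceIdx() always returns -1 and DeviceSize() returns 0. A sketch of the resulting pattern in caller code, which is assumed rather than taken from this diff.

    // Works in both CUDA and CPU-only builds.
    void Process(xgboost::HostDeviceVector<float>* vec) {
      if (vec->DeviceIdx() >= 0) {
        auto d_span = vec->ConstDeviceSpan();        // device path (CUDA build)
        (void)d_span;
      } else {
        const auto& h_vec = vec->ConstHostVector();  // host path
        (void)h_vec;
      }
    }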

View File

@ -10,7 +10,6 @@
#include <mutex> #include <mutex>
#include "./device_helpers.cuh" #include "./device_helpers.cuh"
namespace xgboost { namespace xgboost {
// the handler to call instead of cudaSetDevice; only used for testing // the handler to call instead of cudaSetDevice; only used for testing
@ -43,144 +42,12 @@ class Permissions {
}; };
template <typename T> template <typename T>
struct HostDeviceVectorImpl { class HostDeviceVectorImpl {
struct DeviceShard { public:
DeviceShard() HostDeviceVectorImpl(size_t size, T v, int device) : device_(device), perm_h_(device < 0) {
: proper_size_{0}, device_{-1}, start_{0}, perm_d_{false}, if (device >= 0) {
cached_size_{static_cast<size_t>(~0)}, vec_{nullptr} {}
~DeviceShard() {
SetDevice(); SetDevice();
} data_d_.resize(size, v);
void Init(HostDeviceVectorImpl<T>* vec, int device) {
if (vec_ == nullptr) { vec_ = vec; }
CHECK_EQ(vec, vec_);
device_ = device;
LazyResize(vec_->Size());
perm_d_ = vec_->perm_h_.Complementary();
}
void Init(HostDeviceVectorImpl<T>* vec, const DeviceShard& other) {
if (vec_ == nullptr) { vec_ = vec; }
CHECK_EQ(vec, vec_);
device_ = other.device_;
cached_size_ = other.cached_size_;
start_ = other.start_;
proper_size_ = other.proper_size_;
SetDevice();
data_.resize(other.data_.size());
perm_d_ = other.perm_d_;
}
void ScatterFrom(const T* begin) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
data_.size() * sizeof(T), cudaMemcpyDefault));
}
void GatherTo(thrust::device_ptr<T> begin) {
LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
proper_size_ * sizeof(T), cudaMemcpyDefault));
}
void Fill(T v) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
thrust::fill(data_.begin(), data_.end(), v);
}
void Copy(DeviceShard* other) {
// TODO(canonizer): avoid full copy of host data for this (but not for other)
LazySyncDevice(GPUAccess::kWrite);
other->LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), other->data_.data().get(),
data_.size() * sizeof(T), cudaMemcpyDefault));
}
void LazySyncHost(GPUAccess access) {
SetDevice();
dh::safe_cuda(cudaMemcpy(vec_->data_h_.data() + start_,
data_.data().get(), proper_size_ * sizeof(T),
cudaMemcpyDeviceToHost));
perm_d_.DenyComplementary(access);
}
void LazyResize(size_t new_size) {
if (new_size == cached_size_) { return; }
// resize is required
int ndevices = vec_->distribution_.devices_.Size();
int device_index = vec_->distribution_.devices_.Index(device_);
start_ = vec_->distribution_.ShardStart(new_size, device_index);
proper_size_ = vec_->distribution_.ShardProperSize(new_size, device_index);
// The size on this device.
size_t size_d = vec_->distribution_.ShardSize(new_size, device_index);
SetDevice();
data_.resize(size_d);
cached_size_ = new_size;
}
void LazySyncDevice(GPUAccess access) {
if (perm_d_.CanAccess(access)) { return; }
if (perm_d_.CanRead()) {
// deny read to the host
perm_d_.Grant(access);
std::lock_guard<std::mutex> lock(vec_->mutex_);
vec_->perm_h_.DenyComplementary(access);
return;
}
// data is on the host
size_t size_h = vec_->data_h_.size();
LazyResize(size_h);
SetDevice();
dh::safe_cuda(
cudaMemcpy(data_.data().get(), vec_->data_h_.data() + start_,
data_.size() * sizeof(T), cudaMemcpyHostToDevice));
perm_d_.Grant(access);
std::lock_guard<std::mutex> lock(vec_->mutex_);
vec_->perm_h_.DenyComplementary(access);
vec_->size_d_ = size_h;
}
void SetDevice() {
if (cudaSetDeviceHandler == nullptr) {
dh::safe_cuda(cudaSetDevice(device_));
} else {
(*cudaSetDeviceHandler)(device_);
}
}
T* Raw() { return data_.data().get(); }
size_t Start() const { return start_; }
size_t DataSize() const { return data_.size(); }
Permissions& Perm() { return perm_d_; }
Permissions const& Perm() const { return perm_d_; }
private:
int device_;
dh::device_vector<T> data_;
// cached vector size
size_t cached_size_;
size_t start_;
// size of the portion to copy back to the host
size_t proper_size_;
Permissions perm_d_;
HostDeviceVectorImpl<T>* vec_;
};
HostDeviceVectorImpl(size_t size, T v, const GPUDistribution &distribution)
: distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) {
if (!distribution_.IsEmpty()) {
size_d_ = size;
InitShards();
Fill(v);
} else { } else {
data_h_.resize(size, v); data_h_.resize(size, v);
} }
@ -188,127 +55,81 @@ struct HostDeviceVectorImpl {
// required, as a new std::mutex has to be created // required, as a new std::mutex has to be created
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>& other) HostDeviceVectorImpl(const HostDeviceVectorImpl<T>& other)
: data_h_(other.data_h_), perm_h_(other.perm_h_), size_d_(other.size_d_), : device_(other.device_), data_h_(other.data_h_), perm_h_(other.perm_h_), mutex_() {
distribution_(other.distribution_), mutex_() { if (device_ >= 0) {
shards_.resize(other.shards_.size()); SetDevice();
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { data_d_ = other.data_d_;
shard.Init(this, other.shards_.at(i)); }
});
} }
// Initializer can be std::vector<T> or std::initializer_list<T> // Initializer can be std::vector<T> or std::initializer_list<T>
template <class Initializer> template <class Initializer>
HostDeviceVectorImpl(const Initializer& init, const GPUDistribution &distribution) HostDeviceVectorImpl(const Initializer& init, int device) : device_(device), perm_h_(device < 0) {
: distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) { if (device >= 0) {
if (!distribution_.IsEmpty()) { LazyResizeDevice(init.size());
size_d_ = init.size();
InitShards();
Copy(init); Copy(init);
} else { } else {
data_h_ = init; data_h_ = init;
} }
} }
void InitShards() { ~HostDeviceVectorImpl() {
int ndevices = distribution_.devices_.Size(); if (device_ >= 0) {
shards_.resize(ndevices); SetDevice();
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { }
shard.Init(this, distribution_.devices_.DeviceId(i));
});
} }
size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : size_d_; } size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : data_d_.size(); }
GPUSet Devices() const { return distribution_.devices_; } int DeviceIdx() const { return device_; }
const GPUDistribution& Distribution() const { return distribution_; } T* DevicePointer() {
LazySyncDevice(GPUAccess::kWrite);
T* DevicePointer(int device) { return data_d_.data().get();
CHECK(distribution_.devices_.Contains(device));
LazySyncDevice(device, GPUAccess::kWrite);
return shards_.at(distribution_.devices_.Index(device)).Raw();
} }
const T* ConstDevicePointer(int device) { const T* ConstDevicePointer() {
CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(GPUAccess::kRead);
LazySyncDevice(device, GPUAccess::kRead); return data_d_.data().get();
return shards_.at(distribution_.devices_.Index(device)).Raw();
} }
common::Span<T> DeviceSpan(int device) { common::Span<T> DeviceSpan() {
GPUSet devices = distribution_.devices_; LazySyncDevice(GPUAccess::kWrite);
CHECK(devices.Contains(device)); return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(DeviceSize())};
LazySyncDevice(device, GPUAccess::kWrite);
return {shards_.at(devices.Index(device)).Raw(),
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
} }
common::Span<const T> ConstDeviceSpan(int device) { common::Span<const T> ConstDeviceSpan() {
GPUSet devices = distribution_.devices_; LazySyncDevice(GPUAccess::kRead);
CHECK(devices.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
using SpanInd = typename common::Span<const T>::index_type; using SpanInd = typename common::Span<const T>::index_type;
return {shards_.at(devices.Index(device)).Raw(), return {data_d_.data().get(), static_cast<SpanInd>(DeviceSize())};
static_cast<SpanInd>(DeviceSize(device))};
} }
size_t DeviceSize(int device) { size_t DeviceSize() {
CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(GPUAccess::kRead);
LazySyncDevice(device, GPUAccess::kRead); return data_d_.size();
return shards_.at(distribution_.devices_.Index(device)).DataSize();
} }
size_t DeviceStart(int device) { thrust::device_ptr<T> tbegin() { // NOLINT
CHECK(distribution_.devices_.Contains(device)); return thrust::device_ptr<T>(DevicePointer());
LazySyncDevice(device, GPUAccess::kRead);
return shards_.at(distribution_.devices_.Index(device)).Start();
} }
thrust::device_ptr<T> tbegin(int device) { // NOLINT thrust::device_ptr<const T> tcbegin() { // NOLINT
return thrust::device_ptr<T>(DevicePointer(device)); return thrust::device_ptr<const T>(ConstDevicePointer());
} }
thrust::device_ptr<const T> tcbegin(int device) { // NOLINT thrust::device_ptr<T> tend() { // NOLINT
return thrust::device_ptr<const T>(ConstDevicePointer(device)); return tbegin() + DeviceSize();
} }
thrust::device_ptr<T> tend(int device) { // NOLINT thrust::device_ptr<const T> tcend() { // NOLINT
return tbegin(device) + DeviceSize(device); return tcbegin() + DeviceSize();
}
thrust::device_ptr<const T> tcend(int device) { // NOLINT
return tcbegin(device) + DeviceSize(device);
}
void ScatterFrom(thrust::device_ptr<const T> begin, thrust::device_ptr<const T> end) {
CHECK_EQ(end - begin, Size());
if (perm_h_.CanWrite()) {
dh::safe_cuda(cudaMemcpy(data_h_.data(), begin.get(),
(end - begin) * sizeof(T),
cudaMemcpyDeviceToHost));
} else {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.ScatterFrom(begin.get());
});
}
}
void GatherTo(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) {
CHECK_EQ(end - begin, Size());
if (perm_h_.CanWrite()) {
dh::safe_cuda(cudaMemcpy(begin.get(), data_h_.data(),
data_h_.size() * sizeof(T),
cudaMemcpyHostToDevice));
} else {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.GatherTo(begin); });
}
} }
void Fill(T v) { // NOLINT void Fill(T v) { // NOLINT
if (perm_h_.CanWrite()) { if (perm_h_.CanWrite()) {
std::fill(data_h_.begin(), data_h_.end(), v); std::fill(data_h_.begin(), data_h_.end(), v);
} else { } else {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.Fill(v); }); DeviceFill(v);
} }
} }
@ -320,14 +141,10 @@ struct HostDeviceVectorImpl {
return; return;
} }
// Data is on device; // Data is on device;
if (distribution_ != other->distribution_) { if (device_ != other->device_) {
distribution_ = GPUDistribution(); SetDevice(other->device_);
Shard(other->Distribution());
size_d_ = other->size_d_;
} }
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { DeviceCopy(other);
shard.Copy(&other->shards_.at(i));
});
} }
void Copy(const std::vector<T>& other) { void Copy(const std::vector<T>& other) {
@ -335,9 +152,7 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanWrite()) { if (perm_h_.CanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin()); std::copy(other.begin(), other.end(), data_h_.begin());
} else { } else {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { DeviceCopy(other.data());
shard.ScatterFrom(other.data());
});
} }
} }
@ -346,9 +161,7 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanWrite()) { if (perm_h_.CanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin()); std::copy(other.begin(), other.end(), data_h_.begin());
} else { } else {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { DeviceCopy(other.begin());
shard.ScatterFrom(other.begin());
});
} }
} }
@ -362,40 +175,23 @@ struct HostDeviceVectorImpl {
return data_h_; return data_h_;
} }
void Shard(const GPUDistribution& distribution) { void SetDevice(int device) {
if (distribution_ == distribution) { return; } if (device_ == device) { return; }
CHECK(distribution_.IsEmpty()) if (device_ >= 0) {
<< "Data resides on different GPUs: " << "ID: " LazySyncHost(GPUAccess::kWrite);
<< *(distribution_.Devices().begin()) << " and ID: " }
<< *(distribution.Devices().begin()); device_ = device;
distribution_ = distribution; if (device_ >= 0) {
InitShards(); LazyResizeDevice(data_h_.size());
} }
void Shard(GPUSet new_devices) {
if (distribution_.Devices() == new_devices) { return; }
Shard(GPUDistribution::Block(new_devices));
}
void Reshard(const GPUDistribution &distribution) {
if (distribution_ == distribution) { return; }
LazySyncHost(GPUAccess::kWrite);
distribution_ = distribution;
shards_.clear();
InitShards();
} }
void Resize(size_t new_size, T v) { void Resize(size_t new_size, T v) {
if (new_size == Size()) { return; } if (new_size == Size()) { return; }
if (distribution_.IsFixedSize()) { if (Size() == 0 && device_ >= 0) {
CHECK_EQ(new_size, distribution_.offsets_.back());
}
if (Size() == 0 && !distribution_.IsEmpty()) {
// fast on-device resize // fast on-device resize
perm_h_ = Permissions(false); perm_h_ = Permissions(false);
size_d_ = new_size; data_d_.resize(new_size, v);
InitShards();
Fill(v);
} else { } else {
// resize on host // resize on host
LazySyncHost(GPUAccess::kWrite); LazySyncHost(GPUAccess::kWrite);
@ -407,72 +203,110 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanAccess(access)) { return; } if (perm_h_.CanAccess(access)) { return; }
if (perm_h_.CanRead()) { if (perm_h_.CanRead()) {
// data is present, just need to deny access to the device // data is present, just need to deny access to the device
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.Perm().DenyComplementary(access);
});
perm_h_.Grant(access); perm_h_.Grant(access);
return; return;
} }
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (data_h_.size() != size_d_) { data_h_.resize(size_d_); } if (data_h_.size() != data_d_.size()) { data_h_.resize(data_d_.size()); }
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { SetDevice();
shard.LazySyncHost(access); dh::safe_cuda(cudaMemcpy(data_h_.data(),
}); data_d_.data().get(),
data_d_.size() * sizeof(T),
cudaMemcpyDeviceToHost));
perm_h_.Grant(access); perm_h_.Grant(access);
} }
void LazySyncDevice(int device, GPUAccess access) { void LazySyncDevice(GPUAccess access) {
GPUSet devices = distribution_.Devices(); if (DevicePerm().CanAccess(access)) { return; }
CHECK(devices.Contains(device)); if (DevicePerm().CanRead()) {
shards_.at(devices.Index(device)).LazySyncDevice(access); // deny read to the host
std::lock_guard<std::mutex> lock(mutex_);
perm_h_.DenyComplementary(access);
return;
}
// data is on the host
LazyResizeDevice(data_h_.size());
SetDevice();
dh::safe_cuda(cudaMemcpy(data_d_.data().get(),
data_h_.data(),
data_d_.size() * sizeof(T),
cudaMemcpyHostToDevice));
std::lock_guard<std::mutex> lock(mutex_);
perm_h_.DenyComplementary(access);
} }
bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); } bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); }
bool DeviceCanAccess(GPUAccess access) { return DevicePerm().CanAccess(access); }
bool DeviceCanAccess(int device, GPUAccess access) {
GPUSet devices = distribution_.Devices();
if (!devices.Contains(device)) { return false; }
return shards_.at(devices.Index(device)).Perm().CanAccess(access);
}
private: private:
std::vector<T> data_h_; int device_{-1};
Permissions perm_h_; std::vector<T> data_h_{};
// the total size of the data stored on the devices dh::device_vector<T> data_d_{};
size_t size_d_; Permissions perm_h_{false};
GPUDistribution distribution_;
// protects size_d_ and perm_h_ when updated from multiple threads // protects size_d_ and perm_h_ when updated from multiple threads
std::mutex mutex_; std::mutex mutex_{};
std::vector<DeviceShard> shards_;
void DeviceFill(T v) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
thrust::fill(data_d_.begin(), data_d_.end(), v);
}
void DeviceCopy(HostDeviceVectorImpl* other) {
// TODO(canonizer): avoid full copy of host data for this (but not for other)
LazySyncDevice(GPUAccess::kWrite);
other->LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), other->data_d_.data().get(),
data_d_.size() * sizeof(T), cudaMemcpyDefault));
}
void DeviceCopy(const T* begin) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), begin,
data_d_.size() * sizeof(T), cudaMemcpyDefault));
}
void LazyResizeDevice(size_t new_size) {
if (new_size == data_d_.size()) { return; }
SetDevice();
data_d_.resize(new_size);
}
void SetDevice() {
CHECK_GE(device_, 0);
if (cudaSetDeviceHandler == nullptr) {
dh::safe_cuda(cudaSetDevice(device_));
} else {
(*cudaSetDeviceHandler)(device_);
}
}
Permissions DevicePerm() const { return perm_h_.Complementary(); }
}; };
template <typename T> template<typename T>
HostDeviceVector<T>::HostDeviceVector HostDeviceVector<T>::HostDeviceVector(size_t size, T v, int device)
(size_t size, T v, const GPUDistribution &distribution) : impl_(nullptr) { : impl_(new HostDeviceVectorImpl<T>(size, v, device)) {}
impl_ = new HostDeviceVectorImpl<T>(size, v, distribution);
}
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector HostDeviceVector<T>::HostDeviceVector(std::initializer_list<T> init, int device)
(std::initializer_list<T> init, const GPUDistribution &distribution) : impl_(nullptr) { : impl_(new HostDeviceVectorImpl<T>(init, device)) {}
impl_ = new HostDeviceVectorImpl<T>(init, distribution);
}
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector HostDeviceVector<T>::HostDeviceVector(const std::vector<T>& init, int device)
(const std::vector<T>& init, const GPUDistribution &distribution) : impl_(nullptr) { : impl_(new HostDeviceVectorImpl<T>(init, device)) {}
impl_ = new HostDeviceVectorImpl<T>(init, distribution);
}
template <typename T> template <typename T>
HostDeviceVector<T>::HostDeviceVector(const HostDeviceVector<T>& other) HostDeviceVector<T>::HostDeviceVector(const HostDeviceVector<T>& other)
: impl_(nullptr) { : impl_(new HostDeviceVectorImpl<T>(*other.impl_)) {}
impl_ = new HostDeviceVectorImpl<T>(*other.impl_);
}
template <typename T> template <typename T>
HostDeviceVector<T>& HostDeviceVector<T>::operator= HostDeviceVector<T>& HostDeviceVector<T>::operator=(const HostDeviceVector<T>& other) {
(const HostDeviceVector<T>& other) {
if (this == &other) { return *this; } if (this == &other) { return *this; }
std::unique_ptr<HostDeviceVectorImpl<T>> newImpl(new HostDeviceVectorImpl<T>(*other.impl_)); std::unique_ptr<HostDeviceVectorImpl<T>> newImpl(new HostDeviceVectorImpl<T>(*other.impl_));
@ -491,73 +325,51 @@ template <typename T>
size_t HostDeviceVector<T>::Size() const { return impl_->Size(); } size_t HostDeviceVector<T>::Size() const { return impl_->Size(); }
template <typename T> template <typename T>
GPUSet HostDeviceVector<T>::Devices() const { return impl_->Devices(); } int HostDeviceVector<T>::DeviceIdx() const { return impl_->DeviceIdx(); }
template <typename T> template <typename T>
const GPUDistribution& HostDeviceVector<T>::Distribution() const { T* HostDeviceVector<T>::DevicePointer() {
return impl_->Distribution(); return impl_->DevicePointer();
} }
template <typename T> template <typename T>
T* HostDeviceVector<T>::DevicePointer(int device) { const T* HostDeviceVector<T>::ConstDevicePointer() const {
return impl_->DevicePointer(device); return impl_->ConstDevicePointer();
} }
template <typename T> template <typename T>
const T* HostDeviceVector<T>::ConstDevicePointer(int device) const { common::Span<T> HostDeviceVector<T>::DeviceSpan() {
return impl_->ConstDevicePointer(device); return impl_->DeviceSpan();
} }
template <typename T> template <typename T>
common::Span<T> HostDeviceVector<T>::DeviceSpan(int device) { common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan() const {
return impl_->DeviceSpan(device); return impl_->ConstDeviceSpan();
} }
template <typename T> template <typename T>
common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan(int device) const { size_t HostDeviceVector<T>::DeviceSize() const {
return impl_->ConstDeviceSpan(device); return impl_->DeviceSize();
} }
template <typename T> template <typename T>
size_t HostDeviceVector<T>::DeviceStart(int device) const { thrust::device_ptr<T> HostDeviceVector<T>::tbegin() { // NOLINT
return impl_->DeviceStart(device); return impl_->tbegin();
} }
template <typename T> template <typename T>
size_t HostDeviceVector<T>::DeviceSize(int device) const { thrust::device_ptr<const T> HostDeviceVector<T>::tcbegin() const { // NOLINT
return impl_->DeviceSize(device); return impl_->tcbegin();
} }
template <typename T> template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) { // NOLINT thrust::device_ptr<T> HostDeviceVector<T>::tend() { // NOLINT
return impl_->tbegin(device); return impl_->tend();
} }
template <typename T> template <typename T>
thrust::device_ptr<const T> HostDeviceVector<T>::tcbegin(int device) const { // NOLINT thrust::device_ptr<const T> HostDeviceVector<T>::tcend() const { // NOLINT
return impl_->tcbegin(device); return impl_->tcend();
}
template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) { // NOLINT
return impl_->tend(device);
}
template <typename T>
thrust::device_ptr<const T> HostDeviceVector<T>::tcend(int device) const { // NOLINT
return impl_->tcend(device);
}
template <typename T>
void HostDeviceVector<T>::ScatterFrom
(thrust::device_ptr<const T> begin, thrust::device_ptr<const T> end) {
impl_->ScatterFrom(begin, end);
}
template <typename T>
void HostDeviceVector<T>::GatherTo
(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) const {
impl_->GatherTo(begin, end);
} }
template <typename T> template <typename T>
@ -594,23 +406,13 @@ bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
} }
template <typename T> template <typename T>
bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const { bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
return impl_->DeviceCanAccess(device, access); return impl_->DeviceCanAccess(access);
} }
template <typename T> template <typename T>
void HostDeviceVector<T>::Shard(GPUSet new_devices) const { void HostDeviceVector<T>::SetDevice(int device) const {
impl_->Shard(new_devices); impl_->SetDevice(device);
}
template <typename T>
void HostDeviceVector<T>::Shard(const GPUDistribution &distribution) const {
impl_->Shard(distribution);
}
template <typename T>
void HostDeviceVector<T>::Reshard(const GPUDistribution &distribution) {
impl_->Reshard(distribution);
} }
template <typename T> template <typename T>

View File

@ -79,113 +79,6 @@ void SetCudaSetDeviceHandler(void (*handler)(int));
template <typename T> struct HostDeviceVectorImpl; template <typename T> struct HostDeviceVectorImpl;
// Distribution for the HostDeviceVector; it specifies such aspects as the
// devices it is distributed on, whether there are copies of elements from
// other GPUs as well as the granularity of splitting. It may also specify
// explicit boundaries for devices, in which case the size of the array cannot
// be changed.
class GPUDistribution {
template<typename T> friend struct HostDeviceVectorImpl;
public:
explicit GPUDistribution(GPUSet devices = GPUSet::Empty())
: devices_(devices), granularity_(1), overlap_(0) {}
private:
GPUDistribution(GPUSet devices, int granularity, int overlap,
std::vector<size_t> &&offsets)
: devices_(devices), granularity_(granularity), overlap_(overlap),
offsets_(std::move(offsets)) {}
public:
static GPUDistribution Empty() { return GPUDistribution(); }
static GPUDistribution Block(GPUSet devices) { return GPUDistribution(devices); }
static GPUDistribution Overlap(GPUSet devices, int overlap) {
return GPUDistribution(devices, 1, overlap, std::vector<size_t>());
}
static GPUDistribution Granular(GPUSet devices, int granularity) {
return GPUDistribution(devices, granularity, 0, std::vector<size_t>());
}
// NOTE(rongou): Explicit offsets don't necessarily cover the whole vector. Sections before the
// first shard or after the last shard may be on host only. This windowing is done in the GPU
// predictor for external memory support.
static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) {
return GPUDistribution(devices, 1, 0, std::move(offsets));
}
friend bool operator==(const GPUDistribution& a, const GPUDistribution& b) {
bool const res = a.devices_ == b.devices_ &&
a.granularity_ == b.granularity_ &&
a.overlap_ == b.overlap_ &&
a.offsets_ == b.offsets_;
return res;
}
friend bool operator!=(const GPUDistribution& a, const GPUDistribution& b) {
return !(a == b);
}
GPUSet Devices() const { return devices_; }
bool IsEmpty() const { return devices_.IsEmpty(); }
size_t ShardStart(size_t size, int index) const {
if (size == 0) { return 0; }
if (offsets_.size() > 0) {
// explicit offsets are provided
CHECK_EQ(offsets_.back(), size);
return offsets_.at(index);
}
// no explicit offsets
size_t begin = std::min(index * Portion(size), size);
begin = begin > size ? size : begin;
return begin;
}
size_t ShardSize(size_t size, size_t index) const {
if (size == 0) { return 0; }
if (offsets_.size() > 0) {
// explicit offsets are provided
CHECK_EQ(offsets_.back(), size);
return offsets_.at(index + 1) - offsets_.at(index) +
(index == devices_.Size() - 1 ? overlap_ : 0);
}
size_t portion = Portion(size);
size_t begin = std::min(index * portion, size);
size_t end = std::min((index + 1) * portion + overlap_ * granularity_, size);
return end - begin;
}
size_t ShardProperSize(size_t size, size_t index) const {
if (size == 0) { return 0; }
return ShardSize(size, index) - (devices_.Size() - 1 > index ? overlap_ : 0);
}
bool IsFixedSize() const { return !offsets_.empty(); }
private:
static size_t DivRoundUp(size_t a, size_t b) { return (a + b - 1) / b; }
static size_t RoundUp(size_t a, size_t b) { return DivRoundUp(a, b) * b; }
size_t Portion(size_t size) const {
return RoundUp
(DivRoundUp
(std::max(static_cast<int64_t>(size - overlap_ * granularity_),
static_cast<int64_t>(1)),
devices_.Size()), granularity_);
}
GPUSet devices_;
int granularity_;
int overlap_;
// explicit offsets for the GPU parts, if any
std::vector<size_t> offsets_;
};
enum GPUAccess { enum GPUAccess {
kNone, kRead, kNone, kRead,
// write implies read // write implies read
@ -199,46 +92,38 @@ inline GPUAccess operator-(GPUAccess a, GPUAccess b) {
template <typename T> template <typename T>
class HostDeviceVector { class HostDeviceVector {
public: public:
explicit HostDeviceVector(size_t size = 0, T v = T(), explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
const GPUDistribution &distribution = GPUDistribution()); HostDeviceVector(std::initializer_list<T> init, int device = -1);
HostDeviceVector(std::initializer_list<T> init, explicit HostDeviceVector(const std::vector<T>& init, int device = -1);
const GPUDistribution &distribution = GPUDistribution());
explicit HostDeviceVector(const std::vector<T>& init,
const GPUDistribution &distribution = GPUDistribution());
~HostDeviceVector(); ~HostDeviceVector();
HostDeviceVector(const HostDeviceVector<T>&); HostDeviceVector(const HostDeviceVector<T>&);
HostDeviceVector<T>& operator=(const HostDeviceVector<T>&); HostDeviceVector<T>& operator=(const HostDeviceVector<T>&);
size_t Size() const; size_t Size() const;
GPUSet Devices() const; int DeviceIdx() const;
const GPUDistribution& Distribution() const; common::Span<T> DeviceSpan();
common::Span<T> DeviceSpan(int device); common::Span<const T> ConstDeviceSpan() const;
common::Span<const T> ConstDeviceSpan(int device) const; common::Span<const T> DeviceSpan() const { return ConstDeviceSpan(); }
common::Span<const T> DeviceSpan(int device) const { return ConstDeviceSpan(device); } T* DevicePointer();
T* DevicePointer(int device); const T* ConstDevicePointer() const;
const T* ConstDevicePointer(int device) const; const T* DevicePointer() const { return ConstDevicePointer(); }
const T* DevicePointer(int device) const { return ConstDevicePointer(device); }
T* HostPointer() { return HostVector().data(); } T* HostPointer() { return HostVector().data(); }
const T* ConstHostPointer() const { return ConstHostVector().data(); } const T* ConstHostPointer() const { return ConstHostVector().data(); }
const T* HostPointer() const { return ConstHostPointer(); } const T* HostPointer() const { return ConstHostPointer(); }
size_t DeviceStart(int device) const; size_t DeviceSize() const;
size_t DeviceSize(int device) const;
// only define functions returning device_ptr // only define functions returning device_ptr
// if HostDeviceVector.h is included from a .cu file // if HostDeviceVector.h is included from a .cu file
#ifdef __CUDACC__ #ifdef __CUDACC__
thrust::device_ptr<T> tbegin(int device); // NOLINT thrust::device_ptr<T> tbegin(); // NOLINT
thrust::device_ptr<T> tend(int device); // NOLINT thrust::device_ptr<T> tend(); // NOLINT
thrust::device_ptr<const T> tcbegin(int device) const; // NOLINT thrust::device_ptr<const T> tcbegin() const; // NOLINT
thrust::device_ptr<const T> tcend(int device) const; // NOLINT thrust::device_ptr<const T> tcend() const; // NOLINT
thrust::device_ptr<const T> tbegin(int device) const { // NOLINT thrust::device_ptr<const T> tbegin() const { // NOLINT
return tcbegin(device); return tcbegin();
} }
thrust::device_ptr<const T> tend(int device) const { return tcend(device); } // NOLINT thrust::device_ptr<const T> tend() const { return tcend(); } // NOLINT
void ScatterFrom(thrust::device_ptr<const T> begin, thrust::device_ptr<const T> end);
void GatherTo(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) const;
#endif // __CUDACC__ #endif // __CUDACC__
void Fill(T v); void Fill(T v);
@ -251,18 +136,9 @@ class HostDeviceVector {
const std::vector<T>& HostVector() const {return ConstHostVector(); } const std::vector<T>& HostVector() const {return ConstHostVector(); }
bool HostCanAccess(GPUAccess access) const; bool HostCanAccess(GPUAccess access) const;
bool DeviceCanAccess(int device, GPUAccess access) const; bool DeviceCanAccess(GPUAccess access) const;
/*! void SetDevice(int device) const;
* \brief Specify memory distribution.
*/
void Shard(const GPUDistribution &distribution) const;
void Shard(GPUSet devices) const;
/*!
* \brief Change memory distribution.
*/
void Reshard(const GPUDistribution &distribution);
void Resize(size_t new_size, T v = T()); void Resize(size_t new_size, T v = T());
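
Putting the reduced interface together, a usage sketch assembled only from the declarations above; the ordinal and size are arbitrary.

    #include <vector>
    #include "host_device_vector.h"  // resides in src/common/ at this point in the tree

    void Example() {
      xgboost::HostDeviceVector<float> vec(std::vector<float>(4, 1.0f));  // host-resident
      vec.SetDevice(0);                           // bind to a single CUDA ordinal
      auto d_span = vec.DeviceSpan();             // lazy host-to-device copy, write access
      // ... launch kernels over d_span (CUDA build) ...
      const auto& h_vec = vec.ConstHostVector();  // lazy device-to-host copy for reading
      (void)d_span; (void)h_vec;
    }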

View File

@ -57,14 +57,10 @@ class Transform {
template <typename Functor> template <typename Functor>
struct Evaluator { struct Evaluator {
public: public:
Evaluator(Functor func, Range range, GPUSet devices, bool shard) : Evaluator(Functor func, Range range, int device, bool shard) :
func_(func), range_{std::move(range)}, func_(func), range_{std::move(range)},
shard_{shard}, shard_{shard},
distribution_{GPUDistribution::Block(devices)} {} device_{device} {}
Evaluator(Functor func, Range range, GPUDistribution dist,
bool shard) :
func_(func), range_{std::move(range)}, shard_{shard},
distribution_{std::move(dist)} {}
/*! /*!
* \brief Evaluate the functor with input pointers to HostDeviceVector. * \brief Evaluate the functor with input pointers to HostDeviceVector.
@ -74,7 +70,7 @@ class Transform {
*/ */
template <typename... HDV> template <typename... HDV>
void Eval(HDV... vectors) const { void Eval(HDV... vectors) const {
bool on_device = !distribution_.IsEmpty(); bool on_device = device_ >= 0;
if (on_device) { if (on_device) {
LaunchCUDA(func_, vectors...); LaunchCUDA(func_, vectors...);
@ -86,13 +82,13 @@ class Transform {
private: private:
// CUDA UnpackHDV // CUDA UnpackHDV
template <typename T> template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) const { Span<T> UnpackHDVOnDevice(HostDeviceVector<T>* _vec) const {
auto span = _vec->DeviceSpan(_device); auto span = _vec->DeviceSpan();
return span; return span;
} }
template <typename T> template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) const { Span<T const> UnpackHDVOnDevice(const HostDeviceVector<T>* _vec) const {
auto span = _vec->ConstDeviceSpan(_device); auto span = _vec->ConstDeviceSpan();
return span; return span;
} }
// CPU UnpackHDV // CPU UnpackHDV
@ -108,15 +104,15 @@ class Transform {
} }
// Recursive unpack for Shard. // Recursive unpack for Shard.
template <typename T> template <typename T>
void UnpackShard(GPUDistribution dist, const HostDeviceVector<T> *vector) const { void UnpackShard(int device, const HostDeviceVector<T> *vector) const {
vector->Shard(dist); vector->SetDevice(device);
} }
template <typename Head, typename... Rest> template <typename Head, typename... Rest>
void UnpackShard(GPUDistribution dist, void UnpackShard(int device,
const HostDeviceVector<Head> *_vector, const HostDeviceVector<Head> *_vector,
const HostDeviceVector<Rest> *... _vectors) const { const HostDeviceVector<Rest> *... _vectors) const {
_vector->Shard(dist); _vector->SetDevice(device);
UnpackShard(dist, _vectors...); UnpackShard(device, _vectors...);
} }
#if defined(__CUDACC__) #if defined(__CUDACC__)
@ -124,28 +120,20 @@ class Transform {
typename... HDV> typename... HDV>
void LaunchCUDA(Functor _func, HDV*... _vectors) const { void LaunchCUDA(Functor _func, HDV*... _vectors) const {
if (shard_) if (shard_)
UnpackShard(distribution_, _vectors...); UnpackShard(device_, _vectors...);
GPUSet devices = distribution_.Devices();
size_t range_size = *range_.end() - *range_.begin(); size_t range_size = *range_.end() - *range_.begin();
// Extract index to deal with possible old OpenMP. // Extract index to deal with possible old OpenMP.
size_t device_beg = *(devices.begin()); // This deals with situation like multi-class setting where
size_t device_end = *(devices.end()); // granularity is used in data vector.
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) size_t shard_size = range_size;
for (omp_ulong device = device_beg; device < device_end; ++device) { // NOLINT Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
// Ignore other attributes of GPUDistribution for splitting index. dh::safe_cuda(cudaSetDevice(device_));
// This deals with situation like multi-class setting where const int GRID_SIZE =
// granularity is used in data vector. static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
size_t shard_size = GPUDistribution::Block(devices).ShardSize( detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
range_size, devices.Index(device)); _func, shard_range, UnpackHDVOnDevice(_vectors)...);
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
dh::safe_cuda(cudaSetDevice(device));
const int GRID_SIZE =
static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, shard_range, UnpackHDV(_vectors, device)...);
}
} }
#else #else
/*! \brief Dummy function defined when compiling for CPU. */ /*! \brief Dummy function defined when compiling for CPU. */
@ -172,7 +160,7 @@ class Transform {
Range range_; Range range_;
/*! \brief Whether sharding for vectors is required. */ /*! \brief Whether sharding for vectors is required. */
bool shard_; bool shard_;
GPUDistribution distribution_; int device_;
}; };
public: public:
@ -191,15 +179,9 @@ class Transform {
*/ */
template <typename Functor> template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range, static Evaluator<Functor> Init(Functor func, Range const range,
GPUSet const devices, int device,
bool const shard = true) { bool const shard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(devices), shard}; return Evaluator<Functor> {func, std::move(range), device, shard};
}
template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range,
GPUDistribution const dist,
bool const shard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(dist), shard};
} }
}; };
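Since Transform's Evaluator now carries a single device ordinal, launches follow the same gpu_id convention as the rest of this commit: -1 runs on the CPU, anything else names the one GPU. A sketch of a call site, with illustrative include paths for the internal headers and a functor shaped like the ones in the objectives below:

#include <cstddef>
#include <cstdint>
#include <xgboost/base.h>
#include <xgboost/host_device_vector.h>
#include "../common/transform.h"   // internal header; path shown for illustration only

void HalveInPlace(xgboost::HostDeviceVector<float>* io, int device) {
  xgboost::common::Transform<>::Init(
      [] XGBOOST_DEVICE(std::size_t _idx, xgboost::common::Span<float> _v) {
        _v[_idx] *= 0.5f;                    // runs on the CPU or on the selected GPU
      },
      xgboost::common::Range{0, static_cast<int64_t>(io->Size())},
      device)                                // was a GPUSet / GPUDistribution before
      .Eval(io);
}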

View File

@ -78,9 +78,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else { } else {
LOG(FATAL) << "Unknown metainfo: " << key; LOG(FATAL) << "Unknown metainfo: " << key;
} }
dst->Reshard(GPUDistribution(GPUSet::Range(ptr_device, 1))); dst->SetDevice(ptr_device);
dst->Resize(length); dst->Resize(length);
auto p_dst = thrust::device_pointer_cast(dst->DevicePointer(0)); auto p_dst = thrust::device_pointer_cast(dst->DevicePointer());
thrust::copy(p_src, p_src + length, p_dst); thrust::copy(p_src, p_src + length, p_dst);
} }
} // namespace xgboost } // namespace xgboost
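Written out as a standalone helper, the copy path above becomes: bind the destination vector to the device that owns the source pointer, size it, then copy with thrust. This is a sketch of the flow in the hunk, not library code; src, length, and ptr_device are assumed inputs, and the caller is assumed to have already selected the matching CUDA device:

#include <cstddef>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <xgboost/host_device_vector.h>

void CopyDeviceColumn(int ptr_device, const float* src, std::size_t length,
                      xgboost::HostDeviceVector<float>* dst) {
  dst->SetDevice(ptr_device);          // one device per vector; no GPUDistribution
  dst->Resize(length);
  auto p_src = thrust::device_pointer_cast(src);
  auto p_dst = thrust::device_pointer_cast(dst->DevicePointer());
  thrust::copy(p_src, p_src + length, p_dst);
}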

View File

@ -77,16 +77,14 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
GPUSet devices = GPUSet::Range(device, 1); page_.offset.SetDevice(device);
page_.offset.Reshard(GPUDistribution(devices));
page_.offset.Resize(info.num_row_ + 1); page_.offset.Resize(info.num_row_ + 1);
page_.data.Reshard(GPUDistribution(devices)); page_.data.SetDevice(device);
page_.data.Resize(info.num_nonzero_); page_.data.Resize(info.num_nonzero_);
auto s_data = page_.data.DeviceSpan(device); auto s_data = page_.data.DeviceSpan();
auto s_offsets = page_.offset.DeviceSpan(device); auto s_offsets = page_.offset.DeviceSpan();
CHECK_EQ(s_offsets.size(), n_rows + 1); CHECK_EQ(s_offsets.size(), n_rows + 1);
int32_t constexpr kThreads = 256; int32_t constexpr kThreads = 256;

View File

@ -182,9 +182,9 @@ void GBTree::DoBoost(DMatrix* p_fmat,
CHECK_EQ(in_gpair->Size() % ngroup, 0U) CHECK_EQ(in_gpair->Size() % ngroup, 0U)
<< "must have exactly ngroup*nrow gpairs"; << "must have exactly ngroup*nrow gpairs";
// TODO(canonizer): perform this on GPU if HostDeviceVector has device set. // TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
HostDeviceVector<GradientPair> tmp HostDeviceVector<GradientPair> tmp(in_gpair->Size() / ngroup,
(in_gpair->Size() / ngroup, GradientPair(), GradientPair(),
GPUDistribution::Block(in_gpair->Distribution().Devices())); in_gpair->DeviceIdx());
const auto& gpair_h = in_gpair->ConstHostVector(); const auto& gpair_h = in_gpair->ConstHostVector();
auto nsize = static_cast<bst_omp_uint>(tmp.Size()); auto nsize = static_cast<bst_omp_uint>(tmp.Size());
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
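The temporary above uses the (size, initial value, device) constructor together with DeviceIdx(), which returns -1 when the source vector lives only on the host. Sketch:

#include <cstddef>
#include <xgboost/base.h>
#include <xgboost/host_device_vector.h>

void MakeScratchLike(const xgboost::HostDeviceVector<xgboost::GradientPair>& in_gpair,
                     std::size_t n_per_group) {
  // Scratch buffer placed on the same device as the incoming gradients.
  xgboost::HostDeviceVector<xgboost::GradientPair> tmp(
      n_per_group, xgboost::GradientPair(), in_gpair.DeviceIdx());
  (void)tmp;
}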

View File

@ -237,14 +237,13 @@ class LearnerImpl : public Learner {
std::vector<std::pair<std::string, std::string> > attr; std::vector<std::pair<std::string, std::string> > attr;
fi->Read(&attr); fi->Read(&attr);
for (auto& kv : attr) { for (auto& kv : attr) {
// Load `predictor`, `n_gpus`, `gpu_id` parameters from extra attributes // Load `predictor`, `gpu_id` parameters from extra attributes
const std::string prefix = "SAVED_PARAM_"; const std::string prefix = "SAVED_PARAM_";
if (kv.first.find(prefix) == 0) { if (kv.first.find(prefix) == 0) {
const std::string saved_param = kv.first.substr(prefix.length()); const std::string saved_param = kv.first.substr(prefix.length());
bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor"; bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor";
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
if (saved_param == "predictor" || saved_param == "n_gpus" if (saved_param == "predictor" || saved_param == "gpu_id") {
|| saved_param == "gpu_id") {
cfg_[saved_param] = kv.second; cfg_[saved_param] = kv.second;
LOG(INFO) LOG(INFO)
<< "Parameter '" << saved_param << "' has been recovered from " << "Parameter '" << saved_param << "' has been recovered from "
@ -266,7 +265,7 @@ class LearnerImpl : public Learner {
} }
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
// NO visible GPU in current environment // NO visible GPU in current environment
if (is_gpu_predictor && GPUSet::AllVisible().Size() == 0) { if (is_gpu_predictor && common::AllVisibleGPUs() == 0) {
cfg_["predictor"] = "cpu_predictor"; cfg_["predictor"] = "cpu_predictor";
kv.second = "cpu_predictor"; kv.second = "cpu_predictor";
LOG(INFO) << "Switch gpu_predictor to cpu_predictor."; LOG(INFO) << "Switch gpu_predictor to cpu_predictor.";
@ -294,7 +293,9 @@ class LearnerImpl : public Learner {
auto n = tparam_.__DICT__(); auto n = tparam_.__DICT__();
cfg_.insert(n.cbegin(), n.cend()); cfg_.insert(n.cbegin(), n.cend());
gbm_->Configure({cfg_.cbegin(), cfg_.cend()}); Args args = {cfg_.cbegin(), cfg_.cend()};
generic_param_.InitAllowUnknown(args);
gbm_->Configure(args);
obj_->Configure({cfg_.begin(), cfg_.end()}); obj_->Configure({cfg_.begin(), cfg_.end()});
for (auto& p_metric : metrics_) { for (auto& p_metric : metrics_) {
@ -331,9 +332,8 @@ class LearnerImpl : public Learner {
} }
} }
{ {
// Write `predictor`, `n_gpus`, `gpu_id` parameters as extra attributes // Write `predictor`, `gpu_id` parameters as extra attributes
for (const auto& key : std::vector<std::string>{ for (const auto& key : std::vector<std::string>{"predictor", "gpu_id"}) {
"predictor", "n_gpus", "gpu_id"}) {
auto it = cfg_.find(key); auto it = cfg_.find(key);
if (it != cfg_.end()) { if (it != cfg_.end()) {
mparam.contain_extra_attrs = 1; mparam.contain_extra_attrs = 1;
@ -581,13 +581,8 @@ class LearnerImpl : public Learner {
gbm_->Configure(args); gbm_->Configure(args);
if (this->gbm_->UseGPU()) { if (this->gbm_->UseGPU()) {
if (cfg_.find("n_gpus") == cfg_.cend()) { if (cfg_.find("gpu_id") == cfg_.cend()) {
generic_param_.n_gpus = 1; generic_param_.gpu_id = 0;
}
if (generic_param_.n_gpus != 1) {
LOG(FATAL) << "Single process multi-GPU training is no longer supported. "
"Please switch to distributed GPU training with one process per GPU. "
"This can be done using Dask or Spark.";
} }
} }
} }

View File

@ -19,12 +19,6 @@ namespace linear {
DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate); DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
void RescaleIndices(int device_idx, size_t ridx_begin,
common::Span<xgboost::Entry> data) {
dh::LaunchN(device_idx, data.size(),
[=] __device__(size_t idx) { data[idx].index -= ridx_begin; });
}
class DeviceShard { class DeviceShard {
int device_id_; int device_id_;
dh::BulkAllocator ba_; dh::BulkAllocator ba_;
@ -32,18 +26,16 @@ class DeviceShard {
common::Span<xgboost::Entry> data_; common::Span<xgboost::Entry> data_;
common::Span<GradientPair> gpair_; common::Span<GradientPair> gpair_;
dh::CubMemory temp_; dh::CubMemory temp_;
size_t ridx_begin_; size_t shard_size_;
size_t ridx_end_;
public: public:
DeviceShard(int device_id, DeviceShard(int device_id,
const SparsePage &batch, // column batch const SparsePage &batch, // column batch
bst_uint row_begin, bst_uint row_end, bst_uint shard_size,
const LinearTrainParam &param, const LinearTrainParam &param,
const gbm::GBLinearModelParam &model_param) const gbm::GBLinearModelParam &model_param)
: device_id_(device_id), : device_id_(device_id),
ridx_begin_(row_begin), shard_size_(shard_size) {
ridx_end_(row_end) {
if ( IsEmpty() ) { return; } if ( IsEmpty() ) { return; }
dh::safe_cuda(cudaSetDevice(device_id_)); dh::safe_cuda(cudaSetDevice(device_id_));
// The begin and end indices for the section of each column associated with // The begin and end indices for the section of each column associated with
@ -51,25 +43,25 @@ class DeviceShard {
std::vector<std::pair<bst_uint, bst_uint>> column_segments; std::vector<std::pair<bst_uint, bst_uint>> column_segments;
row_ptr_ = {0}; row_ptr_ = {0};
// iterate through columns // iterate through columns
for (auto fidx = 0; fidx < batch.Size(); fidx++) { for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
common::Span<Entry const> col = batch[fidx]; common::Span<Entry const> col = batch[fidx];
auto cmp = [](Entry e1, Entry e2) { auto cmp = [](Entry e1, Entry e2) {
return e1.index < e2.index; return e1.index < e2.index;
}; };
auto column_begin = auto column_begin =
std::lower_bound(col.cbegin(), col.cend(), std::lower_bound(col.cbegin(), col.cend(),
xgboost::Entry(row_begin, 0.0f), cmp); xgboost::Entry(0, 0.0f), cmp);
auto column_end = auto column_end =
std::lower_bound(col.cbegin(), col.cend(), std::lower_bound(col.cbegin(), col.cend(),
xgboost::Entry(row_end, 0.0f), cmp); xgboost::Entry(shard_size_, 0.0f), cmp);
column_segments.emplace_back( column_segments.emplace_back(
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin())); std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin)); row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
} }
ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_, ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_,
(row_end - row_begin) * model_param.num_output_group); shard_size_ * model_param.num_output_group);
for (int fidx = 0; fidx < batch.Size(); fidx++) { for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
auto col = batch[fidx]; auto col = batch[fidx];
auto seg = column_segments[fidx]; auto seg = column_segments[fidx];
dh::safe_cuda(cudaMemcpy( dh::safe_cuda(cudaMemcpy(
@ -77,23 +69,21 @@ class DeviceShard {
col.data() + seg.first, col.data() + seg.first,
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice)); sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
} }
// Rescale indices with respect to current shard
RescaleIndices(device_id_, ridx_begin_, data_);
} }
~DeviceShard() { ~DeviceShard() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_id_)); dh::safe_cuda(cudaSetDevice(device_id_));
} }
bool IsEmpty() { bool IsEmpty() {
return (ridx_end_ - ridx_begin_) == 0; return shard_size_ == 0;
} }
void UpdateGpair(const std::vector<GradientPair> &host_gpair, void UpdateGpair(const std::vector<GradientPair> &host_gpair,
const gbm::GBLinearModelParam &model_param) { const gbm::GBLinearModelParam &model_param) {
dh::safe_cuda(cudaMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(
gpair_.data(), gpair_.data(),
host_gpair.data() + ridx_begin_ * model_param.num_output_group, host_gpair.data(),
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice)); gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
} }
@ -107,13 +97,13 @@ class DeviceShard {
counting, f); counting, f);
auto perm = thrust::make_permutation_iterator(gpair_.data(), skip); auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);
return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_); return dh::SumReduction(temp_, perm, shard_size_);
} }
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
if (dbias == 0.0f) return; if (dbias == 0.0f) return;
auto d_gpair = gpair_; auto d_gpair = gpair_;
dh::LaunchN(device_id_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) { dh::LaunchN(device_id_, shard_size_, [=] __device__(size_t idx) {
auto &g = d_gpair[idx * num_groups + group_idx]; auto &g = d_gpair[idx * num_groups + group_idx];
g += GradientPair(g.GetHess() * dbias, 0); g += GradientPair(g.GetHess() * dbias, 0);
}); });
@ -154,7 +144,7 @@ class DeviceShard {
* \brief Coordinate descent algorithm that updates one feature per iteration * \brief Coordinate descent algorithm that updates one feature per iteration
*/ */
class GPUCoordinateUpdater : public LinearUpdater { class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
public: public:
// set training parameter // set training parameter
void Configure(Args const& args) override { void Configure(Args const& args) override {
@ -165,37 +155,23 @@ class GPUCoordinateUpdater : public LinearUpdater {
void LazyInitShards(DMatrix *p_fmat, void LazyInitShards(DMatrix *p_fmat,
const gbm::GBLinearModelParam &model_param) { const gbm::GBLinearModelParam &model_param) {
if (!shards_.empty()) return; if (shard_) return;
dist_ = GPUDistribution::Block(GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus, device_ = learner_param_->gpu_id;
p_fmat->Info().num_row_));
auto devices = dist_.Devices();
size_t n_devices = static_cast<size_t>(devices.Size()); auto num_row = static_cast<size_t>(p_fmat->Info().num_row_);
size_t row_begin = 0;
size_t num_row = static_cast<size_t>(p_fmat->Info().num_row_);
// Partition input matrix into row segments // Partition input matrix into row segments
std::vector<size_t> row_segments; std::vector<size_t> row_segments;
row_segments.push_back(0); row_segments.push_back(0);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) { size_t shard_size = num_row;
size_t shard_size = dist_.ShardSize(num_row, d_idx); row_segments.push_back(shard_size);
size_t row_end = row_begin + shard_size;
row_segments.push_back(row_end);
row_begin = row_end;
}
CHECK(p_fmat->SingleColBlock()); CHECK(p_fmat->SingleColBlock());
SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin()); SparsePage const& batch = *(p_fmat->GetBatches<CSCPage>().begin());
shards_.resize(n_devices); // Create device shard
// Create device shards shard_.reset(new DeviceShard(device_, batch, shard_size, tparam_, model_param));
dh::ExecuteIndexShards(&shards_,
[&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(devices.DeviceId(i), batch, row_segments[i],
row_segments[i + 1], tparam_, model_param));
});
} }
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat, void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
@ -208,11 +184,9 @@ class GPUCoordinateUpdater : public LinearUpdater {
monitor_.Start("UpdateGpair"); monitor_.Start("UpdateGpair");
auto &in_gpair_host = in_gpair->ConstHostVector(); auto &in_gpair_host = in_gpair->ConstHostVector();
// Update gpair // Update gpair
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) { if (shard_) {
if (!shard->IsEmpty()) { shard_->UpdateGpair(in_gpair_host, model->param);
shard->UpdateGpair(in_gpair_host, model->param); }
}
});
monitor_.Stop("UpdateGpair"); monitor_.Stop("UpdateGpair");
monitor_.Start("UpdateBias"); monitor_.Start("UpdateBias");
@ -237,32 +211,21 @@ class GPUCoordinateUpdater : public LinearUpdater {
} }
void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) { void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) {
for (int group_idx = 0; group_idx < model->param.num_output_group; for (int group_idx = 0; group_idx < model->param.num_output_group; ++group_idx) {
++group_idx) {
// Get gradient // Get gradient
auto grad = dh::ReduceShards<GradientPair>( auto grad = GradientPair(0, 0);
&shards_, [&](std::unique_ptr<DeviceShard> &shard) { if (shard_) {
if (!shard->IsEmpty()) { grad = shard_->GetBiasGradient(group_idx, model->param.num_output_group);
GradientPair result = }
shard->GetBiasGradient(group_idx,
model->param.num_output_group);
return result;
}
return GradientPair(0, 0);
});
auto dbias = static_cast<float>( auto dbias = static_cast<float>(
tparam_.learning_rate * tparam_.learning_rate *
CoordinateDeltaBias(grad.GetGrad(), grad.GetHess())); CoordinateDeltaBias(grad.GetGrad(), grad.GetHess()));
model->bias()[group_idx] += dbias; model->bias()[group_idx] += dbias;
// Update residual // Update residual
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) { if (shard_) {
if (!shard->IsEmpty()) { shard_->UpdateBiasResidual(dbias, group_idx, model->param.num_output_group);
shard->UpdateBiasResidual(dbias, group_idx, }
model->param.num_output_group);
}
});
} }
} }
@ -271,38 +234,30 @@ class GPUCoordinateUpdater : public LinearUpdater {
gbm::GBLinearModel *model) { gbm::GBLinearModel *model) {
bst_float &w = (*model)[fidx][group_idx]; bst_float &w = (*model)[fidx][group_idx];
// Get gradient // Get gradient
auto grad = dh::ReduceShards<GradientPair>( auto grad = GradientPair(0, 0);
&shards_, [&](std::unique_ptr<DeviceShard> &shard) { if (shard_) {
if (!shard->IsEmpty()) { grad = shard_->GetGradient(group_idx, model->param.num_output_group, fidx);
return shard->GetGradient(group_idx, model->param.num_output_group, }
fidx);
}
return GradientPair(0, 0);
});
auto dw = static_cast<float>(tparam_.learning_rate * auto dw = static_cast<float>(tparam_.learning_rate *
CoordinateDelta(grad.GetGrad(), grad.GetHess(), CoordinateDelta(grad.GetGrad(), grad.GetHess(),
w, tparam_.reg_alpha_denorm, w, tparam_.reg_alpha_denorm,
tparam_.reg_lambda_denorm)); tparam_.reg_lambda_denorm));
w += dw; w += dw;
dh::ExecuteIndexShards(&shards_, [&](int idx, if (shard_) {
std::unique_ptr<DeviceShard> &shard) { shard_->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
if (!shard->IsEmpty()) { }
shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
}
});
} }
private: private:
// training parameter // training parameter
LinearTrainParam tparam_; LinearTrainParam tparam_;
CoordinateParam coord_param_; CoordinateParam coord_param_;
GPUDistribution dist_; int device_{};
std::unique_ptr<FeatureSelector> selector_; std::unique_ptr<FeatureSelector> selector_;
common::Monitor monitor_; common::Monitor monitor_;
std::vector<std::unique_ptr<DeviceShard>> shards_; std::unique_ptr<DeviceShard> shard_{nullptr};
}; };
XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent") XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")

View File

@ -30,8 +30,7 @@ DMLC_REGISTRY_FILE_TAG(elementwise_metric);
template <typename EvalRow> template <typename EvalRow>
class ElementWiseMetricsReduction { class ElementWiseMetricsReduction {
public: public:
explicit ElementWiseMetricsReduction(EvalRow policy) : explicit ElementWiseMetricsReduction(EvalRow policy) : policy_(std::move(policy)) {}
policy_(std::move(policy)) {}
PackedReduceResult CpuReduceMetrics( PackedReduceResult CpuReduceMetrics(
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
@ -59,34 +58,31 @@ class ElementWiseMetricsReduction {
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
~ElementWiseMetricsReduction() { ~ElementWiseMetricsReduction() {
for (GPUSet::GpuIdType id = *devices_.begin(); id < *devices_.end(); ++id) { if (device_ >= 0) {
dh::safe_cuda(cudaSetDevice(id)); dh::safe_cuda(cudaSetDevice(device_));
size_t index = devices_.Index(id); allocator_.Free();
allocators_.at(index).Free();
} }
} }
PackedReduceResult DeviceReduceMetrics( PackedReduceResult DeviceReduceMetrics(
GPUSet::GpuIdType device_id,
size_t device_index,
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels, const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) { const HostDeviceVector<bst_float>& preds) {
size_t n_data = preds.DeviceSize(device_id); size_t n_data = preds.DeviceSize();
thrust::counting_iterator<size_t> begin(0); thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + n_data; thrust::counting_iterator<size_t> end = begin + n_data;
auto s_label = labels.DeviceSpan(device_id); auto s_label = labels.DeviceSpan();
auto s_preds = preds.DeviceSpan(device_id); auto s_preds = preds.DeviceSpan();
auto s_weights = weights.DeviceSpan(device_id); auto s_weights = weights.DeviceSpan();
bool const is_null_weight = weights.Size() == 0; bool const is_null_weight = weights.Size() == 0;
auto d_policy = policy_; auto d_policy = policy_;
PackedReduceResult result = thrust::transform_reduce( PackedReduceResult result = thrust::transform_reduce(
thrust::cuda::par(allocators_.at(device_index)), thrust::cuda::par(allocator_),
begin, end, begin, end,
[=] XGBOOST_DEVICE(size_t idx) { [=] XGBOOST_DEVICE(size_t idx) {
bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
@ -105,37 +101,24 @@ class ElementWiseMetricsReduction {
PackedReduceResult Reduce( PackedReduceResult Reduce(
const GenericParameter &tparam, const GenericParameter &tparam,
GPUSet devices, int device,
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels, const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) { const HostDeviceVector<bst_float>& preds) {
PackedReduceResult result; PackedReduceResult result;
if (devices.IsEmpty()) { if (device < 0) {
result = CpuReduceMetrics(weights, labels, preds); result = CpuReduceMetrics(weights, labels, preds);
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
else { // NOLINT else { // NOLINT
if (allocators_.empty()) { device_ = device;
devices_ = GPUSet::All(tparam.gpu_id, tparam.n_gpus); preds.SetDevice(device_);
allocators_.resize(devices_.Size()); labels.SetDevice(device_);
} weights.SetDevice(device_);
preds.Shard(devices);
labels.Shard(devices);
weights.Shard(devices);
std::vector<PackedReduceResult> res_per_device(devices.Size());
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) dh::safe_cuda(cudaSetDevice(device_));
for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) { result = DeviceReduceMetrics(weights, labels, preds);
dh::safe_cuda(cudaSetDevice(id));
size_t index = devices.Index(id);
res_per_device.at(index) =
DeviceReduceMetrics(id, index, weights, labels, preds);
}
for (auto const& res : res_per_device) {
result += res;
}
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
return result; return result;
@ -144,8 +127,8 @@ class ElementWiseMetricsReduction {
private: private:
EvalRow policy_; EvalRow policy_;
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
GPUSet devices_; int device_{-1};
std::vector<dh::CubMemory> allocators_; dh::CubMemory allocator_;
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
}; };
@ -345,11 +328,10 @@ struct EvalEWiseBase : public Metric {
<< "label and prediction size not match, " << "label and prediction size not match, "
<< "hint: use merror or mlogloss for multi-class classification"; << "hint: use merror or mlogloss for multi-class classification";
const auto ndata = static_cast<omp_ulong>(info.labels_.Size()); const auto ndata = static_cast<omp_ulong>(info.labels_.Size());
// Dealing with ndata < n_gpus. int device = tparam_->gpu_id;
GPUSet devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, ndata);
auto result = auto result =
reducer_.Reduce(*tparam_, devices, info.weights_, info.labels_, preds); reducer_.Reduce(*tparam_, device, info.weights_, info.labels_, preds);
double dat[2] { result.Residue(), result.Weights() }; double dat[2] { result.Residue(), result.Weights() };
if (distributed) { if (distributed) {
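The shape above recurs throughout this commit: gpu_id < 0 selects the CPU path, otherwise the single configured device is used directly, with one allocator instead of a vector of them. A reduced sketch in which the reduction body is only a stand-in for the real metric kernels:

#include <cuda_runtime_api.h>
#include <numeric>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <xgboost/host_device_vector.h>

double SumPreds(int device, const xgboost::HostDeviceVector<float>& preds) {
  if (device < 0) {
    // CPU path, as in CpuReduceMetrics().
    const auto& h = preds.ConstHostVector();
    return std::accumulate(h.cbegin(), h.cend(), 0.0);
  }
  // Single-GPU path, as in DeviceReduceMetrics(): no per-device loop.
  preds.SetDevice(device);
  cudaSetDevice(device);               // error checking omitted in this sketch
  auto d = preds.ConstDeviceSpan();
  return thrust::reduce(thrust::device, d.data(), d.data() + d.size(), 0.0);
}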

View File

@ -74,35 +74,32 @@ class MultiClassMetricsReduction {
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
~MultiClassMetricsReduction() { ~MultiClassMetricsReduction() {
for (GPUSet::GpuIdType id = *devices_.begin(); id < *devices_.end(); ++id) { if (device_ >= 0) {
dh::safe_cuda(cudaSetDevice(id)); dh::safe_cuda(cudaSetDevice(device_));
size_t index = devices_.Index(id); allocator_.Free();
allocators_.at(index).Free();
} }
} }
PackedReduceResult DeviceReduceMetrics( PackedReduceResult DeviceReduceMetrics(
GPUSet::GpuIdType device_id,
size_t device_index,
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels, const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds, const HostDeviceVector<bst_float>& preds,
const size_t n_class) { const size_t n_class) {
size_t n_data = labels.DeviceSize(device_id); size_t n_data = labels.DeviceSize();
thrust::counting_iterator<size_t> begin(0); thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + n_data; thrust::counting_iterator<size_t> end = begin + n_data;
auto s_labels = labels.DeviceSpan(device_id); auto s_labels = labels.DeviceSpan();
auto s_preds = preds.DeviceSpan(device_id); auto s_preds = preds.DeviceSpan();
auto s_weights = weights.DeviceSpan(device_id); auto s_weights = weights.DeviceSpan();
bool const is_null_weight = weights.Size() == 0; bool const is_null_weight = weights.Size() == 0;
auto s_label_error = label_error_.GetSpan<int32_t>(1); auto s_label_error = label_error_.GetSpan<int32_t>(1);
s_label_error[0] = 0; s_label_error[0] = 0;
PackedReduceResult result = thrust::transform_reduce( PackedReduceResult result = thrust::transform_reduce(
thrust::cuda::par(allocators_.at(device_index)), thrust::cuda::par(allocator_),
begin, end, begin, end,
[=] XGBOOST_DEVICE(size_t idx) { [=] XGBOOST_DEVICE(size_t idx) {
bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
@ -127,38 +124,25 @@ class MultiClassMetricsReduction {
PackedReduceResult Reduce( PackedReduceResult Reduce(
const GenericParameter &tparam, const GenericParameter &tparam,
GPUSet devices, int device,
size_t n_class, size_t n_class,
const HostDeviceVector<bst_float>& weights, const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels, const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) { const HostDeviceVector<bst_float>& preds) {
PackedReduceResult result; PackedReduceResult result;
if (devices.IsEmpty()) { if (device < 0) {
result = CpuReduceMetrics(weights, labels, preds, n_class); result = CpuReduceMetrics(weights, labels, preds, n_class);
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
else { // NOLINT else { // NOLINT
if (allocators_.empty()) { device_ = tparam.gpu_id;
devices_ = GPUSet::All(tparam.gpu_id, tparam.n_gpus); preds.SetDevice(device_);
allocators_.resize(devices_.Size()); labels.SetDevice(device_);
} weights.SetDevice(device_);
preds.Shard(GPUDistribution::Granular(devices, n_class));
labels.Shard(devices);
weights.Shard(devices);
std::vector<PackedReduceResult> res_per_device(devices.Size());
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) dh::safe_cuda(cudaSetDevice(device_));
for (GPUSet::GpuIdType id = *devices.begin(); id < *devices.end(); ++id) { result = DeviceReduceMetrics(weights, labels, preds, n_class);
dh::safe_cuda(cudaSetDevice(id));
size_t index = devices.Index(id);
res_per_device.at(index) =
DeviceReduceMetrics(id, index, weights, labels, preds, n_class);
}
for (auto const& res : res_per_device) {
result += res;
}
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
return result; return result;
@ -167,8 +151,8 @@ class MultiClassMetricsReduction {
private: private:
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
dh::PinnedMemory label_error_; dh::PinnedMemory label_error_;
GPUSet devices_; int device_{-1};
std::vector<dh::CubMemory> allocators_; dh::CubMemory allocator_;
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
}; };
@ -190,8 +174,8 @@ struct EvalMClassBase : public Metric {
<< " use logloss for binary classification"; << " use logloss for binary classification";
const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size()); const auto ndata = static_cast<bst_omp_uint>(info.labels_.Size());
GPUSet devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, ndata); int device = tparam_->gpu_id;
auto result = reducer_.Reduce(*tparam_, devices, nclass, info.weights_, info.labels_, preds); auto result = reducer_.Reduce(*tparam_, device, nclass, info.weights_, info.labels_, preds);
double dat[2] { result.Residue(), result.Weights() }; double dat[2] { result.Residue(), result.Weights() };
if (distributed) { if (distributed) {

View File

@ -58,7 +58,7 @@ class HingeObj : public ObjFunction {
_out_gpair[_idx] = GradientPair(g, h); _out_gpair[_idx] = GradientPair(g, h);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, common::Range{0, static_cast<int64_t>(ndata)},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, ndata)).Eval( tparam_->gpu_id).Eval(
out_gpair, &preds, &info.labels_, &info.weights_); out_gpair, &preds, &info.labels_, &info.weights_);
} }
@ -68,7 +68,7 @@ class HingeObj : public ObjFunction {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, common::Range{0, static_cast<int64_t>(io_preds->Size()), 1},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size())) tparam_->gpu_id)
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -59,14 +59,14 @@ class SoftmaxMultiClassObj : public ObjFunction {
const int nclass = param_.num_class; const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass); const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, preds.Size()); auto device = tparam_->gpu_id;
out_gpair->Shard(GPUDistribution::Granular(devices, nclass)); out_gpair->SetDevice(device);
info.labels_.Shard(GPUDistribution::Block(devices)); info.labels_.SetDevice(device);
info.weights_.Shard(GPUDistribution::Block(devices)); info.weights_.SetDevice(device);
preds.Shard(GPUDistribution::Granular(devices, nclass)); preds.SetDevice(device);
label_correct_.Resize(devices.IsEmpty() ? 1 : devices.Size()); label_correct_.Resize(1);
label_correct_.Shard(GPUDistribution::Block(devices)); label_correct_.SetDevice(device);
out_gpair->Resize(preds.Size()); out_gpair->Resize(preds.Size());
label_correct_.Fill(1); label_correct_.Fill(1);
@ -100,7 +100,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
p = label == k ? p - 1.0f : p; p = label == k ? p - 1.0f : p;
gpair[idx * nclass + k] = GradientPair(p * wt, h); gpair[idx * nclass + k] = GradientPair(p * wt, h);
} }
}, common::Range{0, ndata}, devices, false) }, common::Range{0, ndata}, device, false)
.Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_); .Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -125,7 +125,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass); const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata); max_preds_.Resize(ndata);
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size()); auto device = tparam_->gpu_id;
if (prob) { if (prob) {
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
@ -133,11 +133,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
_preds.subspan(_idx * nclass, nclass); _preds.subspan(_idx * nclass, nclass);
common::Softmax(point.begin(), point.end()); common::Softmax(point.begin(), point.end());
}, },
common::Range{0, ndata}, GPUDistribution::Granular(devices, nclass)) common::Range{0, ndata}, device)
.Eval(io_preds); .Eval(io_preds);
} else { } else {
io_preds->Shard(GPUDistribution::Granular(devices, nclass)); io_preds->SetDevice(device);
max_preds_.Shard(GPUDistribution::Block(devices)); max_preds_.SetDevice(device);
common::Transform<>::Init( common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, [=] XGBOOST_DEVICE(size_t _idx,
common::Span<const bst_float> _preds, common::Span<const bst_float> _preds,
@ -148,7 +148,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
common::FindMaxIndex(point.cbegin(), common::FindMaxIndex(point.cbegin(),
point.cend()) - point.cbegin(); point.cend()) - point.cbegin();
}, },
common::Range{0, ndata}, devices, false) common::Range{0, ndata}, device, false)
.Eval(io_preds, &max_preds_); .Eval(io_preds, &max_preds_);
} }
if (!prob) { if (!prob) {
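One detail from the multi-class call sites above: the caller pins every vector to the device itself and then passes shard = false to Transform<>::Init, so the evaluator skips its own SetDevice pass (per-class grouping is handled by index arithmetic inside the functor). Sketched with names following the diff rather than the real member layout:

#include <xgboost/base.h>
#include <xgboost/host_device_vector.h>

void PlaceMultiClassBuffers(int device,
                            xgboost::HostDeviceVector<xgboost::GradientPair>* out_gpair,
                            xgboost::HostDeviceVector<float>* preds,
                            xgboost::HostDeviceVector<int>* label_correct) {
  out_gpair->SetDevice(device);
  preds->SetDevice(device);
  label_correct->SetDevice(device);
  label_correct->Resize(1);            // a single flag suffices with one device
  out_gpair->Resize(preds->Size());
  label_correct->Fill(1);
  // ...then: Transform<>::Init(functor, Range{0, ndata}, device, /*shard=*/false)
  //              .Eval(out_gpair, preds, label_correct, ...);
}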

View File

@ -57,8 +57,8 @@ class RegLossObj : public ObjFunction {
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
size_t ndata = preds.Size(); size_t ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, preds.Size()); auto device = tparam_->gpu_id;
label_correct_.Resize(devices.IsEmpty() ? 1 : devices.Size()); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
@ -83,7 +83,7 @@ class RegLossObj : public ObjFunction {
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, _out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w); Loss::SecondOrderGradient(p, label) * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, devices).Eval( common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
@ -105,7 +105,7 @@ class RegLossObj : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
_preds[_idx] = Loss::PredTransform(_preds[_idx]); _preds[_idx] = Loss::PredTransform(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, }, common::Range{0, static_cast<int64_t>(io_preds->Size())},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size())) tparam_->gpu_id)
.Eval(io_preds); .Eval(io_preds);
} }
@ -175,8 +175,8 @@ class PoissonRegression : public ObjFunction {
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
size_t ndata = preds.Size(); size_t ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, preds.Size()); auto device = tparam_->gpu_id;
label_correct_.Resize(devices.IsEmpty() ? 1 : devices.Size()); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
@ -197,7 +197,7 @@ class PoissonRegression : public ObjFunction {
_out_gpair[_idx] = GradientPair{(expf(p) - y) * w, _out_gpair[_idx] = GradientPair{(expf(p) - y) * w,
expf(p + max_delta_step) * w}; expf(p + max_delta_step) * w};
}, },
common::Range{0, static_cast<int64_t>(ndata)}, devices).Eval( common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -213,7 +213,7 @@ class PoissonRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size())) tparam_->gpu_id)
.Eval(io_preds); .Eval(io_preds);
} }
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
@ -340,9 +340,9 @@ class GammaRegression : public ObjFunction {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, ndata); auto device = tparam_->gpu_id;
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
label_correct_.Resize(devices.IsEmpty() ? 1 : devices.Size()); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
@ -361,7 +361,7 @@ class GammaRegression : public ObjFunction {
} }
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, devices).Eval( common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
@ -378,7 +378,7 @@ class GammaRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size())) tparam_->gpu_id)
.Eval(io_preds); .Eval(io_preds);
} }
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override { void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
@ -430,8 +430,8 @@ class TweedieRegression : public ObjFunction {
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto devices = GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, preds.Size()); auto device = tparam_->gpu_id;
label_correct_.Resize(devices.IsEmpty() ? 1 : devices.Size()); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
@ -455,7 +455,7 @@ class TweedieRegression : public ObjFunction {
std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p); std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p);
_out_gpair[_idx] = GradientPair(grad * w, hess * w); _out_gpair[_idx] = GradientPair(grad * w, hess * w);
}, },
common::Range{0, static_cast<int64_t>(ndata), 1}, devices) common::Range{0, static_cast<int64_t>(ndata), 1}, device)
.Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); .Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
@ -472,7 +472,7 @@ class TweedieRegression : public ObjFunction {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
GPUSet::All(tparam_->gpu_id, tparam_->n_gpus, io_preds->Size())) tparam_->gpu_id)
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -20,12 +20,6 @@ namespace predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor); DMLC_REGISTRY_FILE_TAG(gpu_predictor);
template <typename IterT>
void IncrementOffset(IterT begin_itr, IterT end_itr, size_t amount) {
thrust::transform(begin_itr, end_itr, begin_itr,
[=] __device__(size_t elem) { return elem + amount; });
}
/** /**
* \struct DevicePredictionNode * \struct DevicePredictionNode
* *
@ -44,7 +38,7 @@ struct DevicePredictionNode {
int fidx; int fidx;
int left_child_idx; int left_child_idx;
int right_child_idx; int right_child_idx;
NodeValue val; NodeValue val{};
DevicePredictionNode(const RegTree::Node& n) { // NOLINT DevicePredictionNode(const RegTree::Node& n) { // NOLINT
static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes"); static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
@ -200,58 +194,14 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
} }
class GPUPredictor : public xgboost::Predictor { class GPUPredictor : public xgboost::Predictor {
protected:
struct DevicePredictionCacheEntry {
std::shared_ptr<DMatrix> data;
HostDeviceVector<bst_float> predictions;
};
private: private:
void DeviceOffsets(const HostDeviceVector<size_t>& data,
size_t total_size,
std::vector<size_t>* out_offsets) {
auto& offsets = *out_offsets;
offsets.resize(devices_.Size() + 1);
offsets[0] = 0;
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
for (int shard = 0; shard < devices_.Size(); ++shard) {
int device = devices_.DeviceId(shard);
auto data_span = data.DeviceSpan(device);
dh::safe_cuda(cudaSetDevice(device));
if (data_span.size() == 0) {
offsets[shard + 1] = total_size;
} else {
// copy the last element from every shard
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
&data_span[data_span.size()-1],
sizeof(size_t), cudaMemcpyDeviceToHost));
}
}
}
// This function populates the explicit offsets that can be used to create a window into the
// underlying host vector. The window starts from the `batch_offset` and has a size of
// `batch_size`, and is sharded across all the devices. Each shard is granular depending on
// the number of output classes `n_classes`.
void PredictionDeviceOffsets(size_t total_size, size_t batch_offset, size_t batch_size,
int n_classes, std::vector<size_t>* out_offsets) {
auto& offsets = *out_offsets;
size_t n_shards = devices_.Size();
offsets.resize(n_shards + 2);
size_t rows_per_shard = common::DivRoundUp(batch_size, n_shards);
for (size_t shard = 0; shard < devices_.Size(); ++shard) {
size_t n_rows = std::min(batch_size, shard * rows_per_shard);
offsets[shard] = batch_offset + n_rows * n_classes;
}
offsets[n_shards] = batch_offset + batch_size * n_classes;
offsets[n_shards + 1] = total_size;
}
struct DeviceShard { struct DeviceShard {
DeviceShard() : device_{-1} {} DeviceShard() : device_{-1} {}
~DeviceShard() { ~DeviceShard() {
dh::safe_cuda(cudaSetDevice(device_)); if (device_ >= 0) {
dh::safe_cuda(cudaSetDevice(device_));
}
} }
void Init(int device) { void Init(int device) {
@ -284,10 +234,9 @@ class GPUPredictor : public xgboost::Predictor {
void PredictInternal void PredictInternal
(const SparsePage& batch, size_t num_features, (const SparsePage& batch, size_t num_features,
HostDeviceVector<bst_float>* predictions) { HostDeviceVector<bst_float>* predictions) {
if (predictions->DeviceSize(device_) == 0) { return; }
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
const int BLOCK_THREADS = 128; const int BLOCK_THREADS = 128;
size_t num_rows = batch.offset.DeviceSize(device_) - 1; size_t num_rows = batch.offset.DeviceSize() - 1;
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS)); const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
int shared_memory_bytes = static_cast<int> int shared_memory_bytes = static_cast<int>
@ -297,14 +246,12 @@ class GPUPredictor : public xgboost::Predictor {
shared_memory_bytes = 0; shared_memory_bytes = 0;
use_shared = false; use_shared = false;
} }
const auto& data_distr = batch.data.Distribution(); size_t entry_start = 0;
size_t entry_start = data_distr.ShardStart(batch.data.Size(),
data_distr.Devices().Index(device_));
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>> PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_), (dh::ToSpan(nodes_), predictions->DeviceSpan(), dh::ToSpan(tree_segments_),
dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
batch.data.DeviceSpan(device_), this->tree_begin_, this->tree_end_, num_features, batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features,
num_rows, entry_start, use_shared, this->num_group_); num_rows, entry_start, use_shared, this->num_group_);
} }
@ -322,7 +269,7 @@ class GPUPredictor : public xgboost::Predictor {
void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
CHECK_EQ(model.param.size_leaf_vector, 0); CHECK_EQ(model.param.size_leaf_vector, 0);
// Copy decision trees to device // Copy decision trees to device
thrust::host_vector<size_t> h_tree_segments; thrust::host_vector<size_t> h_tree_segments{};
h_tree_segments.reserve((tree_end - tree_begin) + 1); h_tree_segments.reserve((tree_end - tree_begin) + 1);
size_t sum = 0; size_t sum = 0;
h_tree_segments.push_back(sum); h_tree_segments.push_back(sum);
@ -337,9 +284,7 @@ class GPUPredictor : public xgboost::Predictor {
std::copy(src_nodes.begin(), src_nodes.end(), std::copy(src_nodes.begin(), src_nodes.end(),
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]); h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
} }
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard &shard) { shard_.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
shard.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
});
} }
void DevicePredictInternal(DMatrix* dmat, void DevicePredictInternal(DMatrix* dmat,
@ -352,40 +297,43 @@ class GPUPredictor : public xgboost::Predictor {
InitModel(model, tree_begin, tree_end); InitModel(model, tree_begin, tree_end);
size_t batch_offset = 0; size_t batch_offset = 0;
auto* preds = out_preds;
std::unique_ptr<HostDeviceVector<bst_float>> batch_preds{nullptr};
for (auto &batch : dmat->GetBatches<SparsePage>()) { for (auto &batch : dmat->GetBatches<SparsePage>()) {
bool is_external_memory = batch.Size() < dmat->Info().num_row_; bool is_external_memory = batch.Size() < dmat->Info().num_row_;
if (is_external_memory) { if (is_external_memory) {
std::vector<size_t> out_preds_offsets; batch_preds.reset(new HostDeviceVector<bst_float>);
PredictionDeviceOffsets(out_preds->Size(), batch_offset, batch.Size(), batch_preds->Resize(batch.Size() * model.param.num_output_group);
model.param.num_output_group, &out_preds_offsets); std::copy(out_preds->ConstHostVector().begin() + batch_offset,
out_preds->Reshard(GPUDistribution::Explicit(devices_, out_preds_offsets)); out_preds->ConstHostVector().begin() + batch_offset + batch_preds->Size(),
batch_preds->HostVector().begin());
preds = batch_preds.get();
} }
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1)); batch.offset.SetDevice(device_);
std::vector<size_t> device_offsets; batch.data.SetDevice(device_);
DeviceOffsets(batch.offset, batch.data.Size(), &device_offsets); preds->SetDevice(device_);
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets)); shard_.PredictInternal(batch, model.param.num_feature, preds);
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { if (is_external_memory) {
shard.PredictInternal(batch, model.param.num_feature, out_preds); auto h_preds = preds->ConstHostVector();
}); std::copy(h_preds.begin(), h_preds.end(), out_preds->HostVector().begin() + batch_offset);
}
batch_offset += batch.Size() * model.param.num_output_group; batch_offset += batch.Size() * model.param.num_output_group;
} }
out_preds->Reshard(GPUDistribution::Granular(devices_, model.param.num_output_group));
monitor_.StopCuda("DevicePredictInternal"); monitor_.StopCuda("DevicePredictInternal");
} }
public: public:
GPUPredictor() = default; GPUPredictor() : device_{-1} {};
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds, void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override { unsigned ntree_limit = 0) override {
GPUSet devices = GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus, int device = learner_param_->gpu_id;
dmat->Info().num_row_); CHECK_GE(device, 0);
CHECK_NE(devices.Size(), 0); ConfigureShard(device);
ConfigureShards(devices);
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
return; return;
@ -408,10 +356,9 @@ class GPUPredictor : public xgboost::Predictor {
size_t n_classes = model.param.num_output_group; size_t n_classes = model.param.num_output_group;
size_t n = n_classes * info.num_row_; size_t n = n_classes * info.num_row_;
const HostDeviceVector<bst_float>& base_margin = info.base_margin_; const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
out_preds->Shard(GPUDistribution::Granular(devices_, n_classes));
out_preds->Resize(n); out_preds->Resize(n);
if (base_margin.Size() != 0) { if (base_margin.Size() != 0) {
CHECK_EQ(out_preds->Size(), n); CHECK_EQ(base_margin.Size(), n);
out_preds->Copy(base_margin); out_preds->Copy(base_margin);
} else { } else {
out_preds->Fill(model.base_margin); out_preds->Fill(model.base_margin);
@ -427,7 +374,7 @@ class GPUPredictor : public xgboost::Predictor {
const HostDeviceVector<bst_float>& y = it->second.predictions; const HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.Size() != 0) { if (y.Size() != 0) {
monitor_.StartCuda("PredictFromCache"); monitor_.StartCuda("PredictFromCache");
out_preds->Shard(y.Distribution()); out_preds->SetDevice(y.DeviceIdx());
out_preds->Resize(y.Size()); out_preds->Resize(y.Size());
out_preds->Copy(y); out_preds->Copy(y);
monitor_.StopCuda("PredictFromCache"); monitor_.StopCuda("PredictFromCache");
@ -500,25 +447,23 @@ class GPUPredictor : public xgboost::Predictor {
const std::vector<std::shared_ptr<DMatrix>>& cache) override { const std::vector<std::shared_ptr<DMatrix>>& cache) override {
Predictor::Configure(cfg, cache); Predictor::Configure(cfg, cache);
GPUSet devices = GPUSet::All(learner_param_->gpu_id, learner_param_->n_gpus); int device = learner_param_->gpu_id;
ConfigureShards(devices); if (device >= 0) {
ConfigureShard(device);
}
} }
private: private:
/*! \brief Re configure shards when GPUSet is changed. */ /*! \brief Reconfigure the shard when the device is changed. */
void ConfigureShards(GPUSet devices) { void ConfigureShard(int device) {
if (devices_ == devices) return; if (device_ == device) return;
devices_ = devices; device_ = device;
shards_.clear(); shard_.Init(device_);
shards_.resize(devices_.Size());
dh::ExecuteIndexShards(&shards_, [=](size_t i, DeviceShard& shard){
shard.Init(devices_.DeviceId(i));
});
} }
std::vector<DeviceShard> shards_; DeviceShard shard_;
GPUSet devices_; int device_;
common::Monitor monitor_; common::Monitor monitor_;
}; };
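For external-memory DMatrix batches the predictor now stages each batch's output in a temporary vector on the single device and moves it through the host, instead of resharding out_preds with explicit per-device offsets. The same flow as a free-standing helper; PredictOnDevice is a placeholder for DeviceShard::PredictInternal, not a real symbol:

#include <algorithm>
#include <cstddef>
#include <xgboost/host_device_vector.h>

void StageExternalMemoryBatch(int device, std::size_t batch_offset,
                              std::size_t batch_rows, std::size_t n_groups,
                              xgboost::HostDeviceVector<float>* out_preds) {
  xgboost::HostDeviceVector<float> batch_preds;
  batch_preds.Resize(batch_rows * n_groups);
  const auto& h_out = out_preds->ConstHostVector();
  std::copy(h_out.begin() + batch_offset,
            h_out.begin() + batch_offset + batch_preds.Size(),
            batch_preds.HostVector().begin());
  batch_preds.SetDevice(device);
  // PredictOnDevice(&batch_preds);    // placeholder for DeviceShard::PredictInternal
  const auto& h_batch = batch_preds.ConstHostVector();
  std::copy(h_batch.begin(), h_batch.end(),
            out_preds->HostVector().begin() + batch_offset);
}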

View File

@ -702,7 +702,7 @@ struct DeviceShard {
row_partitioner.reset(new RowPartitioner(device_id, n_rows)); row_partitioner.reset(new RowPartitioner(device_id, n_rows));
dh::safe_cuda(cudaMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(
gpair.data(), dh_gpair->ConstDevicePointer(device_id), gpair.data(), dh_gpair->ConstDevicePointer(),
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost)); gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx); SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx);
hist.Reset(); hist.Reset();
@ -745,8 +745,8 @@ struct DeviceShard {
for (auto i = 0ull; i < nidxs.size(); i++) { for (auto i = 0ull; i < nidxs.size(); i++) {
auto nidx = nidxs[i]; auto nidx = nidxs[i];
auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx)); auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
p_feature_set->Shard(GPUSet(device_id, 1)); p_feature_set->SetDevice(device_id);
auto d_sampled_features = p_feature_set->DeviceSpan(device_id); auto d_sampled_features = p_feature_set->DeviceSpan();
common::Span<int32_t> d_feature_set = common::Span<int32_t> d_feature_set =
interaction_constraints.Query(d_sampled_features, nidx); interaction_constraints.Query(d_sampled_features, nidx);
auto d_split_candidates = auto d_split_candidates =
@ -1016,7 +1016,7 @@ struct DeviceShard {
dh::AllReducer* reducer, int64_t num_columns) { dh::AllReducer* reducer, int64_t num_columns) {
constexpr int kRootNIdx = 0; constexpr int kRootNIdx = 0;
const auto &gpair = gpair_all->DeviceSpan(device_id); const auto &gpair = gpair_all->DeviceSpan();
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
gpair.size()); gpair.size());
@ -1294,11 +1294,8 @@ class GPUHistMakerSpecialised {
param_.InitAllowUnknown(args); param_.InitAllowUnknown(args);
generic_param_ = generic_param; generic_param_ = generic_param;
hist_maker_param_.InitAllowUnknown(args); hist_maker_param_.InitAllowUnknown(args);
auto devices = GPUSet::All(generic_param_->gpu_id, device_ = generic_param_->gpu_id;
generic_param_->n_gpus); CHECK_GE(device_, 0) << "Must have at least one device";
n_devices_ = devices.Size();
CHECK(n_devices_ != 0) << "Must have at least one device";
dist_ = GPUDistribution::Block(devices);
dh::CheckComputeCapability(); dh::CheckComputeCapability();
@ -1330,30 +1327,22 @@ class GPUHistMakerSpecialised {
void InitDataOnce(DMatrix* dmat) { void InitDataOnce(DMatrix* dmat) {
info_ = &dmat->Info(); info_ = &dmat->Info();
int n_devices = dist_.Devices().Size(); reducer_.Init({device_});
device_list_.resize(n_devices);
for (int index = 0; index < n_devices; ++index) {
int device_id = dist_.Devices().DeviceId(index);
device_list_[index] = device_id;
}
reducer_.Init(device_list_);
// Synchronise the column sampling seed // Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()(); uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
// Create device shards // Create device shards
shards_.resize(n_devices); shards_.resize(1);
dh::ExecuteIndexShards( dh::ExecuteIndexShards(
&shards_, &shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) { [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(dist_.Devices().DeviceId(idx))); dh::safe_cuda(cudaSetDevice(device_));
size_t start = dist_.ShardStart(info_->num_row_, idx); size_t start = 0;
size_t size = dist_.ShardSize(info_->num_row_, idx); size_t size = info_->num_row_;
shard = std::unique_ptr<DeviceShard<GradientSumT>>( shard = std::unique_ptr<DeviceShard<GradientSumT>>(
new DeviceShard<GradientSumT>(dist_.Devices().DeviceId(idx), idx, new DeviceShard<GradientSumT>(device_, idx,
start, start + size, param_, start, start + size, param_,
column_sampling_seed, column_sampling_seed,
info_->num_col_)); info_->num_col_));
@ -1436,7 +1425,7 @@ class GPUHistMakerSpecialised {
for (auto& tree : trees) { for (auto& tree : trees) {
tree = *p_tree; tree = *p_tree;
} }
gpair->Reshard(dist_); gpair->SetDevice(device_);
// Launch one thread for each device "shard" containing a subset of rows. // Launch one thread for each device "shard" containing a subset of rows.
// Threads will cooperatively build the tree, synchronising over histograms. // Threads will cooperatively build the tree, synchronising over histograms.
@ -1462,13 +1451,13 @@ class GPUHistMakerSpecialised {
return false; return false;
} }
monitor_.StartCuda("UpdatePredictionCache"); monitor_.StartCuda("UpdatePredictionCache");
p_out_preds->Shard(dist_.Devices()); p_out_preds->SetDevice(device_);
dh::ExecuteIndexShards( dh::ExecuteIndexShards(
&shards_, &shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) { [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id)); dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->UpdatePredictionCache( shard->UpdatePredictionCache(
p_out_preds->DevicePointer(shard->device_id)); p_out_preds->DevicePointer());
}); });
monitor_.StopCuda("UpdatePredictionCache"); monitor_.StopCuda("UpdatePredictionCache");
return true; return true;
@ -1483,7 +1472,6 @@ class GPUHistMakerSpecialised {
private: private:
bool initialised_; bool initialised_;
int n_devices_;
int n_bins_; int n_bins_;
GPUHistMakerTrainParam hist_maker_param_; GPUHistMakerTrainParam hist_maker_param_;
@ -1492,11 +1480,9 @@ class GPUHistMakerSpecialised {
dh::AllReducer reducer_; dh::AllReducer reducer_;
DMatrix* p_last_fmat_; DMatrix* p_last_fmat_;
GPUDistribution dist_; int device_;
common::Monitor monitor_; common::Monitor monitor_;
/*! List storing device id. */
std::vector<int> device_list_;
}; };
class GPUHistMaker : public TreeUpdater { class GPUHistMaker : public TreeUpdater {
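
For orientation, a minimal sketch (an illustration, not part of the change) of the single-device hand-off the updater above now performs, using only HostDeviceVector calls that appear in this diff; the function name and parameters are made up for the example:

    // Sketch only: assumes the HostDeviceVector API shown in this change.
    #include <xgboost/base.h>                // GradientPair, bst_float
    #include <xgboost/host_device_vector.h>  // HostDeviceVector

    void HandOffToSingleDevice(xgboost::HostDeviceVector<xgboost::GradientPair>* gpair,
                               xgboost::HostDeviceVector<xgboost::bst_float>* out_preds,
                               int device) {  // the one ordinal taken from gpu_id
      gpair->SetDevice(device);                 // one device instead of a GPUSet/distribution
      auto d_gpair = gpair->ConstDeviceSpan();  // spans the whole vector, no shard index
      out_preds->SetDevice(device);
      auto* d_preds = out_preds->DevicePointer();
      (void)d_gpair; (void)d_preds;             // tree construction and cache update elided
    }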


@ -1,37 +0,0 @@
#include "../../../src/common/common.h"
#include <gtest/gtest.h>
namespace xgboost {
TEST(GPUSet, Basic) {
GPUSet devices = GPUSet::Empty();
ASSERT_TRUE(devices.IsEmpty());
devices = GPUSet{0, 1};
ASSERT_TRUE(devices != GPUSet::Empty());
EXPECT_EQ(devices.Size(), 1);
devices = GPUSet::Range(1, 0);
EXPECT_EQ(devices.Size(), 0);
EXPECT_TRUE(devices.IsEmpty());
EXPECT_FALSE(devices.Contains(1));
devices = GPUSet::Range(2, -1);
EXPECT_EQ(devices, GPUSet::Empty());
EXPECT_EQ(devices.Size(), 0);
EXPECT_TRUE(devices.IsEmpty());
devices = GPUSet::Range(2, 8); // 2 ~ 10
EXPECT_EQ(devices.Size(), 8);
EXPECT_ANY_THROW(devices.DeviceId(8));
auto device_id = devices.DeviceId(0);
EXPECT_EQ(device_id, 2);
auto device_index = devices.Index(2);
EXPECT_EQ(device_index, 0);
#ifndef XGBOOST_USE_CUDA
EXPECT_EQ(GPUSet::AllVisible(), GPUSet::Empty());
#endif
}
} // namespace xgboost


@ -1,83 +0,0 @@
#include <gtest/gtest.h>
#include <xgboost/logging.h>
#include "../../../src/common/common.h"
#include "../helpers.h"
#include <string>
namespace xgboost {
TEST(GPUSet, GPUBasic) {
GPUSet devices = GPUSet::Empty();
ASSERT_TRUE(devices.IsEmpty());
devices = GPUSet{1, 1};
ASSERT_TRUE(devices != GPUSet::Empty());
EXPECT_EQ(devices.Size(), 1);
EXPECT_EQ(*(devices.begin()), 1);
devices = GPUSet::Range(1, 0);
EXPECT_EQ(devices, GPUSet::Empty());
EXPECT_EQ(devices.Size(), 0);
EXPECT_TRUE(devices.IsEmpty());
EXPECT_FALSE(devices.Contains(1));
devices = GPUSet::Range(2, -1);
EXPECT_EQ(devices, GPUSet::Empty());
devices = GPUSet::Range(2, 8);
EXPECT_EQ(devices.Size(), 8);
EXPECT_EQ(*devices.begin(), 2);
EXPECT_EQ(*devices.end(), 2 + devices.Size());
EXPECT_EQ(8, devices.Size());
ASSERT_NO_THROW(GPUSet::AllVisible());
devices = GPUSet::AllVisible();
if (devices.IsEmpty()) {
LOG(WARNING) << "Empty devices.";
}
}
TEST(GPUSet, Verbose) {
{
std::map<std::string, std::string> args {};
args["verbosity"] = "3"; // LOG INFO
testing::internal::CaptureStderr();
ConsoleLogger::Configure({args.cbegin(), args.cend()});
GPUSet::All(0, 1);
std::string output = testing::internal::GetCapturedStderr();
ASSERT_NE(output.find("GPU ID: 0"), std::string::npos);
ASSERT_NE(output.find("GPUs: 1"), std::string::npos);
args["verbosity"] = "1"; // restore
ConsoleLogger::Configure({args.cbegin(), args.cend()});
}
}
#if defined(XGBOOST_USE_NCCL)
TEST(GPUSet, MGPU_GPUBasic) {
{
GPUSet devices = GPUSet::All(1, 1);
ASSERT_EQ(*(devices.begin()), 1);
ASSERT_EQ(*(devices.end()), 2);
ASSERT_EQ(devices.Size(), 1);
ASSERT_TRUE(devices.Contains(1));
}
{
GPUSet devices = GPUSet::All(0, -1);
ASSERT_GE(devices.Size(), 2);
}
// Specify number of rows.
{
GPUSet devices = GPUSet::All(0, -1, 1);
ASSERT_EQ(devices.Size(), 1);
}
}
#endif
} // namespace xgboost
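
Both GPUSet test files above are removed outright; what replaces the set abstraction throughout this change is a bare device ordinal, with -1 meaning CPU and any non-negative value naming a single CUDA device. A rough sketch of that convention, with a hypothetical helper used purely for illustration:

    // Hypothetical helper, for illustration only; not part of the codebase.
    constexpr int kCpuOrdinal = -1;
    inline bool UseGpu(int gpu_id) { return gpu_id != kCpuOrdinal; }

    // A caller branches once on the ordinal:
    //   if (UseGpu(param.gpu_id)) { /* launch kernels on device param.gpu_id */ }
    //   else                      { /* take the CPU code path */ }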


@ -87,8 +87,8 @@ TEST(ConfigParser, ParseKeyValuePair) {
ASSERT_TRUE(parser.ParseKeyValuePair("booster = gbtree", &key, &value)); ASSERT_TRUE(parser.ParseKeyValuePair("booster = gbtree", &key, &value));
ASSERT_EQ(key, "booster"); ASSERT_EQ(key, "booster");
ASSERT_EQ(value, "gbtree"); ASSERT_EQ(value, "gbtree");
ASSERT_TRUE(parser.ParseKeyValuePair("n_gpus = 2", &key, &value)); ASSERT_TRUE(parser.ParseKeyValuePair("gpu_id = 2", &key, &value));
ASSERT_EQ(key, "n_gpus"); ASSERT_EQ(key, "gpu_id");
ASSERT_EQ(value, "2"); ASSERT_EQ(value, "2");
ASSERT_TRUE(parser.ParseKeyValuePair("monotone_constraints = (1,0,-1)", ASSERT_TRUE(parser.ParseKeyValuePair("monotone_constraints = (1,0,-1)",
&key, &value)); &key, &value));


@ -18,7 +18,7 @@
namespace xgboost { namespace xgboost {
namespace common { namespace common {
void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) { void TestDeviceSketch(bool use_external_memory) {
// create the data // create the data
int nrows = 10001; int nrows = 10001;
std::shared_ptr<xgboost::DMatrix> *dmat = nullptr; std::shared_ptr<xgboost::DMatrix> *dmat = nullptr;
@ -53,7 +53,7 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
// find the cuts on the GPU // find the cuts on the GPU
HistogramCuts hmat_gpu; HistogramCuts hmat_gpu;
size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows, size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0), gpu_batch_nrows,
dmat->get(), &hmat_gpu); dmat->get(), &hmat_gpu);
// compare the row stride with the one obtained from the dmatrix // compare the row stride with the one obtained from the dmatrix
@ -81,11 +81,11 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
} }
TEST(gpu_hist_util, DeviceSketch) { TEST(gpu_hist_util, DeviceSketch) {
TestDeviceSketch(GPUSet::Range(0, 1), false); TestDeviceSketch(false);
} }
TEST(gpu_hist_util, DeviceSketch_ExternalMemory) { TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
TestDeviceSketch(GPUSet::Range(0, 1), true); TestDeviceSketch(true);
} }
} // namespace common } // namespace common


@ -30,45 +30,36 @@ struct HostDeviceVectorSetDeviceHandler {
} }
}; };
void InitHostDeviceVector(size_t n, const GPUDistribution& distribution, void InitHostDeviceVector(size_t n, int device, HostDeviceVector<int> *v) {
HostDeviceVector<int> *v) {
// create the vector // create the vector
GPUSet devices = distribution.Devices(); v->SetDevice(device);
v->Shard(distribution);
v->Resize(n); v->Resize(n);
ASSERT_EQ(v->Size(), n); ASSERT_EQ(v->Size(), n);
ASSERT_TRUE(v->Distribution() == distribution); ASSERT_EQ(v->DeviceIdx(), device);
ASSERT_TRUE(v->Devices() == devices); // ensure that the device has read-write access
// ensure that the devices have read-write access ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
for (int i = 0; i < devices.Size(); ++i) { ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead));
ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kWrite));
}
// ensure that the host has no access // ensure that the host has no access
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite)); ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead)); ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
// fill in the data on the host // fill in the data on the host
std::vector<int>& data_h = v->HostVector(); std::vector<int>& data_h = v->HostVector();
// ensure that the host has full access, while the devices have none // ensure that the host has full access, while the device has none
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead)); ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kWrite)); ASSERT_TRUE(v->HostCanAccess(GPUAccess::kWrite));
for (int i = 0; i < devices.Size(); ++i) { ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kRead));
ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kRead)); ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kWrite));
}
ASSERT_EQ(data_h.size(), n); ASSERT_EQ(data_h.size(), n);
std::copy_n(thrust::make_counting_iterator(0), n, data_h.begin()); std::copy_n(thrust::make_counting_iterator(0), n, data_h.begin());
} }
void PlusOne(HostDeviceVector<int> *v) { void PlusOne(HostDeviceVector<int> *v) {
int n_devices = v->Devices().Size(); int device = v->DeviceIdx();
for (int i = 0; i < n_devices; ++i) { SetDevice(device);
SetDevice(i); thrust::transform(v->tbegin(), v->tend(), v->tbegin(),
thrust::transform(v->tbegin(i), v->tend(i), v->tbegin(i), [=]__device__(unsigned int a){ return a + 1; });
[=]__device__(unsigned int a){ return a + 1; });
}
} }
void CheckDevice(HostDeviceVector<int> *v, void CheckDevice(HostDeviceVector<int> *v,
@ -76,24 +67,24 @@ void CheckDevice(HostDeviceVector<int> *v,
const std::vector<size_t>& sizes, const std::vector<size_t>& sizes,
unsigned int first, GPUAccess access) { unsigned int first, GPUAccess access) {
int n_devices = sizes.size(); int n_devices = sizes.size();
ASSERT_EQ(v->Devices().Size(), n_devices); ASSERT_EQ(n_devices, 1);
for (int i = 0; i < n_devices; ++i) { for (int i = 0; i < n_devices; ++i) {
ASSERT_EQ(v->DeviceSize(i), sizes.at(i)); ASSERT_EQ(v->DeviceSize(), sizes.at(i));
SetDevice(i); SetDevice(i);
ASSERT_TRUE(thrust::equal(v->tcbegin(i), v->tcend(i), ASSERT_TRUE(thrust::equal(v->tcbegin(), v->tcend(),
thrust::make_counting_iterator(first + starts[i]))); thrust::make_counting_iterator(first + starts[i])));
ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead)); ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
// ensure that the device has at most the access specified by access // ensure that the device has at most the access specified by access
ASSERT_EQ(v->DeviceCanAccess(i, GPUAccess::kWrite), access == GPUAccess::kWrite); ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
} }
ASSERT_EQ(v->HostCanAccess(GPUAccess::kRead), access == GPUAccess::kRead); ASSERT_EQ(v->HostCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite)); ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
for (int i = 0; i < n_devices; ++i) { for (int i = 0; i < n_devices; ++i) {
SetDevice(i); SetDevice(i);
ASSERT_TRUE(thrust::equal(v->tbegin(i), v->tend(i), ASSERT_TRUE(thrust::equal(v->tbegin(), v->tend(),
thrust::make_counting_iterator(first + starts[i]))); thrust::make_counting_iterator(first + starts[i])));
ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead)); ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kWrite)); ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
} }
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead)); ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite)); ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
@ -107,20 +98,20 @@ void CheckHost(HostDeviceVector<int> *v, GPUAccess access) {
} }
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead)); ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_EQ(v->HostCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite); ASSERT_EQ(v->HostCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
size_t n_devices = v->Devices().Size(); size_t n_devices = 1;
for (int i = 0; i < n_devices; ++i) { for (int i = 0; i < n_devices; ++i) {
ASSERT_EQ(v->DeviceCanAccess(i, GPUAccess::kRead), access == GPUAccess::kRead); ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
// the devices should have no write access // the devices should have no write access
ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kWrite)); ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
} }
} }
void TestHostDeviceVector void TestHostDeviceVector
(size_t n, const GPUDistribution& distribution, (size_t n, int device,
const std::vector<size_t>& starts, const std::vector<size_t>& sizes) { const std::vector<size_t>& starts, const std::vector<size_t>& sizes) {
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice); HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
HostDeviceVector<int> v; HostDeviceVector<int> v;
InitHostDeviceVector(n, distribution, &v); InitHostDeviceVector(n, device, &v);
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead); CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
PlusOne(&v); PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite); CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
@ -130,54 +121,24 @@ void TestHostDeviceVector
TEST(HostDeviceVector, TestBlock) { TEST(HostDeviceVector, TestBlock) {
size_t n = 1001; size_t n = 1001;
int n_devices = 2; int device = 0;
auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices)); std::vector<size_t> starts{0};
std::vector<size_t> starts{0, 501}; std::vector<size_t> sizes{1001};
std::vector<size_t> sizes{501, 500}; TestHostDeviceVector(n, device, starts, sizes);
TestHostDeviceVector(n, distribution, starts, sizes);
}
TEST(HostDeviceVector, TestGranular) {
size_t n = 3003;
int n_devices = 2;
auto distribution = GPUDistribution::Granular(GPUSet::Range(0, n_devices), 3);
std::vector<size_t> starts{0, 1503};
std::vector<size_t> sizes{1503, 1500};
TestHostDeviceVector(n, distribution, starts, sizes);
}
TEST(HostDeviceVector, TestOverlap) {
size_t n = 1001;
int n_devices = 2;
auto distribution = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1);
std::vector<size_t> starts{0, 500};
std::vector<size_t> sizes{501, 501};
TestHostDeviceVector(n, distribution, starts, sizes);
}
TEST(HostDeviceVector, TestExplicit) {
size_t n = 1001;
int n_devices = 2;
std::vector<size_t> offsets{0, 550, 1001};
auto distribution = GPUDistribution::Explicit(GPUSet::Range(0, n_devices), offsets);
std::vector<size_t> starts{0, 550};
std::vector<size_t> sizes{550, 451};
TestHostDeviceVector(n, distribution, starts, sizes);
} }
TEST(HostDeviceVector, TestCopy) { TEST(HostDeviceVector, TestCopy) {
size_t n = 1001; size_t n = 1001;
int n_devices = 2; int device = 0;
auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices)); std::vector<size_t> starts{0};
std::vector<size_t> starts{0, 501}; std::vector<size_t> sizes{1001};
std::vector<size_t> sizes{501, 500};
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice); HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
HostDeviceVector<int> v; HostDeviceVector<int> v;
{ {
// a separate scope to ensure that v1 is gone before further checks // a separate scope to ensure that v1 is gone before further checks
HostDeviceVector<int> v1; HostDeviceVector<int> v1;
InitHostDeviceVector(n, distribution, &v1); InitHostDeviceVector(n, device, &v1);
v = v1; v = v1;
} }
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead); CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
@ -193,16 +154,16 @@ TEST(HostDeviceVector, Shard) {
h_vec[i] = i; h_vec[i] = i;
} }
HostDeviceVector<int> vec (h_vec); HostDeviceVector<int> vec (h_vec);
auto devices = GPUSet::Range(0, 1); auto device = 0;
vec.Shard(devices); vec.SetDevice(device);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size()); ASSERT_EQ(vec.DeviceSize(), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size()); ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(0); // sync to device auto span = vec.DeviceSpan(); // sync to device
vec.Reshard(GPUDistribution::Empty()); // pull back to cpu, empty devices. vec.SetDevice(-1); // pull back to cpu.
ASSERT_EQ(vec.Size(), h_vec.size()); ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_TRUE(vec.Devices().IsEmpty()); ASSERT_EQ(vec.DeviceIdx(), -1);
auto h_vec_1 = vec.HostVector(); auto h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin())); ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
@ -214,16 +175,16 @@ TEST(HostDeviceVector, Reshard) {
h_vec[i] = i; h_vec[i] = i;
} }
HostDeviceVector<int> vec (h_vec); HostDeviceVector<int> vec (h_vec);
auto devices = GPUSet::Range(0, 1); auto device = 0;
vec.Shard(devices); vec.SetDevice(device);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size()); ASSERT_EQ(vec.DeviceSize(), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size()); ASSERT_EQ(vec.Size(), h_vec.size());
PlusOne(&vec); PlusOne(&vec);
vec.Reshard(GPUDistribution::Empty()); vec.SetDevice(-1);
ASSERT_EQ(vec.Size(), h_vec.size()); ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_TRUE(vec.Devices().IsEmpty()); ASSERT_EQ(vec.DeviceIdx(), -1);
auto h_vec_1 = vec.HostVector(); auto h_vec_1 = vec.HostVector();
for (size_t i = 0; i < h_vec_1.size(); ++i) { for (size_t i = 0; i < h_vec_1.size(); ++i) {
@ -233,97 +194,14 @@ TEST(HostDeviceVector, Reshard) {
TEST(HostDeviceVector, Span) { TEST(HostDeviceVector, Span) {
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f}; HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
vec.Shard(GPUSet{0, 1}); vec.SetDevice(0);
auto span = vec.DeviceSpan(0); auto span = vec.DeviceSpan();
ASSERT_EQ(vec.DeviceSize(0), span.size()); ASSERT_EQ(vec.DeviceSize(), span.size());
ASSERT_EQ(vec.DevicePointer(0), span.data()); ASSERT_EQ(vec.DevicePointer(), span.data());
auto const_span = vec.ConstDeviceSpan(0); auto const_span = vec.ConstDeviceSpan();
ASSERT_EQ(vec.DeviceSize(0), span.size()); ASSERT_EQ(vec.DeviceSize(), span.size());
ASSERT_EQ(vec.ConstDevicePointer(0), span.data()); ASSERT_EQ(vec.ConstDevicePointer(), span.data());
} }
// Multi-GPUs' test
#if defined(XGBOOST_USE_NCCL)
TEST(HostDeviceVector, MGPU_Shard) {
auto devices = GPUSet::AllVisible();
if (devices.Size() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
// Data size for each device.
std::vector<size_t> devices_size (devices.Size());
// From CPU to GPUs.
vec.Shard(devices);
size_t total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
devices_size[i] = vec.DeviceSize(i);
}
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
// Shard from devices to devices with different distribution.
EXPECT_ANY_THROW(
vec.Shard(GPUDistribution::Granular(devices, 12)));
// All data is drawn back to CPU
vec.Reshard(GPUDistribution::Empty());
ASSERT_TRUE(vec.Devices().IsEmpty());
ASSERT_EQ(vec.Size(), h_vec.size());
vec.Shard(GPUDistribution::Granular(devices, 12));
total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
devices_size[i] = vec.DeviceSize(i);
}
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
}
TEST(HostDeviceVector, MGPU_Reshard) {
auto devices = GPUSet::AllVisible();
if (devices.Size() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
size_t n = 1001;
int n_devices = 2;
auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices));
std::vector<size_t> starts{0, 501};
std::vector<size_t> sizes{501, 500};
HostDeviceVector<int> v;
InitHostDeviceVector(n, distribution, &v);
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
auto distribution1 = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1);
v.Reshard(distribution1);
for (size_t i = 0; i < n_devices; ++i) {
auto span = v.DeviceSpan(i); // sync to device
}
std::vector<size_t> starts1{0, 500};
std::vector<size_t> sizes1{501, 501};
CheckDevice(&v, starts1, sizes1, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
}
#endif
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
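
As a usage note, a minimal sketch of the single-GPU round trip the rewritten tests above exercise, relying only on methods that appear in them (the wrapper function is illustrative):

    #include <xgboost/host_device_vector.h>
    #include <vector>

    void RoundTrip() {
      std::vector<int> h_vec(1001, 7);
      xgboost::HostDeviceVector<int> vec(h_vec);

      vec.SetDevice(0);                // attach the vector to device 0
      auto d_span = vec.DeviceSpan();  // first device access copies host -> device
      (void)d_span;                    // device-side work would go here

      vec.SetDevice(-1);               // pull the data back to the CPU
      const std::vector<int>& h_out = vec.HostVector();
      (void)h_out;                     // contents still match h_vec
    }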


@ -33,7 +33,7 @@ std::string GetModelStr() {
}, },
"configuration": { "configuration": {
"booster": "gbtree", "booster": "gbtree",
"n_gpus": "1", "gpu_id": "0",
"num_class": "0", "num_class": "0",
"num_feature": "10", "num_feature": "10",
"objective": "reg:linear", "objective": "reg:linear",


@ -9,13 +9,11 @@
#if defined(__CUDACC__) #if defined(__CUDACC__)
#define TRANSFORM_GPU_RANGE GPUSet::Range(0, 1) #define TRANSFORM_GPU 0
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Range(0, 1))
#else #else
#define TRANSFORM_GPU_RANGE GPUSet::Empty() #define TRANSFORM_GPU -1
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Empty())
#endif #endif
@ -46,13 +44,13 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
std::vector<bst_float> h_sol(size); std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end()); InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec{h_in, TRANSFORM_GPU_DIST}; const HostDeviceVector<bst_float> in_vec{h_in, TRANSFORM_GPU};
HostDeviceVector<bst_float> out_vec{h_out, TRANSFORM_GPU_DIST}; HostDeviceVector<bst_float> out_vec{h_out, TRANSFORM_GPU};
out_vec.Fill(0); out_vec.Fill(0);
Transform<>::Init(TestTransformRange<bst_float>{}, Transform<>::Init(TestTransformRange<bst_float>{},
Range{0, static_cast<Range::DifferenceType>(size)}, Range{0, static_cast<Range::DifferenceType>(size)},
TRANSFORM_GPU_RANGE) TRANSFORM_GPU)
.Eval(&out_vec, &in_vec); .Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector(); std::vector<bst_float> res = out_vec.HostVector();


@ -5,87 +5,13 @@
namespace xgboost { namespace xgboost {
namespace common { namespace common {
// Test here is multi gpu specific
TEST(Transform, MGPU_Basic) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
const size_t size {256};
std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size);
InitializeRange(h_in.begin(), h_in.end());
std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec {h_in,
GPUDistribution::Block(GPUSet::Empty())};
HostDeviceVector<bst_float> out_vec {h_out,
GPUDistribution::Block(GPUSet::Empty())};
out_vec.Fill(0);
in_vec.Shard(GPUDistribution::Granular(devices, 8));
out_vec.Shard(GPUDistribution::Block(devices));
// Granularity is different, sharding will throw.
EXPECT_ANY_THROW(
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, devices)
.Eval(&out_vec, &in_vec));
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
devices, false).Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
// Test for multi-classes setting.
template <typename T>
struct TestTransformRangeGranular {
const size_t granularity = 8;
explicit TestTransformRangeGranular(const size_t granular) : granularity{granular} {}
void XGBOOST_DEVICE operator()(size_t _idx,
Span<bst_float> _out, Span<const bst_float> _in) {
auto in_sub = _in.subspan(_idx * granularity, granularity);
auto out_sub = _out.subspan(_idx * granularity, granularity);
for (size_t i = 0; i < granularity; ++i) {
out_sub[i] = in_sub[i];
}
}
};
TEST(Transform, MGPU_Granularity) {
GPUSet devices = GPUSet::All(0, -1);
const size_t size {8990};
const size_t granularity = 10;
GPUDistribution distribution =
GPUDistribution::Granular(devices, granularity);
std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size);
InitializeRange(h_in.begin(), h_in.end());
std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec {h_in, distribution};
HostDeviceVector<bst_float> out_vec {h_out, distribution};
ASSERT_NO_THROW(
Transform<>::Init(
TestTransformRangeGranular<bst_float>{granularity},
Range{0, size / granularity},
distribution)
.Eval(&out_vec, &in_vec));
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
TEST(Transform, MGPU_SpecifiedGpuId) { TEST(Transform, MGPU_SpecifiedGpuId) {
if (AllVisibleGPUs() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
// Use one GPU, selecting device ordinal 1 rather than the default 0 // Use one GPU, selecting device ordinal 1 rather than the default 0
auto devices = GPUSet::All(1, 1); auto device = 1;
const size_t size {256}; const size_t size {256};
std::vector<bst_float> h_in(size); std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size); std::vector<bst_float> h_out(size);
@ -93,13 +19,11 @@ TEST(Transform, MGPU_SpecifiedGpuId) {
std::vector<bst_float> h_sol(size); std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end()); InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec {h_in, const HostDeviceVector<bst_float> in_vec {h_in, device};
GPUDistribution::Block(devices)}; HostDeviceVector<bst_float> out_vec {h_out, device};
HostDeviceVector<bst_float> out_vec {h_out,
GPUDistribution::Block(devices)};
ASSERT_NO_THROW( ASSERT_NO_THROW(
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, devices) Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, device)
.Eval(&out_vec, &in_vec)); .Eval(&out_vec, &in_vec));
std::vector<bst_float> res = out_vec.HostVector(); std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
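
Pulling the rewritten Transform tests above into one readable piece: a sketch, written as if inside the same test file (InitializeRange and TestTransformRange as defined there), showing that Transform::Init now takes a single device ordinal where it used to take a GPUSet or GPUDistribution:

    // Sketch: run the identity transform on one chosen device, or on the CPU with -1.
    void RunBasicTransform(int device) {
      const size_t size = 256;
      std::vector<bst_float> h_in(size), h_out(size);
      InitializeRange(h_in.begin(), h_in.end());

      const HostDeviceVector<bst_float> in_vec{h_in, device};  // constructors take the ordinal
      HostDeviceVector<bst_float> out_vec{h_out, device};
      out_vec.Fill(0);

      Transform<>::Init(TestTransformRange<bst_float>{},
                        Range{0, static_cast<Range::DifferenceType>(size)},
                        device)
          .Eval(&out_vec, &in_vec);
      // out_vec.HostVector() now mirrors h_in element by element.
    }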


@ -12,7 +12,7 @@ TEST(GBTree, SelectTreeMethod) {
auto p_dmat {(*p_shared_ptr_dmat).get()}; auto p_dmat {(*p_shared_ptr_dmat).get()};
GenericParameter generic_param; GenericParameter generic_param;
generic_param.InitAllowUnknown(std::vector<Arg>{Arg("n_gpus", "0")}); generic_param.InitAllowUnknown(std::vector<Arg>{});
std::unique_ptr<GradientBooster> p_gbm{ std::unique_ptr<GradientBooster> p_gbm{
GradientBooster::Create("gbtree", &generic_param, {}, 0)}; GradientBooster::Create("gbtree", &generic_param, {}, 0)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm); auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) {
Arg{"num_feature", n_feat}}, p_dmat); Arg{"num_feature", n_feat}}, p_dmat);
ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker"); ASSERT_EQ(tparam.updater_seq, "grow_quantile_histmaker");
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
generic_param.InitAllowUnknown(std::vector<Arg>{Arg{"n_gpus", "1"}}); generic_param.InitAllowUnknown(std::vector<Arg>{Arg{"gpu_id", "0"}});
gbtree.ConfigureWithKnownData({Arg("tree_method", "gpu_hist"), Arg("num_feature", n_feat)}, gbtree.ConfigureWithKnownData({Arg("tree_method", "gpu_hist"), Arg("num_feature", n_feat)},
p_dmat); p_dmat);
ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist"); ASSERT_EQ(tparam.updater_seq, "grow_gpu_hist");
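
A compact restatement of the configuration path this test now checks, written as if inside the test body above (Arg, n_feat, p_dmat and tparam as defined there): with n_gpus gone, selecting the GPU updater comes down to tree_method plus gpu_id:

    // Sketch: gpu_id on GenericParameter plus tree_method picks grow_gpu_hist.
    GenericParameter generic_param;
    generic_param.InitAllowUnknown(std::vector<Arg>{Arg{"gpu_id", "0"}});
    std::unique_ptr<GradientBooster> p_gbm{
        GradientBooster::Create("gbtree", &generic_param, {}, 0)};
    auto& gbtree = dynamic_cast<gbm::GBTree&>(*p_gbm);
    gbtree.ConfigureWithKnownData({Arg{"tree_method", "gpu_hist"},
                                   Arg{"num_feature", n_feat}}, p_dmat);
    // Expected outcome, as asserted above: tparam.updater_seq == "grow_gpu_hist".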


@ -29,9 +29,9 @@
#endif #endif
#if defined(__CUDACC__) #if defined(__CUDACC__)
#define NGPUS 1 #define GPUIDX 0
#else #else
#define NGPUS 0 #define GPUIDX -1
#endif #endif
bool FileExists(const std::string& filename); bool FileExists(const std::string& filename);
@ -189,11 +189,10 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_c
gbm::GBTreeModel CreateTestModel(); gbm::GBTreeModel CreateTestModel();
inline GenericParameter CreateEmptyGenericParam(int gpu_id, int n_gpus) { inline GenericParameter CreateEmptyGenericParam(int gpu_id) {
xgboost::GenericParameter tparam; xgboost::GenericParameter tparam;
std::vector<std::pair<std::string, std::string>> args { std::vector<std::pair<std::string, std::string>> args {
{"gpu_id", std::to_string(gpu_id)}, {"gpu_id", std::to_string(gpu_id)}};
{"n_gpus", std::to_string(n_gpus)}};
tparam.Init(args); tparam.Init(args);
return tparam; return tparam;
} }
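
For the many test changes that follow, a short sketch of how this helper is typically used, written as if inside one of those test files; GPUIDX is the macro defined above (0 when compiled as CUDA, -1 otherwise):

    // Sketch: one helper call decides CPU vs. GPU for the whole test.
    auto lparam = CreateEmptyGenericParam(GPUIDX);   // gpu_id = 0 under __CUDACC__, else -1
    xgboost::Metric* metric = xgboost::Metric::Create("rmse", &lparam);
    metric->Configure({});
    // lparam.gpu_id alone now decides whether the metric runs its CUDA or CPU code.
    delete metric;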


@ -7,7 +7,7 @@
TEST(Linear, shotgun) { TEST(Linear, shotgun) {
auto mat = xgboost::CreateDMatrix(10, 10, 0); auto mat = xgboost::CreateDMatrix(10, 10, 0);
auto lparam = xgboost::CreateEmptyGenericParam(0, 0); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
{ {
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("shotgun", &lparam)); xgboost::LinearUpdater::Create("shotgun", &lparam));
@ -33,7 +33,7 @@ TEST(Linear, shotgun) {
TEST(Linear, coordinate) { TEST(Linear, coordinate) {
auto mat = xgboost::CreateDMatrix(10, 10, 0); auto mat = xgboost::CreateDMatrix(10, 10, 0);
auto lparam = xgboost::CreateEmptyGenericParam(0, 0); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("coord_descent", &lparam)); xgboost::LinearUpdater::Create("coord_descent", &lparam));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});


@ -7,8 +7,7 @@ namespace xgboost {
TEST(Linear, GPUCoordinate) { TEST(Linear, GPUCoordinate) {
auto mat = xgboost::CreateDMatrix(10, 10, 0); auto mat = xgboost::CreateDMatrix(10, 10, 0);
auto lparam = CreateEmptyGenericParam(0, 1); auto lparam = CreateEmptyGenericParam(GPUIDX);
lparam.n_gpus = 1;
auto updater = std::unique_ptr<xgboost::LinearUpdater>( auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("gpu_coord_descent", &lparam)); xgboost::LinearUpdater::Create("gpu_coord_descent", &lparam));
updater->Configure({{"eta", "1."}}); updater->Configure({{"eta", "1."}});


@ -6,7 +6,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Metric, DeclareUnifiedTest(RMSE)) { TEST(Metric, DeclareUnifiedTest(RMSE)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "rmse"); ASSERT_STREQ(metric->Name(), "rmse");
@ -20,7 +20,7 @@ TEST(Metric, DeclareUnifiedTest(RMSE)) {
} }
TEST(Metric, DeclareUnifiedTest(RMSLE)) { TEST(Metric, DeclareUnifiedTest(RMSLE)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("rmsle", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("rmsle", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "rmsle"); ASSERT_STREQ(metric->Name(), "rmsle");
@ -32,7 +32,7 @@ TEST(Metric, DeclareUnifiedTest(RMSLE)) {
} }
TEST(Metric, DeclareUnifiedTest(MAE)) { TEST(Metric, DeclareUnifiedTest(MAE)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("mae", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("mae", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "mae"); ASSERT_STREQ(metric->Name(), "mae");
@ -45,7 +45,7 @@ TEST(Metric, DeclareUnifiedTest(MAE)) {
} }
TEST(Metric, DeclareUnifiedTest(LogLoss)) { TEST(Metric, DeclareUnifiedTest(LogLoss)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("logloss", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("logloss", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "logloss"); ASSERT_STREQ(metric->Name(), "logloss");
@ -58,7 +58,7 @@ TEST(Metric, DeclareUnifiedTest(LogLoss)) {
} }
TEST(Metric, DeclareUnifiedTest(Error)) { TEST(Metric, DeclareUnifiedTest(Error)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("error", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("error", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "error"); ASSERT_STREQ(metric->Name(), "error");
@ -90,7 +90,7 @@ TEST(Metric, DeclareUnifiedTest(Error)) {
} }
TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) { TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "poisson-nloglik"); ASSERT_STREQ(metric->Name(), "poisson-nloglik");


@ -4,7 +4,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Metric, UnknownMetric) { TEST(Metric, UnknownMetric) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = nullptr; xgboost::Metric * metric = nullptr;
EXPECT_ANY_THROW(metric = xgboost::Metric::Create("unknown_name", &tparam)); EXPECT_ANY_THROW(metric = xgboost::Metric::Create("unknown_name", &tparam));
EXPECT_NO_THROW(metric = xgboost::Metric::Create("rmse", &tparam)); EXPECT_NO_THROW(metric = xgboost::Metric::Create("rmse", &tparam));


@ -4,10 +4,9 @@
#include "../helpers.h" #include "../helpers.h"
inline void TestMultiClassError(xgboost::GPUSet const& devices) { inline void TestMultiClassError(int device) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(device);
lparam.gpu_id = *devices.begin(); lparam.gpu_id = device;
lparam.n_gpus = devices.Size();
xgboost::Metric * metric = xgboost::Metric::Create("merror", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("merror", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "merror"); ASSERT_STREQ(metric->Name(), "merror");
@ -23,14 +22,12 @@ inline void TestMultiClassError(xgboost::GPUSet const& devices) {
} }
TEST(Metric, DeclareUnifiedTest(MultiClassError)) { TEST(Metric, DeclareUnifiedTest(MultiClassError)) {
auto devices = xgboost::GPUSet::Range(0, NGPUS); TestMultiClassError(GPUIDX);
TestMultiClassError(devices);
} }
inline void TestMultiClassLogLoss(xgboost::GPUSet const& devices) { inline void TestMultiClassLogLoss(int device) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(device);
lparam.gpu_id = *devices.begin(); lparam.gpu_id = device;
lparam.n_gpus = devices.Size();
xgboost::Metric * metric = xgboost::Metric::Create("mlogloss", &lparam); xgboost::Metric * metric = xgboost::Metric::Create("mlogloss", &lparam);
metric->Configure({}); metric->Configure({});
ASSERT_STREQ(metric->Name(), "mlogloss"); ASSERT_STREQ(metric->Name(), "mlogloss");
@ -46,27 +43,31 @@ inline void TestMultiClassLogLoss(xgboost::GPUSet const& devices) {
} }
TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) {
auto devices = xgboost::GPUSet::Range(0, NGPUS); TestMultiClassLogLoss(GPUIDX);
TestMultiClassLogLoss(devices);
} }
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__) #if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
namespace xgboost {
namespace common {
TEST(Metric, MGPU_MultiClassError) { TEST(Metric, MGPU_MultiClassError) {
if (AllVisibleGPUs() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
{ {
auto devices = xgboost::GPUSet::All(0, -1); TestMultiClassError(0);
TestMultiClassError(devices);
} }
{ {
auto devices = xgboost::GPUSet::All(1, -1); TestMultiClassError(1);
TestMultiClassError(devices);
} }
{ {
auto devices = xgboost::GPUSet::All(0, -1); TestMultiClassLogLoss(0);
TestMultiClassLogLoss(devices);
} }
{ {
auto devices = xgboost::GPUSet::All(1, -1); TestMultiClassLogLoss(1);
TestMultiClassLogLoss(devices);
} }
} }
} // namespace common
} // namespace xgboost
#endif // defined(XGBOOST_USE_NCCL) #endif // defined(XGBOOST_USE_NCCL)


@ -4,7 +4,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Metric, AMS) { TEST(Metric, AMS) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
EXPECT_ANY_THROW(xgboost::Metric::Create("ams", &tparam)); EXPECT_ANY_THROW(xgboost::Metric::Create("ams", &tparam));
xgboost::Metric * metric = xgboost::Metric::Create("ams@0.5f", &tparam); xgboost::Metric * metric = xgboost::Metric::Create("ams@0.5f", &tparam);
ASSERT_STREQ(metric->Name(), "ams@0.5"); ASSERT_STREQ(metric->Name(), "ams@0.5");
@ -23,7 +23,7 @@ TEST(Metric, AMS) {
} }
TEST(Metric, AUC) { TEST(Metric, AUC) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("auc", &tparam); xgboost::Metric * metric = xgboost::Metric::Create("auc", &tparam);
ASSERT_STREQ(metric->Name(), "auc"); ASSERT_STREQ(metric->Name(), "auc");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
@ -38,7 +38,7 @@ TEST(Metric, AUC) {
} }
TEST(Metric, AUCPR) { TEST(Metric, AUCPR) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam); xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);
ASSERT_STREQ(metric->Name(), "aucpr"); ASSERT_STREQ(metric->Name(), "aucpr");
EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
@ -65,7 +65,7 @@ TEST(Metric, Precision) {
// When the limit for precision is not given, it takes the limit at // When the limit for precision is not given, it takes the limit at
// std::numeric_limits<unsigned>::max(); hence all values are very small // std::numeric_limits<unsigned>::max(); hence all values are very small
// NOTE(AbdealiJK): Maybe this should be fixed to be num_row by default. // NOTE(AbdealiJK): Maybe this should be fixed to be num_row by default.
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("pre", &tparam); xgboost::Metric * metric = xgboost::Metric::Create("pre", &tparam);
ASSERT_STREQ(metric->Name(), "pre"); ASSERT_STREQ(metric->Name(), "pre");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-7); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-7);
@ -89,7 +89,7 @@ TEST(Metric, Precision) {
} }
TEST(Metric, NDCG) { TEST(Metric, NDCG) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("ndcg", &tparam); xgboost::Metric * metric = xgboost::Metric::Create("ndcg", &tparam);
ASSERT_STREQ(metric->Name(), "ndcg"); ASSERT_STREQ(metric->Name(), "ndcg");
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
@ -147,7 +147,7 @@ TEST(Metric, NDCG) {
} }
TEST(Metric, MAP) { TEST(Metric, MAP) {
auto tparam = xgboost::CreateEmptyGenericParam(0, 0); auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("map", &tparam); xgboost::Metric * metric = xgboost::Metric::Create("map", &tparam);
ASSERT_STREQ(metric->Name(), "map"); ASSERT_STREQ(metric->Name(), "map");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);


@ -6,7 +6,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Objective, DeclareUnifiedTest(HingeObj)) { TEST(Objective, DeclareUnifiedTest(HingeObj)) {
xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge", &tparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge", &tparam);
xgboost::bst_float eps = std::numeric_limits<xgboost::bst_float>::min(); xgboost::bst_float eps = std::numeric_limits<xgboost::bst_float>::min();


@ -7,7 +7,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}}; std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}};
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax", &lparam);
@ -25,7 +25,7 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
auto lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args{ std::vector<std::pair<std::string, std::string>> args{
std::pair<std::string, std::string>("num_class", "3")}; std::pair<std::string, std::string>("num_class", "3")};
@ -47,7 +47,7 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
} }
TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args { std::vector<std::pair<std::string, std::string>> args {
std::pair<std::string, std::string>("num_class", "3")}; std::pair<std::string, std::string>("num_class", "3")};


@ -7,7 +7,7 @@
#include "../helpers.h" #include "../helpers.h"
TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction * obj =
@ -32,7 +32,7 @@ TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(SquaredLog)) { TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction * obj =
@ -56,7 +56,7 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
} }
TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic", &tparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic", &tparam);
@ -72,7 +72,7 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic", &lparam);
@ -102,7 +102,7 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
} }
TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:logitraw", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:logitraw", &lparam);
@ -118,7 +118,7 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) { TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson", &lparam);
@ -140,7 +140,7 @@ TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) { TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson", &lparam);
@ -168,7 +168,7 @@ TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) {
} }
TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) { TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma", &lparam);
@ -189,7 +189,7 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) {
} }
TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma", &lparam);
@ -217,7 +217,7 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
} }
TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie", &lparam);
@ -241,7 +241,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
TEST(Objective, CPU_vs_CUDA) { TEST(Objective, CPU_vs_CUDA) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, 1); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::ObjFunction * obj = xgboost::ObjFunction * obj =
xgboost::ObjFunction::Create("reg:squarederror", &lparam); xgboost::ObjFunction::Create("reg:squarederror", &lparam);
@ -267,12 +267,12 @@ TEST(Objective, CPU_vs_CUDA) {
{ {
// CPU // CPU
lparam.n_gpus = 0; lparam.gpu_id = -1;
obj->GetGradient(preds, info, 0, &cpu_out_preds); obj->GetGradient(preds, info, 0, &cpu_out_preds);
} }
{ {
// CUDA // CUDA
lparam.n_gpus = 1; lparam.gpu_id = 0;
obj->GetGradient(preds, info, 0, &cuda_out_preds); obj->GetGradient(preds, info, 0, &cuda_out_preds);
} }
@ -294,7 +294,7 @@ TEST(Objective, CPU_vs_CUDA) {
#endif #endif
TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, NGPUS); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie", &lparam); xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie", &lparam);
@ -325,7 +325,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
// CoxRegression not implemented in GPU code, no need for testing. // CoxRegression not implemented in GPU code, no need for testing.
#if !defined(__CUDACC__) #if !defined(__CUDACC__)
TEST(Objective, CoxRegressionGPair) { TEST(Objective, CoxRegressionGPair) {
xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(0, 0); xgboost::GenericParameter lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::vector<std::pair<std::string, std::string>> args; std::vector<std::pair<std::string, std::string>> args;
xgboost::ObjFunction * obj = xgboost::ObjFunction * obj =
xgboost::ObjFunction::Create("survival:cox", &lparam); xgboost::ObjFunction::Create("survival:cox", &lparam);


@ -6,7 +6,7 @@
namespace xgboost { namespace xgboost {
TEST(Plugin, ExampleObjective) { TEST(Plugin, ExampleObjective) {
xgboost::GenericParameter tparam = CreateEmptyGenericParam(0, 0); xgboost::GenericParameter tparam = CreateEmptyGenericParam(GPUIDX);
auto * obj = xgboost::ObjFunction::Create("mylogistic", &tparam); auto * obj = xgboost::ObjFunction::Create("mylogistic", &tparam);
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"error"}); ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"error"});
delete obj; delete obj;


@ -6,7 +6,7 @@
namespace xgboost { namespace xgboost {
TEST(cpu_predictor, Test) { TEST(cpu_predictor, Test) {
auto lparam = CreateEmptyGenericParam(0, 0); auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
@ -59,7 +59,7 @@ TEST(cpu_predictor, ExternalMemoryTest) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename); std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
auto lparam = CreateEmptyGenericParam(0, 0); auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));


@ -33,8 +33,8 @@ namespace xgboost {
namespace predictor { namespace predictor {
TEST(gpu_predictor, Test) { TEST(gpu_predictor, Test) {
auto cpu_lparam = CreateEmptyGenericParam(0, 0); auto cpu_lparam = CreateEmptyGenericParam(-1);
auto gpu_lparam = CreateEmptyGenericParam(0, 1); auto gpu_lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
@ -69,7 +69,7 @@ TEST(gpu_predictor, Test) {
} }
TEST(gpu_predictor, ExternalMemoryTest) { TEST(gpu_predictor, ExternalMemoryTest) {
auto lparam = CreateEmptyGenericParam(0, 1); auto lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({}, {}); gpu_predictor->Configure({}, {});
@ -83,26 +83,26 @@ TEST(gpu_predictor, ExternalMemoryTest) {
std::string file1 = tmpdir.path + "/big_1.libsvm"; std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm"; std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0)); dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1)); // dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2)); // dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
for (const auto& dmat: dmats) { for (const auto& dmat: dmats) {
// Test predict batch dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
HostDeviceVector<float> out_predictions; HostDeviceVector<float> out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes); EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
const std::vector<float> &host_vector = out_predictions.ConstHostVector(); const std::vector<float> &host_vector = out_predictions.ConstHostVector();
for (int i = 0; i < host_vector.size() / n_classes; i++) { for (int i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 1.5); ASSERT_EQ(host_vector[i * n_classes], 2.0);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.); ASSERT_EQ(host_vector[i * n_classes + 1], 0.5);
ASSERT_EQ(host_vector[i * n_classes + 2], 0.); ASSERT_EQ(host_vector[i * n_classes + 2], 0.5);
} }
} }
} }
// Test whether pickling preserves predictor parameters // Test whether pickling preserves predictor parameters
TEST(gpu_predictor, PicklingTest) { TEST(gpu_predictor, PicklingTest) {
int const ngpu = 1; int const gpuid = 0;
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
@ -134,7 +134,7 @@ TEST(gpu_predictor, PicklingTest) {
ASSERT_EQ(XGBoosterSetParam( ASSERT_EQ(XGBoosterSetParam(
bst, "tree_method", "gpu_hist"), 0) << XGBGetLastError(); bst, "tree_method", "gpu_hist"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam( ASSERT_EQ(XGBoosterSetParam(
bst, "n_gpus", std::to_string(ngpu).c_str()), 0) << XGBGetLastError(); bst, "gpu_id", std::to_string(gpuid).c_str()), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "predictor", "gpu_predictor"), 0) << XGBGetLastError(); ASSERT_EQ(XGBoosterSetParam(bst, "predictor", "gpu_predictor"), 0) << XGBGetLastError();
// Run boosting iterations // Run boosting iterations
@ -160,7 +160,7 @@ TEST(gpu_predictor, PicklingTest) {
{ // Query predictor { // Query predictor
const auto& kwargs = QueryBoosterConfigurationArguments(bst2); const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
ASSERT_EQ(kwargs.at("predictor"), "gpu_predictor"); ASSERT_EQ(kwargs.at("predictor"), "gpu_predictor");
ASSERT_EQ(kwargs.at("n_gpus"), std::to_string(ngpu).c_str()); ASSERT_EQ(kwargs.at("gpu_id"), std::to_string(gpuid).c_str());
} }
{ // Change predictor and query again { // Change predictor and query again
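PicklingTest above now drives device selection through gpu_id alone. As a rough illustration of the same C-API calls outside the test harness, here is a hedged sketch; the helper name is made up, and it assumes a valid BoosterHandle created elsewhere:

#include <string>
#include <xgboost/c_api.h>

// Illustrative only: bind a booster to a single device ordinal.  A
// non-negative gpu_id selects that GPU (and the GPU predictor, mirroring the
// test); -1 keeps the booster on the CPU.  The old "n_gpus" knob is not set.
int ConfigureDevice(BoosterHandle bst, int gpu_id) {
  int ret = XGBoosterSetParam(bst, "gpu_id", std::to_string(gpu_id).c_str());
  if (ret == 0 && gpu_id >= 0) {
    ret = XGBoosterSetParam(bst, "predictor", "gpu_predictor");
  }
  return ret;
}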


@ -168,10 +168,9 @@ TEST(Learner, IO) {
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "auto"}, learner->SetParams({Arg{"tree_method", "auto"},
Arg{"predictor", "gpu_predictor"}, Arg{"predictor", "gpu_predictor"},
Arg{"n_gpus", "1"}}); Arg{"gpu_id", "0"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model.bst"; const std::string fname = tempdir.path + "/model.bst";
@ -185,7 +184,6 @@ TEST(Learner, IO) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r")); std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
learner->Load(fi.get()); learner->Load(fi.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 0);
delete pp_dmat; delete pp_dmat;
} }
@ -208,31 +206,27 @@ TEST(Learner, GPUConfiguration) {
Arg{"updater", "gpu_coord_descent"}}); Arg{"updater", "gpu_coord_descent"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
} }
{ {
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "gpu_hist"}}); learner->SetParams({Arg{"tree_method", "gpu_hist"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
} }
{ {
// with CPU algorithm // with CPU algorithm
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"}}); learner->SetParams({Arg{"tree_method", "hist"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, -1);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 0);
} }
{ {
// with CPU algorithm, but `n_gpus` takes priority // with CPU algorithm, but `gpu_id` takes priority
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"}, learner->SetParams({Arg{"tree_method", "hist"},
Arg{"n_gpus", "1"}}); Arg{"gpu_id", "0"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
} }
{ {
// With CPU algorithm but GPU Predictor, this is to simulate when // With CPU algorithm but GPU Predictor, this is to simulate when
@ -243,7 +237,6 @@ TEST(Learner, GPUConfiguration) {
Arg{"predictor", "gpu_predictor"}}); Arg{"predictor", "gpu_predictor"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
} }
delete pp_dmat; delete pp_dmat;
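Taken together, the Learner tests above assert that configuration collapses to one ordinal: -1 when only CPU algorithms are requested, 0 (or the user's explicit choice) whenever a GPU tree method, updater, or predictor is involved. A hypothetical consumer-side sketch of what that buys downstream code; nothing here is quoted from the library itself:

#include <xgboost/learner.h>

// Hypothetical sketch: with the per-booster device list gone, a component only
// has to inspect a single ordinal to choose between CPU and GPU code paths.
bool RunsOnGpu(xgboost::Learner* learner) {
  // -1 means CPU only; any non-negative value is the CUDA device ordinal.
  return learner->GetGenericParameter().gpu_id >= 0;
}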


@ -366,7 +366,7 @@ TEST(GpuHist, EvaluateSplits) {
ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps); ASSERT_NEAR(res[1].fvalue, 0.26, xgboost::kRtEps);
} }
void TestHistogramIndexImpl(int n_gpus) { void TestHistogramIndexImpl() {
// Test if the compressed histogram index matches when using a sparse // Test if the compressed histogram index matches when using a sparse
// dmatrix with and without using external memory // dmatrix with and without using external memory
@ -384,7 +384,7 @@ void TestHistogramIndexImpl(int n_gpus) {
{"max_leaves", "0"} {"max_leaves", "0"}
}; };
GenericParameter generic_param(CreateEmptyGenericParam(0, n_gpus)); GenericParameter generic_param(CreateEmptyGenericParam(0));
hist_maker.Configure(training_params, &generic_param); hist_maker.Configure(training_params, &generic_param);
hist_maker.InitDataOnce(hist_maker_dmat.get()); hist_maker.InitDataOnce(hist_maker_dmat.get());
@ -412,7 +412,7 @@ void TestHistogramIndexImpl(int n_gpus) {
} }
TEST(GpuHist, TestHistogramIndex) { TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl(1); TestHistogramIndexImpl();
} }
} // namespace tree } // namespace tree
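TestHistogramIndexImpl no longer takes a GPU count, and the shared helper CreateEmptyGenericParam is now called with a single ordinal. Its definition is not part of this diff, so the following is only a guess at the simplified shape, including the header path; the real helper in the test sources may differ:

#include <string>
#include <utility>
#include <vector>

#include <xgboost/generic_parameters.h>

// Guessed single-argument form of the test helper; gpu_id = -1 asks for a
// CPU-only configuration, a non-negative value pins the tests to that device.
inline xgboost::GenericParameter CreateEmptyGenericParam(int gpu_id) {
  xgboost::GenericParameter param;
  param.InitAllowUnknown(std::vector<std::pair<std::string, std::string>>{});
  param.gpu_id = gpu_id;
  return param;
}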


@ -29,7 +29,7 @@ TEST(Updater, Prune) {
{0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} }; {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} };
auto dmat = CreateDMatrix(32, 16, 0.4, 3); auto dmat = CreateDMatrix(32, 16, 0.4, 3);
auto lparam = CreateEmptyGenericParam(0, 0); auto lparam = CreateEmptyGenericParam(GPUIDX);
// prepare tree // prepare tree
RegTree tree = RegTree(); RegTree tree = RegTree();


@ -25,7 +25,7 @@ TEST(Updater, Refresh) {
{"reg_lambda", "1"}}; {"reg_lambda", "1"}};
RegTree tree = RegTree(); RegTree tree = RegTree();
auto lparam = CreateEmptyGenericParam(0, 0); auto lparam = CreateEmptyGenericParam(GPUIDX);
tree.param.InitAllowUnknown(cfg); tree.param.InitAllowUnknown(cfg);
std::vector<RegTree*> trees {&tree}; std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam)); std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
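The Prune and Refresh updater tests above pass a GPUIDX constant whose definition is not shown in this diff; presumably it resolves to a real device only when the test binary is built with CUDA. A speculative sketch of such a macro, offered purely as a reading aid:

// Speculative: the real definition lives in the shared test helpers.
#if defined(XGBOOST_USE_CUDA)
#define GPUIDX 0   // exercise the updaters on the first GPU in CUDA builds
#else
#define GPUIDX -1  // -1 selects the CPU-only configuration
#endif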


@ -61,7 +61,6 @@ base_params = {
def params_basic_1x4(rank): def params_basic_1x4(rank):
return dict(base_params, **{ return dict(base_params, **{
'n_gpus': 1,
'gpu_id': rank, 'gpu_id': rank,
}), 20 }), 20


@ -23,7 +23,7 @@ class TestGPULinear(unittest.TestCase):
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_gpu_coordinate(self): def test_gpu_coordinate(self):
parameters = self.common_param.copy() parameters = self.common_param.copy()
parameters['n_gpus'] = [1] parameters['gpu_id'] = [0]
for param in test_linear.parameter_combinations(parameters): for param in test_linear.parameter_combinations(parameters):
results = test_linear.run_suite( results = test_linear.run_suite(
param, 150, self.datasets, scale_features=True) param, 150, self.datasets, scale_features=True)


@ -21,7 +21,7 @@ datasets = ["Boston", "Cancer", "Digits", "Sparse regression",
class TestGPU(unittest.TestCase): class TestGPU(unittest.TestCase):
def test_gpu_hist(self): def test_gpu_hist(self):
test_param = parameter_combinations({'n_gpus': [1], 'max_depth': [2, 8], test_param = parameter_combinations({'gpu_id': [0], 'max_depth': [2, 8],
'max_leaves': [255, 4], 'max_leaves': [255, 4],
'max_bin': [2, 256], 'max_bin': [2, 256],
'grow_policy': ['lossguide']}) 'grow_policy': ['lossguide']})
@ -38,8 +38,7 @@ class TestGPU(unittest.TestCase):
@pytest.mark.mgpu @pytest.mark.mgpu
def test_specified_gpu_id_gpu_update(self): def test_specified_gpu_id_gpu_update(self):
variable_param = {'n_gpus': [1], variable_param = {'gpu_id': [1],
'gpu_id': [1],
'max_depth': [8], 'max_depth': [8],
'max_leaves': [255, 4], 'max_leaves': [255, 4],
'max_bin': [2, 64], 'max_bin': [2, 64],


@ -63,7 +63,7 @@ class TestGPU(unittest.TestCase):
'nthread': 0, 'nthread': 0,
'eta': 1, 'eta': 1,
'verbosity': 3, 'verbosity': 3,
'n_gpus': 1, 'gpu_id': 0,
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'max_bin': max_bin, 'max_bin': max_bin,
'eval_metric': 'auc'} 'eval_metric': 'auc'}