Remove column major specialization. (#5755)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-06-05 16:19:14 +08:00 committed by GitHub
parent bd9d57f579
commit cacff9232a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 70 additions and 204 deletions

View File

@ -463,7 +463,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
auto *learner = static_cast<xgboost::Learner *>(handle); auto *learner = static_cast<xgboost::Learner *>(handle);
auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols); std::shared_ptr<xgboost::data::DenseAdapter> x{
new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
HostDeviceVector<float>* p_predt { nullptr }; HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type }; std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt); learner->InplacePredict(x, type, missing, &p_predt);
@ -494,7 +495,8 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
auto *learner = static_cast<xgboost::Learner *>(handle); auto *learner = static_cast<xgboost::Learner *>(handle);
auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col); std::shared_ptr<xgboost::data::CSRAdapter> x{
new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
HostDeviceVector<float>* p_predt { nullptr }; HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type }; std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt); learner->InplacePredict(x, type, missing, &p_predt);

View File

@ -69,7 +69,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
auto *learner = static_cast<Learner*>(handle); auto *learner = static_cast<Learner*>(handle);
std::string json_str{c_json_strs}; std::string json_str{c_json_strs};
auto x = data::CudfAdapter(json_str); auto x = std::make_shared<data::CudfAdapter>(json_str);
HostDeviceVector<float>* p_predt { nullptr }; HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type }; std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt); learner->InplacePredict(x, type, missing, &p_predt);
@ -97,7 +97,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
auto *learner = static_cast<Learner*>(handle); auto *learner = static_cast<Learner*>(handle);
std::string json_str{c_json_strs}; std::string json_str{c_json_strs};
auto x = data::CupyAdapter(json_str); auto x = std::make_shared<data::CupyAdapter>(json_str);
HostDeviceVector<float>* p_predt { nullptr }; HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type }; std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt); learner->InplacePredict(x, type, missing, &p_predt);

View File

@ -34,36 +34,27 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
}; };
class CudfAdapterBatch : public detail::NoMetaInfo { class CudfAdapterBatch : public detail::NoMetaInfo {
friend class CudfAdapter;
public: public:
CudfAdapterBatch() = default; CudfAdapterBatch() = default;
CudfAdapterBatch(common::Span<ArrayInterface> columns, CudfAdapterBatch(common::Span<ArrayInterface> columns, size_t num_rows)
common::Span<size_t> column_ptr, size_t num_elements)
: columns_(columns), : columns_(columns),
column_ptr_(column_ptr), num_rows_(num_rows) {}
num_elements_(num_elements) {} size_t Size() const { return num_rows_ * columns_.size(); }
size_t Size() const { return num_elements_; }
__device__ COOTuple GetElement(size_t idx) const { __device__ COOTuple GetElement(size_t idx) const {
size_t column_idx = size_t column_idx = idx % columns_.size();
thrust::upper_bound(thrust::seq,column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1; size_t row_idx = idx / columns_.size();
auto& column = columns_[column_idx]; auto const& column = columns_[column_idx];
size_t row_idx = idx - column_ptr_[column_idx];
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
? column.GetElement(row_idx) ? column.GetElement(row_idx)
: std::numeric_limits<float>::quiet_NaN(); : std::numeric_limits<float>::quiet_NaN();
return {row_idx, column_idx, value}; return {row_idx, column_idx, value};
} }
__device__ float GetValue(size_t ridx, bst_feature_t fidx) const {
auto const& column = columns_[fidx];
float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
? column.GetElement(ridx)
: std::numeric_limits<float>::quiet_NaN();
return value;
}
private: private:
common::Span<ArrayInterface> columns_; common::Span<ArrayInterface> columns_;
common::Span<size_t> column_ptr_; size_t num_rows_;
size_t num_elements_;
}; };
/*! /*!
@ -127,7 +118,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian(); CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
std::vector<ArrayInterface> columns; std::vector<ArrayInterface> columns;
std::vector<size_t> column_ptr({0});
auto first_column = ArrayInterface(get<Object const>(json_columns[0])); auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
device_idx_ = dh::CudaGetPointerDevice(first_column.data); device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, -1); CHECK_NE(device_idx_, -1);
@ -137,7 +127,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
auto column = ArrayInterface(get<Object const>(json_col)); auto column = ArrayInterface(get<Object const>(json_col));
columns.push_back(column); columns.push_back(column);
CHECK_EQ(column.num_cols, 1); CHECK_EQ(column.num_cols, 1);
column_ptr.emplace_back(column_ptr.back() + column.num_rows);
num_rows_ = std::max(num_rows_, size_t(column.num_rows)); num_rows_ = std::max(num_rows_, size_t(column.num_rows));
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data)) CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
<< "All columns should use the same device."; << "All columns should use the same device.";
@ -145,23 +134,20 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
<< "All columns should have same number of rows."; << "All columns should have same number of rows.";
} }
columns_ = columns; columns_ = columns;
column_ptr_ = column_ptr; batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_);
batch_ = CudfAdapterBatch(dh::ToSpan(columns_), dh::ToSpan(column_ptr_), }
column_ptr.back()); const CudfAdapterBatch& Value() const override {
CHECK_EQ(batch_.columns_.data(), columns_.data().get());
return batch_;
} }
const CudfAdapterBatch& Value() const override { return batch_; }
size_t NumRows() const { return num_rows_; } size_t NumRows() const { return num_rows_; }
size_t NumColumns() const { return columns_.size(); } size_t NumColumns() const { return columns_.size(); }
size_t DeviceIdx() const { return device_idx_; } size_t DeviceIdx() const { return device_idx_; }
// Cudf is column major
bool IsRowMajor() { return false; }
private: private:
CudfAdapterBatch batch_; CudfAdapterBatch batch_;
dh::device_vector<ArrayInterface> columns_; dh::device_vector<ArrayInterface> columns_;
dh::device_vector<size_t> column_ptr_; // Exclusive scan of column sizes
size_t num_rows_{0}; size_t num_rows_{0};
int device_idx_; int device_idx_;
}; };
@ -201,8 +187,6 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
size_t NumColumns() const { return array_interface_.num_cols; } size_t NumColumns() const { return array_interface_.num_cols; }
size_t DeviceIdx() const { return device_idx_; } size_t DeviceIdx() const { return device_idx_; }
bool IsRowMajor() { return true; }
private: private:
ArrayInterface array_interface_; ArrayInterface array_interface_;
CupyAdapterBatch batch_; CupyAdapterBatch batch_;

View File

@ -154,8 +154,8 @@ struct WriteCompressedEllpackFunctor {
// Here the data is already correctly ordered and simply needs to be compacted // Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data // to remove missing data
template <typename AdapterBatchT> template <typename AdapterBatchT>
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst, void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
int device_idx, float missing) { int device_idx, float missing) {
// Some witchcraft happens here // Some witchcraft happens here
// The goal is to copy valid elements out of the input to an ellpack matrix // The goal is to copy valid elements out of the input to an ellpack matrix
// with a given row stride, using no extra working memory Standard stream // with a given row stride, using no extra working memory Standard stream
@ -209,51 +209,6 @@ void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
}); });
} }
template <typename AdapterT, typename AdapterBatchT>
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
EllpackPageImpl* dst, float missing) {
// Step 1: Get the sizes of the input columns
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
auto d_column_sizes = column_sizes.data().get();
// Populate column sizes
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx);
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
});
thrust::host_vector<size_t> host_column_sizes = column_sizes;
// Step 2: Iterate over columns, place elements in correct row, increment
// temporary row pointers
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
auto d_temp_row_ptr = temp_row_ptr.data().get();
auto row_stride = dst->row_stride;
size_t begin = 0;
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
data::IsValidFunctor is_valid(missing);
for (auto size : host_column_sizes) {
size_t end = begin + size;
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
auto writer_non_const =
writer; // For some reason this variable gets captured as const
const auto& e = batch.GetElement(idx + begin);
if (!is_valid(e)) return;
size_t output_position =
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
output_position);
d_temp_row_ptr[e.row_idx] += 1;
});
begin = end;
}
}
void WriteNullValues(EllpackPageImpl* dst, int device_idx, void WriteNullValues(EllpackPageImpl* dst, int device_idx,
common::Span<size_t> row_counts) { common::Span<size_t> row_counts) {
// Write the null values // Write the null values
@ -284,12 +239,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride, *this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
adapter->NumRows()); adapter->NumRows());
if (adapter->IsRowMajor()) { CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
CopyDataRowMajor(batch, this, adapter->DeviceIdx(), missing);
} else {
CopyDataColumnMajor(adapter, batch, this, missing);
}
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span); WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
} }

View File

@ -35,51 +35,12 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
thrust::device_pointer_cast(offset.data())); thrust::device_pointer_cast(offset.data()));
} }
template <typename AdapterT>
void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
int device_idx, float missing,
common::Span<size_t> row_ptr) {
// Step 1: Get the sizes of the input columns
dh::device_vector<size_t> column_sizes(adapter->NumColumns());
auto d_column_sizes = column_sizes.data().get();
auto& batch = adapter->Value();
// Populate column sizes
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx);
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
});
thrust::host_vector<size_t> host_column_sizes = column_sizes;
// Step 2: Iterate over columns, place elements in correct row, increment
// temporary row pointers
dh::device_vector<size_t> temp_row_ptr(
thrust::device_pointer_cast(row_ptr.data()),
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
auto d_temp_row_ptr = temp_row_ptr.data().get();
size_t begin = 0;
IsValidFunctor is_valid(missing);
for (auto size : host_column_sizes) {
size_t end = begin + size;
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx + begin);
if (!is_valid(e)) return;
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
d_temp_row_ptr[e.row_idx] += 1;
});
begin = end;
}
}
// Here the data is already correctly ordered and simply needs to be compacted // Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data // to remove missing data
template <typename AdapterT> template <typename AdapterT>
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data, void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
int device_idx, float missing, int device_idx, float missing,
common::Span<size_t> row_ptr) { common::Span<size_t> row_ptr) {
auto& batch = adapter->Value(); auto& batch = adapter->Value();
auto transform_f = [=] __device__(size_t idx) { auto transform_f = [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx); const auto& e = batch.GetElement(idx);
@ -116,13 +77,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing); CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
info_.num_nonzero_ = sparse_page_.offset.HostVector().back(); info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
sparse_page_.data.Resize(info_.num_nonzero_); sparse_page_.data.Resize(info_.num_nonzero_);
if (adapter->IsRowMajor()) { CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(),
CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(), adapter->DeviceIdx(), missing, s_offset);
adapter->DeviceIdx(), missing, s_offset);
} else {
CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
}
info_.num_col_ = adapter->NumColumns(); info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows(); info_.num_row_ = adapter->NumRows();

View File

@ -271,12 +271,12 @@ class CPUPredictor : public Predictor {
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const { uint32_t tree_begin, uint32_t tree_end) const {
auto threads = omp_get_max_threads(); auto threads = omp_get_max_threads();
auto m = dmlc::get<Adapter>(x); auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
MetaInfo info; MetaInfo info;
info.num_col_ = m.NumColumns(); info.num_col_ = m->NumColumns();
info.num_row_ = m.NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
std::vector<Entry> workspace(info.num_col_ * 8 * threads); std::vector<Entry> workspace(info.num_col_ * 8 * threads);
auto &predictions = out_preds->predictions.HostVector(); auto &predictions = out_preds->predictions.HostVector();
@ -284,17 +284,17 @@ class CPUPredictor : public Predictor {
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp); InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
size_t constexpr kUnroll = 8; size_t constexpr kUnroll = 8;
PredictBatchKernel(AdapterView<Adapter, kUnroll>( PredictBatchKernel(AdapterView<Adapter, kUnroll>(
&m, missing, common::Span<Entry>{workspace}), m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp); &predictions, model, tree_begin, tree_end, &thread_temp);
} }
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
float missing, PredictionCacheEntry *out_preds, float missing, PredictionCacheEntry *out_preds,
uint32_t tree_begin, unsigned tree_end) const override { uint32_t tree_begin, unsigned tree_end) const override {
if (x.type() == typeid(data::DenseAdapter)) { if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter>( this->DispatchedInplacePredict<data::DenseAdapter>(
x, model, missing, out_preds, tree_begin, tree_end); x, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(data::CSRAdapter)) { } else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
this->DispatchedInplacePredict<data::CSRAdapter>( this->DispatchedInplacePredict<data::CSRAdapter>(
x, model, missing, out_preds, tree_begin, tree_end); x, model, missing, out_preds, tree_begin, tree_end);
} else { } else {

View File

@ -118,14 +118,18 @@ struct EllpackLoader {
} }
}; };
struct CuPyAdapterLoader { template <typename Batch>
data::CupyAdapterBatch batch; struct DeviceAdapterLoader {
Batch batch;
bst_feature_t columns; bst_feature_t columns;
float* smem; float* smem;
bool use_shared; bool use_shared;
DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared, using BatchT = Batch;
bst_feature_t num_features, bst_row_t num_rows, size_t entry_start) :
DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
bst_feature_t num_features, bst_row_t num_rows,
size_t entry_start) :
batch{batch}, batch{batch},
columns{num_features}, columns{num_features},
use_shared{use_shared} { use_shared{use_shared} {
@ -155,39 +159,6 @@ struct CuPyAdapterLoader {
} }
}; };
struct CuDFAdapterLoader {
data::CudfAdapterBatch batch;
bst_feature_t columns;
float* smem;
bool use_shared;
DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared,
bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start)
: batch{batch}, columns{num_features}, use_shared{use_shared} {
extern __shared__ float _smem[];
smem = _smem;
if (use_shared) {
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
size_t shared_elements = blockDim.x * num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
__syncthreads();
if (global_idx < num_rows) {
for (size_t i = 0; i < columns; ++i) {
smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i);
}
}
}
__syncthreads();
}
DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
if (use_shared) {
return smem[threadIdx.x * columns + fidx];
}
return batch.GetValue(ridx, fidx);
}
};
template <typename Loader> template <typename Loader>
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
Loader* loader) { Loader* loader) {
@ -429,7 +400,7 @@ class GPUPredictor : public xgboost::Predictor {
out_preds->Size() == dmat->Info().num_row_); out_preds->Size() == dmat->Info().num_row_);
} }
template <typename Adapter, typename Loader, typename Batch> template <typename Adapter, typename Loader>
void DispatchedInplacePredict(dmlc::any const &x, void DispatchedInplacePredict(dmlc::any const &x,
const gbm::GBTreeModel &model, float missing, const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
@ -439,22 +410,22 @@ class GPUPredictor : public xgboost::Predictor {
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id); d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
auto m = dmlc::get<Adapter>(x); auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx()) CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
<< "but data is on: " << m.DeviceIdx(); << "but data is on: " << m->DeviceIdx();
MetaInfo info; MetaInfo info;
info.num_col_ = m.NumColumns(); info.num_col_ = m->NumColumns();
info.num_row_ = m.NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
const uint32_t BLOCK_THREADS = 128; const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
auto shared_memory_bytes = auto shared_memory_bytes =
static_cast<size_t>(sizeof(float) * m.NumColumns() * BLOCK_THREADS); static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
bool use_shared = true; bool use_shared = true;
if (shared_memory_bytes > max_shared_memory_bytes) { if (shared_memory_bytes > max_shared_memory_bytes) {
shared_memory_bytes = 0; shared_memory_bytes = 0;
@ -463,22 +434,24 @@ class GPUPredictor : public xgboost::Predictor {
size_t entry_start = 0; size_t entry_start = 0;
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<Loader, Batch>, PredictKernel<Loader, typename Loader::BatchT>,
m.Value(), m->Value(),
dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(), dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group), dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
tree_begin, tree_end, m.NumColumns(), info.num_row_, tree_begin, tree_end, m->NumColumns(), info.num_row_,
entry_start, use_shared, output_groups); entry_start, use_shared, output_groups);
} }
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
float missing, PredictionCacheEntry *out_preds, float missing, PredictionCacheEntry *out_preds,
uint32_t tree_begin, unsigned tree_end) const override { uint32_t tree_begin, unsigned tree_end) const override {
if (x.type() == typeid(data::CupyAdapter)) { if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict<data::CupyAdapter, CuPyAdapterLoader, data::CupyAdapterBatch>( this->DispatchedInplacePredict<
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
x, model, missing, out_preds, tree_begin, tree_end); x, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(data::CudfAdapter)) { } else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
this->DispatchedInplacePredict<data::CudfAdapter, CuDFAdapterLoader, data::CudfAdapterBatch>( this->DispatchedInplacePredict<
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
x, model, missing, out_preds, tree_begin, tree_end); x, model, missing, out_preds, tree_begin, tree_end);
} else { } else {
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor."; LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";

View File

@ -36,13 +36,12 @@ void TestCudfAdapter()
EXPECT_NO_THROW({ EXPECT_NO_THROW({
dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) { dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx); auto element = batch.GetElement(idx);
if (idx < kRowsA) { KERNEL_CHECK(element.row_idx == idx / 2);
if (idx % 2 == 0) {
KERNEL_CHECK(element.column_idx == 0); KERNEL_CHECK(element.column_idx == 0);
KERNEL_CHECK(element.row_idx == idx);
KERNEL_CHECK(element.value == element.row_idx * 2.0f); KERNEL_CHECK(element.value == element.row_idx * 2.0f);
} else { } else {
KERNEL_CHECK(element.column_idx == 1); KERNEL_CHECK(element.column_idx == 1);
KERNEL_CHECK(element.row_idx == idx - kRowsA);
KERNEL_CHECK(element.value == element.row_idx * 2.0f); KERNEL_CHECK(element.value == element.row_idx * 2.0f);
} }
}); });

View File

@ -149,7 +149,8 @@ TEST(CpuPredictor, InplacePredict) {
HostDeviceVector<float> data; HostDeviceVector<float> data;
gen.GenerateDense(&data); gen.GenerateDense(&data);
ASSERT_EQ(data.Size(), kRows * kCols); ASSERT_EQ(data.Size(), kRows * kCols);
data::DenseAdapter x{data.HostPointer(), kRows, kCols}; std::shared_ptr<data::DenseAdapter> x{
new data::DenseAdapter(data.HostPointer(), kRows, kCols)};
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
} }
@ -158,8 +159,9 @@ TEST(CpuPredictor, InplacePredict) {
HostDeviceVector<bst_row_t> rptrs; HostDeviceVector<bst_row_t> rptrs;
HostDeviceVector<bst_feature_t> columns; HostDeviceVector<bst_feature_t> columns;
gen.GenerateCSR(&data, &rptrs, &columns); gen.GenerateCSR(&data, &rptrs, &columns);
data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(), std::shared_ptr<data::CSRAdapter> x{new data::CSRAdapter(
data.HostPointer(), kRows, data.Size(), kCols); rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows,
data.Size(), kCols)};
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
} }
} }

View File

@ -129,7 +129,7 @@ TEST(GPUPredictor, InplacePredictCupy) {
gen.Device(0); gen.Device(0);
HostDeviceVector<float> data; HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data); std::string interface_str = gen.GenerateArrayInterface(&data);
data::CupyAdapter x{interface_str}; auto x = std::make_shared<data::CupyAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
} }
@ -139,7 +139,7 @@ TEST(GPUPredictor, InplacePredictCuDF) {
gen.Device(0); gen.Device(0);
std::vector<HostDeviceVector<float>> storage(kCols); std::vector<HostDeviceVector<float>> storage(kCols);
auto interface_str = gen.GenerateColumnarArrayInterface(&storage); auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
data::CudfAdapter x {interface_str}; auto x = std::make_shared<data::CudfAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
} }
@ -154,7 +154,7 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
gen.Device(1); gen.Device(1);
HostDeviceVector<float> data; HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data); std::string interface_str = gen.GenerateArrayInterface(&data);
data::CupyAdapter x{interface_str}; auto x = std::make_shared<data::CupyAdapter>(interface_str);
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1); TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0), EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
dmlc::Error); dmlc::Error);