Remove column major specialization. (#5755)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
bd9d57f579
commit
cacff9232a
@ -463,7 +463,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
|
|||||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||||
|
|
||||||
auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols);
|
std::shared_ptr<xgboost::data::DenseAdapter> x{
|
||||||
|
new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
|
||||||
HostDeviceVector<float>* p_predt { nullptr };
|
HostDeviceVector<float>* p_predt { nullptr };
|
||||||
std::string type { c_type };
|
std::string type { c_type };
|
||||||
learner->InplacePredict(x, type, missing, &p_predt);
|
learner->InplacePredict(x, type, missing, &p_predt);
|
||||||
@ -494,7 +495,8 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
|
|||||||
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
|
||||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||||
|
|
||||||
auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col);
|
std::shared_ptr<xgboost::data::CSRAdapter> x{
|
||||||
|
new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
|
||||||
HostDeviceVector<float>* p_predt { nullptr };
|
HostDeviceVector<float>* p_predt { nullptr };
|
||||||
std::string type { c_type };
|
std::string type { c_type };
|
||||||
learner->InplacePredict(x, type, missing, &p_predt);
|
learner->InplacePredict(x, type, missing, &p_predt);
|
||||||
|
|||||||
@ -69,7 +69,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
|||||||
auto *learner = static_cast<Learner*>(handle);
|
auto *learner = static_cast<Learner*>(handle);
|
||||||
|
|
||||||
std::string json_str{c_json_strs};
|
std::string json_str{c_json_strs};
|
||||||
auto x = data::CudfAdapter(json_str);
|
auto x = std::make_shared<data::CudfAdapter>(json_str);
|
||||||
HostDeviceVector<float>* p_predt { nullptr };
|
HostDeviceVector<float>* p_predt { nullptr };
|
||||||
std::string type { c_type };
|
std::string type { c_type };
|
||||||
learner->InplacePredict(x, type, missing, &p_predt);
|
learner->InplacePredict(x, type, missing, &p_predt);
|
||||||
@ -97,7 +97,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
|||||||
auto *learner = static_cast<Learner*>(handle);
|
auto *learner = static_cast<Learner*>(handle);
|
||||||
|
|
||||||
std::string json_str{c_json_strs};
|
std::string json_str{c_json_strs};
|
||||||
auto x = data::CupyAdapter(json_str);
|
auto x = std::make_shared<data::CupyAdapter>(json_str);
|
||||||
HostDeviceVector<float>* p_predt { nullptr };
|
HostDeviceVector<float>* p_predt { nullptr };
|
||||||
std::string type { c_type };
|
std::string type { c_type };
|
||||||
learner->InplacePredict(x, type, missing, &p_predt);
|
learner->InplacePredict(x, type, missing, &p_predt);
|
||||||
|
|||||||
@ -34,36 +34,27 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
class CudfAdapterBatch : public detail::NoMetaInfo {
|
class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||||
|
friend class CudfAdapter;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CudfAdapterBatch() = default;
|
CudfAdapterBatch() = default;
|
||||||
CudfAdapterBatch(common::Span<ArrayInterface> columns,
|
CudfAdapterBatch(common::Span<ArrayInterface> columns, size_t num_rows)
|
||||||
common::Span<size_t> column_ptr, size_t num_elements)
|
|
||||||
: columns_(columns),
|
: columns_(columns),
|
||||||
column_ptr_(column_ptr),
|
num_rows_(num_rows) {}
|
||||||
num_elements_(num_elements) {}
|
size_t Size() const { return num_rows_ * columns_.size(); }
|
||||||
size_t Size() const { return num_elements_; }
|
|
||||||
__device__ COOTuple GetElement(size_t idx) const {
|
__device__ COOTuple GetElement(size_t idx) const {
|
||||||
size_t column_idx =
|
size_t column_idx = idx % columns_.size();
|
||||||
thrust::upper_bound(thrust::seq,column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1;
|
size_t row_idx = idx / columns_.size();
|
||||||
auto& column = columns_[column_idx];
|
auto const& column = columns_[column_idx];
|
||||||
size_t row_idx = idx - column_ptr_[column_idx];
|
|
||||||
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
|
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
|
||||||
? column.GetElement(row_idx)
|
? column.GetElement(row_idx)
|
||||||
: std::numeric_limits<float>::quiet_NaN();
|
: std::numeric_limits<float>::quiet_NaN();
|
||||||
return {row_idx, column_idx, value};
|
return {row_idx, column_idx, value};
|
||||||
}
|
}
|
||||||
__device__ float GetValue(size_t ridx, bst_feature_t fidx) const {
|
|
||||||
auto const& column = columns_[fidx];
|
|
||||||
float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
|
|
||||||
? column.GetElement(ridx)
|
|
||||||
: std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
common::Span<ArrayInterface> columns_;
|
common::Span<ArrayInterface> columns_;
|
||||||
common::Span<size_t> column_ptr_;
|
size_t num_rows_;
|
||||||
size_t num_elements_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -127,7 +118,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
|||||||
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
|
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
|
||||||
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
|
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
|
||||||
std::vector<ArrayInterface> columns;
|
std::vector<ArrayInterface> columns;
|
||||||
std::vector<size_t> column_ptr({0});
|
|
||||||
auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
|
auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
|
||||||
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
||||||
CHECK_NE(device_idx_, -1);
|
CHECK_NE(device_idx_, -1);
|
||||||
@ -137,7 +127,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
|||||||
auto column = ArrayInterface(get<Object const>(json_col));
|
auto column = ArrayInterface(get<Object const>(json_col));
|
||||||
columns.push_back(column);
|
columns.push_back(column);
|
||||||
CHECK_EQ(column.num_cols, 1);
|
CHECK_EQ(column.num_cols, 1);
|
||||||
column_ptr.emplace_back(column_ptr.back() + column.num_rows);
|
|
||||||
num_rows_ = std::max(num_rows_, size_t(column.num_rows));
|
num_rows_ = std::max(num_rows_, size_t(column.num_rows));
|
||||||
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
|
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
|
||||||
<< "All columns should use the same device.";
|
<< "All columns should use the same device.";
|
||||||
@ -145,23 +134,20 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
|||||||
<< "All columns should have same number of rows.";
|
<< "All columns should have same number of rows.";
|
||||||
}
|
}
|
||||||
columns_ = columns;
|
columns_ = columns;
|
||||||
column_ptr_ = column_ptr;
|
batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_);
|
||||||
batch_ = CudfAdapterBatch(dh::ToSpan(columns_), dh::ToSpan(column_ptr_),
|
}
|
||||||
column_ptr.back());
|
const CudfAdapterBatch& Value() const override {
|
||||||
|
CHECK_EQ(batch_.columns_.data(), columns_.data().get());
|
||||||
|
return batch_;
|
||||||
}
|
}
|
||||||
const CudfAdapterBatch& Value() const override { return batch_; }
|
|
||||||
|
|
||||||
size_t NumRows() const { return num_rows_; }
|
size_t NumRows() const { return num_rows_; }
|
||||||
size_t NumColumns() const { return columns_.size(); }
|
size_t NumColumns() const { return columns_.size(); }
|
||||||
size_t DeviceIdx() const { return device_idx_; }
|
size_t DeviceIdx() const { return device_idx_; }
|
||||||
|
|
||||||
// Cudf is column major
|
|
||||||
bool IsRowMajor() { return false; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CudfAdapterBatch batch_;
|
CudfAdapterBatch batch_;
|
||||||
dh::device_vector<ArrayInterface> columns_;
|
dh::device_vector<ArrayInterface> columns_;
|
||||||
dh::device_vector<size_t> column_ptr_; // Exclusive scan of column sizes
|
|
||||||
size_t num_rows_{0};
|
size_t num_rows_{0};
|
||||||
int device_idx_;
|
int device_idx_;
|
||||||
};
|
};
|
||||||
@ -201,8 +187,6 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
|||||||
size_t NumColumns() const { return array_interface_.num_cols; }
|
size_t NumColumns() const { return array_interface_.num_cols; }
|
||||||
size_t DeviceIdx() const { return device_idx_; }
|
size_t DeviceIdx() const { return device_idx_; }
|
||||||
|
|
||||||
bool IsRowMajor() { return true; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArrayInterface array_interface_;
|
ArrayInterface array_interface_;
|
||||||
CupyAdapterBatch batch_;
|
CupyAdapterBatch batch_;
|
||||||
|
|||||||
@ -154,8 +154,8 @@ struct WriteCompressedEllpackFunctor {
|
|||||||
// Here the data is already correctly ordered and simply needs to be compacted
|
// Here the data is already correctly ordered and simply needs to be compacted
|
||||||
// to remove missing data
|
// to remove missing data
|
||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||||
int device_idx, float missing) {
|
int device_idx, float missing) {
|
||||||
// Some witchcraft happens here
|
// Some witchcraft happens here
|
||||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||||
// with a given row stride, using no extra working memory Standard stream
|
// with a given row stride, using no extra working memory Standard stream
|
||||||
@ -209,51 +209,6 @@ void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AdapterT, typename AdapterBatchT>
|
|
||||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
|
||||||
EllpackPageImpl* dst, float missing) {
|
|
||||||
// Step 1: Get the sizes of the input columns
|
|
||||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
|
||||||
auto d_column_sizes = column_sizes.data().get();
|
|
||||||
// Populate column sizes
|
|
||||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
|
||||||
const auto& e = batch.GetElement(idx);
|
|
||||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
|
||||||
&d_column_sizes[e.column_idx]),
|
|
||||||
static_cast<unsigned long long>(1)); // NOLINT
|
|
||||||
});
|
|
||||||
|
|
||||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
|
||||||
|
|
||||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
|
||||||
// temporary row pointers
|
|
||||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
|
||||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
|
||||||
auto row_stride = dst->row_stride;
|
|
||||||
size_t begin = 0;
|
|
||||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
|
||||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
|
||||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
|
||||||
data::IsValidFunctor is_valid(missing);
|
|
||||||
for (auto size : host_column_sizes) {
|
|
||||||
size_t end = begin + size;
|
|
||||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
|
||||||
auto writer_non_const =
|
|
||||||
writer; // For some reason this variable gets captured as const
|
|
||||||
const auto& e = batch.GetElement(idx + begin);
|
|
||||||
if (!is_valid(e)) return;
|
|
||||||
size_t output_position =
|
|
||||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
|
||||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
|
||||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
|
||||||
output_position);
|
|
||||||
d_temp_row_ptr[e.row_idx] += 1;
|
|
||||||
});
|
|
||||||
|
|
||||||
begin = end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||||
common::Span<size_t> row_counts) {
|
common::Span<size_t> row_counts) {
|
||||||
// Write the null values
|
// Write the null values
|
||||||
@ -284,12 +239,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense
|
|||||||
|
|
||||||
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
|
||||||
adapter->NumRows());
|
adapter->NumRows());
|
||||||
if (adapter->IsRowMajor()) {
|
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
|
||||||
CopyDataRowMajor(batch, this, adapter->DeviceIdx(), missing);
|
|
||||||
} else {
|
|
||||||
CopyDataColumnMajor(adapter, batch, this, missing);
|
|
||||||
}
|
|
||||||
|
|
||||||
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -35,51 +35,12 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
|||||||
thrust::device_pointer_cast(offset.data()));
|
thrust::device_pointer_cast(offset.data()));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AdapterT>
|
|
||||||
void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
|
||||||
int device_idx, float missing,
|
|
||||||
common::Span<size_t> row_ptr) {
|
|
||||||
// Step 1: Get the sizes of the input columns
|
|
||||||
dh::device_vector<size_t> column_sizes(adapter->NumColumns());
|
|
||||||
auto d_column_sizes = column_sizes.data().get();
|
|
||||||
auto& batch = adapter->Value();
|
|
||||||
// Populate column sizes
|
|
||||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
|
||||||
const auto& e = batch.GetElement(idx);
|
|
||||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
|
||||||
&d_column_sizes[e.column_idx]),
|
|
||||||
static_cast<unsigned long long>(1)); // NOLINT
|
|
||||||
});
|
|
||||||
|
|
||||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
|
||||||
|
|
||||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
|
||||||
// temporary row pointers
|
|
||||||
dh::device_vector<size_t> temp_row_ptr(
|
|
||||||
thrust::device_pointer_cast(row_ptr.data()),
|
|
||||||
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
|
|
||||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
|
||||||
size_t begin = 0;
|
|
||||||
IsValidFunctor is_valid(missing);
|
|
||||||
for (auto size : host_column_sizes) {
|
|
||||||
size_t end = begin + size;
|
|
||||||
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
|
|
||||||
const auto& e = batch.GetElement(idx + begin);
|
|
||||||
if (!is_valid(e)) return;
|
|
||||||
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
|
|
||||||
d_temp_row_ptr[e.row_idx] += 1;
|
|
||||||
});
|
|
||||||
|
|
||||||
begin = end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Here the data is already correctly ordered and simply needs to be compacted
|
// Here the data is already correctly ordered and simply needs to be compacted
|
||||||
// to remove missing data
|
// to remove missing data
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
|
void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
|
||||||
int device_idx, float missing,
|
int device_idx, float missing,
|
||||||
common::Span<size_t> row_ptr) {
|
common::Span<size_t> row_ptr) {
|
||||||
auto& batch = adapter->Value();
|
auto& batch = adapter->Value();
|
||||||
auto transform_f = [=] __device__(size_t idx) {
|
auto transform_f = [=] __device__(size_t idx) {
|
||||||
const auto& e = batch.GetElement(idx);
|
const auto& e = batch.GetElement(idx);
|
||||||
@ -116,13 +77,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
|||||||
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
|
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
|
||||||
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
|
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
|
||||||
sparse_page_.data.Resize(info_.num_nonzero_);
|
sparse_page_.data.Resize(info_.num_nonzero_);
|
||||||
if (adapter->IsRowMajor()) {
|
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(),
|
||||||
CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(),
|
adapter->DeviceIdx(), missing, s_offset);
|
||||||
adapter->DeviceIdx(), missing, s_offset);
|
|
||||||
} else {
|
|
||||||
CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
|
|
||||||
adapter->DeviceIdx(), missing, s_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
info_.num_row_ = adapter->NumRows();
|
info_.num_row_ = adapter->NumRows();
|
||||||
|
|||||||
@ -271,12 +271,12 @@ class CPUPredictor : public Predictor {
|
|||||||
PredictionCacheEntry *out_preds,
|
PredictionCacheEntry *out_preds,
|
||||||
uint32_t tree_begin, uint32_t tree_end) const {
|
uint32_t tree_begin, uint32_t tree_end) const {
|
||||||
auto threads = omp_get_max_threads();
|
auto threads = omp_get_max_threads();
|
||||||
auto m = dmlc::get<Adapter>(x);
|
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||||
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
|
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||||
<< "Number of columns in data must equal to trained model.";
|
<< "Number of columns in data must equal to trained model.";
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
info.num_col_ = m.NumColumns();
|
info.num_col_ = m->NumColumns();
|
||||||
info.num_row_ = m.NumRows();
|
info.num_row_ = m->NumRows();
|
||||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||||
auto &predictions = out_preds->predictions.HostVector();
|
auto &predictions = out_preds->predictions.HostVector();
|
||||||
@ -284,17 +284,17 @@ class CPUPredictor : public Predictor {
|
|||||||
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
|
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
|
||||||
size_t constexpr kUnroll = 8;
|
size_t constexpr kUnroll = 8;
|
||||||
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
|
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
|
||||||
&m, missing, common::Span<Entry>{workspace}),
|
m.get(), missing, common::Span<Entry>{workspace}),
|
||||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||||
float missing, PredictionCacheEntry *out_preds,
|
float missing, PredictionCacheEntry *out_preds,
|
||||||
uint32_t tree_begin, unsigned tree_end) const override {
|
uint32_t tree_begin, unsigned tree_end) const override {
|
||||||
if (x.type() == typeid(data::DenseAdapter)) {
|
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||||
x, model, missing, out_preds, tree_begin, tree_end);
|
x, model, missing, out_preds, tree_begin, tree_end);
|
||||||
} else if (x.type() == typeid(data::CSRAdapter)) {
|
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||||
x, model, missing, out_preds, tree_begin, tree_end);
|
x, model, missing, out_preds, tree_begin, tree_end);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -118,14 +118,18 @@ struct EllpackLoader {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CuPyAdapterLoader {
|
template <typename Batch>
|
||||||
data::CupyAdapterBatch batch;
|
struct DeviceAdapterLoader {
|
||||||
|
Batch batch;
|
||||||
bst_feature_t columns;
|
bst_feature_t columns;
|
||||||
float* smem;
|
float* smem;
|
||||||
bool use_shared;
|
bool use_shared;
|
||||||
|
|
||||||
DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared,
|
using BatchT = Batch;
|
||||||
bst_feature_t num_features, bst_row_t num_rows, size_t entry_start) :
|
|
||||||
|
DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
||||||
|
bst_feature_t num_features, bst_row_t num_rows,
|
||||||
|
size_t entry_start) :
|
||||||
batch{batch},
|
batch{batch},
|
||||||
columns{num_features},
|
columns{num_features},
|
||||||
use_shared{use_shared} {
|
use_shared{use_shared} {
|
||||||
@ -155,39 +159,6 @@ struct CuPyAdapterLoader {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CuDFAdapterLoader {
|
|
||||||
data::CudfAdapterBatch batch;
|
|
||||||
bst_feature_t columns;
|
|
||||||
float* smem;
|
|
||||||
bool use_shared;
|
|
||||||
|
|
||||||
DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared,
|
|
||||||
bst_feature_t num_features,
|
|
||||||
bst_row_t num_rows, size_t entry_start)
|
|
||||||
: batch{batch}, columns{num_features}, use_shared{use_shared} {
|
|
||||||
extern __shared__ float _smem[];
|
|
||||||
smem = _smem;
|
|
||||||
if (use_shared) {
|
|
||||||
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
|
||||||
size_t shared_elements = blockDim.x * num_features;
|
|
||||||
dh::BlockFill(smem, shared_elements, nanf(""));
|
|
||||||
__syncthreads();
|
|
||||||
if (global_idx < num_rows) {
|
|
||||||
for (size_t i = 0; i < columns; ++i) {
|
|
||||||
smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
|
|
||||||
if (use_shared) {
|
|
||||||
return smem[threadIdx.x * columns + fidx];
|
|
||||||
}
|
|
||||||
return batch.GetValue(ridx, fidx);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Loader>
|
template <typename Loader>
|
||||||
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
||||||
Loader* loader) {
|
Loader* loader) {
|
||||||
@ -429,7 +400,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
out_preds->Size() == dmat->Info().num_row_);
|
out_preds->Size() == dmat->Info().num_row_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Adapter, typename Loader, typename Batch>
|
template <typename Adapter, typename Loader>
|
||||||
void DispatchedInplacePredict(dmlc::any const &x,
|
void DispatchedInplacePredict(dmlc::any const &x,
|
||||||
const gbm::GBTreeModel &model, float missing,
|
const gbm::GBTreeModel &model, float missing,
|
||||||
PredictionCacheEntry *out_preds,
|
PredictionCacheEntry *out_preds,
|
||||||
@ -439,22 +410,22 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
DeviceModel d_model;
|
DeviceModel d_model;
|
||||||
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
|
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
|
||||||
|
|
||||||
auto m = dmlc::get<Adapter>(x);
|
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||||
CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
|
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||||
<< "Number of columns in data must equal to trained model.";
|
<< "Number of columns in data must equal to trained model.";
|
||||||
CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx())
|
CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
|
||||||
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
|
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
|
||||||
<< "but data is on: " << m.DeviceIdx();
|
<< "but data is on: " << m->DeviceIdx();
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
info.num_col_ = m.NumColumns();
|
info.num_col_ = m->NumColumns();
|
||||||
info.num_row_ = m.NumRows();
|
info.num_row_ = m->NumRows();
|
||||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||||
|
|
||||||
const uint32_t BLOCK_THREADS = 128;
|
const uint32_t BLOCK_THREADS = 128;
|
||||||
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
|
||||||
|
|
||||||
auto shared_memory_bytes =
|
auto shared_memory_bytes =
|
||||||
static_cast<size_t>(sizeof(float) * m.NumColumns() * BLOCK_THREADS);
|
static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
|
||||||
bool use_shared = true;
|
bool use_shared = true;
|
||||||
if (shared_memory_bytes > max_shared_memory_bytes) {
|
if (shared_memory_bytes > max_shared_memory_bytes) {
|
||||||
shared_memory_bytes = 0;
|
shared_memory_bytes = 0;
|
||||||
@ -463,22 +434,24 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
size_t entry_start = 0;
|
size_t entry_start = 0;
|
||||||
|
|
||||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||||
PredictKernel<Loader, Batch>,
|
PredictKernel<Loader, typename Loader::BatchT>,
|
||||||
m.Value(),
|
m->Value(),
|
||||||
dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
|
dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
|
||||||
dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
|
dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
|
||||||
tree_begin, tree_end, m.NumColumns(), info.num_row_,
|
tree_begin, tree_end, m->NumColumns(), info.num_row_,
|
||||||
entry_start, use_shared, output_groups);
|
entry_start, use_shared, output_groups);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||||
float missing, PredictionCacheEntry *out_preds,
|
float missing, PredictionCacheEntry *out_preds,
|
||||||
uint32_t tree_begin, unsigned tree_end) const override {
|
uint32_t tree_begin, unsigned tree_end) const override {
|
||||||
if (x.type() == typeid(data::CupyAdapter)) {
|
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
|
||||||
this->DispatchedInplacePredict<data::CupyAdapter, CuPyAdapterLoader, data::CupyAdapterBatch>(
|
this->DispatchedInplacePredict<
|
||||||
|
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
|
||||||
x, model, missing, out_preds, tree_begin, tree_end);
|
x, model, missing, out_preds, tree_begin, tree_end);
|
||||||
} else if (x.type() == typeid(data::CudfAdapter)) {
|
} else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
|
||||||
this->DispatchedInplacePredict<data::CudfAdapter, CuDFAdapterLoader, data::CudfAdapterBatch>(
|
this->DispatchedInplacePredict<
|
||||||
|
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
|
||||||
x, model, missing, out_preds, tree_begin, tree_end);
|
x, model, missing, out_preds, tree_begin, tree_end);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
|
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
|
||||||
|
|||||||
@ -36,13 +36,12 @@ void TestCudfAdapter()
|
|||||||
EXPECT_NO_THROW({
|
EXPECT_NO_THROW({
|
||||||
dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) {
|
dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) {
|
||||||
auto element = batch.GetElement(idx);
|
auto element = batch.GetElement(idx);
|
||||||
if (idx < kRowsA) {
|
KERNEL_CHECK(element.row_idx == idx / 2);
|
||||||
|
if (idx % 2 == 0) {
|
||||||
KERNEL_CHECK(element.column_idx == 0);
|
KERNEL_CHECK(element.column_idx == 0);
|
||||||
KERNEL_CHECK(element.row_idx == idx);
|
|
||||||
KERNEL_CHECK(element.value == element.row_idx * 2.0f);
|
KERNEL_CHECK(element.value == element.row_idx * 2.0f);
|
||||||
} else {
|
} else {
|
||||||
KERNEL_CHECK(element.column_idx == 1);
|
KERNEL_CHECK(element.column_idx == 1);
|
||||||
KERNEL_CHECK(element.row_idx == idx - kRowsA);
|
|
||||||
KERNEL_CHECK(element.value == element.row_idx * 2.0f);
|
KERNEL_CHECK(element.value == element.row_idx * 2.0f);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@ -149,7 +149,8 @@ TEST(CpuPredictor, InplacePredict) {
|
|||||||
HostDeviceVector<float> data;
|
HostDeviceVector<float> data;
|
||||||
gen.GenerateDense(&data);
|
gen.GenerateDense(&data);
|
||||||
ASSERT_EQ(data.Size(), kRows * kCols);
|
ASSERT_EQ(data.Size(), kRows * kCols);
|
||||||
data::DenseAdapter x{data.HostPointer(), kRows, kCols};
|
std::shared_ptr<data::DenseAdapter> x{
|
||||||
|
new data::DenseAdapter(data.HostPointer(), kRows, kCols)};
|
||||||
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,8 +159,9 @@ TEST(CpuPredictor, InplacePredict) {
|
|||||||
HostDeviceVector<bst_row_t> rptrs;
|
HostDeviceVector<bst_row_t> rptrs;
|
||||||
HostDeviceVector<bst_feature_t> columns;
|
HostDeviceVector<bst_feature_t> columns;
|
||||||
gen.GenerateCSR(&data, &rptrs, &columns);
|
gen.GenerateCSR(&data, &rptrs, &columns);
|
||||||
data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(),
|
std::shared_ptr<data::CSRAdapter> x{new data::CSRAdapter(
|
||||||
data.HostPointer(), kRows, data.Size(), kCols);
|
rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows,
|
||||||
|
data.Size(), kCols)};
|
||||||
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -129,7 +129,7 @@ TEST(GPUPredictor, InplacePredictCupy) {
|
|||||||
gen.Device(0);
|
gen.Device(0);
|
||||||
HostDeviceVector<float> data;
|
HostDeviceVector<float> data;
|
||||||
std::string interface_str = gen.GenerateArrayInterface(&data);
|
std::string interface_str = gen.GenerateArrayInterface(&data);
|
||||||
data::CupyAdapter x{interface_str};
|
auto x = std::make_shared<data::CupyAdapter>(interface_str);
|
||||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ TEST(GPUPredictor, InplacePredictCuDF) {
|
|||||||
gen.Device(0);
|
gen.Device(0);
|
||||||
std::vector<HostDeviceVector<float>> storage(kCols);
|
std::vector<HostDeviceVector<float>> storage(kCols);
|
||||||
auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
|
auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
|
||||||
data::CudfAdapter x {interface_str};
|
auto x = std::make_shared<data::CudfAdapter>(interface_str);
|
||||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,7 +154,7 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
|
|||||||
gen.Device(1);
|
gen.Device(1);
|
||||||
HostDeviceVector<float> data;
|
HostDeviceVector<float> data;
|
||||||
std::string interface_str = gen.GenerateArrayInterface(&data);
|
std::string interface_str = gen.GenerateArrayInterface(&data);
|
||||||
data::CupyAdapter x{interface_str};
|
auto x = std::make_shared<data::CupyAdapter>(interface_str);
|
||||||
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
|
TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
|
||||||
EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
|
EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
|
||||||
dmlc::Error);
|
dmlc::Error);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user