Implement iterative DMatrix. (#5837)
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include "../common/version.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/iterative_device_dmatrix.h"
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
#include "./sparse_page_source.h"
|
||||
@@ -569,6 +570,26 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
template <typename DataIterHandle, typename DMatrixHandle,
|
||||
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
|
||||
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int nthread,
|
||||
int max_bin) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
return new data::IterativeDeviceDMatrix(iter, proxy, reset, next, missing, nthread, max_bin);
|
||||
#else
|
||||
common::AssertGPUSupport();
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
|
||||
DataIterResetCallback, XGDMatrixCallbackNext>(
|
||||
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int nthread,
|
||||
int max_bin);
|
||||
|
||||
template <typename AdapterT>
|
||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
|
||||
|
||||
@@ -252,6 +252,31 @@ EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense
|
||||
ELLPACK_SPECIALIZATION(data::CudfAdapter)
|
||||
ELLPACK_SPECIALIZATION(data::CupyAdapter)
|
||||
|
||||
|
||||
template <typename AdapterBatch>
|
||||
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
|
||||
bool is_dense, int nthread,
|
||||
common::Span<size_t> row_counts_span,
|
||||
size_t row_stride, size_t n_rows, size_t n_cols,
|
||||
common::HistogramCuts const& cuts) {
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
|
||||
*this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
|
||||
CopyDataToEllpack(batch, this, device, missing);
|
||||
WriteNullValues(this, device, row_counts_span);
|
||||
}
|
||||
|
||||
#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
|
||||
template EllpackPageImpl::EllpackPageImpl( \
|
||||
__BATCH_T batch, float missing, int device, \
|
||||
bool is_dense, int nthread, \
|
||||
common::Span<size_t> row_counts_span, \
|
||||
size_t row_stride, size_t n_rows, size_t n_cols, \
|
||||
common::HistogramCuts const& cuts);
|
||||
|
||||
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
|
||||
ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
|
||||
|
||||
// A functor that copies the data from one EllpackPage to another.
|
||||
struct CopyPage {
|
||||
common::CompressedBufferWriter cbw;
|
||||
@@ -279,6 +304,10 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) {
|
||||
CHECK_EQ(row_stride, page->row_stride);
|
||||
CHECK_EQ(NumSymbols(), page->NumSymbols());
|
||||
CHECK_GE(n_rows * row_stride, offset + num_elements);
|
||||
if (page == this) {
|
||||
LOG(FATAL) << "Concatenating the same Ellpack.";
|
||||
return this->n_rows * this->row_stride;
|
||||
}
|
||||
gidx_buffer.SetDevice(device);
|
||||
page->gidx_buffer.SetDevice(device);
|
||||
dh::LaunchN(device, num_elements, CopyPage(this, page, offset));
|
||||
|
||||
@@ -149,7 +149,7 @@ class EllpackPageImpl {
|
||||
|
||||
EllpackPageImpl(int device, common::HistogramCuts cuts,
|
||||
const SparsePage& page,
|
||||
bool is_dense,size_t row_stride);
|
||||
bool is_dense, size_t row_stride);
|
||||
|
||||
/*!
|
||||
* \brief Constructor from an existing DMatrix.
|
||||
@@ -161,8 +161,16 @@ class EllpackPageImpl {
|
||||
|
||||
template <typename AdapterT>
|
||||
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
|
||||
int max_bin, common::Span<size_t> row_counts_span,
|
||||
int max_bin,
|
||||
common::Span<size_t> row_counts_span,
|
||||
size_t row_stride);
|
||||
|
||||
template <typename AdapterBatch>
|
||||
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
|
||||
common::Span<size_t> row_counts_span,
|
||||
size_t row_stride, size_t n_rows, size_t n_cols,
|
||||
common::HistogramCuts const& cuts);
|
||||
|
||||
/*! \brief Copy the elements of the given ELLPACK page into this page.
|
||||
*
|
||||
* @param device The GPU device to use.
|
||||
|
||||
188
src/data/iterative_device_dmatrix.cu
Normal file
188
src/data/iterative_device_dmatrix.cu
Normal file
@@ -0,0 +1,188 @@
|
||||
/*!
|
||||
* Copyright 2020 XGBoost contributors
|
||||
*/
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <algorithm>
|
||||
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "iterative_device_dmatrix.h"
|
||||
#include "sparse_page_source.h"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
template <typename Fn>
|
||||
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
|
||||
if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
|
||||
auto value = dmlc::get<std::shared_ptr<CupyAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
return fn(value);
|
||||
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
|
||||
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
return fn(value);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
return fn(value);
|
||||
}
|
||||
}
|
||||
|
||||
void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) {
|
||||
// A handle passed to external iterator.
|
||||
auto handle = static_cast<std::shared_ptr<DMatrix>*>(proxy_);
|
||||
CHECK(handle);
|
||||
DMatrixProxy* proxy = static_cast<DMatrixProxy*>(handle->get());
|
||||
CHECK(proxy);
|
||||
// The external iterator
|
||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
||||
iter_handle, reset_, next_};
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
auto num_rows = [&]() {
|
||||
return Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
|
||||
};
|
||||
auto num_cols = [&]() {
|
||||
return Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
|
||||
};
|
||||
|
||||
size_t row_stride = 0;
|
||||
size_t nnz = 0;
|
||||
// Sketch for all batches.
|
||||
iter.Reset();
|
||||
common::HistogramCuts cuts;
|
||||
common::DenseCuts dense_cuts(&cuts);
|
||||
|
||||
std::vector<common::SketchContainer> sketch_containers;
|
||||
size_t batches = 0;
|
||||
size_t accumulated_rows = 0;
|
||||
bst_feature_t cols = 0;
|
||||
while (iter.Next()) {
|
||||
auto device = proxy->DeviceIdx();
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
} else {
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
if (proxy->Info().weights_.Size() != 0) {
|
||||
proxy->Info().weights_.SetDevice(device);
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
|
||||
proxy->Info(),
|
||||
missing, device, p_sketch);
|
||||
});
|
||||
} else {
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
|
||||
device, p_sketch);
|
||||
});
|
||||
}
|
||||
|
||||
auto batch_rows = num_rows();
|
||||
accumulated_rows += batch_rows;
|
||||
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
row_stride =
|
||||
std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
|
||||
return GetRowCounts(value, row_counts_span, device, missing);
|
||||
}));
|
||||
nnz += thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
batches++;
|
||||
}
|
||||
|
||||
// Merging multiple batches for each column
|
||||
std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
|
||||
size_t intermediate_num_cuts = std::min(
|
||||
accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
|
||||
common::SketchContainer::kFactor));
|
||||
size_t nbytes =
|
||||
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
|
||||
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
|
||||
for (omp_ulong c = 0; c < cols; ++c) {
|
||||
for (auto& sketch_batch : sketch_containers) {
|
||||
common::WQSketch::SummaryContainer summary;
|
||||
sketch_batch.sketches_.at(c).GetSummary(&summary);
|
||||
sketch_batch.sketches_.at(c).Init(0, 1);
|
||||
summary_array.at(c).Reduce(summary, nbytes);
|
||||
}
|
||||
}
|
||||
sketch_containers.clear();
|
||||
|
||||
// Build the final summary.
|
||||
std::vector<common::WQSketch> sketches(cols);
|
||||
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
|
||||
for (omp_ulong c = 0; c < cols; ++c) {
|
||||
sketches.at(c).Init(
|
||||
accumulated_rows,
|
||||
1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
|
||||
sketches.at(c).PushSummary(summary_array.at(c));
|
||||
}
|
||||
dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
|
||||
summary_array.clear();
|
||||
|
||||
this->info_.num_col_ = cols;
|
||||
this->info_.num_row_ = accumulated_rows;
|
||||
this->info_.num_nonzero_ = nnz;
|
||||
|
||||
// Construct the final ellpack page.
|
||||
page_.reset(new EllpackPage);
|
||||
*(page_->Impl()) = EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(),
|
||||
row_stride, accumulated_rows);
|
||||
|
||||
size_t offset = 0;
|
||||
iter.Reset();
|
||||
while (iter.Next()) {
|
||||
auto device = proxy->DeviceIdx();
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
auto rows = num_rows();
|
||||
dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
Dispatch(proxy, [=](auto const& value) {
|
||||
return GetRowCounts(value, row_counts_span, device, missing);
|
||||
});
|
||||
auto is_dense = this->IsDense();
|
||||
auto new_impl = Dispatch(proxy, [&](auto const &value) {
|
||||
return EllpackPageImpl(value, missing, device, is_dense, nthread,
|
||||
row_counts_span, row_stride, rows, cols, cuts);
|
||||
});
|
||||
size_t num_elements = page_->Impl()->Copy(device, &new_impl, offset);
|
||||
offset += num_elements;
|
||||
|
||||
proxy->Info().num_row_ = num_rows();
|
||||
proxy->Info().num_col_ = cols;
|
||||
if (batches != 1) {
|
||||
this->info_.Extend(std::move(proxy->Info()), false);
|
||||
}
|
||||
}
|
||||
|
||||
if (batches == 1) {
|
||||
this->info_ = std::move(proxy->Info());
|
||||
CHECK_EQ(proxy->Info().labels_.Size(), 0);
|
||||
}
|
||||
|
||||
iter.Reset();
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> IterativeDeviceDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
CHECK(page_);
|
||||
auto begin_iter =
|
||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(page_.get()));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
76
src/data/iterative_device_dmatrix.h
Normal file
76
src/data/iterative_device_dmatrix.h
Normal file
@@ -0,0 +1,76 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file iterative_device_dmatrix.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class IterativeDeviceDMatrix : public DMatrix {
|
||||
MetaInfo info_;
|
||||
BatchParam batch_param_;
|
||||
std::shared_ptr<EllpackPage> page_;
|
||||
|
||||
DMatrixHandle proxy_;
|
||||
DataIterResetCallback *reset_;
|
||||
XGDMatrixCallbackNext *next_;
|
||||
|
||||
public:
|
||||
void Initialize(DataIterHandle iter, float missing, int nthread);
|
||||
|
||||
public:
|
||||
explicit IterativeDeviceDMatrix(DataIterHandle iter, DMatrixHandle proxy,
|
||||
DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing,
|
||||
int nthread, int max_bin)
|
||||
: proxy_{proxy}, reset_{reset}, next_{next} {
|
||||
batch_param_ = BatchParam{0, max_bin, 0};
|
||||
this->Initialize(iter, missing, nthread);
|
||||
}
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
|
||||
bool SingleColBlock() const override { return false; }
|
||||
|
||||
MetaInfo& Info() override {
|
||||
return info_;
|
||||
}
|
||||
MetaInfo const& Info() const override {
|
||||
return info_;
|
||||
}
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_
|
||||
Reference in New Issue
Block a user