@@ -201,7 +201,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
|
||||
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.cu
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "../common/hist_util.h"
|
||||
#include "adapter.h"
|
||||
#include "device_adapter.cuh"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
template <typename AdapterT>
|
||||
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
|
||||
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
||||
auto& batch = adapter->Value();
|
||||
// Work out how many valid entries we have in each row
|
||||
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
size_t row_stride =
|
||||
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.num_row_ = adapter->NumRows();
|
||||
|
||||
ellpack_page_.reset(new EllpackPage());
|
||||
*ellpack_page_->Impl() =
|
||||
EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
|
||||
row_counts_span, row_stride);
|
||||
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||
}
|
||||
|
||||
#define DEVICE_DMARIX_SPECIALIZATION(__ADAPTER_T) \
|
||||
template DeviceDMatrix::DeviceDMatrix(__ADAPTER_T* adapter, float missing, \
|
||||
int nthread, int max_bin);
|
||||
|
||||
DEVICE_DMARIX_SPECIALIZATION(CudfAdapter);
|
||||
DEVICE_DMARIX_SPECIALIZATION(CupyAdapter);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
@@ -1,64 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.h
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "adapter.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "simple_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class DeviceDMatrix : public DMatrix {
|
||||
public:
|
||||
template <typename AdapterT>
|
||||
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin);
|
||||
|
||||
MetaInfo& Info() override { return info_; }
|
||||
|
||||
const MetaInfo& Info() const override { return info_; }
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
|
||||
auto begin_iter = BatchIterator<EllpackPage>(
|
||||
new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
MetaInfo info_;
|
||||
// source data pointer.
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
@@ -93,6 +93,11 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
batches++;
|
||||
}
|
||||
|
||||
if (device < 0) { // error or empty
|
||||
this->page_.reset(new EllpackPage);
|
||||
return;
|
||||
}
|
||||
|
||||
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
|
||||
for (auto const& sketch : sketch_containers) {
|
||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
@@ -108,14 +113,23 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
this->info_.num_row_ = accumulated_rows;
|
||||
this->info_.num_nonzero_ = nnz;
|
||||
|
||||
// Construct the final ellpack page.
|
||||
page_.reset(new EllpackPage);
|
||||
*(page_->Impl()) = EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(),
|
||||
row_stride, accumulated_rows);
|
||||
auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
|
||||
if (!page_) {
|
||||
// Should be put inside the while loop to protect against empty batch. In
|
||||
// that case device id is invalid.
|
||||
page_.reset(new EllpackPage);
|
||||
*(page_->Impl()) =
|
||||
EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
|
||||
accumulated_rows);
|
||||
}
|
||||
};
|
||||
|
||||
// Construct the final ellpack page.
|
||||
size_t offset = 0;
|
||||
iter.Reset();
|
||||
size_t n_batches_for_verification = 0;
|
||||
while (iter.Next()) {
|
||||
init_page();
|
||||
auto device = proxy->DeviceIdx();
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
auto rows = num_rows();
|
||||
@@ -138,7 +152,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
if (batches != 1) {
|
||||
this->info_.Extend(std::move(proxy->Info()), false);
|
||||
}
|
||||
n_batches_for_verification++;
|
||||
}
|
||||
CHECK_EQ(batches, n_batches_for_verification)
|
||||
<< "Different number of batches returned between 2 iterations";
|
||||
|
||||
if (batches == 1) {
|
||||
this->info_ = std::move(proxy->Info());
|
||||
|
||||
Reference in New Issue
Block a user