Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -164,34 +164,30 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
return values[gidx];
}
auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
auto get_bin_idx = [&](auto &column) {
auto get_bin_val = [&](auto &column) {
auto bin_idx = column[ridx];
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
return std::numeric_limits<float>::quiet_NaN();
}
if (bin_idx == lower) {
return mins[fidx];
}
return values[bin_idx - 1];
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
};
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_idx(column);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
return get_bin_idx(column);
return get_bin_val(column);
});
}
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_idx(column);
return get_bin_val(column);
});
}

View File

@@ -6,6 +6,8 @@
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <algorithm> // std::min
#include <cinttypes> // std::uint32_t
#include <cstddef> // std::size_t
#include <memory>
#include <vector>
@@ -229,6 +231,53 @@ class GHistIndexMatrix {
bool isDense_;
};
/**
* \brief Helper for recovering feature index from row-based storage of histogram
* bin. (`GHistIndexMatrix`).
*
* \param assign A callback function that takes bin index, index into the whole batch, row
* index and feature index
*/
/**
 * \brief Walk a `GHistIndexMatrix` entry by entry, recovering the feature index for
 *        each stored histogram bin.
 *
 * \param page   The gradient index page to traverse.
 * \param assign Callback invoked once per entry with
 *               (bin_idx, index into the whole batch, row index, feature index).
 */
template <typename Fn>
void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
  auto const batch_size = page.Size();
  auto const& ptrs = page.cut.Ptrs();
  std::size_t k{0};  // running position inside the whole batch

  auto dense = page.IsDense();

  common::DispatchBinType(page.index.GetBinTypeSize(), [&](auto t) {
    using BinT = decltype(t);
    auto const& index = page.index;
    for (std::size_t ridx = 0; ridx < batch_size; ++ridx) {
      auto r_beg = page.row_ptr[ridx];
      auto r_end = page.row_ptr[ridx + 1];
      bst_feature_t fidx{0};
      if (dense) {
        // Dense rows store one bin per feature in compressed form; operator[]
        // decompresses the true (global) bin index.  The feature index is simply
        // the offset within the row.
        for (std::size_t j = r_beg; j < r_end; ++j) {
          // Distinct name: the previous revision shadowed the outer `fidx` here.
          bst_feature_t dense_fidx = j - r_beg;
          std::uint32_t bin_idx = index[k];
          assign(bin_idx, k, ridx, dense_fidx);
          ++k;
        }
      } else {
        // Sparse rows store raw (uncompressed) bin values of type BinT.
        auto const* row_index = index.data<BinT>() + page.row_ptr[page.base_rowid];
        for (std::size_t j = r_beg; j < r_end; ++j) {
          std::uint32_t bin_idx = row_index[k];
          // Recover the feature index for the current bin: cut pointers are
          // monotone, and bins within a row appear in increasing feature order,
          // so a forward scan suffices.
          while (bin_idx >= ptrs[fidx + 1]) {
            fidx++;
          }
          assign(bin_idx, k, ridx, fidx);
          ++k;
        }
      }
    }
  });
}
/**
* \brief Should we regenerate the gradient index?
*

View File

@@ -5,16 +5,18 @@
#include <rabit/rabit.h>
#include <algorithm> // std::copy
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "../common/hist_util.h" // common::HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
namespace xgboost {
namespace data {
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int nthread,
@@ -144,7 +146,6 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}
size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
@@ -161,6 +162,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
/**
* Generate quantiles
*/
@@ -249,9 +252,47 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
"of `DMatrix`.";
}
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
/**
 * \brief Materialize a single CSR (SparsePage) batch from the quantized gradient index.
 *
 * Feature values are reconstructed from the histogram cuts, so they are the
 * quantized representatives of each bin rather than the original inputs.
 */
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
  for (auto const& page : this->GetGradientIndex(param)) {
    auto p_out = std::make_shared<SparsePage>();
    // Pre-size the CSR buffers from the meta info so entries can be written by index.
    p_out->data.Resize(this->Info().num_nonzero_);
    p_out->offset.Resize(this->Info().num_row_ + 1);

    // The gradient index row pointer doubles as the CSR offset array.
    auto& h_offset = p_out->offset.HostVector();
    CHECK_EQ(page.row_ptr.size(), h_offset.size());
    std::copy(page.row_ptr.cbegin(), page.row_ptr.cend(), h_offset.begin());

    auto& h_data = p_out->data.HostVector();
    auto const& vals = page.cut.Values();
    auto const& mins = page.cut.MinValues();
    auto const& ptrs = page.cut.Ptrs();
    auto ft = Info().feature_types.ConstHostSpan();

    AssignColumnBinIndex(page, [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
      float v;
      if (common::IsCat(ft, fidx)) {
        // Categorical feature: the stored cut value is the category itself.
        v = vals[bin_idx];
      } else {
        // Numerical feature: map the bin index back to a representative cut value.
        v = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, fidx, bin_idx);
      }
      h_data[idx] = Entry{fidx, v};
    });

    auto p_ext_out = std::make_shared<ExtSparsePage>(p_out);
    auto begin_iter =
        BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(p_ext_out));
    // Return on the first page: this DMatrix is expected to hold a single
    // gradient index page covering all rows.
    return BatchSet<ExtSparsePage>(begin_iter);
  }
  // The loop above must yield at least one page; falling through is a bug.
  LOG(FATAL) << "Unreachable";
  auto begin_iter =
      BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
  return BatchSet<ExtSparsePage>(begin_iter);
}
} // namespace data
} // namespace xgboost

View File

@@ -88,6 +88,9 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
} while (iter.Next());
iter.Reset();
auto n_features = cols;
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
dh::safe_cuda(cudaSetDevice(get_device()));
if (!ref) {
HostDeviceVector<FeatureType> ft;

View File

@@ -97,6 +97,7 @@ class IterativeDMatrix : public DMatrix {
BatchSet<GHistIndexMatrix> GetGradientIndex(BatchParam const &param) override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
bool SingleColBlock() const override { return true; }
@@ -117,15 +118,14 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
#if !defined(XGBOOST_USE_CUDA)
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle iter, float missing,
std::shared_ptr<DMatrix> ref) {
// CPU-only build stub: fail loudly if GPU initialization is requested.
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle, float, std::shared_ptr<DMatrix>) {
  // Silence warnings about members that are only used in the CUDA build.
  (void)(proxy_);
  (void)(reset_);
  (void)(next_);
  common::AssertGPUSupport();
}
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam &param) {
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam &) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));

View File

@@ -57,6 +57,7 @@ class DMatrixProxy : public DMatrix {
void SetCUDAArray(char const* c_interface) {
common::AssertGPUSupport();
CHECK(c_interface);
#if defined(XGBOOST_USE_CUDA)
StringView interface_str{c_interface};
Json json_array_interface = Json::Load(interface_str);
@@ -106,7 +107,10 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Not implemented.";
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
}
// The proxy only forwards externally supplied data; it cannot materialize a
// CSR batch itself.
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const&) override {
  LOG(FATAL) << "Not implemented.";
  return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));  // unreachable
}
dmlc::any Adapter() const {
return batch_;
}

View File

@@ -114,6 +114,14 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
return BatchSet<GHistIndexMatrix>(begin_iter);
}
/**
 * \brief Expose the in-core sparse page as a single-element batch of
 *        `ExtSparsePage`.
 */
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
  // Wrap the already-materialized CSR page; no copy of the data is made.
  auto wrapped = std::make_shared<ExtSparsePage>(sparse_page_);
  CHECK(wrapped);
  return BatchSet<ExtSparsePage>(
      BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(wrapped)));
}
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
this->ctx_.nthread = nthread;

View File

@@ -45,6 +45,7 @@ class SimpleDMatrix : public DMatrix {
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
MetaInfo info_;
// Primary storage type

View File

@@ -114,6 +114,10 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override;
// External-memory data is spread across multiple on-disk pages, so a single
// in-core CSR view is not available.
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const &) override {
  LOG(FATAL) << "Can not obtain a single CSR page for external memory DMatrix";
  return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));  // unreachable
}
// source data pointers.
std::shared_ptr<SparsePageSource> sparse_page_source_;

View File

@@ -210,9 +210,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
#if defined(XGBOOST_USE_CUDA)
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page);
#else
inline void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
common::AssertGPUSupport();
}
// CPU-only build stub: reaching this without CUDA support is an error.
inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); }
#endif
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {