Obtain CSR matrix from DMatrix. (#8269)

Jiaming Yuan 2022-09-29 20:41:43 +08:00 committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions


@@ -761,6 +761,39 @@ XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
*/
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
bst_ulong *out);
/*!
* \brief Get number of valid values from DMatrix.
*
* \param handle the handle to the DMatrix
* \param out Output for the number of non-missing values
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
/*!
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
* quantized DMatrix, quantized values are returned instead.
*
* Unlike most other XGBoost C functions, the caller of `XGDMatrixGetDataAsCSR` is required to
* allocate the memory for the return buffers instead of using thread-local memory from
* XGBoost. This is to avoid allocating a huge memory buffer that cannot be freed until
* the thread exits.
*
* \param handle the handle to the DMatrix
* \param config Json configuration string. At the moment it should be an empty document,
* preserved for future use.
* \param out_indptr indptr of output CSR matrix.
* \param out_indices Column index of output CSR matrix.
* \param out_data Data value of CSR matrix.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
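As a rough usage sketch (not part of the header), the caller-allocated pattern described above queries the sizes first and then fills the buffers; it assumes `handle` already refers to a valid DMatrix and omits error-code checks:

#include <vector>
#include <xgboost/c_api.h>

bst_ulong n_rows = 0, n_nonmissing = 0;
XGDMatrixNumRow(handle, &n_rows);
XGDMatrixNumNonMissing(handle, &n_nonmissing);
std::vector<bst_ulong> indptr(n_rows + 1);    // CSR row pointers
std::vector<unsigned> indices(n_nonmissing);  // column index of each stored value
std::vector<float> data(n_nonmissing);        // stored values
// The config is an empty JSON document, reserved for future use.
XGDMatrixGetDataAsCSR(handle, "{}", indptr.data(), indices.data(), data.data());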
// --- start XGBoost class
/*!
* \brief create xgboost learner


@@ -284,12 +284,17 @@ class SparsePage {
return {offset.ConstHostSpan(), data.ConstHostSpan()};
}
/*! \brief constructor */
SparsePage() {
this->Clear();
}
SparsePage(SparsePage const& that) = delete;
SparsePage(SparsePage&& that) = default;
SparsePage& operator=(SparsePage const& that) = delete;
SparsePage& operator=(SparsePage&& that) = default;
virtual ~SparsePage() = default;
/*! \return Number of instances in the page. */
inline size_t Size() const {
return offset.Size() == 0 ? 0 : offset.Size() - 1;
@@ -358,6 +363,16 @@ class CSCPage: public SparsePage {
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
/**
* \brief Sparse page for exporting DMatrix. Same as SparsePage, just a different type to
* prevent it from being used internally.
*/
class ExtSparsePage {
public:
std::shared_ptr<SparsePage const> page;
explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page{std::move(p)} {}
};
class SortedCSCPage : public SparsePage {
public:
SortedCSCPage() : SparsePage() {}
@@ -610,6 +625,7 @@ class DMatrix {
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
@@ -651,10 +667,15 @@ inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
}
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
}
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
return GetExtBatches(BatchParam{});
}
} // namespace xgboost
namespace dmlc {


@@ -609,7 +609,7 @@ def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
return inner_f
class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.
DMatrix is an internal data structure that is used by XGBoost,
@@ -1015,29 +1015,49 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
group_ptr = self.get_uint_info("group_ptr")
return np.diff(group_ptr)
def get_data(self) -> scipy.sparse.csr_matrix:
"""Get the predictors from DMatrix as a CSR matrix. This getter is mostly for
testing purposes. If this is a quantized DMatrix then quantized values are
returned instead of input values.
.. versionadded:: 2.0.0
"""
indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
indices = np.empty(self.num_nonmissing(), dtype=np.uint32)
data = np.empty(self.num_nonmissing(), dtype=np.float32)
c_indptr = indptr.ctypes.data_as(ctypes.POINTER(c_bst_ulong))
c_indices = indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
config = from_pystr_to_cstr(json.dumps({}))
_check_call(
_LIB.XGDMatrixGetDataAsCSR(self.handle, config, c_indptr, c_indices, c_data)
)
ret = scipy.sparse.csr_matrix(
(data, indices, indptr), shape=(self.num_row(), self.num_col())
)
return ret
def num_row(self) -> int:
"""Get the number of rows in the DMatrix."""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumRow(self.handle, ctypes.byref(ret)))
return ret.value
def num_col(self) -> int:
"""Get the number of columns (features) in the DMatrix."""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
return ret.value
def num_nonmissing(self) -> int:
"""Get the number of non-missing values in the DMatrix."""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
return ret.value
def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":


@@ -684,9 +684,9 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_row_);
API_END();
}
@@ -694,9 +694,52 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_col_);
API_END();
}
// We name the function non-missing instead of non-zero since zero is perfectly valid for XGBoost.
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_nonzero_);
API_END();
}
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
float *out_data) {
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out_indptr);
xgboost_CHECK_C_ARG_PTR(out_indices);
xgboost_CHECK_C_ARG_PTR(out_data);
CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());
for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
CHECK(page.page);
auto const &h_offset = page.page->offset.ConstHostVector();
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
auto pv = page.page->GetView();
common::ParallelFor(page.page->data.Size(), p_m->Ctx()->Threads(), [&](std::size_t i) {
auto fvalue = pv.data[i].fvalue;
auto findex = pv.data[i].index;
out_data[i] = fvalue;
out_indices[i] = findex;
});
}
API_END();
}


@@ -6,14 +6,16 @@
#include <algorithm>
#include <functional>
#include <memory> // std::shared_ptr
#include <string>
#include <vector>
#include "xgboost/c_api.h"
#include "xgboost/data.h" // DMatrix
#include "xgboost/json.h"
#include "xgboost/learner.h"
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // StringView
namespace xgboost {
/* \brief Determine the output shape of prediction.
@@ -259,5 +261,17 @@ auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
}
return dft;
}
/**
* \brief Get shared ptr from DMatrix C handle with additional checks.
*/
inline std::shared_ptr<DMatrix> CastDMatrixHandle(DMatrixHandle const handle) {
auto pp_m = static_cast<std::shared_ptr<DMatrix> *>(handle);
StringView msg{"Invalid DMatrix handle"};
CHECK(pp_m) << msg;
auto p_m = *pp_m;
CHECK(p_m) << msg;
return p_m;
}
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_


@@ -316,30 +316,16 @@ class ColumnMatrix {
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
auto n_features = gmat.Features();
missing_flags_.resize(feature_offsets_[n_features], true);
num_nonzeros_.resize(n_features, 0);
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
CHECK(this->any_missing_);
AssignColumnBinIndex(gmat,
[&](auto bin_idx, std::size_t, std::size_t ridx, bst_feature_t fidx) {
SetBinSparse(bin_idx, ridx, fidx, local_index);
});
});
}


@@ -149,6 +149,19 @@ class HistogramCuts {
return this->SearchCatBin(value, fidx, ptrs, vals);
}
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
/**
* \brief Return numerical bin value given bin index.
*/
static float NumericBinValue(std::vector<std::uint32_t> const& ptrs,
std::vector<float> const& vals, std::vector<float> const& mins,
bst_feature_t fidx, bst_bin_t bin_idx) {
auto lower = static_cast<bst_bin_t>(ptrs[fidx]);
if (bin_idx == lower) {
return mins[fidx];
}
return vals[bin_idx - 1];
}
};
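To make the bin-to-value mapping concrete, here is a small illustrative sketch; the numbers are made up and not taken from the diff. For a single feature whose cut points occupy global bin indices [0, 3), the first bin maps back to the recorded per-feature minimum and every later bin maps to the cut value just below it:

std::vector<std::uint32_t> ptrs{0, 3};      // feature 0 owns bins [0, 3)
std::vector<float> vals{1.0f, 2.5f, 4.0f};  // upper cut of each bin
std::vector<float> mins{0.5f};              // per-feature minimum
// bin 0 is the feature's lowest bin, so the minimum is returned: 0.5f
float v0 = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, 0, 0);
// bin 2 returns the cut just below it, vals[1]: 2.5f
float v2 = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, 0, 2);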
/**


@@ -164,34 +164,30 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
return values[gidx];
}
auto get_bin_val = [&](auto &column) {
auto bin_idx = column[ridx];
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
return std::numeric_limits<float>::quiet_NaN();
}
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
};
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
return get_bin_val(column);
});
}
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_val(column);
});
}


@@ -6,6 +6,8 @@
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <algorithm> // std::min
#include <cinttypes> // std::uint32_t
#include <cstddef> // std::size_t
#include <memory>
#include <vector>
@@ -229,6 +231,53 @@ class GHistIndexMatrix {
bool isDense_;
};
/**
* \brief Helper for recovering the feature index from the row-based storage of
* histogram bins (`GHistIndexMatrix`).
*
* \param assign A callback invoked with the bin index, the index into the whole batch,
* the row index, and the feature index.
*/
template <typename Fn>
void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
auto const batch_size = page.Size();
auto const& ptrs = page.cut.Ptrs();
std::size_t k{0};
auto dense = page.IsDense();
common::DispatchBinType(page.index.GetBinTypeSize(), [&](auto t) {
using BinT = decltype(t);
auto const& index = page.index;
for (std::size_t ridx = 0; ridx < batch_size; ++ridx) {
auto r_beg = page.row_ptr[ridx];
auto r_end = page.row_ptr[ridx + 1];
bst_feature_t fidx{0};
if (dense) {
// compressed, use the operator to obtain the true value.
for (std::size_t j = r_beg; j < r_end; ++j) {
bst_feature_t fidx = j - r_beg;
std::uint32_t bin_idx = index[k];
assign(bin_idx, k, ridx, fidx);
++k;
}
} else {
// not compressed
auto const* row_index = index.data<BinT>() + page.row_ptr[page.base_rowid];
for (std::size_t j = r_beg; j < r_end; ++j) {
std::uint32_t bin_idx = row_index[k];
// find the feature index for current bin.
while (bin_idx >= ptrs[fidx + 1]) {
fidx++;
}
assign(bin_idx, k, ridx, fidx);
++k;
}
}
}
});
}
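A hedged usage sketch (not part of the diff) of the callback shape this helper expects: counting the stored entries per feature of a GHistIndexMatrix `page`; `entries_per_feature` is a hypothetical local.

std::vector<std::size_t> entries_per_feature(page.Features(), 0);
AssignColumnBinIndex(page, [&](auto bin_idx, std::size_t k, std::size_t ridx, bst_feature_t fidx) {
  (void)bin_idx;  // histogram bin index
  (void)k;        // index into the whole batch
  (void)ridx;     // row index
  ++entries_per_feature[fidx];
});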
/**
* \brief Should we regenerate the gradient index?
*


@@ -5,16 +5,18 @@
#include <rabit/rabit.h>
#include <algorithm> // std::copy
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../common/hist_util.h" // common::HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
namespace xgboost {
namespace data {
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int nthread,
@@ -144,7 +146,6 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}
size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
@@ -161,6 +162,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
/**
* Generate quantiles
*/
@@ -249,9 +252,47 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
"of `DMatrix`.";
}
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
for (auto const& page : this->GetGradientIndex(param)) {
auto p_out = std::make_shared<SparsePage>();
p_out->data.Resize(this->Info().num_nonzero_);
p_out->offset.Resize(this->Info().num_row_ + 1);
auto& h_offset = p_out->offset.HostVector();
CHECK_EQ(page.row_ptr.size(), h_offset.size());
std::copy(page.row_ptr.cbegin(), page.row_ptr.cend(), h_offset.begin());
auto& h_data = p_out->data.HostVector();
auto const& vals = page.cut.Values();
auto const& mins = page.cut.MinValues();
auto const& ptrs = page.cut.Ptrs();
auto ft = Info().feature_types.ConstHostSpan();
AssignColumnBinIndex(page, [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
float v;
if (common::IsCat(ft, fidx)) {
v = vals[bin_idx];
} else {
v = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, fidx, bin_idx);
}
h_data[idx] = Entry{fidx, v};
});
auto p_ext_out = std::make_shared<ExtSparsePage>(p_out);
auto begin_iter =
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(p_ext_out));
return BatchSet<ExtSparsePage>(begin_iter);
}
LOG(FATAL) << "Unreachable";
auto begin_iter =
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
return BatchSet<ExtSparsePage>(begin_iter);
}
} // namespace data
} // namespace xgboost


@@ -88,6 +88,9 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
} while (iter.Next());
iter.Reset();
auto n_features = cols;
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
dh::safe_cuda(cudaSetDevice(get_device()));
if (!ref) {
HostDeviceVector<FeatureType> ft;


@@ -97,6 +97,7 @@ class IterativeDMatrix : public DMatrix {
BatchSet<GHistIndexMatrix> GetGradientIndex(BatchParam const &param) override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
bool SingleColBlock() const override { return true; }
@@ -117,15 +118,14 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
#if !defined(XGBOOST_USE_CUDA)
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle, float, std::shared_ptr<DMatrix>) {
// silent the warning about unused variables.
(void)(proxy_);
(void)(reset_);
(void)(next_);
common::AssertGPUSupport();
}
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam &) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));


@@ -57,6 +57,7 @@ class DMatrixProxy : public DMatrix {
void SetCUDAArray(char const* c_interface) {
common::AssertGPUSupport();
CHECK(c_interface);
#if defined(XGBOOST_USE_CUDA)
StringView interface_str{c_interface};
Json json_array_interface = Json::Load(interface_str);
@@ -106,7 +107,10 @@ class DMatrixProxy : public DMatrix {
LOG(FATAL) << "Not implemented.";
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
}
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const&) override {
LOG(FATAL) << "Not implemented.";
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
}
dmlc::any Adapter() const {
return batch_;
}


@@ -114,6 +114,14 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
return BatchSet<GHistIndexMatrix>(begin_iter);
}
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
auto casted = std::make_shared<ExtSparsePage>(sparse_page_);
CHECK(casted);
auto begin_iter =
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(casted));
return BatchSet<ExtSparsePage>(begin_iter);
}
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
this->ctx_.nthread = nthread;


@@ -45,6 +45,7 @@ class SimpleDMatrix : public DMatrix {
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
MetaInfo info_;
// Primary storage type


@@ -114,6 +114,10 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override;
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const &) override {
LOG(FATAL) << "Can not obtain a single CSR page for external memory DMatrix";
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
}
// source data pointers.
std::shared_ptr<SparsePageSource> sparse_page_source_;


@@ -210,9 +210,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
#if defined(XGBOOST_USE_CUDA)
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page);
#else
inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); }
#endif
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {


@@ -125,6 +125,7 @@ if __name__ == "__main__":
# tests
"tests/python/test_config.py",
"tests/python/test_spark/",
"tests/python/test_quantile_dmatrix.py",
"tests/python-gpu/test_gpu_spark/",
"tests/ci_build/lint_python.py",
# demo


@@ -2,6 +2,8 @@ import numpy as np
import xgboost as xgb
import pytest
import sys
from hypothesis import given, strategies, settings
from scipy import sparse
sys.path.append("tests/python")
import testing as tm
@@ -96,3 +98,42 @@ class TestDeviceQuantileDMatrix:
import cupy as cp
rng = cp.random.RandomState(1994)
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
@given(
strategies.integers(1, 1000),
strategies.integers(1, 100),
strategies.fractions(0, 0.99),
)
@settings(print_blob=True, deadline=None)
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
import cupy as cp
X, y = tm.make_sparse_regression(
n_samples, n_features, sparsity, False
)
h_X = X.astype(np.float32)
csr = h_X
h_X = X.toarray()
h_X[h_X == 0] = np.nan
h_m = xgb.QuantileDMatrix(data=h_X)
h_ret = h_m.get_data()
d_X = cp.array(h_X)
d_m = xgb.QuantileDMatrix(data=d_X, label=y)
d_ret = d_m.get_data()
np.testing.assert_equal(csr.indptr, d_ret.indptr)
np.testing.assert_equal(csr.indices, d_ret.indices)
np.testing.assert_equal(h_ret.indptr, d_ret.indptr)
np.testing.assert_equal(h_ret.indices, d_ret.indices)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain=d_m)
np.testing.assert_allclose(
booster.predict(d_m),
booster.predict(xgb.DMatrix(d_m.get_data())),
atol=1e-6
)


@@ -1,13 +1,14 @@
import os
import tempfile
import numpy as np
import pytest
import scipy.sparse
import testing as tm
from hypothesis import given, settings, strategies
from scipy.sparse import csr_matrix, rand
import xgboost as xgb
rng = np.random.RandomState(1)
@@ -433,3 +434,22 @@ class TestDMatrix:
def test_base_margin(self):
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
@given(
strategies.integers(0, 1000),
strategies.integers(0, 100),
strategies.fractions(0, 1),
)
@settings(deadline=None, print_blob=True)
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
if n_samples == 0 or n_features == 0 or sparsity == 1.0:
csr = scipy.sparse.csr_matrix(np.empty((0, 0)))
else:
csr = tm.make_sparse_regression(n_samples, n_features, sparsity, False)[
0
].astype(np.float32)
m = xgb.DMatrix(data=csr)
ret = m.get_data()
np.testing.assert_equal(csr.indptr, ret.indptr)
np.testing.assert_equal(csr.data, ret.data)
np.testing.assert_equal(csr.indices, ret.indices)


@@ -1,9 +1,16 @@
from typing import Any, Dict, List
import numpy as np
import pytest
from hypothesis import given, settings, strategies
from scipy import sparse
from testing import (
IteratorForTest,
make_batches,
make_batches_sparse,
make_categorical,
make_sparse_regression,
)
import xgboost as xgb
@@ -102,6 +109,7 @@ class TestQuantileDMatrix:
)
if tree_method == "gpu_hist":
import cudf
X = cudf.from_pandas(X)
y = cudf.from_pandas(y)
else:
@@ -154,6 +162,7 @@ class TestQuantileDMatrix:
X, y = make_categorical(n_samples, n_features, 13, onehot=False)
if tree_method == "gpu_hist":
import cudf
X = cudf.from_pandas(X)
y = cudf.from_pandas(y)
else:
@@ -198,9 +207,7 @@ class TestQuantileDMatrix:
def test_predict(self) -> None:
n_samples, n_features = 16, 2
X, y = make_categorical(n_samples, n_features, n_categories=13, onehot=False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "hist"}, Xy)
@@ -210,3 +217,24 @@ class TestQuantileDMatrix:
qXy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
b = booster.predict(qXy)
np.testing.assert_allclose(a, b)
# we don't test empty Quantile DMatrix in single node construction.
@given(
strategies.integers(1, 1000),
strategies.integers(1, 100),
strategies.fractions(0, 0.99),
)
@settings(deadline=None, print_blob=True)
def test_to_csr(self, n_samples: int, n_features: int, sparsity: float) -> None:
csr, y = make_sparse_regression(n_samples, n_features, sparsity, False)
csr = csr.astype(np.float32)
qdm = xgb.QuantileDMatrix(data=csr, label=y)
ret = qdm.get_data()
np.testing.assert_equal(csr.indptr, ret.indptr)
np.testing.assert_equal(csr.indices, ret.indices)
booster = xgb.train({"tree_method": "hist"}, dtrain=qdm)
np.testing.assert_allclose(
booster.predict(qdm), booster.predict(xgb.DMatrix(qdm.get_data()))
)


@@ -577,6 +577,8 @@ def make_sparse_regression(
if as_dense:
arr = csr.toarray()
assert arr.shape[0] == n_samples
assert arr.shape[1] == n_features
arr[arr == 0] = np.nan
return arr, y