Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
parent
b14c44ee5e
commit
55cf24cc32
@ -761,6 +761,39 @@ XGB_DLL int XGDMatrixNumRow(DMatrixHandle handle,
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
|
||||
bst_ulong *out);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Get number of valid values from DMatrix.
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param out The output of number of non-missing values
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
|
||||
/*!
|
||||
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
|
||||
* quantized DMatrix, quantized values are returned instead.
|
||||
*
|
||||
* Unlike most of XGBoost C functions, caller of `XGDMatrixGetDataAsCSR` is required to
|
||||
* allocate the memory for return buffer instead of using thread local memory from
|
||||
* XGBoost. This is to avoid allocating a huge memory buffer that can not be freed until
|
||||
* exiting the thread.
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param config Json configuration string. At the moment it should be an empty document,
|
||||
* preserved for future use.
|
||||
* \param out_indptr indptr of output CSR matrix.
|
||||
* \param out_indices Column index of output CSR matrix.
|
||||
* \param out_data Data value of CSR matrix.
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
|
||||
|
||||
// --- start XGBoost class
|
||||
/*!
|
||||
* \brief create xgboost learner
|
||||
|
||||
@ -284,12 +284,17 @@ class SparsePage {
|
||||
return {offset.ConstHostSpan(), data.ConstHostSpan()};
|
||||
}
|
||||
|
||||
|
||||
/*! \brief constructor */
|
||||
SparsePage() {
|
||||
this->Clear();
|
||||
}
|
||||
|
||||
SparsePage(SparsePage const& that) = delete;
|
||||
SparsePage(SparsePage&& that) = default;
|
||||
SparsePage& operator=(SparsePage const& that) = delete;
|
||||
SparsePage& operator=(SparsePage&& that) = default;
|
||||
virtual ~SparsePage() = default;
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
inline size_t Size() const {
|
||||
return offset.Size() == 0 ? 0 : offset.Size() - 1;
|
||||
@ -358,6 +363,16 @@ class CSCPage: public SparsePage {
|
||||
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Sparse page for exporting DMatrix. Same as SparsePage, just a different type to
|
||||
* prevent being used internally.
|
||||
*/
|
||||
class ExtSparsePage {
|
||||
public:
|
||||
std::shared_ptr<SparsePage const> page;
|
||||
explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page{std::move(p)} {}
|
||||
};
|
||||
|
||||
class SortedCSCPage : public SparsePage {
|
||||
public:
|
||||
SortedCSCPage() : SparsePage() {}
|
||||
@ -610,6 +625,7 @@ class DMatrix {
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool GHistIndexExists() const = 0;
|
||||
@ -651,10 +667,15 @@ inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetEllpackBatches(param);
|
||||
}
|
||||
|
||||
template<>
|
||||
template <>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetGradientIndex(param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
||||
return GetExtBatches(BatchParam{});
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
|
||||
@ -609,7 +609,7 @@ def _deprecate_positional_args(f: Callable[..., _T]) -> Callable[..., _T]:
|
||||
return inner_f
|
||||
|
||||
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
DMatrix is an internal data structure that is used by XGBoost,
|
||||
@ -1015,29 +1015,49 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
group_ptr = self.get_uint_info("group_ptr")
|
||||
return np.diff(group_ptr)
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix.
|
||||
def get_data(self) -> scipy.sparse.csr_matrix:
|
||||
"""Get the predictors from DMatrix as a CSR matrix. This getter is mostly for
|
||||
testing purposes. If this is a quantized DMatrix then quantized values are
|
||||
returned instead of input values.
|
||||
|
||||
.. versionadded:: 2.0.0
|
||||
|
||||
Returns
|
||||
-------
|
||||
number of rows : int
|
||||
"""
|
||||
indptr = np.empty(self.num_row() + 1, dtype=np.uint64)
|
||||
indices = np.empty(self.num_nonmissing(), dtype=np.uint32)
|
||||
data = np.empty(self.num_nonmissing(), dtype=np.float32)
|
||||
|
||||
c_indptr = indptr.ctypes.data_as(ctypes.POINTER(c_bst_ulong))
|
||||
c_indices = indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))
|
||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
config = from_pystr_to_cstr(json.dumps({}))
|
||||
|
||||
_check_call(
|
||||
_LIB.XGDMatrixGetDataAsCSR(self.handle, config, c_indptr, c_indices, c_data)
|
||||
)
|
||||
ret = scipy.sparse.csr_matrix(
|
||||
(data, indices, indptr), shape=(self.num_row(), self.num_col())
|
||||
)
|
||||
return ret
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix."""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumRow(self.handle,
|
||||
ctypes.byref(ret)))
|
||||
_check_call(_LIB.XGDMatrixNumRow(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def num_col(self) -> int:
|
||||
"""Get the number of columns (features) in the DMatrix.
|
||||
|
||||
Returns
|
||||
-------
|
||||
number of columns
|
||||
"""
|
||||
"""Get the number of columns (features) in the DMatrix."""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumCol(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def num_nonmissing(self) -> int:
|
||||
"""Get the number of non-missing values in the DMatrix."""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def slice(
|
||||
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||
) -> "DMatrix":
|
||||
|
||||
@ -684,9 +684,9 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_row_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -694,9 +694,52 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_col_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// We name the function non-missing instead of non-zero since zero is perfectly valid for XGBoost.
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().num_nonzero_);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
|
||||
float *out_data) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(out_indptr);
|
||||
xgboost_CHECK_C_ARG_PTR(out_indices);
|
||||
xgboost_CHECK_C_ARG_PTR(out_data);
|
||||
|
||||
CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());
|
||||
|
||||
for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
|
||||
CHECK(page.page);
|
||||
auto const &h_offset = page.page->offset.ConstHostVector();
|
||||
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
|
||||
auto pv = page.page->GetView();
|
||||
common::ParallelFor(page.page->data.Size(), p_m->Ctx()->Threads(), [&](std::size_t i) {
|
||||
auto fvalue = pv.data[i].fvalue;
|
||||
auto findex = pv.data[i].index;
|
||||
out_data[i] = fvalue;
|
||||
out_indices[i] = findex;
|
||||
});
|
||||
}
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -6,14 +6,16 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <memory> // std::shared_ptr
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h" // DMatrix
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
|
||||
namespace xgboost {
|
||||
/* \brief Determine the output shape of prediction.
|
||||
@ -259,5 +261,17 @@ auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
|
||||
}
|
||||
return dft;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get shared ptr from DMatrix C handle with additional checks.
|
||||
*/
|
||||
inline std::shared_ptr<DMatrix> CastDMatrixHandle(DMatrixHandle const handle) {
|
||||
auto pp_m = static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
StringView msg{"Invalid DMatrix handle"};
|
||||
CHECK(pp_m) << msg;
|
||||
auto p_m = *pp_m;
|
||||
CHECK(p_m) << msg;
|
||||
return p_m;
|
||||
}
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
@ -316,30 +316,16 @@ class ColumnMatrix {
|
||||
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
|
||||
auto n_features = gmat.Features();
|
||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
||||
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
|
||||
num_nonzeros_.resize(n_features, 0);
|
||||
auto const& ptrs = gmat.cut.Ptrs();
|
||||
|
||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||
using ColumnBinT = decltype(t);
|
||||
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
|
||||
auto const batch_size = gmat.Size();
|
||||
size_t k{0};
|
||||
|
||||
for (size_t ridx = 0; ridx < batch_size; ++ridx) {
|
||||
auto r_beg = gmat.row_ptr[ridx];
|
||||
auto r_end = gmat.row_ptr[ridx + 1];
|
||||
bst_feature_t fidx{0};
|
||||
for (size_t j = r_beg; j < r_end; ++j) {
|
||||
const uint32_t bin_idx = row_index[k];
|
||||
// find the feature index for current bin.
|
||||
while (bin_idx >= ptrs[fidx + 1]) {
|
||||
fidx++;
|
||||
}
|
||||
SetBinSparse(bin_idx, ridx, fidx, local_index);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
CHECK(this->any_missing_);
|
||||
AssignColumnBinIndex(gmat,
|
||||
[&](auto bin_idx, std::size_t, std::size_t ridx, bst_feature_t fidx) {
|
||||
SetBinSparse(bin_idx, ridx, fidx, local_index);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -149,6 +149,19 @@ class HistogramCuts {
|
||||
return this->SearchCatBin(value, fidx, ptrs, vals);
|
||||
}
|
||||
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
|
||||
|
||||
/**
|
||||
* \brief Return numerical bin value given bin index.
|
||||
*/
|
||||
static float NumericBinValue(std::vector<std::uint32_t> const& ptrs,
|
||||
std::vector<float> const& vals, std::vector<float> const& mins,
|
||||
bst_feature_t fidx, bst_bin_t bin_idx) {
|
||||
auto lower = static_cast<bst_bin_t>(ptrs[fidx]);
|
||||
if (bin_idx == lower) {
|
||||
return mins[fidx];
|
||||
}
|
||||
return vals[bin_idx - 1];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@ -164,34 +164,30 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
return values[gidx];
|
||||
}
|
||||
|
||||
auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
|
||||
auto get_bin_idx = [&](auto &column) {
|
||||
auto get_bin_val = [&](auto &column) {
|
||||
auto bin_idx = column[ridx];
|
||||
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
if (bin_idx == lower) {
|
||||
return mins[fidx];
|
||||
}
|
||||
return values[bin_idx - 1];
|
||||
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||
};
|
||||
|
||||
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
|
||||
if (columns_->AnyMissing()) {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
@ -229,6 +231,53 @@ class GHistIndexMatrix {
|
||||
bool isDense_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Helper for recovering feature index from row-based storage of histogram
|
||||
* bin. (`GHistIndexMatrix`).
|
||||
*
|
||||
* \param assign A callback function that takes bin index, index into the whole batch, row
|
||||
* index and feature index
|
||||
*/
|
||||
template <typename Fn>
|
||||
void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
|
||||
auto const batch_size = page.Size();
|
||||
auto const& ptrs = page.cut.Ptrs();
|
||||
std::size_t k{0};
|
||||
|
||||
auto dense = page.IsDense();
|
||||
|
||||
common::DispatchBinType(page.index.GetBinTypeSize(), [&](auto t) {
|
||||
using BinT = decltype(t);
|
||||
auto const& index = page.index;
|
||||
for (std::size_t ridx = 0; ridx < batch_size; ++ridx) {
|
||||
auto r_beg = page.row_ptr[ridx];
|
||||
auto r_end = page.row_ptr[ridx + 1];
|
||||
bst_feature_t fidx{0};
|
||||
if (dense) {
|
||||
// compressed, use the operator to obtain the true value.
|
||||
for (std::size_t j = r_beg; j < r_end; ++j) {
|
||||
bst_feature_t fidx = j - r_beg;
|
||||
std::uint32_t bin_idx = index[k];
|
||||
assign(bin_idx, k, ridx, fidx);
|
||||
++k;
|
||||
}
|
||||
} else {
|
||||
// not compressed
|
||||
auto const* row_index = index.data<BinT>() + page.row_ptr[page.base_rowid];
|
||||
for (std::size_t j = r_beg; j < r_end; ++j) {
|
||||
std::uint32_t bin_idx = row_index[k];
|
||||
// find the feature index for current bin.
|
||||
while (bin_idx >= ptrs[fidx + 1]) {
|
||||
fidx++;
|
||||
}
|
||||
assign(bin_idx, k, ridx, fidx);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Should we regenerate the gradient index?
|
||||
*
|
||||
|
||||
@ -5,16 +5,18 @@
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <algorithm> // std::copy
|
||||
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "../common/hist_util.h" // common::HistogramCuts
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||
@ -144,7 +146,6 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
} else {
|
||||
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
|
||||
size_t batch_size = num_rows();
|
||||
batch_nnz.push_back(nnz_cnt());
|
||||
nnz += batch_nnz.back();
|
||||
@ -161,6 +162,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
return f > accumulated_rows;
|
||||
})) << "Something went wrong during iteration.";
|
||||
|
||||
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
|
||||
|
||||
/**
|
||||
* Generate quantiles
|
||||
*/
|
||||
@ -249,9 +252,47 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
||||
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
|
||||
"of `DMatrix`.";
|
||||
}
|
||||
|
||||
auto begin_iter =
|
||||
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
|
||||
for (auto const& page : this->GetGradientIndex(param)) {
|
||||
auto p_out = std::make_shared<SparsePage>();
|
||||
p_out->data.Resize(this->Info().num_nonzero_);
|
||||
p_out->offset.Resize(this->Info().num_row_ + 1);
|
||||
|
||||
auto& h_offset = p_out->offset.HostVector();
|
||||
CHECK_EQ(page.row_ptr.size(), h_offset.size());
|
||||
std::copy(page.row_ptr.cbegin(), page.row_ptr.cend(), h_offset.begin());
|
||||
|
||||
auto& h_data = p_out->data.HostVector();
|
||||
auto const& vals = page.cut.Values();
|
||||
auto const& mins = page.cut.MinValues();
|
||||
auto const& ptrs = page.cut.Ptrs();
|
||||
auto ft = Info().feature_types.ConstHostSpan();
|
||||
|
||||
AssignColumnBinIndex(page, [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
|
||||
float v;
|
||||
if (common::IsCat(ft, fidx)) {
|
||||
v = vals[bin_idx];
|
||||
} else {
|
||||
v = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, fidx, bin_idx);
|
||||
}
|
||||
h_data[idx] = Entry{fidx, v};
|
||||
});
|
||||
|
||||
auto p_ext_out = std::make_shared<ExtSparsePage>(p_out);
|
||||
auto begin_iter =
|
||||
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(p_ext_out));
|
||||
return BatchSet<ExtSparsePage>(begin_iter);
|
||||
}
|
||||
LOG(FATAL) << "Unreachable";
|
||||
auto begin_iter =
|
||||
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
|
||||
return BatchSet<ExtSparsePage>(begin_iter);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -88,6 +88,9 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
} while (iter.Next());
|
||||
iter.Reset();
|
||||
|
||||
auto n_features = cols;
|
||||
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
if (!ref) {
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
|
||||
@ -97,6 +97,7 @@ class IterativeDMatrix : public DMatrix {
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(BatchParam const ¶m) override;
|
||||
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam ¶m) override;
|
||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
|
||||
@ -117,15 +118,14 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
|
||||
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle iter, float missing,
|
||||
std::shared_ptr<DMatrix> ref) {
|
||||
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle, float, std::shared_ptr<DMatrix>) {
|
||||
// silent the warning about unused variables.
|
||||
(void)(proxy_);
|
||||
(void)(reset_);
|
||||
(void)(next_);
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam ¶m) {
|
||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam &) {
|
||||
common::AssertGPUSupport();
|
||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||
|
||||
@ -57,6 +57,7 @@ class DMatrixProxy : public DMatrix {
|
||||
|
||||
void SetCUDAArray(char const* c_interface) {
|
||||
common::AssertGPUSupport();
|
||||
CHECK(c_interface);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
StringView interface_str{c_interface};
|
||||
Json json_array_interface = Json::Load(interface_str);
|
||||
@ -106,7 +107,10 @@ class DMatrixProxy : public DMatrix {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
||||
}
|
||||
|
||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||
}
|
||||
dmlc::any Adapter() const {
|
||||
return batch_;
|
||||
}
|
||||
|
||||
@ -114,6 +114,14 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
|
||||
auto casted = std::make_shared<ExtSparsePage>(sparse_page_);
|
||||
CHECK(casted);
|
||||
auto begin_iter =
|
||||
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(casted));
|
||||
return BatchSet<ExtSparsePage>(begin_iter);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
this->ctx_.nthread = nthread;
|
||||
|
||||
@ -45,6 +45,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
|
||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) override;
|
||||
|
||||
MetaInfo info_;
|
||||
// Primary storage type
|
||||
|
||||
@ -114,6 +114,10 @@ class SparsePageDMatrix : public DMatrix {
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override;
|
||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const &) override {
|
||||
LOG(FATAL) << "Can not obtain a single CSR page for external memory DMatrix";
|
||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||
}
|
||||
|
||||
// source data pointers.
|
||||
std::shared_ptr<SparsePageSource> sparse_page_source_;
|
||||
|
||||
@ -210,9 +210,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page);
|
||||
#else
|
||||
inline void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); }
|
||||
#endif
|
||||
|
||||
class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
|
||||
|
||||
@ -125,6 +125,7 @@ if __name__ == "__main__":
|
||||
# tests
|
||||
"tests/python/test_config.py",
|
||||
"tests/python/test_spark/",
|
||||
"tests/python/test_quantile_dmatrix.py",
|
||||
"tests/python-gpu/test_gpu_spark/",
|
||||
"tests/ci_build/lint_python.py",
|
||||
# demo
|
||||
|
||||
@ -2,6 +2,8 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import sys
|
||||
from hypothesis import given, strategies, settings
|
||||
from scipy import sparse
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
@ -96,3 +98,42 @@ class TestDeviceQuantileDMatrix:
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
|
||||
|
||||
@given(
|
||||
strategies.integers(1, 1000),
|
||||
strategies.integers(1, 100),
|
||||
strategies.fractions(0, 0.99),
|
||||
)
|
||||
@settings(print_blob=True, deadline=None)
|
||||
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
|
||||
import cupy as cp
|
||||
X, y = tm.make_sparse_regression(
|
||||
n_samples, n_features, sparsity, False
|
||||
)
|
||||
h_X = X.astype(np.float32)
|
||||
|
||||
csr = h_X
|
||||
h_X = X.toarray()
|
||||
h_X[h_X == 0] = np.nan
|
||||
|
||||
h_m = xgb.QuantileDMatrix(data=h_X)
|
||||
h_ret = h_m.get_data()
|
||||
|
||||
d_X = cp.array(h_X)
|
||||
|
||||
d_m = xgb.QuantileDMatrix(data=d_X, label=y)
|
||||
d_ret = d_m.get_data()
|
||||
|
||||
np.testing.assert_equal(csr.indptr, d_ret.indptr)
|
||||
np.testing.assert_equal(csr.indices, d_ret.indices)
|
||||
|
||||
np.testing.assert_equal(h_ret.indptr, d_ret.indptr)
|
||||
np.testing.assert_equal(h_ret.indices, d_ret.indices)
|
||||
|
||||
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain=d_m)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
booster.predict(d_m),
|
||||
booster.predict(xgb.DMatrix(d_m.get_data())),
|
||||
atol=1e-6
|
||||
)
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import scipy.sparse
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse
|
||||
import testing as tm
|
||||
from hypothesis import given, settings, strategies
|
||||
from scipy.sparse import csr_matrix, rand
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
@ -433,3 +434,22 @@ class TestDMatrix:
|
||||
|
||||
def test_base_margin(self):
|
||||
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
|
||||
|
||||
@given(
|
||||
strategies.integers(0, 1000),
|
||||
strategies.integers(0, 100),
|
||||
strategies.fractions(0, 1),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
|
||||
if n_samples == 0 or n_features == 0 or sparsity == 1.0:
|
||||
csr = scipy.sparse.csr_matrix(np.empty((0, 0)))
|
||||
else:
|
||||
csr = tm.make_sparse_regression(n_samples, n_features, sparsity, False)[
|
||||
0
|
||||
].astype(np.float32)
|
||||
m = xgb.DMatrix(data=csr)
|
||||
ret = m.get_data()
|
||||
np.testing.assert_equal(csr.indptr, ret.indptr)
|
||||
np.testing.assert_equal(csr.data, ret.data)
|
||||
np.testing.assert_equal(csr.indices, ret.indices)
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import given, settings, strategies
|
||||
from scipy import sparse
|
||||
from testing import IteratorForTest, make_batches, make_batches_sparse, make_categorical
|
||||
from testing import (
|
||||
IteratorForTest,
|
||||
make_batches,
|
||||
make_batches_sparse,
|
||||
make_categorical,
|
||||
make_sparse_regression,
|
||||
)
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
@ -102,6 +109,7 @@ class TestQuantileDMatrix:
|
||||
)
|
||||
if tree_method == "gpu_hist":
|
||||
import cudf
|
||||
|
||||
X = cudf.from_pandas(X)
|
||||
y = cudf.from_pandas(y)
|
||||
else:
|
||||
@ -154,6 +162,7 @@ class TestQuantileDMatrix:
|
||||
X, y = make_categorical(n_samples, n_features, 13, onehot=False)
|
||||
if tree_method == "gpu_hist":
|
||||
import cudf
|
||||
|
||||
X = cudf.from_pandas(X)
|
||||
y = cudf.from_pandas(y)
|
||||
else:
|
||||
@ -198,9 +207,7 @@ class TestQuantileDMatrix:
|
||||
|
||||
def test_predict(self) -> None:
|
||||
n_samples, n_features = 16, 2
|
||||
X, y = make_categorical(
|
||||
n_samples, n_features, n_categories=13, onehot=False
|
||||
)
|
||||
X, y = make_categorical(n_samples, n_features, n_categories=13, onehot=False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
|
||||
booster = xgb.train({"tree_method": "hist"}, Xy)
|
||||
@ -210,3 +217,24 @@ class TestQuantileDMatrix:
|
||||
qXy = xgb.QuantileDMatrix(X, y, enable_categorical=True)
|
||||
b = booster.predict(qXy)
|
||||
np.testing.assert_allclose(a, b)
|
||||
|
||||
# we don't test empty Quantile DMatrix in single node construction.
|
||||
@given(
|
||||
strategies.integers(1, 1000),
|
||||
strategies.integers(1, 100),
|
||||
strategies.fractions(0, 0.99),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_to_csr(self, n_samples: int, n_features: int, sparsity: float) -> None:
|
||||
csr, y = make_sparse_regression(n_samples, n_features, sparsity, False)
|
||||
csr = csr.astype(np.float32)
|
||||
qdm = xgb.QuantileDMatrix(data=csr, label=y)
|
||||
ret = qdm.get_data()
|
||||
np.testing.assert_equal(csr.indptr, ret.indptr)
|
||||
np.testing.assert_equal(csr.indices, ret.indices)
|
||||
|
||||
booster = xgb.train({"tree_method": "hist"}, dtrain=qdm)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
booster.predict(qdm), booster.predict(xgb.DMatrix(qdm.get_data()))
|
||||
)
|
||||
|
||||
@ -577,6 +577,8 @@ def make_sparse_regression(
|
||||
|
||||
if as_dense:
|
||||
arr = csr.toarray()
|
||||
assert arr.shape[0] == n_samples
|
||||
assert arr.shape[1] == n_features
|
||||
arr[arr == 0] = np.nan
|
||||
return arr, y
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user