Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -5,16 +5,18 @@
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <algorithm> // std::copy
|
||||
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "../common/hist_util.h" // common::HistogramCuts
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||
@@ -144,7 +146,6 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
} else {
|
||||
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
|
||||
size_t batch_size = num_rows();
|
||||
batch_nnz.push_back(nnz_cnt());
|
||||
nnz += batch_nnz.back();
|
||||
@@ -161,6 +162,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
return f > accumulated_rows;
|
||||
})) << "Something went wrong during iteration.";
|
||||
|
||||
CHECK_GE(n_features, 1) << "Data must has at least 1 column.";
|
||||
|
||||
/**
|
||||
* Generate quantiles
|
||||
*/
|
||||
@@ -249,9 +252,47 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
||||
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
|
||||
"of `DMatrix`.";
|
||||
}
|
||||
|
||||
auto begin_iter =
|
||||
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
/**
 * Build a sparse (CSR-layout) page out of the gradient index and hand it back as a batch set.
 *
 * The page's offset vector is a copy of the gradient index `row_ptr`.  Each entry value is
 * looked up from the histogram cuts: for a categorical feature it is `Values()[bin_idx]`,
 * otherwise it is whatever `HistogramCuts::NumericBinValue` returns for the bin.
 * NOTE(review): the values therefore come from the cuts rather than from the raw input
 * data -- confirm that callers expect this.
 *
 * Only the first page produced by `GetGradientIndex` is consumed; the function returns from
 * inside the first loop iteration.
 *
 * \param param Batch parameter forwarded unchanged to `GetGradientIndex`.
 *
 * \return A batch set wrapping the newly built `ExtSparsePage`.
 */
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
  for (auto const& gidx : this->GetGradientIndex(param)) {
    auto p_csr = std::make_shared<SparsePage>();
    p_csr->data.Resize(this->Info().num_nonzero_);
    p_csr->offset.Resize(this->Info().num_row_ + 1);

    // Row offsets are taken verbatim from the gradient index.
    auto& row_offsets = p_csr->offset.HostVector();
    CHECK_EQ(gidx.row_ptr.size(), row_offsets.size());
    std::copy(gidx.row_ptr.cbegin(), gidx.row_ptr.cend(), row_offsets.begin());

    auto& entries = p_csr->data.HostVector();
    auto const& cut_values = gidx.cut.Values();
    auto const& cut_mins = gidx.cut.MinValues();
    auto const& cut_ptrs = gidx.cut.Ptrs();
    auto feature_types = Info().feature_types.ConstHostSpan();

    // Translate one bin index into a feature value and store it at its slot in the page.
    auto write_entry = [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
      float const value =
          common::IsCat(feature_types, fidx)
              ? cut_values[bin_idx]
              : common::HistogramCuts::NumericBinValue(cut_ptrs, cut_values, cut_mins, fidx,
                                                       bin_idx);
      entries[idx] = Entry{fidx, value};
    };
    AssignColumnBinIndex(gidx, write_entry);

    auto p_ext_page = std::make_shared<ExtSparsePage>(p_csr);
    auto first =
        BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(p_ext_page));
    return BatchSet<ExtSparsePage>(first);
  }

  // Reached only when the loop above never ran; the statements after the fatal log keep
  // the function well-formed with a return on every path.
  LOG(FATAL) << "Unreachable";
  auto empty_iter =
      BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
  return BatchSet<ExtSparsePage>(empty_iter);
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user