Implement slice via adapters (#5198)
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
#include "../data/simple_csr_source.h"
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
// declare the data callback.
|
||||
@@ -287,53 +288,20 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
data::SimpleCSRSource src;
|
||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
|
||||
if (!allow_groups) {
|
||||
CHECK_EQ(src.info.group_ptr_.size(), 0U)
|
||||
<< "slice does not support group structure";
|
||||
CHECK_EQ(static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()
|
||||
->Info()
|
||||
.group_ptr_.size(),
|
||||
0U)
|
||||
<< "slice does not support group structure";
|
||||
}
|
||||
|
||||
ret.Clear();
|
||||
ret.info.num_row_ = len;
|
||||
ret.info.num_col_ = src.info.num_col_;
|
||||
|
||||
auto iter = &src;
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
|
||||
const auto& batch = iter->Value();
|
||||
const auto& src_labels = src.info.labels_.ConstHostVector();
|
||||
const auto& src_weights = src.info.weights_.ConstHostVector();
|
||||
const auto& src_base_margin = src.info.base_margin_.ConstHostVector();
|
||||
auto& ret_labels = ret.info.labels_.HostVector();
|
||||
auto& ret_weights = ret.info.weights_.HostVector();
|
||||
auto& ret_base_margin = ret.info.base_margin_.HostVector();
|
||||
auto& offset_vec = ret.page_.offset.HostVector();
|
||||
auto& data_vec = ret.page_.data.HostVector();
|
||||
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
const int ridx = idxset[i];
|
||||
auto inst = batch[ridx];
|
||||
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
|
||||
data_vec.insert(data_vec.end(), inst.data(),
|
||||
inst.data() + inst.size());
|
||||
offset_vec.push_back(offset_vec.back() + inst.size());
|
||||
ret.info.num_nonzero_ += inst.size();
|
||||
|
||||
if (src_labels.size() != 0) {
|
||||
ret_labels.push_back(src_labels[ridx]);
|
||||
}
|
||||
if (src_weights.size() != 0) {
|
||||
ret_weights.push_back(src_weights[ridx]);
|
||||
}
|
||||
if (src_base_margin.size() != 0) {
|
||||
ret_base_margin.push_back(src_base_margin[ridx]);
|
||||
}
|
||||
}
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
|
||||
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
||||
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
|
||||
<< "Slice only supported for SimpleDMatrix currently.";
|
||||
data::DMatrixSliceAdapter adapter(dmat, {idxset, len});
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user