Implement slice via adapters (#5198)

This commit is contained in:
Rory Mitchell
2020-01-14 12:55:41 +13:00
committed by GitHub
parent f100b8d878
commit a73e25e15f
7 changed files with 190 additions and 45 deletions

View File

@@ -23,6 +23,7 @@
#include "../data/simple_csr_source.h"
#include "../common/io.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
namespace xgboost {
// declare the data callback.
@@ -287,53 +288,20 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
API_BEGIN();
CHECK_HANDLE();
data::SimpleCSRSource src;
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;
if (!allow_groups) {
CHECK_EQ(src.info.group_ptr_.size(), 0U)
<< "slice does not support group structure";
CHECK_EQ(static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()
->Info()
.group_ptr_.size(),
0U)
<< "slice does not support group structure";
}
ret.Clear();
ret.info.num_row_ = len;
ret.info.num_col_ = src.info.num_col_;
auto iter = &src;
iter->BeforeFirst();
CHECK(iter->Next());
const auto& batch = iter->Value();
const auto& src_labels = src.info.labels_.ConstHostVector();
const auto& src_weights = src.info.weights_.ConstHostVector();
const auto& src_base_margin = src.info.base_margin_.ConstHostVector();
auto& ret_labels = ret.info.labels_.HostVector();
auto& ret_weights = ret.info.weights_.HostVector();
auto& ret_base_margin = ret.info.base_margin_.HostVector();
auto& offset_vec = ret.page_.offset.HostVector();
auto& data_vec = ret.page_.data.HostVector();
for (xgboost::bst_ulong i = 0; i < len; ++i) {
const int ridx = idxset[i];
auto inst = batch[ridx];
CHECK_LT(static_cast<xgboost::bst_ulong>(ridx), batch.Size());
data_vec.insert(data_vec.end(), inst.data(),
inst.data() + inst.size());
offset_vec.push_back(offset_vec.back() + inst.size());
ret.info.num_nonzero_ += inst.size();
if (src_labels.size() != 0) {
ret_labels.push_back(src_labels[ridx]);
}
if (src_weights.size() != 0) {
ret_weights.push_back(src_weights[ridx]);
}
if (src_base_margin.size() != 0) {
ret_base_margin.push_back(src_base_margin[ridx]);
}
}
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
<< "Slice only supported for SimpleDMatrix currently.";
data::DMatrixSliceAdapter adapter(dmat, {idxset, len});
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
API_END();
}