Implement slice via adapters (#5198)
This commit is contained in:
@@ -61,3 +61,31 @@ TEST(adapter, CSCAdapterColsMoreThanRows) {
|
||||
EXPECT_EQ(inst[3].fvalue, 8);
|
||||
EXPECT_EQ(inst[3].index, 3);
|
||||
}
|
||||
|
||||
TEST(c_api, DMatrixSliceAdapterFromSimpleDMatrix) {
|
||||
auto pp_dmat = CreateDMatrix(6, 2, 1.0);
|
||||
auto p_dmat = *pp_dmat;
|
||||
|
||||
std::vector<int> ridx_set = {1, 3, 5};
|
||||
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||
{ridx_set.data(), ridx_set.size()});
|
||||
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
|
||||
|
||||
adapter.BeforeFirst();
|
||||
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
|
||||
adapter.Next();
|
||||
auto &adapter_batch = adapter.Value();
|
||||
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
|
||||
auto inst = batch[ridx_set[i]];
|
||||
auto line = adapter_batch.GetLine(i);
|
||||
ASSERT_EQ(inst.size(), line.Size());
|
||||
for (auto j = 0ull; j < line.Size(); j++) {
|
||||
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
|
||||
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
|
||||
EXPECT_EQ(i, line.GetElement(j).row_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
@@ -210,3 +210,47 @@ TEST(SimpleDMatrix, FromFile) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, Slice) {
|
||||
const int kRows = 6;
|
||||
const int kCols = 2;
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 1.0);
|
||||
auto p_dmat = *pp_dmat;
|
||||
auto &labels = p_dmat->Info().labels_.HostVector();
|
||||
auto &weights = p_dmat->Info().weights_.HostVector();
|
||||
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
|
||||
weights.resize(kRows);
|
||||
labels.resize(kRows);
|
||||
base_margin.resize(kRows);
|
||||
std::iota(labels.begin(), labels.end(), 0);
|
||||
std::iota(weights.begin(), weights.end(), 0);
|
||||
std::iota(base_margin.begin(), base_margin.end(), 0);
|
||||
|
||||
std::vector<int> ridx_set = {1, 3, 5};
|
||||
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||
{ridx_set.data(), ridx_set.size()});
|
||||
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
|
||||
data::SimpleDMatrix new_dmat(&adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1);
|
||||
|
||||
EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size());
|
||||
|
||||
auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin();
|
||||
auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin();
|
||||
for (auto i = 0ull; i < ridx_set.size(); i++) {
|
||||
EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i],
|
||||
p_dmat->Info().labels_.HostVector()[ridx_set[i]]);
|
||||
EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i],
|
||||
p_dmat->Info().weights_.HostVector()[ridx_set[i]]);
|
||||
EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i],
|
||||
p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]);
|
||||
auto old_inst = old_batch[ridx_set[i]];
|
||||
auto new_inst = new_batch[i];
|
||||
ASSERT_EQ(old_inst.size(), new_inst.size());
|
||||
for (auto j = 0ull; j < old_inst.size(); j++) {
|
||||
EXPECT_EQ(old_inst[j], new_inst[j]);
|
||||
}
|
||||
}
|
||||
|
||||
delete pp_dmat;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user