Fix slice and get info. (#5552)
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
// Copyright by Contributors
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
|
||||
#include <array>
|
||||
#include "xgboost/base.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@@ -218,45 +219,64 @@ TEST(SimpleDMatrix, FromFile) {
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, Slice) {
|
||||
const int kRows = 6;
|
||||
const int kCols = 2;
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatrix();
|
||||
auto &labels = p_dmat->Info().labels_.HostVector();
|
||||
auto &weights = p_dmat->Info().weights_.HostVector();
|
||||
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
|
||||
size_t constexpr kRows {16};
|
||||
size_t constexpr kCols {8};
|
||||
size_t constexpr kClasses {3};
|
||||
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||
auto& weights = p_m->Info().weights_.HostVector();
|
||||
weights.resize(kRows);
|
||||
labels.resize(kRows);
|
||||
base_margin.resize(kRows);
|
||||
std::iota(labels.begin(), labels.end(), 0);
|
||||
std::iota(weights.begin(), weights.end(), 0);
|
||||
std::iota(base_margin.begin(), base_margin.end(), 0);
|
||||
std::iota(weights.begin(), weights.end(), 0.0f);
|
||||
|
||||
std::vector<int> ridx_set = {1, 3, 5};
|
||||
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||
{ridx_set.data(), ridx_set.size()});
|
||||
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
|
||||
data::SimpleDMatrix new_dmat(&adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1);
|
||||
auto& lower = p_m->Info().labels_lower_bound_.HostVector();
|
||||
auto& upper = p_m->Info().labels_upper_bound_.HostVector();
|
||||
lower.resize(kRows);
|
||||
upper.resize(kRows);
|
||||
|
||||
EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size());
|
||||
std::iota(lower.begin(), lower.end(), 0.0f);
|
||||
std::iota(upper.begin(), upper.end(), 1.0f);
|
||||
|
||||
auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin();
|
||||
auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin();
|
||||
for (auto i = 0ull; i < ridx_set.size(); i++) {
|
||||
EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i],
|
||||
p_dmat->Info().labels_.HostVector()[ridx_set[i]]);
|
||||
EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i],
|
||||
p_dmat->Info().weights_.HostVector()[ridx_set[i]]);
|
||||
EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i],
|
||||
p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]);
|
||||
auto old_inst = old_batch[ridx_set[i]];
|
||||
auto new_inst = new_batch[i];
|
||||
ASSERT_EQ(old_inst.size(), new_inst.size());
|
||||
for (auto j = 0ull; j < old_inst.size(); j++) {
|
||||
EXPECT_EQ(old_inst[j], new_inst[j]);
|
||||
auto& margin = p_m->Info().base_margin_.HostVector();
|
||||
margin.resize(kRows * kClasses);
|
||||
|
||||
std::array<int32_t, 3> ridxs {1, 3, 5};
|
||||
std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };
|
||||
ASSERT_EQ(out->Info().labels_.Size(), ridxs.size());
|
||||
ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size());
|
||||
ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size());
|
||||
ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses);
|
||||
|
||||
for (auto const& in_page : p_m->GetBatches<SparsePage>()) {
|
||||
for (auto const &out_page : out->GetBatches<SparsePage>()) {
|
||||
for (size_t i = 0; i < ridxs.size(); ++i) {
|
||||
auto ridx = ridxs[i];
|
||||
auto out_inst = out_page[i];
|
||||
auto in_inst = in_page[ridx];
|
||||
ASSERT_EQ(out_inst.size(), in_inst.size()) << i;
|
||||
for (size_t j = 0; j < in_inst.size(); ++j) {
|
||||
ASSERT_EQ(in_inst[j].fvalue, out_inst[j].fvalue);
|
||||
ASSERT_EQ(in_inst[j].index, out_inst[j].index);
|
||||
}
|
||||
|
||||
ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(ridx),
|
||||
out->Info().labels_lower_bound_.HostVector().at(i));
|
||||
ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx),
|
||||
out->Info().labels_upper_bound_.HostVector().at(i));
|
||||
ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
|
||||
out->Info().weights_.HostVector().at(i));
|
||||
|
||||
auto& out_margin = out->Info().base_margin_.HostVector();
|
||||
for (size_t j = 0; j < kClasses; ++j) {
|
||||
auto in_beg = ridx * kClasses;
|
||||
ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compare the slice against the source matrix; checking `out` against itself can never fail.
ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_);
|
||||
ASSERT_EQ(out->Info().num_row_, ridxs.size());
|
||||
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, SaveLoadBinary) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
|
||||
Reference in New Issue
Block a user