Fix slice and get info. (#5552)

This commit is contained in:
Jiaming Yuan
2020-04-18 18:00:13 +08:00
committed by GitHub
parent c245eb8755
commit e1f22baf8c
14 changed files with 177 additions and 163 deletions

View File

@@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].index, 3);
}
// Builds a DMatrixSliceAdapter over rows {1, 3, 5} of a generated 6x2 matrix
// and checks that every line the adapter yields matches the corresponding
// source row: same length, same feature values and column indices, and a
// row index equal to the line's position inside the slice.
TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
  auto source = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();
  std::vector<int> selected_rows = {1, 3, 5};
  data::DMatrixSliceAdapter slice_adapter(
      source.get(), {selected_rows.data(), selected_rows.size()});
  EXPECT_EQ(slice_adapter.NumRows(), selected_rows.size());

  slice_adapter.BeforeFirst();
  for (auto &page : source->GetBatches<SparsePage>()) {
    // Advance the adapter once per source page before reading its batch.
    slice_adapter.Next();
    auto &sliced = slice_adapter.Value();
    for (unsigned long long pos = 0; pos < sliced.Size(); pos++) {
      auto expected_row = page[selected_rows[pos]];
      auto actual_line = sliced.GetLine(pos);
      ASSERT_EQ(expected_row.size(), actual_line.Size());
      for (unsigned long long k = 0; k < actual_line.Size(); k++) {
        auto &&element = actual_line.GetElement(k);
        EXPECT_EQ(expected_row[k].fvalue, element.value);
        EXPECT_EQ(expected_row[k].index, element.column_idx);
        // Row indices are relative to the slice, not to the source matrix.
        EXPECT_EQ(pos, element.row_idx);
      }
    }
  }
}
// A mock for JVM data iterator.
class DataIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5};

View File

@@ -125,5 +125,4 @@ TEST(DMatrix, Uri) {
ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows);
}
} // namespace xgboost

View File

@@ -1,11 +1,12 @@
// Copyright by Contributors
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include "../../../src/data/simple_dmatrix.h"
#include <array>
#include "xgboost/base.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../../../src/data/adapter.h"
#include "../helpers.h"
#include "xgboost/base.h"
using namespace xgboost; // NOLINT
@@ -218,45 +219,64 @@ TEST(SimpleDMatrix, FromFile) {
}
// Verifies DMatrix::Slice on a dense 16x8 in-memory matrix.
//
// Rows {1, 3, 5} are sliced out and the result is compared against the source:
//   * per-row feature entries (value and column index),
//   * weights and lower/upper label bounds (one entry per row),
//   * base margin (kClasses entries per row),
//   * the num_col_ / num_row_ / num_nonzero_ counters.
// Every per-row field is filled with distinct increasing values so that a
// slice picking the wrong rows cannot pass by accident.
TEST(SimpleDMatrix, Slice) {
  size_t constexpr kRows {16};
  size_t constexpr kCols {8};
  size_t constexpr kClasses {3};
  // Sparsity 0 => dense matrix. NOTE(review): the `true` argument presumably
  // asks the generator to also produce labels - confirm against the helper.
  auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);

  auto& weights = p_m->Info().weights_.HostVector();
  weights.resize(kRows);
  std::iota(weights.begin(), weights.end(), 0.0f);

  auto& lower = p_m->Info().labels_lower_bound_.HostVector();
  auto& upper = p_m->Info().labels_upper_bound_.HostVector();
  lower.resize(kRows);
  upper.resize(kRows);
  std::iota(lower.begin(), lower.end(), 0.0f);
  std::iota(upper.begin(), upper.end(), 1.0f);

  // One margin per (row, class). Use distinct values; with a merely resized
  // vector every entry is identical and the comparison below proves nothing.
  auto& margin = p_m->Info().base_margin_.HostVector();
  margin.resize(kRows * kClasses);
  std::iota(margin.begin(), margin.end(), 0.0f);

  std::array<int32_t, 3> ridxs {1, 3, 5};
  std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };
  ASSERT_EQ(out->Info().labels_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses);

  for (auto const& in_page : p_m->GetBatches<SparsePage>()) {
    for (auto const& out_page : out->GetBatches<SparsePage>()) {
      for (size_t i = 0; i < ridxs.size(); ++i) {
        auto ridx = ridxs[i];
        auto out_inst = out_page[i];
        auto in_inst = in_page[ridx];
        ASSERT_EQ(out_inst.size(), in_inst.size()) << i;
        for (size_t j = 0; j < in_inst.size(); ++j) {
          ASSERT_EQ(in_inst[j].fvalue, out_inst[j].fvalue);
          ASSERT_EQ(in_inst[j].index, out_inst[j].index);
        }

        ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(ridx),
                  out->Info().labels_lower_bound_.HostVector().at(i));
        ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx),
                  out->Info().labels_upper_bound_.HostVector().at(i));
        ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
                  out->Info().weights_.HostVector().at(i));

        // Row `ridx` of the source owns margins [ridx * kClasses, +kClasses).
        auto& out_margin = out->Info().base_margin_.HostVector();
        for (size_t j = 0; j < kClasses; ++j) {
          auto in_beg = ridx * kClasses;
          ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j));
        }
      }
    }
  }

  // Slicing rows must not change the column count: compare against the
  // source matrix, not against the slice itself.
  ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_);
  ASSERT_EQ(out->Info().num_row_, ridxs.size());
  ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols);  // dense
}
TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir;