Fix slice and get info. (#5552)

This commit is contained in:
Jiaming Yuan 2020-04-18 18:00:13 +08:00 committed by GitHub
parent c245eb8755
commit e1f22baf8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 177 additions and 163 deletions

View File

@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo")
getinfo.xgb.DMatrix <- function(object, name, ...) { getinfo.xgb.DMatrix <- function(object, name, ...) {
if (typeof(name) != "character" || if (typeof(name) != "character" ||
length(name) != 1 || length(name) != 1 ||
!name %in% c('label', 'weight', 'base_margin', 'nrow')) { !name %in% c('label', 'weight', 'base_margin', 'nrow',
'label_lower_bound', 'label_upper_bound')) {
stop("getinfo: name must be one of the following\n", stop("getinfo: name must be one of the following\n",
" 'label', 'weight', 'base_margin', 'nrow'") " 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
} }
if (name != "nrow"){ if (name != "nrow"){
ret <- .Call(XGDMatrixGetInfo_R, object, name) ret <- .Call(XGDMatrixGetInfo_R, object, name)

View File

@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
labels <- getinfo(dtest, 'label') labels <- getinfo(dtest, 'label')
expect_equal(test_label, getinfo(dtest, 'label')) expect_equal(test_label, getinfo(dtest, 'label'))
expect_true(setinfo(dtest, 'label_lower_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_lower_bound'))
expect_true(setinfo(dtest, 'label_upper_bound', test_label))
expect_equal(test_label, getinfo(dtest, 'label_upper_bound'))
expect_true(length(getinfo(dtest, 'weight')) == 0) expect_true(length(getinfo(dtest, 'weight')) == 0)
expect_true(length(getinfo(dtest, 'base_margin')) == 0) expect_true(length(getinfo(dtest, 'base_margin')) == 0)
@ -59,7 +65,7 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
expect_error(setinfo(dtest, 'group', test_label)) expect_error(setinfo(dtest, 'group', test_label))
# providing character values will give a warning # providing character values will give a warning
expect_warning( setinfo(dtest, 'weight', rep('a', nrow(test_data))) ) expect_warning(setinfo(dtest, 'weight', rep('a', nrow(test_data))))
# any other label should error # any other label should error
expect_error(setinfo(dtest, 'asdf', test_label)) expect_error(setinfo(dtest, 'asdf', test_label))

View File

@ -73,6 +73,8 @@ class MetaInfo {
/*! \brief default constructor */ /*! \brief default constructor */
MetaInfo() = default; MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) { MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_; this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_; this->num_col_ = that.num_col_;
@ -89,6 +91,8 @@ class MetaInfo {
this->base_margin_.Copy(that.base_margin_); this->base_margin_.Copy(that.base_margin_);
return *this; return *this;
} }
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
/*! /*!
* \brief Get weight of each instances. * \brief Get weight of each instances.
* \param i Instance index. * \param i Instance index.
@ -491,7 +495,7 @@ class DMatrix {
const std::string& cache_prefix = "", const std::string& cache_prefix = "",
size_t page_size = kPageSize); size_t page_size = kPageSize);
virtual DMatrix* Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief page size 32 MB */ /*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL; static const size_t kPageSize = 32UL << 20UL;

View File

@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
<< "slice does not support group structure"; << "slice does not support group structure";
} }
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get(); DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat)) *out = new std::shared_ptr<DMatrix>(dmat->Slice({idxset, len}));
<< "Slice only supported for SimpleDMatrix currently.";
data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast<size_t>(len)});
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
API_END(); API_END();
} }

View File

@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
dmlc::RowBlock<uint32_t> block_; dmlc::RowBlock<uint32_t> block_;
std::unique_ptr<FileAdapterBatch> batch_; std::unique_ptr<FileAdapterBatch> batch_;
}; };
class DMatrixSliceAdapterBatch {
public:
// Fetch metainfo values according to sliced rows
template <typename T>
std::vector<T> Gather(const std::vector<T>& in) {
if (in.empty()) return {};
std::vector<T> out(this->Size());
for (auto i = 0ull; i < this->Size(); i++) {
out[i] = in[ridx_set[i]];
}
return out;
}
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
common::Span<const int> ridx_set)
: batch(batch), ridx_set(ridx_set) {
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
}
class Line {
public:
Line(const SparsePage::Inst& inst, size_t row_idx)
: inst_(inst), row_idx_(row_idx) {}
size_t Size() { return inst_.size(); }
COOTuple GetElement(size_t idx) {
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
}
private:
SparsePage::Inst inst_;
size_t row_idx_;
};
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
const float* Labels() const {
if (batch_labels.empty()) {
return nullptr;
}
return batch_labels.data();
}
const float* Weights() const {
if (batch_weights.empty()) {
return nullptr;
}
return batch_weights.data();
}
const uint64_t* Qid() const { return nullptr; }
const float* BaseMargin() const {
if (batch_base_margin.empty()) {
return nullptr;
}
return batch_base_margin.data();
}
size_t Size() const { return ridx_set.size(); }
const SparsePage& batch;
common::Span<const int> ridx_set;
std::vector<float> batch_labels;
std::vector<float> batch_weights;
std::vector<float> batch_base_margin;
};
// Group pointer is not exposed
// This is because external bindings currently manipulate the group values
// manually when slicing This could potentially be moved to internal C++ code if
// needed
class DMatrixSliceAdapter
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
public:
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
: dmat_(dmat),
ridx_set_(ridx_set),
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
// Indicates a number of rows/columns must be inferred
size_t NumRows() const { return ridx_set_.size(); }
size_t NumColumns() const { return dmat_->Info().num_col_; }
private:
DMatrix* dmat_;
common::Span<const int> ridx_set_;
DMatrixSliceAdapterBatch batch_;
};
}; // namespace data }; // namespace data
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_ #endif // XGBOOST_DATA_ADAPTER_H_

View File

@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_); LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
} }
template <typename T>
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
if (in.empty()) {
return {};
}
auto size = ridxs.size();
std::vector<T> out(size * stride);
for (auto i = 0ull; i < size; i++) {
auto ridx = ridxs[i];
for (size_t j = 0; j < stride; ++j) {
out[i * stride +j] = in[ridx * stride + j];
}
}
return out;
}
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
MetaInfo out;
out.num_row_ = ridxs.size();
out.num_col_ = this->num_col_;
// Groups is maintained by a higher level Python function. We should aim at deprecating
// the slice function.
out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
out.labels_upper_bound_.HostVector() =
Gather(this->labels_upper_bound_.HostVector(), ridxs);
out.labels_lower_bound_.HostVector() =
Gather(this->labels_lower_bound_.HostVector(), ridxs);
// weights
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
auto& h_weights = out.weights_.HostVector();
// Assuming all groups are available.
out.weights_.HostVector() = h_weights;
} else {
out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
}
if (this->base_margin_.Size() != this->num_row_) {
CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
<< "Incorrect size of base margin vector.";
size_t stride = this->base_margin_.Size() / this->num_row_;
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
return out;
}
// try to load group information from file, if exists // try to load group information from file, if exists
inline bool MetaTryLoadGroup(const std::string& fname, inline bool MetaTryLoadGroup(const std::string& fname,
std::vector<unsigned>* group) { std::vector<unsigned>* group) {
@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>( template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread, data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size); const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>( template DMatrix* DMatrix::Create<data::IteratorAdapter>(
data::IteratorAdapter* adapter, float missing, int nthread, data::IteratorAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size); const std::string& cache_prefix, size_t page_size);

View File

@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix {
bool EllpackExists() const override { return true; } bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; } bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
return nullptr;
}
private: private:
BatchSet<SparsePage> GetRowBatches() override { BatchSet<SparsePage> GetRowBatches() override {

View File

@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; }
const MetaInfo& SimpleDMatrix::Info() const { return info_; } const MetaInfo& SimpleDMatrix::Info() const { return info_; }
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
auto out = new SimpleDMatrix;
SparsePage& out_page = out->sparse_page_;
for (auto const &page : this->GetBatches<SparsePage>()) {
page.data.HostVector();
page.offset.HostVector();
auto& h_data = out_page.data.HostVector();
auto& h_offset = out_page.offset.HostVector();
size_t rptr{0};
for (auto ridx : ridxs) {
auto inst = page[ridx];
rptr += inst.size();
std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
h_offset.emplace_back(rptr);
}
out->Info() = this->Info().Slice(ridxs);
out->Info().num_nonzero_ = h_offset.back();
}
return out;
}
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() { BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available. // since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>( auto begin_iter = BatchIterator<SparsePage>(
@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread); int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread); int nthread);
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
int nthread); int nthread);
} // namespace data } // namespace data

View File

@ -19,6 +19,7 @@ namespace data {
// Used for single batch data. // Used for single batch data.
class SimpleDMatrix : public DMatrix { class SimpleDMatrix : public DMatrix {
public: public:
SimpleDMatrix() = default;
template <typename AdapterT> template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread); explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override; const MetaInfo& Info() const override;
bool SingleColBlock() const override { return true; } bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
/*! \brief magic number used to identify SimpleDMatrix binary files */ /*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01; static const int kMagic = 0xffffab01;

View File

@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix {
const MetaInfo& Info() const override; const MetaInfo& Info() const override;
bool SingleColBlock() const override { return false; } bool SingleColBlock() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
}
private: private:
BatchSet<SparsePage> GetRowBatches() override; BatchSet<SparsePage> GetRowBatches() override;

View File

@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].index, 3); EXPECT_EQ(inst[3].index, 3);
} }
TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();
std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),
{ridx_set.data(), ridx_set.size()});
EXPECT_EQ(adapter.NumRows(), ridx_set.size());
adapter.BeforeFirst();
for (auto &batch : p_dmat->GetBatches<SparsePage>()) {
adapter.Next();
auto &adapter_batch = adapter.Value();
for (auto i = 0ull; i < adapter_batch.Size(); i++) {
auto inst = batch[ridx_set[i]];
auto line = adapter_batch.GetLine(i);
ASSERT_EQ(inst.size(), line.Size());
for (auto j = 0ull; j < line.Size(); j++) {
EXPECT_EQ(inst[j].fvalue, line.GetElement(j).value);
EXPECT_EQ(inst[j].index, line.GetElement(j).column_idx);
EXPECT_EQ(i, line.GetElement(j).row_idx);
}
}
}
}
// A mock for JVM data iterator. // A mock for JVM data iterator.
class DataIterForTest { class DataIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5}; std::vector<float> data_ {1, 2, 3, 4, 5};

View File

@ -125,5 +125,4 @@ TEST(DMatrix, Uri) {
ASSERT_EQ(dmat->Info().num_col_, kCols); ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows); ASSERT_EQ(dmat->Info().num_row_, kRows);
} }
} // namespace xgboost } // namespace xgboost

View File

@ -1,11 +1,12 @@
// Copyright by Contributors // Copyright by Contributors
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include "../../../src/data/simple_dmatrix.h"
#include <array>
#include "xgboost/base.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../../../src/data/adapter.h" #include "../../../src/data/adapter.h"
#include "../helpers.h" #include "../helpers.h"
#include "xgboost/base.h"
using namespace xgboost; // NOLINT using namespace xgboost; // NOLINT
@ -218,45 +219,64 @@ TEST(SimpleDMatrix, FromFile) {
} }
TEST(SimpleDMatrix, Slice) { TEST(SimpleDMatrix, Slice) {
const int kRows = 6; size_t constexpr kRows {16};
const int kCols = 2; size_t constexpr kCols {8};
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatrix(); size_t constexpr kClasses {3};
auto &labels = p_dmat->Info().labels_.HostVector(); auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
auto &weights = p_dmat->Info().weights_.HostVector(); auto& weights = p_m->Info().weights_.HostVector();
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
weights.resize(kRows); weights.resize(kRows);
labels.resize(kRows); std::iota(weights.begin(), weights.end(), 0.0f);
base_margin.resize(kRows);
std::iota(labels.begin(), labels.end(), 0);
std::iota(weights.begin(), weights.end(), 0);
std::iota(base_margin.begin(), base_margin.end(), 0);
std::vector<int> ridx_set = {1, 3, 5}; auto& lower = p_m->Info().labels_lower_bound_.HostVector();
data::DMatrixSliceAdapter adapter(p_dmat.get(), auto& upper = p_m->Info().labels_upper_bound_.HostVector();
{ridx_set.data(), ridx_set.size()}); lower.resize(kRows);
EXPECT_EQ(adapter.NumRows(), ridx_set.size()); upper.resize(kRows);
data::SimpleDMatrix new_dmat(&adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
EXPECT_EQ(new_dmat.Info().num_row_, ridx_set.size()); std::iota(lower.begin(), lower.end(), 0.0f);
std::iota(upper.begin(), upper.end(), 1.0f);
auto &old_batch = *p_dmat->GetBatches<SparsePage>().begin(); auto& margin = p_m->Info().base_margin_.HostVector();
auto &new_batch = *new_dmat.GetBatches<SparsePage>().begin(); margin.resize(kRows * kClasses);
for (auto i = 0ull; i < ridx_set.size(); i++) {
EXPECT_EQ(new_dmat.Info().labels_.HostVector()[i], std::array<int32_t, 3> ridxs {1, 3, 5};
p_dmat->Info().labels_.HostVector()[ridx_set[i]]); std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };
EXPECT_EQ(new_dmat.Info().weights_.HostVector()[i], ASSERT_EQ(out->Info().labels_.Size(), ridxs.size());
p_dmat->Info().weights_.HostVector()[ridx_set[i]]); ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size());
EXPECT_EQ(new_dmat.Info().base_margin_.HostVector()[i], ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size());
p_dmat->Info().base_margin_.HostVector()[ridx_set[i]]); ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses);
auto old_inst = old_batch[ridx_set[i]];
auto new_inst = new_batch[i]; for (auto const& in_page : p_m->GetBatches<SparsePage>()) {
ASSERT_EQ(old_inst.size(), new_inst.size()); for (auto const &out_page : out->GetBatches<SparsePage>()) {
for (auto j = 0ull; j < old_inst.size(); j++) { for (size_t i = 0; i < ridxs.size(); ++i) {
EXPECT_EQ(old_inst[j], new_inst[j]); auto ridx = ridxs[i];
auto out_inst = out_page[i];
auto in_inst = in_page[ridx];
ASSERT_EQ(out_inst.size(), in_inst.size()) << i;
for (size_t j = 0; j < in_inst.size(); ++j) {
ASSERT_EQ(in_inst[j].fvalue, out_inst[j].fvalue);
ASSERT_EQ(in_inst[j].index, out_inst[j].index);
}
ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(ridx),
out->Info().labels_lower_bound_.HostVector().at(i));
ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx),
out->Info().labels_upper_bound_.HostVector().at(i));
ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
out->Info().weights_.HostVector().at(i));
auto& out_margin = out->Info().base_margin_.HostVector();
for (size_t j = 0; j < kClasses; ++j) {
auto in_beg = ridx * kClasses;
ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j));
}
}
} }
} }
};
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
ASSERT_EQ(out->Info().num_row_, ridxs.size());
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
}
TEST(SimpleDMatrix, SaveLoadBinary) { TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;

View File

@ -71,7 +71,34 @@ class TestDMatrix(unittest.TestCase):
assert (from_view.shape == from_array.shape) assert (from_view.shape == from_array.shape)
assert (from_view == from_array).all() assert (from_view == from_array).all()
def test_feature_names(self): def test_slice(self):
X = rng.randn(100, 100)
y = rng.randint(low=0, high=3, size=100)
d = xgb.DMatrix(X, y)
eval_res_0 = {}
booster = xgb.train(
{'num_class': 3, 'objective': 'multi:softprob'}, d,
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0)
predt = booster.predict(d)
predt = predt.reshape(100 * 3, 1)
d.set_base_margin(predt)
ridxs = [1, 2, 3, 4, 5, 6]
d = d.slice(ridxs)
sliced_margin = d.get_float_info('base_margin')
assert sliced_margin.shape[0] == len(ridxs) * 3
eval_res_1 = {}
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, d,
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_1)
eval_res_0 = eval_res_0['d']['merror']
eval_res_1 = eval_res_1['d']['merror']
for i in range(len(eval_res_0)):
assert abs(eval_res_0[i] - eval_res_1[i]) < 0.02
def test_feature_names_slice(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)
# different length # different length