Fix slice and get info. (#5552)
This commit is contained in:
parent
c245eb8755
commit
e1f22baf8c
@ -188,9 +188,10 @@ getinfo <- function(object, ...) UseMethod("getinfo")
|
||||
getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
if (typeof(name) != "character" ||
|
||||
length(name) != 1 ||
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow')) {
|
||||
!name %in% c('label', 'weight', 'base_margin', 'nrow',
|
||||
'label_lower_bound', 'label_upper_bound')) {
|
||||
stop("getinfo: name must be one of the following\n",
|
||||
" 'label', 'weight', 'base_margin', 'nrow'")
|
||||
" 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
|
||||
}
|
||||
if (name != "nrow"){
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
|
||||
@ -50,6 +50,12 @@ test_that("xgb.DMatrix: getinfo & setinfo", {
|
||||
labels <- getinfo(dtest, 'label')
|
||||
expect_equal(test_label, getinfo(dtest, 'label'))
|
||||
|
||||
expect_true(setinfo(dtest, 'label_lower_bound', test_label))
|
||||
expect_equal(test_label, getinfo(dtest, 'label_lower_bound'))
|
||||
|
||||
expect_true(setinfo(dtest, 'label_upper_bound', test_label))
|
||||
expect_equal(test_label, getinfo(dtest, 'label_upper_bound'))
|
||||
|
||||
expect_true(length(getinfo(dtest, 'weight')) == 0)
|
||||
expect_true(length(getinfo(dtest, 'base_margin')) == 0)
|
||||
|
||||
|
||||
@ -73,6 +73,8 @@ class MetaInfo {
|
||||
|
||||
/*! \brief default constructor */
|
||||
MetaInfo() = default;
|
||||
MetaInfo(MetaInfo&& that) = default;
|
||||
MetaInfo& operator=(MetaInfo&& that) = default;
|
||||
MetaInfo& operator=(MetaInfo const& that) {
|
||||
this->num_row_ = that.num_row_;
|
||||
this->num_col_ = that.num_col_;
|
||||
@ -89,6 +91,8 @@ class MetaInfo {
|
||||
this->base_margin_.Copy(that.base_margin_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
|
||||
/*!
|
||||
* \brief Get weight of each instances.
|
||||
* \param i Instance index.
|
||||
@ -491,7 +495,7 @@ class DMatrix {
|
||||
const std::string& cache_prefix = "",
|
||||
size_t page_size = kPageSize);
|
||||
|
||||
|
||||
virtual DMatrix* Slice(common::Span<int32_t const> ridxs) = 0;
|
||||
/*! \brief page size 32 MB */
|
||||
static const size_t kPageSize = 32UL << 20UL;
|
||||
|
||||
|
||||
@ -181,11 +181,7 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
|
||||
<< "slice does not support group structure";
|
||||
}
|
||||
DMatrix* dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
||||
CHECK(dynamic_cast<data::SimpleDMatrix*>(dmat))
|
||||
<< "Slice only supported for SimpleDMatrix currently.";
|
||||
data::DMatrixSliceAdapter adapter(dmat, {idxset, static_cast<size_t>(len)});
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
|
||||
*out = new std::shared_ptr<DMatrix>(dmat->Slice({idxset, len}));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -599,93 +599,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
|
||||
dmlc::RowBlock<uint32_t> block_;
|
||||
std::unique_ptr<FileAdapterBatch> batch_;
|
||||
};
|
||||
|
||||
class DMatrixSliceAdapterBatch {
|
||||
public:
|
||||
// Fetch metainfo values according to sliced rows
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T>& in) {
|
||||
if (in.empty()) return {};
|
||||
|
||||
std::vector<T> out(this->Size());
|
||||
for (auto i = 0ull; i < this->Size(); i++) {
|
||||
out[i] = in[ridx_set[i]];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
DMatrixSliceAdapterBatch(const SparsePage& batch, DMatrix* dmat,
|
||||
common::Span<const int> ridx_set)
|
||||
: batch(batch), ridx_set(ridx_set) {
|
||||
batch_labels = this->Gather(dmat->Info().labels_.HostVector());
|
||||
batch_weights = this->Gather(dmat->Info().weights_.HostVector());
|
||||
batch_base_margin = this->Gather(dmat->Info().base_margin_.HostVector());
|
||||
}
|
||||
|
||||
class Line {
|
||||
public:
|
||||
Line(const SparsePage::Inst& inst, size_t row_idx)
|
||||
: inst_(inst), row_idx_(row_idx) {}
|
||||
|
||||
size_t Size() { return inst_.size(); }
|
||||
COOTuple GetElement(size_t idx) {
|
||||
return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue};
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage::Inst inst_;
|
||||
size_t row_idx_;
|
||||
};
|
||||
Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); }
|
||||
const float* Labels() const {
|
||||
if (batch_labels.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_labels.data();
|
||||
}
|
||||
const float* Weights() const {
|
||||
if (batch_weights.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_weights.data();
|
||||
}
|
||||
const uint64_t* Qid() const { return nullptr; }
|
||||
const float* BaseMargin() const {
|
||||
if (batch_base_margin.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
return batch_base_margin.data();
|
||||
}
|
||||
|
||||
size_t Size() const { return ridx_set.size(); }
|
||||
const SparsePage& batch;
|
||||
common::Span<const int> ridx_set;
|
||||
std::vector<float> batch_labels;
|
||||
std::vector<float> batch_weights;
|
||||
std::vector<float> batch_base_margin;
|
||||
};
|
||||
|
||||
// Group pointer is not exposed
|
||||
// This is because external bindings currently manipulate the group values
|
||||
// manually when slicing. This could potentially be moved to internal C++ code if
|
||||
// needed
|
||||
|
||||
class DMatrixSliceAdapter
|
||||
: public detail::SingleBatchDataIter<DMatrixSliceAdapterBatch> {
|
||||
public:
|
||||
DMatrixSliceAdapter(DMatrix* dmat, common::Span<const int> ridx_set)
|
||||
: dmat_(dmat),
|
||||
ridx_set_(ridx_set),
|
||||
batch_(*dmat_->GetBatches<SparsePage>().begin(), dmat_, ridx_set) {}
|
||||
const DMatrixSliceAdapterBatch& Value() const override { return batch_; }
|
||||
// Indicates a number of rows/columns must be inferred
|
||||
size_t NumRows() const { return ridx_set_.size(); }
|
||||
size_t NumColumns() const { return dmat_->Info().num_col_; }
|
||||
|
||||
private:
|
||||
DMatrix* dmat_;
|
||||
common::Span<const int> ridx_set_;
|
||||
DMatrixSliceAdapterBatch batch_;
|
||||
};
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ADAPTER_H_
|
||||
|
||||
@ -205,6 +205,53 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
|
||||
if (in.empty()) {
|
||||
return {};
|
||||
}
|
||||
auto size = ridxs.size();
|
||||
std::vector<T> out(size * stride);
|
||||
for (auto i = 0ull; i < size; i++) {
|
||||
auto ridx = ridxs[i];
|
||||
for (size_t j = 0; j < stride; ++j) {
|
||||
out[i * stride +j] = in[ridx * stride + j];
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/*!
 * \brief Build the meta info of a row slice.
 *
 * Labels, label bounds, weights and base margin are gathered for the rows in
 * `ridxs`; the column count is copied unchanged.  Group information is not
 * sliced here.
 *
 * \param ridxs Indices of the rows to keep, in output order.
 * \return Meta info describing the selected rows.
 */
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
  MetaInfo out;
  out.num_row_ = ridxs.size();
  out.num_col_ = this->num_col_;
  // Groups is maintained by a higher level Python function.  We should aim at deprecating
  // the slice function.
  out.labels_.HostVector() = Gather(this->labels_.HostVector(), ridxs);
  out.labels_upper_bound_.HostVector() =
      Gather(this->labels_upper_bound_.HostVector(), ridxs);
  out.labels_lower_bound_.HostVector() =
      Gather(this->labels_lower_bound_.HostVector(), ridxs);

  // weights
  if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
    // One weight per group rather than per row: they cannot be indexed by row,
    // so copy them from the source as-is.  Assuming all groups are available.
    auto const& h_weights = this->weights_.HostVector();
    out.weights_.HostVector() = h_weights;
  } else {
    out.weights_.HostVector() = Gather(this->weights_.HostVector(), ridxs);
  }

  // base margin: may hold several values per row, stored contiguously by row.
  // NOTE(review): a non-empty margin with num_row_ == 0 would divide by zero
  // below -- confirm that state cannot occur.
  size_t stride = 1;
  if (this->base_margin_.Size() != this->num_row_) {
    CHECK_EQ(this->base_margin_.Size() % this->num_row_, 0)
        << "Incorrect size of base margin vector.";
    stride = this->base_margin_.Size() / this->num_row_;
  }
  out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs, stride);
  return out;
}
|
||||
|
||||
// try to load group information from file, if exists
|
||||
inline bool MetaTryLoadGroup(const std::string& fname,
|
||||
std::vector<unsigned>* group) {
|
||||
@ -459,9 +506,6 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||
data::FileAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::DMatrixSliceAdapter>(
|
||||
data::DMatrixSliceAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
|
||||
data::IteratorAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
|
||||
@ -31,6 +31,10 @@ class DeviceDMatrix : public DMatrix {
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
/*!
 * \brief Row slicing is not available for this matrix type.
 *
 * Reports a fatal error through LOG(FATAL); the nullptr return only exists
 * so that every path yields a value.
 */
DMatrix* Slice(common::Span<int32_t const> ridxs) override {
  LOG(FATAL)
      << "Slicing DMatrix is not supported for Device DMatrix.";
  return nullptr;
}
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
|
||||
@ -16,6 +16,27 @@ MetaInfo& SimpleDMatrix::Info() { return info_; }
|
||||
|
||||
const MetaInfo& SimpleDMatrix::Info() const { return info_; }
|
||||
|
||||
/*!
 * \brief Create a new SimpleDMatrix holding only the rows in `ridxs`.
 *
 * Rows are appended to the new page in the order given by `ridxs`; the meta
 * info is sliced with MetaInfo::Slice and the non-zero count is taken from the
 * last row offset.
 *
 * \param ridxs Indices of the rows to copy.
 * \return Newly allocated matrix; the caller takes ownership.
 */
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
  auto sliced = new SimpleDMatrix;
  SparsePage& dst_page = sliced->sparse_page_;
  for (auto const& src_page : this->GetBatches<SparsePage>()) {
    // NOTE(review): results unused; presumably forces the data onto the host
    // before rows are read -- confirm.
    src_page.data.HostVector();
    src_page.offset.HostVector();

    auto& dst_entries = dst_page.data.HostVector();
    auto& dst_row_ptr = dst_page.offset.HostVector();
    size_t n_entries{0};
    for (auto ridx : ridxs) {
      auto row = src_page[ridx];
      n_entries += row.size();
      for (auto const& entry : row) {
        dst_entries.push_back(entry);
      }
      dst_row_ptr.emplace_back(n_entries);
    }

    sliced->Info() = this->Info().Slice(ridxs);
    sliced->Info().num_nonzero_ = dst_row_ptr.back();
  }
  return sliced;
}
|
||||
|
||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
// since csr is the default data structure so `source_` is always available.
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
@ -174,8 +195,6 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(DMatrixSliceAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
} // namespace data
|
||||
|
||||
@ -19,6 +19,7 @@ namespace data {
|
||||
// Used for single batch data.
|
||||
class SimpleDMatrix : public DMatrix {
|
||||
public:
|
||||
SimpleDMatrix() = default;
|
||||
template <typename AdapterT>
|
||||
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
|
||||
|
||||
@ -32,6 +33,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
const MetaInfo& Info() const override;
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
||||
|
||||
/*! \brief magic number used to identify SimpleDMatrix binary files */
|
||||
static const int kMagic = 0xffffab01;
|
||||
|
||||
@ -37,6 +37,10 @@ class SparsePageDMatrix : public DMatrix {
|
||||
const MetaInfo& Info() const override;
|
||||
|
||||
bool SingleColBlock() const override { return false; }
|
||||
/*!
 * \brief Row slicing is not available for this matrix type.
 *
 * Reports a fatal error through LOG(FATAL); the nullptr return only exists
 * so that every path yields a value.
 */
DMatrix* Slice(common::Span<int32_t const> ridxs) override {
  LOG(FATAL)
      << "Slicing DMatrix is not supported for external memory.";
  return nullptr;
}
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override;
|
||||
|
||||
@ -67,31 +67,6 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
|
||||
EXPECT_EQ(inst[3].index, 3);
|
||||
}
|
||||
|
||||
// The slice adapter must expose exactly the selected rows of the source
// matrix, renumbered from zero in selection order.
TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
  auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();

  std::vector<int> ridx_set = {1, 3, 5};
  data::DMatrixSliceAdapter adapter(p_dmat.get(),
                                    {ridx_set.data(), ridx_set.size()});
  EXPECT_EQ(adapter.NumRows(), ridx_set.size());

  adapter.BeforeFirst();
  for (auto &src_batch : p_dmat->GetBatches<SparsePage>()) {
    adapter.Next();
    auto &sliced_batch = adapter.Value();
    for (auto row = 0ull; row < sliced_batch.Size(); row++) {
      auto expected = src_batch[ridx_set[row]];
      auto line = sliced_batch.GetLine(row);
      ASSERT_EQ(expected.size(), line.Size());
      for (auto k = 0ull; k < line.Size(); k++) {
        // Values and columns come from the source row; the row index is the
        // position inside the slice.
        EXPECT_EQ(expected[k].fvalue, line.GetElement(k).value);
        EXPECT_EQ(expected[k].index, line.GetElement(k).column_idx);
        EXPECT_EQ(row, line.GetElement(k).row_idx);
      }
    }
  }
}
|
||||
|
||||
// A mock for JVM data iterator.
|
||||
class DataIterForTest {
|
||||
std::vector<float> data_ {1, 2, 3, 4, 5};
|
||||
|
||||
@ -125,5 +125,4 @@ TEST(DMatrix, Uri) {
|
||||
ASSERT_EQ(dmat->Info().num_col_, kCols);
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
// Copyright by Contributors
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
|
||||
#include <array>
|
||||
#include "xgboost/base.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@ -218,45 +219,64 @@ TEST(SimpleDMatrix, FromFile) {
|
||||
}
|
||||
|
||||
// Slicing a SimpleDMatrix must carry over the selected rows together with
// their weights, label bounds and (multi-value) base margin.
TEST(SimpleDMatrix, Slice) {
  size_t constexpr kRows {16};
  size_t constexpr kCols {8};
  size_t constexpr kClasses {3};
  // NOTE(review): the `true` argument presumably requests generated labels -- confirm.
  auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);

  // Distinct per-row values so that a wrong row is detectable.
  auto& weights = p_m->Info().weights_.HostVector();
  weights.resize(kRows);
  std::iota(weights.begin(), weights.end(), 0.0f);

  auto& lower = p_m->Info().labels_lower_bound_.HostVector();
  auto& upper = p_m->Info().labels_upper_bound_.HostVector();
  lower.resize(kRows);
  upper.resize(kRows);

  std::iota(lower.begin(), lower.end(), 0.0f);
  std::iota(upper.begin(), upper.end(), 1.0f);

  // kClasses margin values per row.
  auto& margin = p_m->Info().base_margin_.HostVector();
  margin.resize(kRows * kClasses);

  std::array<int32_t, 3> ridxs {1, 3, 5};
  std::unique_ptr<DMatrix> out { p_m->Slice(ridxs) };
  ASSERT_EQ(out->Info().labels_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().labels_lower_bound_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().labels_upper_bound_.Size(), ridxs.size());
  ASSERT_EQ(out->Info().base_margin_.Size(), ridxs.size() * kClasses);

  for (auto const& in_page : p_m->GetBatches<SparsePage>()) {
    for (auto const &out_page : out->GetBatches<SparsePage>()) {
      for (size_t i = 0; i < ridxs.size(); ++i) {
        auto ridx = ridxs[i];
        auto out_inst = out_page[i];
        auto in_inst = in_page[ridx];
        ASSERT_EQ(out_inst.size(), in_inst.size()) << i;
        for (size_t j = 0; j < in_inst.size(); ++j) {
          ASSERT_EQ(in_inst[j].fvalue, out_inst[j].fvalue);
          ASSERT_EQ(in_inst[j].index, out_inst[j].index);
        }

        ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(ridx),
                  out->Info().labels_lower_bound_.HostVector().at(i));
        ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(ridx),
                  out->Info().labels_upper_bound_.HostVector().at(i));
        ASSERT_EQ(p_m->Info().weights_.HostVector().at(ridx),
                  out->Info().weights_.HostVector().at(i));

        auto& out_margin = out->Info().base_margin_.HostVector();
        for (size_t j = 0; j < kClasses; ++j) {
          auto in_beg = ridx * kClasses;
          ASSERT_EQ(out_margin.at(i * kClasses + j), margin.at(in_beg + j));
        }
      }
    }
  }

  // Compare against the source matrix: the column count must survive slicing.
  ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_);
  ASSERT_EQ(out->Info().num_row_, ridxs.size());
  ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols);  // dense
}
|
||||
|
||||
TEST(SimpleDMatrix, SaveLoadBinary) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
|
||||
@ -71,7 +71,34 @@ class TestDMatrix(unittest.TestCase):
|
||||
assert (from_view.shape == from_array.shape)
|
||||
assert (from_view == from_array).all()
|
||||
|
||||
def test_feature_names(self):
|
||||
def test_slice(self):
    """Slicing keeps a multi-class base margin aligned with its rows.

    Trains on the full matrix, installs the predictions as base margin,
    slices a few rows and checks that the sliced margin has one value per
    class per row and that training on the slice reports a similar error.
    """
    n_rows = 100
    n_classes = 3
    params = {'num_class': 3, 'objective': 'multi:softprob'}

    X = rng.randn(n_rows, 100)
    y = rng.randint(low=0, high=n_classes, size=n_rows)
    d = xgb.DMatrix(X, y)

    full_result = {}
    booster = xgb.train(params, d, num_boost_round=2,
                        evals=[(d, 'd')], evals_result=full_result)

    # One margin value per (row, class), flattened into a single column.
    margin = booster.predict(d).reshape(n_rows * n_classes, 1)
    d.set_base_margin(margin)

    ridxs = [1, 2, 3, 4, 5, 6]
    d = d.slice(ridxs)
    sliced_margin = d.get_float_info('base_margin')
    assert sliced_margin.shape[0] == len(ridxs) * n_classes

    sliced_result = {}
    xgb.train(params, d, num_boost_round=2,
              evals=[(d, 'd')], evals_result=sliced_result)

    full_merror = full_result['d']['merror']
    sliced_merror = sliced_result['d']['merror']
    for round_idx in range(len(full_merror)):
        difference = full_merror[round_idx] - sliced_merror[round_idx]
        assert abs(difference) < 0.02
|
||||
|
||||
def test_feature_names_slice(self):
|
||||
data = np.random.randn(5, 5)
|
||||
|
||||
# different length
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user