Initial support for column-wise data split (#8468)
This commit is contained in:
parent
c0609b98f1
commit
78d65a1928
@ -112,6 +112,9 @@ class MetaInfo {
|
|||||||
void Validate(int32_t device) const;
|
void Validate(int32_t device) const;
|
||||||
|
|
||||||
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
|
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
|
||||||
|
|
||||||
|
/*!
 * \brief Return a new MetaInfo produced by extending an empty instance with this one.
 */
MetaInfo Copy() const;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Get weight of each instances.
|
* \brief Get weight of each instances.
|
||||||
* \param i Instance index.
|
* \param i Instance index.
|
||||||
@ -620,6 +623,15 @@ class DMatrix {
|
|||||||
|
|
||||||
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Slice a DMatrix by columns.
|
||||||
|
*
|
||||||
|
* \param start The position of the first column
|
||||||
|
* \param size The number of columns in the slice
|
||||||
|
* \return DMatrix containing the slice of columns
|
||||||
|
*/
|
||||||
|
virtual DMatrix *SliceCol(std::size_t start, std::size_t size) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||||
|
|||||||
@ -381,6 +381,12 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Produce a copy of this MetaInfo.
 *
 * The copy is built by extending a default-constructed MetaInfo with this instance,
 * accumulating rows and skipping the column check.
 */
MetaInfo MetaInfo::Copy() const {
  MetaInfo copied;
  copied.Extend(*this, /*accumulate_rows=*/true, /*check_column=*/false);
  return copied;
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <int32_t D, typename T>
|
template <int32_t D, typename T>
|
||||||
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||||
@ -777,8 +783,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
|||||||
|
|
||||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
||||||
const std::string& file_format) {
|
const std::string& file_format) {
|
||||||
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
|
CHECK(data_split_mode == DataSplitMode::kRow ||
|
||||||
<< "Precondition violated; data split mode can only be 'row' or 'none'";
|
data_split_mode == DataSplitMode::kCol ||
|
||||||
|
data_split_mode == DataSplitMode::kNone)
|
||||||
|
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'";
|
||||||
std::string fname, cache_file;
|
std::string fname, cache_file;
|
||||||
size_t dlm_pos = uri.find('#');
|
size_t dlm_pos = uri.find('#');
|
||||||
if (dlm_pos != std::string::npos) {
|
if (dlm_pos != std::string::npos) {
|
||||||
@ -878,7 +886,20 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
|||||||
* partitioned data will fail the train/val validation check
|
* partitioned data will fail the train/val validation check
|
||||||
* since partitioned data not knowing the real number of features. */
|
* since partitioned data not knowing the real number of features. */
|
||||||
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
||||||
return dmat;
|
|
||||||
|
if (data_split_mode == DataSplitMode::kCol) {
|
||||||
|
if (!cache_file.empty()) {
|
||||||
|
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||||
|
}
|
||||||
|
auto slice_cols = (dmat->Info().num_col_ + 1) / npart;
|
||||||
|
auto slice_start = slice_cols * partid;
|
||||||
|
auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start);
|
||||||
|
auto* sliced = dmat->SliceCol(slice_start, size);
|
||||||
|
delete dmat;
|
||||||
|
return sliced;
|
||||||
|
} else {
|
||||||
|
return dmat;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
|
||||||
|
|||||||
@ -86,6 +86,10 @@ class IterativeDMatrix : public DMatrix {
|
|||||||
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
/*!
 * \brief Column slicing entry point for the Quantile DMatrix; not supported.
 *
 * Emits a fatal log message; the statement after it returns nullptr.
 */
DMatrix *SliceCol(std::size_t start, std::size_t size) override {
  LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
  return nullptr;
}
|
||||||
BatchSet<SparsePage> GetRowBatches() override {
|
BatchSet<SparsePage> GetRowBatches() override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
LOG(FATAL) << "Not implemented.";
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||||
|
|||||||
@ -87,6 +87,10 @@ class DMatrixProxy : public DMatrix {
|
|||||||
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
/*!
 * \brief Column slicing entry point for the Proxy DMatrix; not supported.
 *
 * Emits a fatal log message; the statement after it returns nullptr.
 */
DMatrix* SliceCol(std::size_t start, std::size_t size) override {
  LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
  return nullptr;
}
|
||||||
BatchSet<SparsePage> GetRowBatches() override {
|
BatchSet<SparsePage> GetRowBatches() override {
|
||||||
LOG(FATAL) << "Not implemented.";
|
LOG(FATAL) << "Not implemented.";
|
||||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||||
|
|||||||
@ -45,6 +45,29 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief Build a new SimpleDMatrix that keeps only the entries whose feature index lies
 *        in [start, start + size).
 *
 * Every input row produces exactly one output row (possibly empty), and retained entries
 * keep their original feature index. The meta info is copied from this matrix and its
 * num_nonzero_ is set to the number of retained entries.
 *
 * \param start Index of the first column to keep.
 * \param size  Number of columns to keep.
 * \return Pointer to the matrix allocated with `new` inside this function.
 */
DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
  auto out = new SimpleDMatrix;
  SparsePage& out_page = *out->sparse_page_;
  auto& h_data = out_page.data.HostVector();
  auto& h_offset = out_page.offset.HostVector();

  // Loop invariants: exclusive upper bound of the column range and the row count.
  std::size_t const col_end = start + size;
  std::size_t const n_rows = this->Info().num_row_;

  // Running number of entries written to the output page. It is kept outside the batch
  // loop so the row offsets stay cumulative across the whole output.
  std::size_t n_kept{0};

  for (auto const& page : this->GetBatches<SparsePage>()) {
    auto batch = page.GetView();
    // Unsigned index: the row count is unsigned, an `int` counter would mix signedness.
    for (std::size_t i = 0; i < n_rows; ++i) {
      auto inst = batch[i];
      auto const before = h_data.size();
      std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data),
                   [start, col_end](Entry const& e) {
                     return e.index >= start && e.index < col_end;
                   });
      n_kept += h_data.size() - before;
      h_offset.emplace_back(n_kept);
    }
  }

  // Copy the meta info once, then record how many entries survived the column filter.
  out->Info() = this->Info().Copy();
  out->Info().num_nonzero_ = n_kept;
  return out;
}
|
||||||
|
|
||||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||||
// since csr is the default data structure so `source_` is always available.
|
// since csr is the default data structure so `source_` is always available.
|
||||||
auto begin_iter = BatchIterator<SparsePage>(
|
auto begin_iter = BatchIterator<SparsePage>(
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class SimpleDMatrix : public DMatrix {
|
|||||||
|
|
||||||
bool SingleColBlock() const override { return true; }
|
bool SingleColBlock() const override { return true; }
|
||||||
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
||||||
|
DMatrix* SliceCol(std::size_t start, std::size_t size) override;
|
||||||
|
|
||||||
/*! \brief magic number used to identify SimpleDMatrix binary files */
|
/*! \brief magic number used to identify SimpleDMatrix binary files */
|
||||||
static const int kMagic = 0xffffab01;
|
static const int kMagic = 0xffffab01;
|
||||||
|
|||||||
@ -107,6 +107,10 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
/*!
 * \brief Column slicing entry point for the external memory DMatrix; not supported.
 *
 * Emits a fatal log message; the statement after it returns nullptr.
 */
DMatrix *SliceCol(std::size_t start, std::size_t size) override {
  LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
  return nullptr;
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BatchSet<SparsePage> GetRowBatches() override;
|
BatchSet<SparsePage> GetRowBatches() override;
|
||||||
|
|||||||
@ -300,6 +300,69 @@ TEST(SimpleDMatrix, Slice) {
|
|||||||
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
|
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies SimpleDMatrix::SliceCol: each column slice must keep the per-row meta info of
// the source matrix and contain exactly the entries of the requested column range.
TEST(SimpleDMatrix, SliceCol) {
  size_t constexpr kRows{16};
  size_t constexpr kCols{8};
  size_t constexpr kClasses{3};
  auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);

  // Fill the per-row meta info with distinguishable values.
  auto& weights = p_m->Info().weights_.HostVector();
  weights.resize(kRows);
  std::iota(weights.begin(), weights.end(), 0.0f);

  auto& lower = p_m->Info().labels_lower_bound_.HostVector();
  auto& upper = p_m->Info().labels_upper_bound_.HostVector();
  lower.resize(kRows);
  upper.resize(kRows);
  std::iota(lower.begin(), lower.end(), 0.0f);
  std::iota(upper.begin(), upper.end(), 1.0f);

  auto& margin = p_m->Info().base_margin_;
  margin = decltype(p_m->Info().base_margin_){{kRows, kClasses}, GenericParameter::kCpuId};

  size_t constexpr kSliceCols{4};
  // Number of equally sized slices covering all columns.
  size_t constexpr kSlices{kCols / kSliceCols};

  for (size_t slice = 0; slice < kSlices; ++slice) {
    auto const slice_start = slice * kSliceCols;
    std::unique_ptr<DMatrix> out{p_m->SliceCol(slice_start, kSliceCols)};

    // Row-wise meta info keeps its size after a column slice.
    ASSERT_EQ(out->Info().labels.Size(), kRows);
    ASSERT_EQ(out->Info().labels_lower_bound_.Size(), kRows);
    ASSERT_EQ(out->Info().labels_upper_bound_.Size(), kRows);
    ASSERT_EQ(out->Info().base_margin_.Size(), kRows * kClasses);

    for (auto const& in_batch : p_m->GetBatches<SparsePage>()) {
      auto in_page = in_batch.GetView();
      for (auto const& out_batch : out->GetBatches<SparsePage>()) {
        auto out_page = out_batch.GetView();
        for (size_t i = 0; i < kRows; ++i) {
          auto out_inst = out_page[i];
          auto in_inst = in_page[i];
          // Each slice holds 1 / kSlices of the entries of the corresponding input row.
          ASSERT_EQ(out_inst.size() * kSlices, in_inst.size()) << i;
          for (size_t j = 0; j < kSliceCols; ++j) {
            ASSERT_EQ(in_inst[slice_start + j].fvalue, out_inst[j].fvalue);
            ASSERT_EQ(in_inst[slice_start + j].index, out_inst[j].index);
          }

          ASSERT_EQ(p_m->Info().labels_lower_bound_.HostVector().at(i),
                    out->Info().labels_lower_bound_.HostVector().at(i));
          ASSERT_EQ(p_m->Info().labels_upper_bound_.HostVector().at(i),
                    out->Info().labels_upper_bound_.HostVector().at(i));
          ASSERT_EQ(p_m->Info().weights_.HostVector().at(i),
                    out->Info().weights_.HostVector().at(i));

          auto out_margin = out->Info().base_margin_.View(GenericParameter::kCpuId);
          auto in_margin = margin.View(GenericParameter::kCpuId);
          for (size_t j = 0; j < kClasses; ++j) {
            ASSERT_EQ(out_margin(i, j), in_margin(i, j));
          }
        }
      }
    }

    // Compare against the source matrix; comparing `out` with itself can never fail.
    ASSERT_EQ(out->Info().num_col_, p_m->Info().num_col_);
    ASSERT_EQ(out->Info().num_row_, kRows);
    ASSERT_EQ(out->Info().num_nonzero_, kRows * kSliceCols);  // dense
  }
}
|
||||||
|
|
||||||
TEST(SimpleDMatrix, SaveLoadBinary) {
|
TEST(SimpleDMatrix, SaveLoadBinary) {
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user