Support cpu quantile sketch with column-wise data split (#8742)
This commit is contained in:
@@ -897,10 +897,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
if (!cache_file.empty()) {
|
||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||
}
|
||||
auto slice_cols = (dmat->Info().num_col_ + 1) / npart;
|
||||
auto slice_start = slice_cols * partid;
|
||||
auto size = std::min(slice_cols, dmat->Info().num_col_ - slice_start);
|
||||
auto* sliced = dmat->SliceCol(slice_start, size);
|
||||
auto* sliced = dmat->SliceCol(npart, partid);
|
||||
delete dmat;
|
||||
return sliced;
|
||||
} else {
|
||||
|
||||
@@ -172,9 +172,9 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
size_t i = 0;
|
||||
while (iter.Next()) {
|
||||
if (!p_sketch) {
|
||||
p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
|
||||
proxy->Info().feature_types.ConstHostSpan(),
|
||||
column_sizes, false, ctx_.Threads()});
|
||||
p_sketch.reset(new common::HostSketchContainer{
|
||||
batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes, false,
|
||||
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
|
||||
}
|
||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||
|
||||
@@ -86,7 +86,7 @@ class IterativeDMatrix : public DMatrix {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
DMatrix *SliceCol(std::size_t, std::size_t) override {
|
||||
DMatrix *SliceCol(int num_slices, int slice_id) override {
|
||||
LOG(FATAL) << "Slicing DMatrix columns is not supported for Quantile DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ class DMatrixProxy : public DMatrix {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
DMatrix* SliceCol(std::size_t, std::size_t) override {
|
||||
DMatrix* SliceCol(int num_slices, int slice_id) override {
|
||||
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -46,9 +46,12 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
||||
return out;
|
||||
}
|
||||
|
||||
DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
|
||||
DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
auto out = new SimpleDMatrix;
|
||||
SparsePage& out_page = *out->sparse_page_;
|
||||
auto const slice_size = info_.num_col_ / num_slices;
|
||||
auto const slice_start = slice_size * slice_id;
|
||||
auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
|
||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
||||
auto batch = page.GetView();
|
||||
auto& h_data = out_page.data.HostVector();
|
||||
@@ -58,7 +61,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
|
||||
auto inst = batch[i];
|
||||
auto prev_size = h_data.size();
|
||||
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry e) {
|
||||
return e.index >= start && e.index < start + size;
|
||||
return e.index >= slice_start && e.index < slice_end;
|
||||
});
|
||||
rptr += h_data.size() - prev_size;
|
||||
h_offset.emplace_back(rptr);
|
||||
|
||||
@@ -35,7 +35,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
|
||||
DMatrix* SliceCol(std::size_t start, std::size_t size) override;
|
||||
DMatrix* SliceCol(int num_slices, int slice_id) override;
|
||||
|
||||
/*! \brief magic number used to identify SimpleDMatrix binary files */
|
||||
static const int kMagic = 0xffffab01;
|
||||
|
||||
@@ -107,7 +107,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
|
||||
return nullptr;
|
||||
}
|
||||
DMatrix *SliceCol(std::size_t, std::size_t) override {
|
||||
DMatrix *SliceCol(int num_slices, int slice_id) override {
|
||||
LOG(FATAL) << "Slicing DMatrix columns is not supported for external memory.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user