/*!
 * Copyright 2022 XGBoost contributors
 */
#include "iterative_dmatrix.h"

#include <algorithm>  // std::copy
#include <numeric>    // std::accumulate

#include "../collective/communicator-inl.h"
#include "../common/categorical.h"  // common::IsCat
#include "../common/column_matrix.h"
#include "../tree/param.h"  // FIXME(jiamingy): Find a better way to share this parameter.
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"

namespace xgboost {
namespace data {
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
                                   std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
                                   XGDMatrixCallbackNext* next, float missing, int nthread,
                                   bst_bin_t max_bin)
    : proxy_{proxy}, reset_{reset}, next_{next} {
  // Fetch the first batch.
  auto iter =
      DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
  iter.Reset();
  bool valid = iter.Next();
  CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";

  auto d = MakeProxy(proxy_)->DeviceIdx();

  StringView msg{"All batches should be on the same device."};
  if (batch_param_.gpu_id != Context::kCpuId) {
    CHECK_EQ(d, batch_param_.gpu_id) << msg;
  }

  batch_param_ = BatchParam{d, max_bin};
  // Hardcoded parameter.
  batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();

  ctx_.UpdateAllowUnknown(
      Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
  if (ctx_.IsCPU()) {
    this->InitFromCPU(iter_handle, missing, ref);
  } else {
    this->InitFromCUDA(iter_handle, missing, ref);
  }
}

void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
                    common::HistogramCuts* p_cuts) {
  CHECK(ref_);
  CHECK(p_cuts);
  auto csr = [&]() {
    for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
      *p_cuts = page.cut;
      break;
    }
  };
  auto ellpack = [&]() {
    // Workaround for ellpack being initialized from CPU: fall back to the ref's
    // device, then to device 0 if the ref is also on CPU.
    if (p.gpu_id == Context::kCpuId) {
      p.gpu_id = ref_->Ctx()->gpu_id;
    }
    if (p.gpu_id == Context::kCpuId) {
      p.gpu_id = 0;
    }
    for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
      GetCutsFromEllpack(page, p_cuts);
      break;
    }
  };

  // Reuse whichever quantized page the ref has already materialized; otherwise
  // choose the format matching the requested device.
  if (ref_->PageExists<GHistIndexMatrix>()) {
    csr();
  } else if (ref_->PageExists<EllpackPage>()) {
    ellpack();
  } else {
    if (p.gpu_id == Context::kCpuId) {
      csr();
    } else {
      ellpack();
    }
  }
  CHECK_EQ(ref_->Info().num_col_, n_features)
      << "Invalid ref DMatrix, different number of features.";
}
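/**
 * InitFromCPU makes several passes over the external data iterator:
 *
 *   1. Count rows, columns, and valid (non-missing) entries per column.
 *   2. Accumulate quantile sketches and build the histogram cuts, unless the
 *      cuts are copied from a reference DMatrix.
 *   3. Push each batch into the gradient index.
 *   4. Build the column matrix from the batches.
 *
 * Only the proxy's current batch is resident at any time, so peak memory stays
 * roughly proportional to one batch plus the final quantized index.
 */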
void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
                                   std::shared_ptr<DMatrix> ref) {
  DMatrixProxy* proxy = MakeProxy(proxy_);
  CHECK(proxy);

  // The external iterator.
  auto iter =
      DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
  common::HistogramCuts cuts;

  auto num_rows = [&]() {
    return HostAdapterDispatch(proxy, [](auto const& value) { return value.Size(); });
  };
  auto num_cols = [&]() {
    return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
  };

  std::vector<size_t> column_sizes;
  auto const is_valid = data::IsValidFunctor{missing};
  // Count the valid entries in the current batch and accumulate the per-column
  // sizes, using one counter row per thread to avoid atomics.
  auto nnz_cnt = [&]() {
    return HostAdapterDispatch(proxy, [&](auto const& value) {
      size_t n_threads = ctx_.Threads();
      size_t n_features = column_sizes.size();
      linalg::Tensor<size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
      column_sizes_tloc.Data()->Fill(0);
      auto view = column_sizes_tloc.HostView();
      common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
        auto const& line = value.GetLine(i);
        for (size_t j = 0; j < line.Size(); ++j) {
          data::COOTuple const& elem = line.GetElement(j);
          if (is_valid(elem)) {
            view(omp_get_thread_num(), elem.column_idx)++;
          }
        }
      });
      auto ptr = column_sizes_tloc.Data()->HostPointer();
      auto result = std::accumulate(ptr, ptr + column_sizes_tloc.Size(), static_cast<size_t>(0));
      for (size_t tidx = 0; tidx < n_threads; ++tidx) {
        for (size_t fidx = 0; fidx < n_features; ++fidx) {
          column_sizes[fidx] += view(tidx, fidx);
        }
      }
      return result;
    });
  };

  size_t n_features = 0;
  size_t n_batches = 0;
  size_t accumulated_rows{0};
  size_t nnz{0};

  /**
   * The CPU impl needs an additional loop for accumulating the column sizes.
   */
  std::unique_ptr<common::HostSketchContainer> p_sketch;
  std::vector<size_t> batch_nnz;
  do {
    // We use do-while here as the first batch is fetched in the ctor.
    if (n_features == 0) {
      n_features = num_cols();
      collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
      column_sizes.resize(n_features);
      info_.num_col_ = n_features;
    } else {
      CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
    }
    size_t batch_size = num_rows();
    batch_nnz.push_back(nnz_cnt());
    nnz += batch_nnz.back();
    accumulated_rows += batch_size;
    n_batches++;
  } while (iter.Next());
  iter.Reset();

  // From here on, Info() has the correct data shape.
  Info().num_row_ = accumulated_rows;
  Info().num_nonzero_ = nnz;
  collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
  CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(),
                     [&](auto f) { return f > accumulated_rows; }))
      << "Something went wrong during iteration.";
  CHECK_GE(n_features, 1) << "Data must have at least 1 column.";

  /**
   * Generate quantiles.
   */
  accumulated_rows = 0;
  if (ref) {
    GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
  } else {
    size_t i = 0;
    while (iter.Next()) {
      if (!p_sketch) {
        p_sketch.reset(new common::HostSketchContainer{
            batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes,
            !proxy->Info().group_ptr_.empty(), ctx_.Threads()});
      }
      HostAdapterDispatch(proxy, [&](auto const& batch) {
        proxy->Info().num_nonzero_ = batch_nnz[i];
        // We don't need the base row idx here as Info is from the proxy and the
        // number of rows in it is consistent with the data batch.
        p_sketch->PushAdapterBatch(batch, 0, proxy->Info(), missing);
      });
      accumulated_rows += num_rows();
      ++i;
    }
    iter.Reset();
    CHECK_EQ(accumulated_rows, Info().num_row_);

    CHECK(p_sketch);
    p_sketch->MakeCuts(&cuts);
  }

  /**
   * Generate the gradient index.
   */
  this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
  size_t rbegin = 0;
  size_t prev_sum = 0;
  size_t i = 0;
  while (iter.Next()) {
    HostAdapterDispatch(proxy, [&](auto const& batch) {
      proxy->Info().num_nonzero_ = batch_nnz[i];
      this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing,
                                     proxy->Info().feature_types.ConstHostSpan(),
                                     batch_param_.sparse_thresh, Info().num_row_);
    });
    if (n_batches != 1) {
      this->info_.Extend(std::move(proxy->Info()), false, true);
    }
    size_t batch_size = num_rows();
    prev_sum = this->ghist_->row_ptr[rbegin + batch_size];
    rbegin += batch_size;
    ++i;
  }
  iter.Reset();
  CHECK_EQ(rbegin, Info().num_row_);

  /**
   * Generate the column matrix.
   */
  accumulated_rows = 0;
  while (iter.Next()) {
    HostAdapterDispatch(proxy, [&](auto const& batch) {
      this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
    });
    accumulated_rows += num_rows();
  }
  iter.Reset();
  CHECK_EQ(accumulated_rows, Info().num_row_);

  if (n_batches == 1) {
    this->info_ = std::move(proxy->Info());
    this->info_.num_nonzero_ = nnz;
    this->info_.num_col_ = n_features;  // proxy might be empty.
    CHECK_EQ(proxy->Info().labels.Size(), 0);
  }
}
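// A minimal usage sketch (hypothetical caller code, not part of this
// translation unit): once constructed, the quantized pages are consumed
// through the regular batch interface, e.g.
//
//   BatchParam param{Context::kCpuId, max_bin};
//   param.sparse_thresh = tree::TrainParam::DftSparseThreshold();
//   for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(param)) {
//     // Train on the quantized page.
//   }
//
// GetGradientIndex below serves such requests, lazily converting from the
// ellpack page when the data was initialized on a GPU.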
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
  CheckParam(param);
  if (!ghist_) {
    CHECK(ellpack_);
    ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
  }
  if (!std::isnan(param.sparse_thresh) &&
      param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {
    LOG(WARNING) << "`sparse_threshold` cannot be changed when `QuantileDMatrix` is used instead "
                    "of `DMatrix`.";
  }

  auto begin_iter =
      BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
  return BatchSet<GHistIndexMatrix>(begin_iter);
}

BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
  for (auto const& page : this->GetGradientIndex(param)) {
    auto p_out = std::make_shared<SparsePage>();
    p_out->data.Resize(this->Info().num_nonzero_);
    p_out->offset.Resize(this->Info().num_row_ + 1);

    auto& h_offset = p_out->offset.HostVector();
    CHECK_EQ(page.row_ptr.size(), h_offset.size());
    std::copy(page.row_ptr.cbegin(), page.row_ptr.cend(), h_offset.begin());

    auto& h_data = p_out->data.HostVector();
    auto const& vals = page.cut.Values();
    auto const& mins = page.cut.MinValues();
    auto const& ptrs = page.cut.Ptrs();

    // Recover a representative feature value for each bin index: categorical
    // features store the category itself, numeric features a cut value.
    auto ft = Info().feature_types.ConstHostSpan();
    AssignColumnBinIndex(page,
                         [&](auto bin_idx, std::size_t idx, std::size_t, bst_feature_t fidx) {
                           float v;
                           if (common::IsCat(ft, fidx)) {
                             v = vals[bin_idx];
                           } else {
                             v = common::HistogramCuts::NumericBinValue(ptrs, vals, mins, fidx,
                                                                        bin_idx);
                           }
                           h_data[idx] = Entry{fidx, v};
                         });

    auto p_ext_out = std::make_shared<ExtSparsePage>(p_out);
    auto begin_iter =
        BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(p_ext_out));
    return BatchSet<ExtSparsePage>(begin_iter);
  }
  LOG(FATAL) << "Unreachable";
  auto begin_iter =
      BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
  return BatchSet<ExtSparsePage>(begin_iter);
}
}  // namespace data
}  // namespace xgboost