/** * Copyright 2020-2024, XGBoost contributors */ #include // for shared_ptr #include // for move #include "batch_utils.h" // for RegenGHist, CheckParam #include "device_adapter.cuh" #include "ellpack_page.cuh" #include "iterative_dmatrix.h" #include "proxy_dmatrix.cuh" #include "proxy_dmatrix.h" // for BatchSamples, BatchColumns #include "simple_batch_iterator.h" namespace xgboost::data { void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, DataIterHandle iter_handle, float missing, std::shared_ptr ref) { // A handle passed to external iterator. DMatrixProxy* proxy = MakeProxy(proxy_); CHECK(proxy); // The external iterator auto iter = DataIterProxy{iter_handle, reset_, next_}; dh::XGBCachingDeviceAllocator alloc; // Sketch for all batches. std::int32_t current_device{dh::CurrentDevice()}; auto get_ctx = [&]() { Context d_ctx = (ctx->IsCUDA()) ? *ctx : Context{}.MakeCUDA(current_device); CHECK(!d_ctx.IsCPU()); return d_ctx; }; fmat_ctx_ = get_ctx(); /** * Generate quantiles */ auto cuts = std::make_shared(); ExternalDataInfo ext_info; cuda_impl::MakeSketches(ctx, &iter, proxy, ref, p, missing, cuts, this->Info(), &ext_info); ext_info.SetInfo(ctx, &this->info_); auto init_page = [this, &cuts, &ext_info]() { if (!ellpack_) { // Should be put inside the while loop to protect against empty batch. In // that case device id is invalid. ellpack_.reset(new EllpackPage); *(ellpack_->Impl()) = EllpackPageImpl(&fmat_ctx_, cuts, this->IsDense(), ext_info.row_stride, ext_info.accumulated_rows); } }; /** * Generate gradient index. */ size_t offset = 0; iter.Reset(); size_t n_batches_for_verification = 0; while (iter.Next()) { init_page(); dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); auto rows = BatchSamples(proxy); dh::device_vector row_counts(rows + 1, 0); common::Span row_counts_span(row_counts.data().get(), row_counts.size()); cuda_impl::Dispatch(proxy, [=](auto const& value) { return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); }); auto is_dense = this->IsDense(); proxy->Info().feature_types.SetDevice(dh::GetDevice(ctx)); auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan(); auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) { return EllpackPageImpl(&fmat_ctx_, value, missing, is_dense, row_counts_span, d_feature_types, ext_info.row_stride, rows, cuts); }); std::size_t num_elements = ellpack_->Impl()->Copy(&fmat_ctx_, &new_impl, offset); offset += num_elements; proxy->Info().num_row_ = BatchSamples(proxy); proxy->Info().num_col_ = ext_info.n_features; if (ext_info.n_batches != 1) { this->info_.Extend(std::move(proxy->Info()), false, true); } n_batches_for_verification++; } CHECK_EQ(ext_info.n_batches, n_batches_for_verification) << "Different number of batches returned between 2 iterations"; if (ext_info.n_batches == 1) { this->info_ = std::move(proxy->Info()); this->info_.num_nonzero_ = ext_info.nnz; CHECK_EQ(proxy->Info().labels.Size(), 0); } iter.Reset(); // Synchronise worker columns } BatchSet IterativeDMatrix::GetEllpackBatches(Context const* ctx, BatchParam const& param) { if (param.Initialized()) { detail::CheckParam(this->batch_, param); CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin(); } if (!ellpack_ && !ghist_) { LOG(FATAL) << "`QuantileDMatrix` not initialized."; } if (!ellpack_) { ellpack_.reset(new EllpackPage()); if (ctx->IsCUDA()) { this->Info().feature_types.SetDevice(ctx->Device()); *ellpack_->Impl() = EllpackPageImpl(ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan()); } else if (fmat_ctx_.IsCUDA()) { this->Info().feature_types.SetDevice(fmat_ctx_.Device()); *ellpack_->Impl() = EllpackPageImpl(&fmat_ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan()); } else { // Can happen when QDM is initialized on CPU, but a GPU version is queried by a different QDM // for cut reference. auto cuda_ctx = ctx->MakeCUDA(); this->Info().feature_types.SetDevice(cuda_ctx.Device()); *ellpack_->Impl() = EllpackPageImpl(&cuda_ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan()); } } CHECK(ellpack_); auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ellpack_)); return BatchSet(begin_iter); } } // namespace xgboost::data