[EM] Log the page size of ellpack. (#10713)

This commit is contained in:
Jiaming Yuan 2024-08-17 01:35:47 +08:00 committed by GitHub
parent abe65e3769
commit 033a666900
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 9 additions and 21 deletions

View File

@ -1,18 +1,15 @@
/*! /**
* Copyright by Contributors 2017-2019 * Copyright 2017-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <chrono> #include <chrono>
#include <iostream>
#include <map> #include <map>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
namespace xgboost {
namespace common {
namespace xgboost::common {
struct Timer { struct Timer {
using ClockT = std::chrono::high_resolution_clock; using ClockT = std::chrono::high_resolution_clock;
using TimePointT = std::chrono::high_resolution_clock::time_point; using TimePointT = std::chrono::high_resolution_clock::time_point;
@ -82,5 +79,4 @@ struct Monitor {
void Start(const std::string &name); void Start(const std::string &name);
void Stop(const std::string &name); void Stop(const std::string &name);
}; };
} // namespace common } // namespace xgboost::common
} // namespace xgboost

View File

@ -356,7 +356,6 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
: is_dense{page.IsDense()}, : is_dense{page.IsDense()},
base_rowid{page.base_rowid}, base_rowid{page.base_rowid},
n_rows{page.Size()}, n_rows{page.Size()},
// This makes a copy of the cut values.
cuts_{std::make_shared<common::HistogramCuts>(page.cut)} { cuts_{std::make_shared<common::HistogramCuts>(page.cut)} {
auto it = common::MakeIndexTransformIter( auto it = common::MakeIndexTransformIter(
[&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; }); [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
@ -540,15 +539,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
// Return the number of rows contained in this page. // Return the number of rows contained in this page.
[[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; } [[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }
// Return the memory cost for storing the compressed features. std::size_t EllpackPageImpl::MemCostBytes() const { return this->gidx_buffer.size_bytes(); }
// Static helper (declared `static` in the class): forwards to
// common::CompressedBufferWriter::CalculateBufferSize and returns its result unchanged.
// Works purely from the supplied dimensions and cuts; it does not read any page instance.
size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
const common::HistogramCuts& cuts) {
// Required buffer size for storing data matrix in ELLPACK format.
// First argument: row_stride * num_rows, i.e. one slot per (row, position-in-row) pair.
// Second argument: total bin count plus one -- presumably the extra symbol reserved for
// a missing/null entry; TODO confirm against CalculateBufferSize.
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(row_stride * num_rows,
cuts.TotalBins() + 1);
return compressed_size_bytes;
}
EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
DeviceOrd device, common::Span<FeatureType const> feature_types) const { DeviceOrd device, common::Span<FeatureType const> feature_types) const {

View File

@ -217,8 +217,7 @@ class EllpackPageImpl {
[[nodiscard]] bool IsDense() const { return is_dense; } [[nodiscard]] bool IsDense() const { return is_dense; }
/** @return Estimation of memory cost of this page. */ /** @return Estimation of memory cost of this page. */
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ; std::size_t MemCostBytes() const;
/** /**
* @brief Return the total number of symbols (total number of bins plus 1 for not * @brief Return the total number of symbols (total number of bins plus 1 for not

View File

@ -172,6 +172,8 @@ void EllpackPageSourceImpl<F>::Fetch() {
Context ctx = Context{}.MakeCUDA(this->Device().ordinal); Context ctx = Context{}.MakeCUDA(this->Device().ordinal);
*impl = EllpackPageImpl{&ctx, this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_}; *impl = EllpackPageImpl{&ctx, this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
this->page_->SetBaseRowId(csr->base_rowid); this->page_->SetBaseRowId(csr->base_rowid);
LOG(INFO) << "Generated an Ellpack page with size: " << impl->MemCostBytes()
<< " from a SparsePage with size:" << csr->MemCostBytes();
this->WriteCache(); this->WriteCache();
} }
} }