merge 23Mar01
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2017-2023, XGBoost Contributors
|
||||
* \brief Data type for fast histogram aggregation.
|
||||
*/
|
||||
#include "gradient_index.h"
|
||||
@@ -19,18 +19,18 @@ namespace xgboost {
|
||||
|
||||
GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
|
||||
|
||||
GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
double sparse_thresh, bool sorted_sketch, int32_t n_threads,
|
||||
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
double sparse_thresh, bool sorted_sketch,
|
||||
common::Span<float> hess)
|
||||
: max_numeric_bins_per_feat{max_bins_per_feat} {
|
||||
CHECK(p_fmat->SingleColBlock());
|
||||
// We use sorted sketching for approx tree method since it's more efficient in
|
||||
// computation time (but higher memory usage).
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins_per_feat, n_threads, sorted_sketch, hess);
|
||||
cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess);
|
||||
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||
hit_count_tloc_.resize(ctx->Threads() * nbins, 0);
|
||||
|
||||
size_t new_size = 1;
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
@@ -45,7 +45,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->PushBatch(batch, ft, n_threads);
|
||||
this->PushBatch(batch, ft, ctx->Threads());
|
||||
}
|
||||
this->columns_ = std::make_unique<common::ColumnMatrix>();
|
||||
|
||||
@@ -54,7 +54,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
|
||||
// hist
|
||||
CHECK(!sorted_sketch);
|
||||
for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
|
||||
this->columns_->InitFromSparse(page, *this, sparse_thresh, n_threads);
|
||||
this->columns_->InitFromSparse(page, *this, sparse_thresh, ctx->Threads());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,6 +166,12 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
auto const &values = cut.Values();
|
||||
auto const &mins = cut.MinValues();
|
||||
auto const &ptrs = cut.Ptrs();
|
||||
return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat);
|
||||
}
|
||||
|
||||
float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
|
||||
std::vector<float> const &values, std::vector<float> const &mins,
|
||||
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
|
||||
if (is_cat) {
|
||||
auto gidx = GetGindex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
@@ -181,24 +187,27 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
}
|
||||
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||
};
|
||||
|
||||
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
|
||||
if (columns_->AnyMissing()) {
|
||||
switch (columns_->GetColumnType(fidx)) {
|
||||
case common::kDenseColumn: {
|
||||
if (columns_->AnyMissing()) {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||
auto bin_idx = column[ridx];
|
||||
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
case common::kSparseColumn: {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
|
||||
SPAN_CHECK(false);
|
||||
|
||||
Reference in New Issue
Block a user