Categorical data support in CPU sketching. (#7221)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 by Contributors
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \file hist_util.h
|
||||
* \brief Utility for fast histogram aggregation
|
||||
* \author Philip Cho, Tianqi Chen
|
||||
@@ -128,6 +128,7 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
}
|
||||
}
|
||||
HostSketchContainer container(reduced, max_bins,
|
||||
m->Info().feature_types.ConstHostSpan(),
|
||||
HostSketchContainer::UseGroup(info), threads);
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
|
||||
@@ -1,29 +1,35 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
* Copyright 2020-2021 by XGBoost Contributors
|
||||
*/
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include "quantile.h"
|
||||
#include "hist_util.h"
|
||||
#include "categorical.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
HostSketchContainer::HostSketchContainer(std::vector<bst_row_t> columns_size,
|
||||
int32_t max_bins, bool use_group,
|
||||
int32_t n_threads)
|
||||
: columns_size_{std::move(columns_size)}, max_bins_{max_bins},
|
||||
HostSketchContainer::HostSketchContainer(
|
||||
std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads)
|
||||
: feature_types_(feature_types.cbegin(), feature_types.cend()),
|
||||
columns_size_{std::move(columns_size)}, max_bins_{max_bins},
|
||||
use_group_ind_{use_group}, n_threads_{n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
CHECK_NE(columns_size_.size(), 0);
|
||||
sketches_.resize(columns_size_.size());
|
||||
CHECK_GE(n_threads_, 1);
|
||||
categories_.resize(columns_size_.size());
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
|
||||
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
|
||||
sketches_[i].Init(columns_size_[i], eps);
|
||||
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
|
||||
if (!IsCat(this->feature_types_, i)) {
|
||||
sketches_[i].Init(columns_size_[i], eps);
|
||||
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -182,13 +188,21 @@ void HostSketchContainer::PushRowPage(
|
||||
auto p_inst = inst.data();
|
||||
if (is_dense) {
|
||||
for (size_t ii = begin; ii < end; ii++) {
|
||||
sketches_[ii].Push(p_inst[ii].fvalue, w);
|
||||
if (IsCat(feature_types_, ii)) {
|
||||
categories_[ii].emplace(p_inst[ii].fvalue);
|
||||
} else {
|
||||
sketches_[ii].Push(p_inst[ii].fvalue, w);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < inst.size(); ++i) {
|
||||
auto const& entry = p_inst[i];
|
||||
if (entry.index >= begin && entry.index < end) {
|
||||
sketches_[entry.index].Push(entry.fvalue, w);
|
||||
if (IsCat(feature_types_, entry.index)) {
|
||||
categories_[entry.index].emplace(entry.fvalue);
|
||||
} else {
|
||||
sketches_[entry.index].Push(entry.fvalue, w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -338,6 +352,13 @@ void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary,
|
||||
}
|
||||
}
|
||||
|
||||
void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) {
|
||||
auto &cut_values = cuts->cut_values_.HostVector();
|
||||
for (auto const &v : categories) {
|
||||
cut_values.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<WQSketch::SummaryContainer> reduced;
|
||||
@@ -348,6 +369,9 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
std::vector<WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||
|
||||
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
|
||||
if (IsCat(feature_types_, fidx)) {
|
||||
return;
|
||||
}
|
||||
WQSketch::SummaryContainer &a = final_summaries[fidx];
|
||||
size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
|
||||
a.Reserve(max_num_bins + 1);
|
||||
@@ -367,13 +391,17 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
||||
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
||||
WQSketch::SummaryContainer const& a = final_summaries[fid];
|
||||
AddCutPoint(a, max_num_bins, cuts);
|
||||
// push a value that is greater than anything
|
||||
const bst_float cpt
|
||||
= (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
|
||||
// this must be bigger than last value in a scale
|
||||
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
||||
cuts->cut_values_.HostVector().push_back(last);
|
||||
if (IsCat(feature_types_, fid)) {
|
||||
AddCategories(categories_.at(fid), cuts);
|
||||
} else {
|
||||
AddCutPoint(a, max_num_bins, cuts);
|
||||
// push a value that is greater than anything
|
||||
const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
|
||||
: cuts->min_vals_.HostVector()[fid];
|
||||
// this must be bigger than last value in a scale
|
||||
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
||||
cuts->cut_values_.HostVector().push_back(last);
|
||||
}
|
||||
|
||||
// Ensure that every feature gets at least one quantile point
|
||||
CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file quantile.h
|
||||
* \brief util to compute quantiles
|
||||
* \author Tianqi Chen
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
|
||||
#include "timer.h"
|
||||
|
||||
@@ -707,6 +708,9 @@ class HostSketchContainer {
|
||||
|
||||
private:
|
||||
std::vector<WQSketch> sketches_;
|
||||
std::vector<std::set<bst_cat_t>> categories_;
|
||||
std::vector<FeatureType> const feature_types_;
|
||||
|
||||
std::vector<bst_row_t> columns_size_;
|
||||
int32_t max_bins_;
|
||||
bool use_group_ind_{false};
|
||||
@@ -721,7 +725,8 @@ class HostSketchContainer {
|
||||
* \param use_group whether is assigned to group to data instance.
|
||||
*/
|
||||
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
bool use_group, int32_t n_threads);
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads);
|
||||
|
||||
static bool UseGroup(MetaInfo const &info) {
|
||||
size_t const num_groups =
|
||||
|
||||
Reference in New Issue
Block a user