Support building gradient index with cat data. (#7371)

This commit is contained in:
Jiaming Yuan
2021-11-03 22:37:37 +08:00
committed by GitHub
parent 57a4b4ff64
commit ccdabe4512
10 changed files with 105 additions and 27 deletions

View File

@@ -16,13 +16,12 @@
#include <utility>
#include <map>
#include "row_set.h"
#include "categorical.h"
#include "common.h"
#include "quantile.h"
#include "row_set.h"
#include "threading_utils.h"
#include "../tree/param.h"
#include "./quantile.h"
#include "./timer.h"
#include "../include/rabit/rabit.h"
#include "timer.h"
namespace xgboost {
class GHistIndexMatrix;
@@ -105,9 +104,29 @@ class HistogramCuts {
return idx;
}
/**
 * \brief Search the bin index for numerical feature.
 *
 * Convenience overload: unpacks the entry and forwards its value and feature
 * index to the two-argument SearchBin overload.
 *
 * \param e Entry providing the feature value (fvalue) and feature index (index).
 *
 * \return Whatever the two-argument SearchBin overload returns for this entry.
 */
BinIdx SearchBin(Entry const& e) const {
  auto const &value = e.fvalue;
  auto const &feature = e.index;
  return SearchBin(value, feature);
}
/**
 * \brief Search the bin index for categorical feature.
 *
 * The entry's value is converted with common::AsCat and looked up with a
 * lower-bound search restricted to the slice of Values() delimited by
 * Ptrs()[e.index] and Ptrs()[e.index + 1].
 *
 * \param e Entry providing the category value (fvalue) and feature index (index).
 *
 * \return Offset of the matching position measured from the beginning of
 *         Values(). When the search runs off the end of the feature's slice,
 *         the offset is moved back by one.
 */
BinIdx SearchCatBin(Entry const &e) const {
  auto const &cut_ptrs = this->Ptrs();
  auto const &cut_values = this->Values();
  // Bounds-checked read of the offset where this feature's slice stops.
  auto const feature_end = cut_ptrs.at(e.index + 1);
  auto const last = cut_values.cbegin() + feature_end;
  auto const first = cut_values.cbegin() + cut_ptrs[e.index];
  // Truncates the value in case it's not perfectly rounded.
  auto const category = static_cast<float>(common::AsCat(e.fvalue));
  auto bin_idx = std::lower_bound(first, last, category) - cut_values.cbegin();
  // Nothing in the slice is >= category: step back one position.
  if (bin_idx == feature_end) {
    --bin_idx;
  }
  return bin_idx;
}
};
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,

View File

@@ -3,6 +3,8 @@
*/
#include <limits>
#include <utility>
#include "rabit/rabit.h"
#include "quantile.h"
#include "hist_util.h"
#include "categorical.h"
@@ -189,7 +191,7 @@ void HostSketchContainer::PushRowPage(
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(p_inst[ii].fvalue);
categories_[ii].emplace(AsCat(p_inst[ii].fvalue));
} else {
sketches_[ii].Push(p_inst[ii].fvalue, w);
}
@@ -199,7 +201,7 @@ void HostSketchContainer::PushRowPage(
auto const& entry = p_inst[i];
if (entry.index >= begin && entry.index < end) {
if (IsCat(feature_types_, entry.index)) {
categories_[entry.index].emplace(entry.fvalue);
categories_[entry.index].emplace(AsCat(entry.fvalue));
} else {
sketches_[entry.index].Push(entry.fvalue, w);
}