Support building gradient index with cat data. (#7371)

This commit is contained in:
Jiaming Yuan
2021-11-03 22:37:37 +08:00
committed by GitHub
parent 57a4b4ff64
commit ccdabe4512
10 changed files with 105 additions and 27 deletions

View File

@@ -16,13 +16,12 @@
#include <utility>
#include <map>
#include "row_set.h"
#include "categorical.h"
#include "common.h"
#include "quantile.h"
#include "row_set.h"
#include "threading_utils.h"
#include "../tree/param.h"
#include "./quantile.h"
#include "./timer.h"
#include "../include/rabit/rabit.h"
#include "timer.h"
namespace xgboost {
class GHistIndexMatrix;
@@ -105,9 +104,29 @@ class HistogramCuts {
return idx;
}
/**
* \brief Search the bin index for numerical feature.
*/
BinIdx SearchBin(Entry const& e) const {
return SearchBin(e.fvalue, e.index);
}
/**
* \brief Search the bin index for categorical feature.
*/
BinIdx SearchCatBin(Entry const &e) const {
auto const &ptrs = this->Ptrs();
auto const &vals = this->Values();
auto end = ptrs.at(e.index + 1) + vals.cbegin();
auto beg = ptrs[e.index] + vals.cbegin();
// Truncates the value in case it's not perfectly rounded.
auto v = static_cast<float>(common::AsCat(e.fvalue));
auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
if (bin_idx == ptrs.at(e.index + 1)) {
bin_idx -= 1;
}
return bin_idx;
}
};
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,