Support building gradient index with cat data. (#7371)
This commit is contained in:
@@ -16,13 +16,12 @@
|
||||
#include <utility>
|
||||
#include <map>
|
||||
|
||||
#include "row_set.h"
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "quantile.h"
|
||||
#include "row_set.h"
|
||||
#include "threading_utils.h"
|
||||
#include "../tree/param.h"
|
||||
#include "./quantile.h"
|
||||
#include "./timer.h"
|
||||
#include "../include/rabit/rabit.h"
|
||||
#include "timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
class GHistIndexMatrix;
|
||||
@@ -105,9 +104,29 @@ class HistogramCuts {
|
||||
return idx;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Search the bin index for numerical feature.
|
||||
*/
|
||||
BinIdx SearchBin(Entry const& e) const {
  // Convenience overload: unpack the entry into its value and feature index
  // and defer to the two-argument SearchBin.
  auto const& feature_value = e.fvalue;
  auto const& feature_index = e.index;
  return SearchBin(feature_value, feature_index);
}
|
||||
|
||||
/**
|
||||
* \brief Search the bin index for categorical feature.
|
||||
*/
|
||||
BinIdx SearchCatBin(Entry const &e) const {
  auto const &cut_ptrs = this->Ptrs();
  auto const &cut_values = this->Values();

  // Offset one past the last cut value of feature `e.index`; read through
  // `.at()` exactly as before so the access semantics are unchanged.
  auto const feature_end = cut_ptrs.at(e.index + 1);
  auto const last = cut_values.cbegin() + feature_end;
  auto const first = cut_values.cbegin() + cut_ptrs[e.index];

  // Truncates the value in case it's not perfectly rounded.
  auto const category = static_cast<float>(common::AsCat(e.fvalue));

  // Position of the first cut value in [first, last) that is not less than
  // the category, expressed as an offset from the start of all cut values.
  auto const position =
      std::lower_bound(first, last, category) - cut_values.cbegin();

  // If the search ran off the end of this feature's range, fall back to the
  // slot immediately before that end.
  return position == feature_end ? position - 1 : position;
}
|
||||
};
|
||||
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
*/
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "quantile.h"
|
||||
#include "hist_util.h"
|
||||
#include "categorical.h"
|
||||
@@ -189,7 +191,7 @@ void HostSketchContainer::PushRowPage(
|
||||
if (is_dense) {
|
||||
for (size_t ii = begin; ii < end; ii++) {
|
||||
if (IsCat(feature_types_, ii)) {
|
||||
categories_[ii].emplace(p_inst[ii].fvalue);
|
||||
categories_[ii].emplace(AsCat(p_inst[ii].fvalue));
|
||||
} else {
|
||||
sketches_[ii].Push(p_inst[ii].fvalue, w);
|
||||
}
|
||||
@@ -199,7 +201,7 @@ void HostSketchContainer::PushRowPage(
|
||||
auto const& entry = p_inst[i];
|
||||
if (entry.index >= begin && entry.index < end) {
|
||||
if (IsCat(feature_types_, entry.index)) {
|
||||
categories_[entry.index].emplace(entry.fvalue);
|
||||
categories_[entry.index].emplace(AsCat(entry.fvalue));
|
||||
} else {
|
||||
sketches_[entry.index].Push(entry.fvalue, w);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user