Define bin type. (#7850)
This commit is contained in:
parent
f7db16add1
commit
288c52596c
@ -121,6 +121,8 @@ using bst_float = float; // NOLINT
|
|||||||
using bst_cat_t = int32_t; // NOLINT
|
using bst_cat_t = int32_t; // NOLINT
|
||||||
/*! \brief Type for data column (feature) index. */
|
/*! \brief Type for data column (feature) index. */
|
||||||
using bst_feature_t = uint32_t; // NOLINT
|
using bst_feature_t = uint32_t; // NOLINT
|
||||||
|
/*! \brief Type for histogram bin index. */
|
||||||
|
using bst_bin_t = int32_t; // NOLINT
|
||||||
/*! \brief Type for data row index.
|
/*! \brief Type for data row index.
|
||||||
*
|
*
|
||||||
* Be careful `std::size_t' is implementation-defined. Meaning that the binary
|
* Be careful `std::size_t' is implementation-defined. Meaning that the binary
|
||||||
|
|||||||
@ -34,8 +34,8 @@ class Column {
|
|||||||
public:
|
public:
|
||||||
static constexpr int32_t kMissingId = -1;
|
static constexpr int32_t kMissingId = -1;
|
||||||
|
|
||||||
Column(ColumnType type, common::Span<const BinIdxType> index, const uint32_t index_base)
|
Column(ColumnType type, common::Span<const BinIdxType> index, const bst_bin_t index_base)
|
||||||
: type_(type), index_(index), index_base_(index_base) {}
|
: type_(type), index_(index), index_base_{index_base} {}
|
||||||
|
|
||||||
virtual ~Column() = default;
|
virtual ~Column() = default;
|
||||||
|
|
||||||
@ -60,19 +60,19 @@ class Column {
|
|||||||
/* bin indexes in range [0, max_bins - 1] */
|
/* bin indexes in range [0, max_bins - 1] */
|
||||||
common::Span<const BinIdxType> index_;
|
common::Span<const BinIdxType> index_;
|
||||||
/* bin index offset for specific feature */
|
/* bin index offset for specific feature */
|
||||||
const uint32_t index_base_;
|
bst_bin_t const index_base_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename BinIdxType>
|
template <typename BinIdxType>
|
||||||
class SparseColumn : public Column<BinIdxType> {
|
class SparseColumn : public Column<BinIdxType> {
|
||||||
public:
|
public:
|
||||||
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
|
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, bst_bin_t index_base,
|
||||||
common::Span<const size_t> row_ind)
|
common::Span<const size_t> row_ind)
|
||||||
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
|
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
|
||||||
|
|
||||||
const size_t* GetRowData() const { return row_ind_.data(); }
|
const size_t* GetRowData() const { return row_ind_.data(); }
|
||||||
|
|
||||||
int32_t GetBinIdx(size_t rid, size_t* state) const {
|
bst_bin_t GetBinIdx(size_t rid, size_t* state) const {
|
||||||
const size_t column_size = this->Size();
|
const size_t column_size = this->Size();
|
||||||
if (!((*state) < column_size)) {
|
if (!((*state) < column_size)) {
|
||||||
return this->kMissingId;
|
return this->kMissingId;
|
||||||
|
|||||||
@ -40,8 +40,6 @@ class HistogramCuts {
|
|||||||
float max_cat_{-1.0f};
|
float max_cat_{-1.0f};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
using BinIdx = uint32_t;
|
|
||||||
|
|
||||||
void Swap(HistogramCuts&& that) noexcept(true) {
|
void Swap(HistogramCuts&& that) noexcept(true) {
|
||||||
std::swap(cut_values_, that.cut_values_);
|
std::swap(cut_values_, that.cut_values_);
|
||||||
std::swap(cut_ptrs_, that.cut_ptrs_);
|
std::swap(cut_ptrs_, that.cut_ptrs_);
|
||||||
@ -110,31 +108,31 @@ class HistogramCuts {
|
|||||||
|
|
||||||
// Return the index of a cut point that is strictly greater than the input
|
// Return the index of a cut point that is strictly greater than the input
|
||||||
// value, or the last available index if none exists
|
// value, or the last available index if none exists
|
||||||
BinIdx SearchBin(float value, bst_feature_t column_id, std::vector<uint32_t> const& ptrs,
|
bst_bin_t SearchBin(float value, bst_feature_t column_id, std::vector<uint32_t> const& ptrs,
|
||||||
std::vector<float> const& values) const {
|
std::vector<float> const& values) const {
|
||||||
auto end = ptrs[column_id + 1];
|
auto end = ptrs[column_id + 1];
|
||||||
auto beg = ptrs[column_id];
|
auto beg = ptrs[column_id];
|
||||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||||
BinIdx idx = it - values.cbegin();
|
bst_bin_t idx = it - values.cbegin();
|
||||||
idx -= !!(idx == end);
|
idx -= !!(idx == end);
|
||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
BinIdx SearchBin(float value, bst_feature_t column_id) const {
|
bst_bin_t SearchBin(float value, bst_feature_t column_id) const {
|
||||||
return this->SearchBin(value, column_id, Ptrs(), Values());
|
return this->SearchBin(value, column_id, Ptrs(), Values());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Search the bin index for numerical feature.
|
* \brief Search the bin index for numerical feature.
|
||||||
*/
|
*/
|
||||||
BinIdx SearchBin(Entry const& e) const {
|
bst_bin_t SearchBin(Entry const& e) const {
|
||||||
return SearchBin(e.fvalue, e.index);
|
return SearchBin(e.fvalue, e.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Search the bin index for categorical feature.
|
* \brief Search the bin index for categorical feature.
|
||||||
*/
|
*/
|
||||||
BinIdx SearchCatBin(Entry const &e) const {
|
bst_bin_t SearchCatBin(Entry const &e) const {
|
||||||
auto const &ptrs = this->Ptrs();
|
auto const &ptrs = this->Ptrs();
|
||||||
auto const &vals = this->Values();
|
auto const &vals = this->Values();
|
||||||
auto end = ptrs.at(e.index + 1) + vals.cbegin();
|
auto end = ptrs.at(e.index + 1) + vals.cbegin();
|
||||||
@ -296,10 +294,10 @@ struct Index {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename GradientIndex>
|
template <typename GradientIndex>
|
||||||
int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
||||||
GradientIndex const &data,
|
GradientIndex const& data,
|
||||||
uint32_t const fidx_begin,
|
uint32_t const fidx_begin,
|
||||||
uint32_t const fidx_end) {
|
uint32_t const fidx_end) {
|
||||||
size_t previous_middle = std::numeric_limits<size_t>::max();
|
size_t previous_middle = std::numeric_limits<size_t>::max();
|
||||||
while (end != begin) {
|
while (end != begin) {
|
||||||
size_t middle = begin + (end - begin) / 2;
|
size_t middle = begin + (end - begin) / 2;
|
||||||
@ -324,8 +322,6 @@ int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
class ColumnMatrix;
|
|
||||||
|
|
||||||
template<typename GradientSumT>
|
template<typename GradientSumT>
|
||||||
using GHistRow = Span<xgboost::detail::GradientPairInternal<GradientSumT> >;
|
using GHistRow = Span<xgboost::detail::GradientPairInternal<GradientSumT> >;
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,10 @@
|
|||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
class ColumnMatrix;
|
||||||
|
} // namespace common
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief preprocessed global index matrix, in CSR format
|
* \brief preprocessed global index matrix, in CSR format
|
||||||
*
|
*
|
||||||
@ -80,13 +84,13 @@ class GHistIndexMatrix {
|
|||||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||||
auto e = inst[j];
|
auto e = inst[j];
|
||||||
if (common::IsCat(ft, e.index)) {
|
if (common::IsCat(ft, e.index)) {
|
||||||
auto bin_idx = cut.SearchCatBin(e);
|
bst_bin_t bin_idx = cut.SearchCatBin(e);
|
||||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||||
} else {
|
} else {
|
||||||
uint32_t idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
|
bst_bin_t bin_idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
|
||||||
index_data[ibegin + j] = get_offset(idx, j);
|
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||||
++hit_count_tloc_[tid * nbins + idx];
|
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user