Unify CPU hist sketching (#5880)
This commit is contained in:
@@ -9,12 +9,15 @@
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
#include "timer.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/*!
|
||||
@@ -682,6 +685,57 @@ template<typename DType, typename RType = unsigned>
|
||||
// Quantile sketch that plugs WXQSummary<DType, RType> in as the summary type of
// QuantileSketchTemplate; it declares no members of its own.
class WXQuantileSketch :
|
||||
public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
|
||||
};
|
||||
|
||||
class HistogramCuts;
|
||||
|
||||
/*!
|
||||
* A sketch matrix storing sketches for each feature.
|
||||
*/
|
||||
class HostSketchContainer {
|
||||
public:
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
private:
|
||||
std::vector<WQSketch> sketches_;
|
||||
std::vector<bst_row_t> columns_size_;
|
||||
int32_t max_bins_;
|
||||
bool use_group_ind_{false};
|
||||
Monitor monitor_;
|
||||
|
||||
public:
|
||||
/* \brief Initialize necessary info.
|
||||
*
|
||||
* \param columns_size Size of each column.
|
||||
* \param max_bins maximum number of bins for each feature.
|
||||
* \param use_group whether is assigned to group to data instance.
|
||||
*/
|
||||
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
bool use_group);
|
||||
|
||||
static bool UseGroup(MetaInfo const &info) {
|
||||
size_t const num_groups =
|
||||
info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
|
||||
// Use group index for weights?
|
||||
bool const use_group_ind =
|
||||
num_groups != 0 && (info.weights_.Size() != info.num_row_);
|
||||
return use_group_ind;
|
||||
}
|
||||
|
||||
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
||||
size_t const base_rowid) {
|
||||
CHECK_LT(base_rowid, group_ptr.back())
|
||||
<< "Row: " << base_rowid << " is not found in any group.";
|
||||
bst_group_t group_ind =
|
||||
std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
|
||||
group_ptr.cbegin() - 1;
|
||||
return group_ind;
|
||||
}
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const& page, MetaInfo const& info);
|
||||
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
Reference in New Issue
Block a user