Simplify sparse and dense CPU hist kernels (#7029)
* Simplify sparse and dense kernels * Extract row partitioner. Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -23,19 +23,19 @@ TEST(DenseColumn, Test) {
|
||||
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t>(j);
|
||||
auto col = column_matrix.GetColumn<uint8_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t>(j);
|
||||
auto col = column_matrix.GetColumn<uint16_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t>(j);
|
||||
auto col = column_matrix.GetColumn<uint32_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
@@ -68,17 +68,17 @@ TEST(SparseColumn, Test) {
|
||||
column_matrix.Init(gmat, 0.5);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint8_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint16_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint32_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
@@ -89,7 +89,7 @@ TEST(SparseColumn, Test) {
|
||||
template<typename BinIdxType>
|
||||
inline void CheckColumWithMissingValue(const Column<BinIdxType>& col_input,
|
||||
const GHistIndexMatrix& gmat) {
|
||||
const DenseColumn<BinIdxType>& col = static_cast<const DenseColumn<BinIdxType>& >(col_input);
|
||||
const DenseColumn<BinIdxType, true>& col = static_cast<const DenseColumn<BinIdxType, true>& >(col_input);
|
||||
for (auto i = 0ull; i < col.Size(); i++) {
|
||||
if (col.IsMissing(i)) continue;
|
||||
EXPECT_EQ(gmat.index[gmat.row_ptr[i]],
|
||||
@@ -109,17 +109,17 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint8_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint16_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t>(0);
|
||||
auto col = column_matrix.GetColumn<uint32_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "../../../src/common/row_set.h"
|
||||
#include "../../../src/common/partition_builder.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -309,7 +309,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
RealImpl::InitData(gmat, fmat, tree, &gpair);
|
||||
this->hist_.AddHistRow(nid);
|
||||
this->hist_.AllocateAllData();
|
||||
this->BuildHist(gpair, this->row_set_collection_[nid],
|
||||
this->hist_builder_.template BuildHist<true>(gpair, this->row_set_collection_[nid],
|
||||
gmat, this->hist_[nid]);
|
||||
|
||||
// Check if number of histogram bins is correct
|
||||
@@ -350,7 +350,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
this->BuildHist(row_gpairs, this->row_set_collection_[0],
|
||||
this->hist_builder_.template BuildHist<false>(row_gpairs, this->row_set_collection_[0],
|
||||
gmat, this->hist_[0]);
|
||||
|
||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
|
||||
@@ -482,8 +482,13 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
});
|
||||
const size_t task_id = RealImpl::partition_builder_.GetTaskIdx(0, 0);
|
||||
RealImpl::partition_builder_.AllocateForTask(task_id);
|
||||
this->template PartitionKernel<uint8_t>(0, 0, common::Range1d(0, kNRows),
|
||||
split, cm, tree);
|
||||
if (cm.AnyMissing()) {
|
||||
RealImpl::partition_builder_.template Partition<uint8_t, true>(0, 0, common::Range1d(0, kNRows),
|
||||
split, cm, tree, this->row_set_collection_[0].begin);
|
||||
} else {
|
||||
RealImpl::partition_builder_.template Partition<uint8_t, false>(0, 0, common::Range1d(0, kNRows),
|
||||
split, cm, tree, this->row_set_collection_[0].begin);
|
||||
}
|
||||
RealImpl::partition_builder_.CalculateRowOffsets();
|
||||
ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
|
||||
ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
|
||||
|
||||
Reference in New Issue
Block a user