diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 861cfae8a..7cf9b9071 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -154,6 +154,7 @@ class ColumnMatrix { index_base_ = const_cast(gmat.cut.Ptrs().data()); const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature); + any_missing_ = !noMissingValues; if (noMissingValues) { missing_flags_.resize(feature_offsets_[nfeature], false); @@ -311,11 +312,18 @@ class ColumnMatrix { const BinTypeSize GetTypeSize() const { return bins_type_size_; } + + // This is just an utility function const bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) { return n_elements == n_features * n_row; } + // And this returns part of state + const bool AnyMissing() const { + return any_missing_; + } + private: std::vector index_; @@ -329,6 +337,7 @@ class ColumnMatrix { uint32_t* index_base_; std::vector missing_flags_; BinTypeSize bins_type_size_; + bool any_missing_; }; } // namespace common diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 05dd9cc06..84dc625ae 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -826,7 +826,7 @@ void QuantileHistMaker::Builder::EvaluateSplits(const std::vector& // on comparison of indexes values (idx_span) and split point (split_cond) // Handle dense columns // Analog of std::stable_partition, but in no-inplace manner -template +template inline std::pair PartitionDenseKernel(const common::DenseColumn& column, common::Span rid_span, const int32_t split_cond, common::Span left_part, common::Span right_part) { @@ -837,14 +837,24 @@ inline std::pair PartitionDenseKernel(const common::DenseColumn< size_t nleft_elems = 0; size_t nright_elems = 0; - for (auto rid : rid_span) { - if (column.IsMissing(rid)) { - if (default_left) { - p_left_part[nleft_elems++] = rid; + if (any_missing) { + for (auto rid : rid_span) { + if (column.IsMissing(rid)) { + if (default_left) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } } else { - p_right_part[nright_elems++] = rid; + if ((static_cast(idx[rid]) + offset) <= split_cond) { + p_left_part[nleft_elems++] = rid; + } else { + p_right_part[nright_elems++] = rid; + } } - } else { + } + } else { + for (auto rid : rid_span) { if ((static_cast(idx[rid]) + offset) <= split_cond) { p_left_part[nleft_elems++] = rid; } else { @@ -919,6 +929,7 @@ void QuantileHistMaker::Builder::PartitionKernel( const size_t node_in_set, const size_t nid, common::Range1d range, const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) { const size_t* rid = row_set_collection_[nid].begin; + common::Span rid_span(rid + range.begin(), rid + range.end()); common::Span left = partition_builder_.GetLeftBuffer(node_in_set, range.begin(), range.end()); @@ -934,9 +945,21 @@ void QuantileHistMaker::Builder::PartitionKernel( const common::DenseColumn& column = static_cast& >(*(column_ptr.get())); if (default_left) { - child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, left, right); + if (column_matrix.AnyMissing()) { + child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, + left, right); + } else { + child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, + left, right); + } } else { - child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, left, right); + if (column_matrix.AnyMissing()) { + child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, + left, right); + } else { + child_nodes_sizes = PartitionDenseKernel(column, rid_span, split_cond, + left, right); + } } } else { const common::SparseColumn& column diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index ad07930b4..2c16f68fc 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -250,6 +250,71 @@ class QuantileHistMock : public QuantileHistMaker { omp_set_num_threads(1); } + void TestApplySplit(const GHistIndexBlockMatrix& quantile_index_block, + const RegTree& tree) { + std::vector row_gpairs = + { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, + {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} }; + size_t constexpr kMaxBins = 4; + + // try out different sparsity to get different number of missing values + for (double sparsity : {0.0, 0.1, 0.2}) { + // kNRows samples with kNCols features + auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix(); + + common::GHistIndexMatrix gmat; + gmat.Init(dmat.get(), kMaxBins); + ColumnMatrix cm; + + // treat everything as dense, as this is what we intend to test here + cm.Init(gmat, 0.0); + RealImpl::InitData(gmat, row_gpairs, *dmat, tree); + hist_.AddHistRow(0); + + RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree); + + const size_t num_row = dmat->Info().num_row_; + // split by feature 0 + const size_t bin_id_min = gmat.cut.Ptrs()[0]; + const size_t bin_id_max = gmat.cut.Ptrs()[1]; + + // attempt to split at different bins + for (size_t split = 0; split < 4; split++) { + size_t left_cnt = 0, right_cnt = 0; + + // manually compute how many samples go left or right + for (size_t rid = 0; rid < num_row; ++rid) { + for (size_t offset = gmat.row_ptr[rid]; offset < gmat.row_ptr[rid + 1]; ++offset) { + const size_t bin_id = gmat.index[offset]; + if (bin_id >= bin_id_min && bin_id < bin_id_max) { + if (bin_id <= split) { + left_cnt ++; + } else { + right_cnt ++; + } + } + } + } + + // if any were missing due to sparsity, we add them to the left or to the right + size_t missing = kNRows - left_cnt - right_cnt; + if (tree[0].DefaultLeft()) { + left_cnt += missing; + } else { + right_cnt += missing; + } + + // have one node with kNRows (=8 at the moment) rows, just one task + RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) { + return 1; + }); + RealImpl::PartitionKernel(0, 0, common::Range1d(0, kNRows), split, cm, tree); + RealImpl::partition_builder_.CalculateRowOffsets(); + ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt); + ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt); + } + } + } }; int static constexpr kNRows = 8, kNCols = 16; @@ -322,6 +387,13 @@ class QuantileHistMock : public QuantileHistMaker { builder_->TestEvaluateSplit(gmatb_, tree); } + + void TestApplySplit() { + RegTree tree = RegTree(); + tree.param.UpdateAllowUnknown(cfg_); + + builder_->TestApplySplit(gmatb_, tree); + } }; TEST(QuantileHist, InitData) { @@ -359,5 +431,15 @@ TEST(QuantileHist, EvalSplits) { maker.TestEvaluateSplit(); } +TEST(QuantileHist, ApplySplit) { + std::vector> cfg + {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}, + {"split_evaluator", "elastic_net"}, + {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"}, + {"min_child_weight", "0"}}; + QuantileHistMock maker(cfg); + maker.TestApplySplit(); +} + } // namespace tree } // namespace xgboost