Support column split in approx tree method (#8847)
This commit is contained in:
parent
6d8afb2218
commit
7cbaee9916
@ -529,6 +529,11 @@ class DMatrix {
|
|||||||
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*! \brief Whether the data is split row-wise. */
|
||||||
|
bool IsRowSplit() const {
|
||||||
|
return Info().data_split_mode == DataSplitMode::kRow;
|
||||||
|
}
|
||||||
|
|
||||||
/*! \brief Whether the data is split column-wise. */
|
/*! \brief Whether the data is split column-wise. */
|
||||||
bool IsColumnSplit() const {
|
bool IsColumnSplit() const {
|
||||||
return Info().data_split_mode == DataSplitMode::kCol;
|
return Info().data_split_mode == DataSplitMode::kCol;
|
||||||
|
|||||||
@ -912,6 +912,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
|||||||
if (!cache_file.empty()) {
|
if (!cache_file.empty()) {
|
||||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||||
}
|
}
|
||||||
|
LOG(CONSOLE) << "Splitting data by column";
|
||||||
auto* sliced = dmat->SliceCol(npart, partid);
|
auto* sliced = dmat->SliceCol(npart, partid);
|
||||||
delete dmat;
|
delete dmat;
|
||||||
return sliced;
|
return sliced;
|
||||||
|
|||||||
@ -38,6 +38,7 @@ class HistEvaluator {
|
|||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||||
TreeEvaluator tree_evaluator_;
|
TreeEvaluator tree_evaluator_;
|
||||||
|
bool is_col_split_{false};
|
||||||
FeatureInteractionConstraintHost interaction_constraints_;
|
FeatureInteractionConstraintHost interaction_constraints_;
|
||||||
std::vector<NodeEntry> snode_;
|
std::vector<NodeEntry> snode_;
|
||||||
|
|
||||||
@ -355,7 +356,24 @@ class HistEvaluator {
|
|||||||
tloc_candidates[n_threads * nidx_in_set + tidx].split);
|
tloc_candidates[n_threads * nidx_in_set + tidx].split);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (is_col_split_) {
|
||||||
|
// With column-wise data split, we gather the best splits from all the workers and update the
|
||||||
|
// expand entries accordingly.
|
||||||
|
auto const world = collective::GetWorldSize();
|
||||||
|
auto const rank = collective::GetRank();
|
||||||
|
auto const num_entries = entries.size();
|
||||||
|
std::vector<ExpandEntry> buffer{num_entries * world};
|
||||||
|
std::copy_n(entries.cbegin(), num_entries, buffer.begin() + num_entries * rank);
|
||||||
|
collective::Allgather(buffer.data(), buffer.size() * sizeof(ExpandEntry));
|
||||||
|
for (auto worker = 0; worker < world; ++worker) {
|
||||||
|
for (auto nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
|
||||||
|
entries[nidx_in_set].split.Update(buffer[worker * num_entries + nidx_in_set].split);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add splits to tree, handles all statistic
|
// Add splits to tree, handles all statistic
|
||||||
void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
|
void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
|
||||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||||
@ -429,7 +447,8 @@ class HistEvaluator {
|
|||||||
std::shared_ptr<common::ColumnSampler> sampler)
|
std::shared_ptr<common::ColumnSampler> sampler)
|
||||||
: ctx_{ctx}, param_{param},
|
: ctx_{ctx}, param_{param},
|
||||||
column_sampler_{std::move(sampler)},
|
column_sampler_{std::move(sampler)},
|
||||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId} {
|
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||||
|
is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
|
||||||
interaction_constraints_.Configure(param, info.num_col_);
|
interaction_constraints_.Configure(param, info.num_col_);
|
||||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||||
param_.colsample_bynode, param_.colsample_bylevel,
|
param_.colsample_bynode, param_.colsample_bylevel,
|
||||||
|
|||||||
@ -98,7 +98,7 @@ class HistogramBuilder {
|
|||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
RegTree const *p_tree) {
|
RegTree const *p_tree) {
|
||||||
if (is_distributed_) {
|
if (is_distributed_ && !is_col_split_) {
|
||||||
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
|
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
|
||||||
nodes_for_subtraction_trick, p_tree);
|
nodes_for_subtraction_trick, p_tree);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -90,7 +90,9 @@ class GloablApproxBuilder {
|
|||||||
for (auto const &g : gpair) {
|
for (auto const &g : gpair) {
|
||||||
root_sum.Add(g);
|
root_sum.Add(g);
|
||||||
}
|
}
|
||||||
|
if (p_fmat->IsRowSplit()) {
|
||||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
||||||
|
}
|
||||||
std::vector<CPUExpandEntry> nodes{best};
|
std::vector<CPUExpandEntry> nodes{best};
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
#include <xgboost/tree_updater.h>
|
#include <xgboost/tree_updater.h>
|
||||||
|
|
||||||
@ -8,26 +7,34 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
TEST(GrowHistMaker, InteractionConstraint) {
|
std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols){
|
||||||
size_t constexpr kRows = 32;
|
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
|
||||||
size_t constexpr kCols = 16;
|
}
|
||||||
|
|
||||||
Context ctx;
|
std::unique_ptr<HostDeviceVector<GradientPair>> GenerateGradients(std::size_t rows) {
|
||||||
|
auto p_gradients = std::make_unique<HostDeviceVector<GradientPair>>(rows);
|
||||||
auto p_dmat = RandomDataGenerator{kRows, kCols, 0.6f}.Seed(3).GenerateDMatrix();
|
auto& h_gradients = p_gradients->HostVector();
|
||||||
|
|
||||||
HostDeviceVector<GradientPair> gradients (kRows);
|
|
||||||
std::vector<GradientPair>& h_gradients = gradients.HostVector();
|
|
||||||
|
|
||||||
xgboost::SimpleLCG gen;
|
xgboost::SimpleLCG gen;
|
||||||
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
|
||||||
|
|
||||||
for (size_t i = 0; i < kRows; ++i) {
|
for (auto i = 0; i < rows; ++i) {
|
||||||
bst_float grad = dist(&gen);
|
auto grad = dist(&gen);
|
||||||
bst_float hess = dist(&gen);
|
auto hess = dist(&gen);
|
||||||
h_gradients[i] = GradientPair(grad, hess);
|
h_gradients[i] = GradientPair{grad, hess};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return p_gradients;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GrowHistMaker, InteractionConstraint)
|
||||||
|
{
|
||||||
|
auto constexpr kRows = 32;
|
||||||
|
auto constexpr kCols = 16;
|
||||||
|
auto p_dmat = GenerateDMatrix(kRows, kCols);
|
||||||
|
auto p_gradients = GenerateGradients(kRows);
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
{
|
{
|
||||||
// With constraints
|
// With constraints
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
@ -39,7 +46,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
|||||||
{"interaction_constraints", "[[0, 1]]"},
|
{"interaction_constraints", "[[0, 1]]"},
|
||||||
{"num_feature", std::to_string(kCols)}});
|
{"num_feature", std::to_string(kCols)}});
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(&gradients, p_dmat.get(), position, {&tree});
|
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
@ -56,7 +63,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
|||||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
updater->Update(&gradients, p_dmat.get(), position, {&tree});
|
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
@ -66,5 +73,53 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||||
|
auto p_dmat = GenerateDMatrix(rows, cols);
|
||||||
|
auto p_gradients = GenerateGradients(rows);
|
||||||
|
Context ctx;
|
||||||
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
|
updater->Configure(Args{{"num_feature", std::to_string(cols)}});
|
||||||
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
|
||||||
|
std::unique_ptr<DMatrix> sliced{
|
||||||
|
p_dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
|
|
||||||
|
RegTree tree;
|
||||||
|
tree.param.num_feature = cols;
|
||||||
|
updater->Update(p_gradients.get(), sliced.get(), position, {&tree});
|
||||||
|
|
||||||
|
EXPECT_EQ(tree.NumExtraNodes(), 10);
|
||||||
|
EXPECT_EQ(tree[0].SplitIndex(), 1);
|
||||||
|
|
||||||
|
EXPECT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
|
||||||
|
EXPECT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
|
||||||
|
|
||||||
|
EXPECT_EQ(tree, expected_tree);
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
TEST(GrowHistMaker, ColumnSplit) {
|
||||||
|
auto constexpr kRows = 32;
|
||||||
|
auto constexpr kCols = 16;
|
||||||
|
|
||||||
|
RegTree expected_tree;
|
||||||
|
expected_tree.param.num_feature = kCols;
|
||||||
|
{
|
||||||
|
auto p_dmat = GenerateDMatrix(kRows, kCols);
|
||||||
|
auto p_gradients = GenerateGradients(kRows);
|
||||||
|
Context ctx;
|
||||||
|
std::unique_ptr<TreeUpdater> updater{
|
||||||
|
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||||
|
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||||
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
updater->Update(p_gradients.get(), p_dmat.get(), position, {&expected_tree});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto constexpr kWorldSize = 2;
|
||||||
|
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit, kRows, kCols, std::cref(expected_tree));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user