Make objectives work with vertical distributed and federated learning (#9002)
This commit is contained in:
parent
720a8c3273
commit
15e073ca9d
@ -85,7 +85,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
|||||||
size_t n_leaf = nidx.size();
|
size_t n_leaf = nidx.size();
|
||||||
if (nptr.empty()) {
|
if (nptr.empty()) {
|
||||||
std::vector<float> quantiles;
|
std::vector<float> quantiles;
|
||||||
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,39 +99,46 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
|||||||
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
|
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
|
||||||
predt.Size() / info.num_row_);
|
predt.Size() / info.num_row_);
|
||||||
|
|
||||||
// loop over each leaf
|
if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
|
||||||
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
|
// loop over each leaf
|
||||||
auto nidx = h_node_idx[k];
|
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
|
||||||
CHECK(tree[nidx].IsLeaf());
|
auto nidx = h_node_idx[k];
|
||||||
CHECK_LT(k + 1, h_node_ptr.size());
|
CHECK(tree[nidx].IsLeaf());
|
||||||
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
|
CHECK_LT(k + 1, h_node_ptr.size());
|
||||||
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
|
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
|
||||||
|
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
|
||||||
|
|
||||||
auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
|
auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
|
||||||
auto h_weights = linalg::MakeVec(&info.weights_);
|
auto h_weights = linalg::MakeVec(&info.weights_);
|
||||||
|
|
||||||
auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
|
auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
|
||||||
auto row_idx = h_row_set[i];
|
auto row_idx = h_row_set[i];
|
||||||
return h_labels(row_idx) - h_predt(row_idx, group_idx);
|
return h_labels(row_idx) - h_predt(row_idx, group_idx);
|
||||||
});
|
});
|
||||||
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
|
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
|
||||||
auto row_idx = h_row_set[i];
|
auto row_idx = h_row_set[i];
|
||||||
return h_weights(row_idx);
|
return h_weights(row_idx);
|
||||||
|
});
|
||||||
|
|
||||||
|
float q{0};
|
||||||
|
if (info.weights_.Empty()) {
|
||||||
|
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
|
||||||
|
} else {
|
||||||
|
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
|
||||||
|
}
|
||||||
|
if (std::isnan(q)) {
|
||||||
|
CHECK(h_row_set.empty());
|
||||||
|
}
|
||||||
|
quantiles.at(k) = q;
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
float q{0};
|
if (info.IsVerticalFederated()) {
|
||||||
if (info.weights_.Empty()) {
|
collective::Broadcast(static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float),
|
||||||
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
|
0);
|
||||||
} else {
|
}
|
||||||
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
|
|
||||||
}
|
|
||||||
if (std::isnan(q)) {
|
|
||||||
CHECK(h_row_set.empty());
|
|
||||||
}
|
|
||||||
quantiles.at(k) = q;
|
|
||||||
});
|
|
||||||
|
|
||||||
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
|
|||||||
@ -151,7 +151,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
|||||||
|
|
||||||
if (nptr.Empty()) {
|
if (nptr.Empty()) {
|
||||||
std::vector<float> quantiles;
|
std::vector<float> quantiles;
|
||||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), learning_rate, p_tree);
|
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
HostDeviceVector<float> quantiles;
|
HostDeviceVector<float> quantiles;
|
||||||
@ -186,7 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
|||||||
w_it + d_weights.size(), &quantiles);
|
w_it + d_weights.size(), &quantiles);
|
||||||
}
|
}
|
||||||
|
|
||||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), learning_rate, p_tree);
|
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
|
|||||||
@ -36,13 +36,15 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
|
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
|
||||||
float learning_rate, RegTree* p_tree) {
|
MetaInfo const& info, float learning_rate, RegTree* p_tree) {
|
||||||
auto& tree = *p_tree;
|
auto& tree = *p_tree;
|
||||||
auto& quantiles = *p_quantiles;
|
auto& quantiles = *p_quantiles;
|
||||||
auto const& h_node_idx = nidx;
|
auto const& h_node_idx = nidx;
|
||||||
|
|
||||||
size_t n_leaf{h_node_idx.size()};
|
size_t n_leaf{h_node_idx.size()};
|
||||||
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
|
if (info.IsRowSplit()) {
|
||||||
|
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
|
||||||
|
}
|
||||||
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
||||||
if (quantiles.empty()) {
|
if (quantiles.empty()) {
|
||||||
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
||||||
@ -52,12 +54,16 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
|||||||
std::vector<int32_t> n_valids(quantiles.size());
|
std::vector<int32_t> n_valids(quantiles.size());
|
||||||
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
|
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
|
||||||
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
|
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
|
||||||
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
|
if (info.IsRowSplit()) {
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
|
||||||
|
}
|
||||||
// convert to 0 for all reduce
|
// convert to 0 for all reduce
|
||||||
std::replace_if(
|
std::replace_if(
|
||||||
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
|
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
|
||||||
// use the mean value
|
// use the mean value
|
||||||
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
|
if (info.IsRowSplit()) {
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
|
||||||
|
}
|
||||||
for (size_t i = 0; i < n_leaf; ++i) {
|
for (size_t i = 0; i < n_leaf; ++i) {
|
||||||
if (n_valids[i] > 0) {
|
if (n_valids[i] > 0) {
|
||||||
quantiles[i] /= static_cast<float>(n_valids[i]);
|
quantiles[i] /= static_cast<float>(n_valids[i]);
|
||||||
|
|||||||
@ -35,7 +35,10 @@ class QuantileRegression : public ObjFunction {
|
|||||||
bst_target_t Targets(MetaInfo const& info) const override {
|
bst_target_t Targets(MetaInfo const& info) const override {
|
||||||
auto const& alpha = param_.quantile_alpha.Get();
|
auto const& alpha = param_.quantile_alpha.Get();
|
||||||
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
|
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
|
||||||
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss.";
|
if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
|
||||||
|
CHECK_EQ(info.labels.Shape(1), 1)
|
||||||
|
<< "Multi-target is not yet supported by the quantile loss.";
|
||||||
|
}
|
||||||
CHECK(!alpha.empty());
|
CHECK(!alpha.empty());
|
||||||
// We have some placeholders for multi-target in the quantile loss. But it's not
|
// We have some placeholders for multi-target in the quantile loss. But it's not
|
||||||
// supported as the gbtree doesn't know how to slice the gradient and there's no 3-dim
|
// supported as the gbtree doesn't know how to slice the gradient and there's no 3-dim
|
||||||
@ -167,8 +170,10 @@ class QuantileRegression : public ObjFunction {
|
|||||||
common::Mean(ctx_, *base_score, &temp);
|
common::Mean(ctx_, *base_score, &temp);
|
||||||
double meanq = temp(0) * sw;
|
double meanq = temp(0) * sw;
|
||||||
|
|
||||||
collective::Allreduce<collective::Operation::kSum>(&meanq, 1);
|
if (info.IsRowSplit()) {
|
||||||
collective::Allreduce<collective::Operation::kSum>(&sw, 1);
|
collective::Allreduce<collective::Operation::kSum>(&meanq, 1);
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(&sw, 1);
|
||||||
|
}
|
||||||
meanq /= (sw + kRtEps);
|
meanq /= (sw + kRtEps);
|
||||||
base_score->Reshape(1);
|
base_score->Reshape(1);
|
||||||
base_score->Data()->Fill(meanq);
|
base_score->Data()->Fill(meanq);
|
||||||
|
|||||||
@ -728,8 +728,10 @@ class MeanAbsoluteError : public ObjFunction {
|
|||||||
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
||||||
[w](float v) { return v * w; });
|
[w](float v) { return v * w; });
|
||||||
|
|
||||||
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
|
if (info.IsRowSplit()) {
|
||||||
collective::Allreduce<collective::Operation::kSum>(&w, 1);
|
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(&w, 1);
|
||||||
|
}
|
||||||
|
|
||||||
if (common::CloseTo(w, 0.0)) {
|
if (common::CloseTo(w, 0.0)) {
|
||||||
// Mostly for handling empty dataset test.
|
// Mostly for handling empty dataset test.
|
||||||
|
|||||||
@ -13,66 +13,91 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
|
void VerifyObjectives(size_t rows, size_t cols, std::vector<float> const &expected_base_scores,
|
||||||
|
std::vector<Json> const &expected_models) {
|
||||||
|
auto const world_size = collective::GetWorldSize();
|
||||||
|
auto const rank = collective::GetRank();
|
||||||
|
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||||
|
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||||
|
h_lower.resize(rows);
|
||||||
|
h_upper.resize(rows);
|
||||||
|
for (size_t i = 0; i < rows; ++i) {
|
||||||
|
h_lower[i] = 1;
|
||||||
|
h_upper[i] = 10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
|
auto i = 0;
|
||||||
|
for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||||
|
learner->SetParam("tree_method", "approx");
|
||||||
|
learner->SetParam("objective", entry->name);
|
||||||
|
if (entry->name.find("quantile") != std::string::npos) {
|
||||||
|
learner->SetParam("quantile_alpha", "0.5");
|
||||||
|
}
|
||||||
|
if (entry->name.find("multi") != std::string::npos) {
|
||||||
|
learner->SetParam("num_class", "3");
|
||||||
|
}
|
||||||
|
learner->UpdateOneIter(0, sliced);
|
||||||
|
|
||||||
|
Json config{Object{}};
|
||||||
|
learner->SaveConfig(&config);
|
||||||
|
auto base_score = GetBaseScore(config);
|
||||||
|
ASSERT_EQ(base_score, expected_base_scores[i]);
|
||||||
|
|
||||||
|
Json model{Object{}};
|
||||||
|
learner->SaveModel(&model);
|
||||||
|
ASSERT_EQ(model, expected_models[i]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class FederatedLearnerTest : public BaseFederatedTest {
|
class FederatedLearnerTest : public BaseFederatedTest {
|
||||||
protected:
|
protected:
|
||||||
static auto constexpr kRows{16};
|
static auto constexpr kRows{16};
|
||||||
static auto constexpr kCols{16};
|
static auto constexpr kCols{16};
|
||||||
};
|
};
|
||||||
|
|
||||||
void VerifyBaseScore(size_t rows, size_t cols, float expected_base_score) {
|
TEST_F(FederatedLearnerTest, Objectives) {
|
||||||
auto const world_size = collective::GetWorldSize();
|
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||||
auto const rank = collective::GetRank();
|
|
||||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
|
||||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
|
||||||
learner->SetParam("tree_method", "approx");
|
|
||||||
learner->SetParam("objective", "binary:logistic");
|
|
||||||
learner->UpdateOneIter(0, sliced);
|
|
||||||
Json config{Object{}};
|
|
||||||
learner->SaveConfig(&config);
|
|
||||||
auto base_score = GetBaseScore(config);
|
|
||||||
ASSERT_EQ(base_score, expected_base_score);
|
|
||||||
}
|
|
||||||
|
|
||||||
void VerifyModel(size_t rows, size_t cols, Json const& expected_model) {
|
auto &h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||||
auto const world_size = collective::GetWorldSize();
|
auto &h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||||
auto const rank = collective::GetRank();
|
h_lower.resize(kRows);
|
||||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
h_upper.resize(kRows);
|
||||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
for (size_t i = 0; i < kRows; ++i) {
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
h_lower[i] = 1;
|
||||||
learner->SetParam("tree_method", "approx");
|
h_upper[i] = 10;
|
||||||
learner->SetParam("objective", "binary:logistic");
|
}
|
||||||
learner->UpdateOneIter(0, sliced);
|
|
||||||
Json model{Object{}};
|
|
||||||
learner->SaveModel(&model);
|
|
||||||
ASSERT_EQ(model, expected_model);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(FederatedLearnerTest, BaseScore) {
|
std::vector<float> base_scores;
|
||||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
std::vector<Json> models;
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
for (auto const *entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||||
learner->SetParam("tree_method", "approx");
|
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||||
learner->SetParam("objective", "binary:logistic");
|
learner->SetParam("tree_method", "approx");
|
||||||
learner->UpdateOneIter(0, Xy_);
|
learner->SetParam("objective", entry->name);
|
||||||
Json config{Object{}};
|
if (entry->name.find("quantile") != std::string::npos) {
|
||||||
learner->SaveConfig(&config);
|
learner->SetParam("quantile_alpha", "0.5");
|
||||||
auto base_score = GetBaseScore(config);
|
}
|
||||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
if (entry->name.find("multi") != std::string::npos) {
|
||||||
|
learner->SetParam("num_class", "3");
|
||||||
|
}
|
||||||
|
learner->UpdateOneIter(0, dmat);
|
||||||
|
Json config{Object{}};
|
||||||
|
learner->SaveConfig(&config);
|
||||||
|
base_scores.emplace_back(GetBaseScore(config));
|
||||||
|
|
||||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyBaseScore, kRows, kCols,
|
Json model{Object{}};
|
||||||
base_score);
|
learner->SaveModel(&model);
|
||||||
}
|
models.emplace_back(model);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(FederatedLearnerTest, Model) {
|
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyObjectives, kRows, kCols,
|
||||||
std::shared_ptr<DMatrix> Xy_{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
base_scores, models);
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
|
||||||
learner->SetParam("tree_method", "approx");
|
|
||||||
learner->SetParam("objective", "binary:logistic");
|
|
||||||
learner->UpdateOneIter(0, Xy_);
|
|
||||||
Json model{Object{}};
|
|
||||||
learner->SaveModel(&model);
|
|
||||||
|
|
||||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyModel, kRows, kCols,
|
|
||||||
std::cref(model));
|
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -608,31 +608,74 @@ TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
|
|||||||
|
|
||||||
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
|
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
|
||||||
|
|
||||||
void TestColumnSplitBaseScore(std::shared_ptr<DMatrix> Xy_, float expected_base_score) {
|
void TestColumnSplit(std::shared_ptr<DMatrix> dmat, std::vector<float> const& expected_base_scores,
|
||||||
|
std::vector<Json> const& expected_models) {
|
||||||
auto const world_size = collective::GetWorldSize();
|
auto const world_size = collective::GetWorldSize();
|
||||||
auto const rank = collective::GetRank();
|
auto const rank = collective::GetRank();
|
||||||
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
|
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
|
||||||
learner->SetParam("tree_method", "approx");
|
auto i = 0;
|
||||||
learner->SetParam("objective", "binary:logistic");
|
for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||||
learner->UpdateOneIter(0, sliced);
|
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||||
Json config{Object{}};
|
learner->SetParam("tree_method", "approx");
|
||||||
learner->SaveConfig(&config);
|
learner->SetParam("objective", entry->name);
|
||||||
auto base_score = GetBaseScore(config);
|
if (entry->name.find("quantile") != std::string::npos) {
|
||||||
ASSERT_EQ(base_score, expected_base_score);
|
learner->SetParam("quantile_alpha", "0.5");
|
||||||
|
}
|
||||||
|
if (entry->name.find("multi") != std::string::npos) {
|
||||||
|
learner->SetParam("num_class", "3");
|
||||||
|
}
|
||||||
|
learner->UpdateOneIter(0, sliced);
|
||||||
|
Json config{Object{}};
|
||||||
|
learner->SaveConfig(&config);
|
||||||
|
auto base_score = GetBaseScore(config);
|
||||||
|
ASSERT_EQ(base_score, expected_base_scores[i]);
|
||||||
|
|
||||||
|
Json model{Object{}};
|
||||||
|
learner->SaveModel(&model);
|
||||||
|
ASSERT_EQ(model, expected_models[i]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InitBaseScore, ColumnSplit) {
|
TEST(ColumnSplit, Objectives) {
|
||||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
auto constexpr kRows = 10, kCols = 10;
|
||||||
learner->SetParam("tree_method", "approx");
|
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||||
learner->SetParam("objective", "binary:logistic");
|
|
||||||
learner->UpdateOneIter(0, Xy_);
|
auto& h_upper = dmat->Info().labels_upper_bound_.HostVector();
|
||||||
Json config{Object{}};
|
auto& h_lower = dmat->Info().labels_lower_bound_.HostVector();
|
||||||
learner->SaveConfig(&config);
|
h_lower.resize(kRows);
|
||||||
auto base_score = GetBaseScore(config);
|
h_upper.resize(kRows);
|
||||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
for (size_t i = 0; i < kRows; ++i) {
|
||||||
|
h_lower[i] = 1;
|
||||||
|
h_upper[i] = 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> base_scores;
|
||||||
|
std::vector<Json> models;
|
||||||
|
for (auto const* entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||||
|
learner->SetParam("tree_method", "approx");
|
||||||
|
learner->SetParam("objective", entry->name);
|
||||||
|
if (entry->name.find("quantile") != std::string::npos) {
|
||||||
|
learner->SetParam("quantile_alpha", "0.5");
|
||||||
|
}
|
||||||
|
if (entry->name.find("multi") != std::string::npos) {
|
||||||
|
learner->SetParam("num_class", "3");
|
||||||
|
}
|
||||||
|
learner->UpdateOneIter(0, dmat);
|
||||||
|
|
||||||
|
Json config{Object{}};
|
||||||
|
learner->SaveConfig(&config);
|
||||||
|
base_scores.emplace_back(GetBaseScore(config));
|
||||||
|
|
||||||
|
Json model{Object{}};
|
||||||
|
learner->SaveModel(&model);
|
||||||
|
models.emplace_back(model);
|
||||||
|
}
|
||||||
|
|
||||||
auto constexpr kWorldSize{3};
|
auto constexpr kWorldSize{3};
|
||||||
RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score);
|
RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplit, dmat, base_scores, models);
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user