Use matrix for gradient. (#9508)

- Use `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objectives for multi-class/multi-target models are now required to return the correct shape.
- Custom objectives in Python can accept arrays with any strides (row-major or column-major).
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -66,14 +66,13 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets);
auto n_samples = gpair.Size() / n_targets;
gpair.SetDevice(ctx->gpu_id);
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
}

View File

@@ -31,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
/**
* @brief Fit a tree stump as an estimation of base_score.
*/
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out);
} // namespace tree
} // namespace xgboost

View File

@@ -269,17 +269,18 @@ class GlobalApproxUpdater : public TreeUpdater {
out["hist_train_param"] = ToJson(hist_param_);
}
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
void InitData(TrainParam const &param, linalg::Matrix<GradientPair> const *gpair,
linalg::Matrix<GradientPair> *sampled) {
*sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1);
sampled->Data()->Copy(*gpair);
auto in = gpair->HostView().Values();
std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data());
SampleGradient(ctx_, param, sampled->HostView());
}
[[nodiscard]] char const *Name() const override { return "grow_histmaker"; }
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *m,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
CHECK(hist_param_.GetInitialised());

View File

@@ -91,7 +91,7 @@ class ColMaker: public TreeUpdater {
}
}
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *dmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override {
if (collective::IsDistributed()) {
@@ -106,10 +106,11 @@ class ColMaker: public TreeUpdater {
// rescale learning rate according to size of trees
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
// build tree
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
for (auto tree : trees) {
CHECK(ctx_);
Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_);
builder.Update(gpair->ConstHostVector(), dmat, tree);
builder.Update(gpair->Data()->ConstHostVector(), dmat, tree);
}
}

View File

@@ -760,16 +760,18 @@ class GPUHistMaker : public TreeUpdater {
dh::GlobalMemoryLogger().Log();
}
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override {
monitor_.Start("Update");
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
auto gpair_hdv = gpair->Data();
// build tree
try {
std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) {
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
this->UpdateTree(param, gpair_hdv, dmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx;
}
@@ -887,7 +889,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
}
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override {
monitor_.Start("Update");
@@ -898,7 +900,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
auto hess = dh::ToSpan(hess_);
gpair->SetDevice(ctx_->Device());
auto d_gpair = gpair->ConstDeviceSpan();
auto d_gpair = gpair->Data()->ConstDeviceSpan();
auto cuctx = ctx_->CUDACtx();
thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess),
[=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); });
@@ -912,7 +914,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) {
this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]);
this->UpdateTree(gpair->Data(), p_fmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx;
}

View File

@@ -31,7 +31,7 @@ class TreePruner : public TreeUpdater {
[[nodiscard]] bool CanModifyTree() const override { return true; }
// update the tree, do pruning
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override {
pruner_monitor_.Start("PrunerUpdate");

View File

@@ -492,7 +492,7 @@ class QuantileHistMaker : public TreeUpdater {
[[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
if (trees.front()->IsMultiTarget()) {
@@ -511,8 +511,7 @@ class QuantileHistMaker : public TreeUpdater {
}
bst_target_t n_targets = trees.front()->NumTargets();
auto h_gpair =
linalg::MakeTensorView(ctx_, gpair->HostSpan(), p_fmat->Info().num_row_, n_targets);
auto h_gpair = gpair->HostView();
linalg::Matrix<GradientPair> sample_out;
auto h_sample_out = h_gpair;

View File

@@ -31,11 +31,14 @@ class TreeRefresher : public TreeUpdater {
[[nodiscard]] char const *Name() const override { return "refresh"; }
[[nodiscard]] bool CanModifyTree() const override { return true; }
// update the tree, do pruning
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree *> &trees) override {
if (trees.size() == 0) return;
const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector();
if (trees.size() == 0) {
return;
}
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
const std::vector<GradientPair> &gpair_h = gpair->Data()->ConstHostVector();
// thread temporal space
std::vector<std::vector<GradStats> > stemp;
std::vector<RegTree::FVec> fvec_temp;

View File

@@ -31,7 +31,7 @@ class TreeSyncher : public TreeUpdater {
[[nodiscard]] char const* Name() const override { return "prune"; }
void Update(TrainParam const*, HostDeviceVector<GradientPair>*, DMatrix*,
void Update(TrainParam const*, linalg::Matrix<GradientPair>*, DMatrix*,
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
const std::vector<RegTree*>& trees) override {
if (collective::GetWorldSize() == 1) return;