Use matrix for gradient. (#9508)
- Use `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides (row-major, column-major).
This commit is contained in:
@@ -66,14 +66,13 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda_impl
|
||||
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out) {
|
||||
out->SetDevice(ctx->gpu_id);
|
||||
out->Reshape(n_targets);
|
||||
auto n_samples = gpair.Size() / n_targets;
|
||||
|
||||
gpair.SetDevice(ctx->gpu_id);
|
||||
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
|
||||
gpair.SetDevice(ctx->Device());
|
||||
auto gpair_t = gpair.View(ctx->Device());
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
|
||||
/**
|
||||
* @brief Fit a tree stump as an estimation of base_score.
|
||||
*/
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -269,17 +269,18 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
out["hist_train_param"] = ToJson(hist_param_);
|
||||
}
|
||||
|
||||
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
|
||||
void InitData(TrainParam const &param, linalg::Matrix<GradientPair> const *gpair,
|
||||
linalg::Matrix<GradientPair> *sampled) {
|
||||
*sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1);
|
||||
sampled->Data()->Copy(*gpair);
|
||||
auto in = gpair->HostView().Values();
|
||||
std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data());
|
||||
|
||||
SampleGradient(ctx_, param, sampled->HostView());
|
||||
}
|
||||
|
||||
[[nodiscard]] char const *Name() const override { return "grow_histmaker"; }
|
||||
|
||||
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
|
||||
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *m,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
|
||||
@@ -91,7 +91,7 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *dmat,
|
||||
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (collective::IsDistributed()) {
|
||||
@@ -106,10 +106,11 @@ class ColMaker: public TreeUpdater {
|
||||
// rescale learning rate according to size of trees
|
||||
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
|
||||
// build tree
|
||||
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
|
||||
for (auto tree : trees) {
|
||||
CHECK(ctx_);
|
||||
Builder builder(*param, colmaker_param_, interaction_constraints_, ctx_, column_densities_);
|
||||
builder.Update(gpair->ConstHostVector(), dmat, tree);
|
||||
builder.Update(gpair->Data()->ConstHostVector(), dmat, tree);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -760,16 +760,18 @@ class GPUHistMaker : public TreeUpdater {
|
||||
dh::GlobalMemoryLogger().Log();
|
||||
}
|
||||
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor_.Start("Update");
|
||||
|
||||
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
|
||||
auto gpair_hdv = gpair->Data();
|
||||
// build tree
|
||||
try {
|
||||
std::size_t t_idx{0};
|
||||
for (xgboost::RegTree* tree : trees) {
|
||||
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
|
||||
this->UpdateTree(param, gpair_hdv, dmat, tree, &out_position[t_idx]);
|
||||
this->hist_maker_param_.CheckTreesSynchronized(tree);
|
||||
++t_idx;
|
||||
}
|
||||
@@ -887,7 +889,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
}
|
||||
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
|
||||
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor_.Start("Update");
|
||||
@@ -898,7 +900,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
auto hess = dh::ToSpan(hess_);
|
||||
|
||||
gpair->SetDevice(ctx_->Device());
|
||||
auto d_gpair = gpair->ConstDeviceSpan();
|
||||
auto d_gpair = gpair->Data()->ConstDeviceSpan();
|
||||
auto cuctx = ctx_->CUDACtx();
|
||||
thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess),
|
||||
[=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); });
|
||||
@@ -912,7 +914,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
|
||||
std::size_t t_idx{0};
|
||||
for (xgboost::RegTree* tree : trees) {
|
||||
this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]);
|
||||
this->UpdateTree(gpair->Data(), p_fmat, tree, &out_position[t_idx]);
|
||||
this->hist_maker_param_.CheckTreesSynchronized(tree);
|
||||
++t_idx;
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ class TreePruner : public TreeUpdater {
|
||||
[[nodiscard]] bool CanModifyTree() const override { return true; }
|
||||
|
||||
// update the tree, do pruning
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
void Update(TrainParam const* param, linalg::Matrix<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
pruner_monitor_.Start("PrunerUpdate");
|
||||
|
||||
@@ -492,7 +492,7 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
|
||||
[[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
|
||||
|
||||
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (trees.front()->IsMultiTarget()) {
|
||||
@@ -511,8 +511,7 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
}
|
||||
|
||||
bst_target_t n_targets = trees.front()->NumTargets();
|
||||
auto h_gpair =
|
||||
linalg::MakeTensorView(ctx_, gpair->HostSpan(), p_fmat->Info().num_row_, n_targets);
|
||||
auto h_gpair = gpair->HostView();
|
||||
|
||||
linalg::Matrix<GradientPair> sample_out;
|
||||
auto h_sample_out = h_gpair;
|
||||
|
||||
@@ -31,11 +31,14 @@ class TreeRefresher : public TreeUpdater {
|
||||
[[nodiscard]] char const *Name() const override { return "refresh"; }
|
||||
[[nodiscard]] bool CanModifyTree() const override { return true; }
|
||||
// update the tree, do pruning
|
||||
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (trees.size() == 0) return;
|
||||
const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector();
|
||||
if (trees.size() == 0) {
|
||||
return;
|
||||
}
|
||||
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
|
||||
const std::vector<GradientPair> &gpair_h = gpair->Data()->ConstHostVector();
|
||||
// thread temporal space
|
||||
std::vector<std::vector<GradStats> > stemp;
|
||||
std::vector<RegTree::FVec> fvec_temp;
|
||||
|
||||
@@ -31,7 +31,7 @@ class TreeSyncher : public TreeUpdater {
|
||||
|
||||
[[nodiscard]] char const* Name() const override { return "prune"; }
|
||||
|
||||
void Update(TrainParam const*, HostDeviceVector<GradientPair>*, DMatrix*,
|
||||
void Update(TrainParam const*, linalg::Matrix<GradientPair>*, DMatrix*,
|
||||
common::Span<HostDeviceVector<bst_node_t>> /*out_position*/,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
if (collective::GetWorldSize() == 1) return;
|
||||
|
||||
Reference in New Issue
Block a user