[SYCL] Implement UpdatePredictionCache and connect updater with leraner. (#10701)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
9b88495840
commit
24d225c1ab
@@ -307,6 +307,99 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
|
||||
builder_monitor_.Stop("ExpandWithLossGuide");
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::Update(
|
||||
xgboost::tree::TrainParam const *param,
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device>& gpair,
|
||||
DMatrix *p_fmat,
|
||||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
RegTree *p_tree) {
|
||||
builder_monitor_.Start("Update");
|
||||
|
||||
tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_);
|
||||
interaction_constraints_.Reset();
|
||||
|
||||
this->InitData(gmat, gpair, *p_fmat, *p_tree);
|
||||
if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) {
|
||||
ExpandWithLossGuide(gmat, p_tree, gpair);
|
||||
} else {
|
||||
ExpandWithDepthWise(gmat, p_tree, gpair);
|
||||
}
|
||||
|
||||
for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
|
||||
p_tree->Stat(nid).loss_chg = snode_host_[nid].best.loss_chg;
|
||||
p_tree->Stat(nid).base_weight = snode_host_[nid].weight;
|
||||
p_tree->Stat(nid).sum_hess = static_cast<float>(snode_host_[nid].stats.GetHess());
|
||||
}
|
||||
|
||||
builder_monitor_.Stop("Update");
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
linalg::MatrixView<float> out_preds) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
return false;
|
||||
}
|
||||
builder_monitor_.Start("UpdatePredictionCache");
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
const size_t stride = out_preds.Stride(0);
|
||||
const bool is_first_group = (out_pred_ptr == nullptr);
|
||||
const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
|
||||
const bool is_last_group = (gid + 1 == stride);
|
||||
|
||||
const int buffer_size = out_preds.Size() *stride;
|
||||
if (buffer_size == 0) return true;
|
||||
|
||||
::sycl::event event;
|
||||
if (is_first_group) {
|
||||
out_preds_buf_.ResizeNoCopy(&qu_, buffer_size);
|
||||
out_pred_ptr = &out_preds(0);
|
||||
event = qu_.memcpy(out_preds_buf_.Data(), out_pred_ptr, buffer_size * sizeof(bst_float), event);
|
||||
}
|
||||
auto* out_preds_buf_ptr = out_preds_buf_.Data();
|
||||
|
||||
size_t n_nodes = row_set_collection_.Size();
|
||||
std::vector<::sycl::event> events(n_nodes);
|
||||
for (size_t node = 0; node < n_nodes; node++) {
|
||||
const common::RowSetCollection::Elem& rowset = row_set_collection_[node];
|
||||
if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) {
|
||||
int nid = rowset.node_id;
|
||||
// if a node is marked as deleted by the pruner, traverse upward to locate
|
||||
// a non-deleted leaf.
|
||||
if ((*p_last_tree_)[nid].IsDeleted()) {
|
||||
while ((*p_last_tree_)[nid].IsDeleted()) {
|
||||
nid = (*p_last_tree_)[nid].Parent();
|
||||
}
|
||||
CHECK((*p_last_tree_)[nid].IsLeaf());
|
||||
}
|
||||
bst_float leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
const size_t* rid = rowset.begin;
|
||||
const size_t num_rows = rowset.Size();
|
||||
|
||||
events[node] = qu_.submit([&](::sycl::handler& cgh) {
|
||||
cgh.depends_on(event);
|
||||
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
|
||||
out_preds_buf_ptr[rid[pid.get_id(0)]*stride + gid] += leaf_value;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
if (is_last_group) {
|
||||
qu_.memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
|
||||
out_pred_ptr = nullptr;
|
||||
}
|
||||
qu_.wait();
|
||||
|
||||
builder_monitor_.Stop("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitSampling(
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
@@ -479,6 +572,8 @@ void HistUpdater<GradientSumT>::InitData(
|
||||
}
|
||||
}
|
||||
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
|
||||
@@ -11,10 +11,10 @@
|
||||
#include <xgboost/tree_updater.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
#include "../common/partition_builder.h"
|
||||
#include "split_evaluator.h"
|
||||
@@ -54,12 +54,10 @@ class HistUpdater {
|
||||
explicit HistUpdater(const Context* ctx,
|
||||
::sycl::queue qu,
|
||||
const xgboost::tree::TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraints_,
|
||||
DMatrix const* fmat)
|
||||
: ctx_(ctx), qu_(qu), param_(param),
|
||||
tree_evaluator_(qu, param, fmat->Info().num_col_),
|
||||
pruner_(std::move(pruner)),
|
||||
interaction_constraints_{std::move(int_constraints_)},
|
||||
p_last_tree_(nullptr), p_last_fmat_(fmat) {
|
||||
builder_monitor_.Init("SYCL::Quantile::HistUpdater");
|
||||
@@ -73,6 +71,17 @@ class HistUpdater {
|
||||
sub_group_size_ = sub_group_sizes.back();
|
||||
}
|
||||
|
||||
// update one tree, growing
|
||||
void Update(xgboost::tree::TrainParam const *param,
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device>& gpair,
|
||||
DMatrix *p_fmat,
|
||||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
RegTree *p_tree);
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::MatrixView<float> p_out_preds);
|
||||
|
||||
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
|
||||
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
|
||||
|
||||
@@ -200,7 +209,6 @@ class HistUpdater {
|
||||
std::vector<SplitEntry<GradientSumT>> best_splits_host_;
|
||||
|
||||
TreeEvaluator<GradientSumT> tree_evaluator_;
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
FeatureInteractionConstraintHost interaction_constraints_;
|
||||
|
||||
// back pointers to tree and data matrix
|
||||
@@ -247,6 +255,9 @@ class HistUpdater {
|
||||
std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
|
||||
std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
|
||||
|
||||
USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
|
||||
bst_float* out_pred_ptr = nullptr;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
* \file updater_quantile_hist.cc
|
||||
*/
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
@@ -29,6 +30,50 @@ void QuantileHistMaker::Configure(const Args& args) {
|
||||
|
||||
param_.UpdateAllowUnknown(args);
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
|
||||
bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
|
||||
if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
|
||||
if (!hist_maker_param_.single_precision_histogram) {
|
||||
LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
|
||||
}
|
||||
hist_precision_ = HistPrecision::fp32;
|
||||
} else {
|
||||
hist_precision_ = HistPrecision::fp64;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pimpl,
|
||||
DMatrix *dmat) {
|
||||
pimpl->reset(new HistUpdater<GradientSumT>(
|
||||
ctx_,
|
||||
qu_,
|
||||
param_,
|
||||
int_constraint_, dmat));
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
|
||||
} else {
|
||||
(*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
|
||||
(*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void QuantileHistMaker::CallUpdate(
|
||||
const std::unique_ptr<HistUpdater<GradientSumT>>& pimpl,
|
||||
xgboost::tree::TrainParam const *param,
|
||||
linalg::Matrix<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
const auto* gpair_h = gpair->Data();
|
||||
gpair_device_.Resize(&qu_, gpair_h->Size());
|
||||
qu_.memcpy(gpair_device_.Data(), gpair_h->HostPointer(), gpair_h->Size() * sizeof(GradientPair));
|
||||
qu_.wait();
|
||||
|
||||
for (auto tree : trees) {
|
||||
pimpl->Update(param, gmat_, gpair_device_, dmat, out_position, tree);
|
||||
}
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
|
||||
@@ -36,12 +81,55 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
|
||||
DMatrix *dmat,
|
||||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
LOG(FATAL) << "Not Implemented yet";
|
||||
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
|
||||
updater_monitor_.Start("DeviceMatrixInitialization");
|
||||
sycl::DeviceMatrix dmat_device;
|
||||
dmat_device.Init(qu_, dmat);
|
||||
updater_monitor_.Stop("DeviceMatrixInitialization");
|
||||
updater_monitor_.Start("GmatInitialization");
|
||||
gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
|
||||
updater_monitor_.Stop("GmatInitialization");
|
||||
is_gmat_initialized_ = true;
|
||||
}
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
int_constraint_.Configure(param_, dmat->Info().num_col_);
|
||||
// build tree
|
||||
if (hist_precision_ == HistPrecision::fp32) {
|
||||
if (!pimpl_fp32) {
|
||||
SetPimpl(&pimpl_fp32, dmat);
|
||||
}
|
||||
CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
|
||||
} else {
|
||||
if (!pimpl_fp64) {
|
||||
SetPimpl(&pimpl_fp64, dmat);
|
||||
}
|
||||
CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
|
||||
}
|
||||
|
||||
param_.learning_rate = lr;
|
||||
|
||||
p_last_dmat_ = dmat;
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::MatrixView<float> out_preds) {
|
||||
LOG(FATAL) << "Not Implemented yet";
|
||||
if (param_.subsample < 1.0f) return false;
|
||||
|
||||
if (hist_precision_ == HistPrecision::fp32) {
|
||||
if (pimpl_fp32) {
|
||||
return pimpl_fp32->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (pimpl_fp64) {
|
||||
return pimpl_fp64->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "../data/gradient_index.h"
|
||||
#include "../common/hist_util.h"
|
||||
@@ -16,8 +17,9 @@
|
||||
#include "../common/partition_builder.h"
|
||||
#include "split_evaluator.h"
|
||||
#include "../device_manager.h"
|
||||
|
||||
#include "hist_updater.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "../../src/tree/constraints.h"
|
||||
#include "../../src/common/random.h"
|
||||
@@ -75,12 +77,39 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
HistMakerTrainParam hist_maker_param_;
|
||||
// training parameter
|
||||
xgboost::tree::TrainParam param_;
|
||||
// quantized data matrix
|
||||
common::GHistIndexMatrix gmat_;
|
||||
// (optional) data matrix with feature grouping
|
||||
// column accessor
|
||||
DMatrix const* p_last_dmat_ {nullptr};
|
||||
bool is_gmat_initialized_ {false};
|
||||
|
||||
xgboost::common::Monitor updater_monitor_;
|
||||
|
||||
template<typename GradientSumT>
|
||||
void SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>*, DMatrix *dmat);
|
||||
|
||||
template<typename GradientSumT>
|
||||
void CallUpdate(const std::unique_ptr<HistUpdater<GradientSumT>>& builder,
|
||||
xgboost::tree::TrainParam const *param,
|
||||
linalg::Matrix<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees);
|
||||
|
||||
enum class HistPrecision {fp32, fp64};
|
||||
HistPrecision hist_precision_;
|
||||
|
||||
std::unique_ptr<HistUpdater<float>> pimpl_fp32;
|
||||
std::unique_ptr<HistUpdater<double>> pimpl_fp64;
|
||||
|
||||
FeatureInteractionConstraintHost int_constraint_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
DeviceManager device_manager;
|
||||
ObjInfo const *task_{nullptr};
|
||||
|
||||
USMVector<GradientPair, MemoryType::on_device> gpair_device_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user