[sycl] add loss guided hist building (#10251)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
9b465052ce
commit
f588252481
@ -46,6 +46,93 @@ template<typename GradientSumT>
|
||||
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
|
||||
size_t size, ::sycl::event event_priv);
|
||||
|
||||
/*!
|
||||
* \brief Histograms of gradient statistics for multiple nodes
|
||||
*/
|
||||
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
|
||||
class HistCollection {
|
||||
public:
|
||||
using GHistRowT = GHistRow<GradientSumT, memory_type>;
|
||||
|
||||
// Access histogram for i-th node
|
||||
GHistRowT& operator[](bst_uint nid) {
|
||||
return *(data_.at(nid));
|
||||
}
|
||||
|
||||
const GHistRowT& operator[](bst_uint nid) const {
|
||||
return *(data_.at(nid));
|
||||
}
|
||||
|
||||
// Initialize histogram collection
|
||||
void Init(::sycl::queue qu, uint32_t nbins) {
|
||||
qu_ = qu;
|
||||
if (nbins_ != nbins) {
|
||||
nbins_ = nbins;
|
||||
data_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Create an empty histogram for i-th node
|
||||
::sycl::event AddHistRow(bst_uint nid) {
|
||||
::sycl::event event;
|
||||
if (data_.count(nid) == 0) {
|
||||
data_[nid] =
|
||||
std::make_shared<GHistRowT>(&qu_, nbins_,
|
||||
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
|
||||
&event);
|
||||
} else {
|
||||
data_[nid]->Resize(&qu_, nbins_,
|
||||
xgboost::detail::GradientPairInternal<GradientSumT>(0, 0),
|
||||
&event);
|
||||
}
|
||||
return event;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Number of all bins over all features */
|
||||
uint32_t nbins_ = 0;
|
||||
|
||||
std::unordered_map<uint32_t, std::shared_ptr<GHistRowT>> data_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Stores temporary histograms to compute them in parallel
|
||||
*/
|
||||
template<typename GradientSumT>
|
||||
class ParallelGHistBuilder {
|
||||
public:
|
||||
using GHistRowT = GHistRow<GradientSumT, MemoryType::on_device>;
|
||||
|
||||
void Init(::sycl::queue qu, size_t nbins) {
|
||||
qu_ = qu;
|
||||
if (nbins != nbins_) {
|
||||
hist_buffer_.Init(qu_, nbins);
|
||||
nbins_ = nbins;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset(size_t nblocks) {
|
||||
hist_device_buffer_.Resize(&qu_, nblocks * nbins_ * 2);
|
||||
}
|
||||
|
||||
GHistRowT& GetDeviceBuffer() {
|
||||
return hist_device_buffer_;
|
||||
}
|
||||
|
||||
protected:
|
||||
/*! \brief Number of bins in each histogram */
|
||||
size_t nbins_ = 0;
|
||||
/*! \brief Buffers for histograms for all nodes processed */
|
||||
HistCollection<GradientSumT> hist_buffer_;
|
||||
|
||||
/*! \brief Buffer for additional histograms for Parallel processing */
|
||||
GHistRowT hist_device_buffer_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Builder for histograms of gradient statistics
|
||||
*/
|
||||
|
||||
@ -80,6 +80,12 @@ class USMVector {
|
||||
qu->fill(data_.get(), v, size_).wait();
|
||||
}
|
||||
|
||||
USMVector(::sycl::queue* qu, size_t size, T v,
|
||||
::sycl::event* event) : size_(size), capacity_(size) {
|
||||
data_ = allocate_memory_(qu, size_);
|
||||
*event = qu->fill(data_.get(), v, size_, *event);
|
||||
}
|
||||
|
||||
USMVector(::sycl::queue* qu, const std::vector<T> &vec) {
|
||||
size_ = vec.size();
|
||||
capacity_ = size_;
|
||||
|
||||
46
plugin/sycl/tree/hist_row_adder.h
Normal file
46
plugin/sycl/tree/hist_row_adder.h
Normal file
@ -0,0 +1,46 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file hist_row_adder.h
|
||||
*/
|
||||
#ifndef PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_
|
||||
#define PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
template <typename GradientSumT>
|
||||
class HistRowsAdder {
|
||||
public:
|
||||
virtual void AddHistRows(HistUpdater<GradientSumT>* builder,
|
||||
std::vector<int>* sync_ids, RegTree *p_tree) = 0;
|
||||
virtual ~HistRowsAdder() = default;
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
class BatchHistRowsAdder: public HistRowsAdder<GradientSumT> {
|
||||
public:
|
||||
void AddHistRows(HistUpdater<GradientSumT>* builder,
|
||||
std::vector<int>* sync_ids, RegTree *p_tree) override {
|
||||
builder->builder_monitor_.Start("AddHistRows");
|
||||
|
||||
for (auto const& entry : builder->nodes_for_explicit_hist_build_) {
|
||||
int nid = entry.nid;
|
||||
auto event = builder->hist_.AddHistRow(nid);
|
||||
}
|
||||
for (auto const& node : builder->nodes_for_subtraction_trick_) {
|
||||
auto event = builder->hist_.AddHistRow(node.nid);
|
||||
}
|
||||
|
||||
builder->builder_monitor_.Stop("AddHistRows");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_
|
||||
68
plugin/sycl/tree/hist_synchronizer.h
Normal file
68
plugin/sycl/tree/hist_synchronizer.h
Normal file
@ -0,0 +1,68 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file hist_synchronizer.h
|
||||
*/
|
||||
#ifndef PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_
|
||||
#define PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../common/hist_util.h"
|
||||
#include "expand_entry.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
template <typename GradientSumT>
|
||||
class HistUpdater;
|
||||
|
||||
template <typename GradientSumT>
|
||||
class HistSynchronizer {
|
||||
public:
|
||||
virtual void SyncHistograms(HistUpdater<GradientSumT>* builder,
|
||||
const std::vector<int>& sync_ids,
|
||||
RegTree *p_tree) = 0;
|
||||
virtual ~HistSynchronizer() = default;
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
class BatchHistSynchronizer: public HistSynchronizer<GradientSumT> {
|
||||
public:
|
||||
void SyncHistograms(HistUpdater<GradientSumT>* builder,
|
||||
const std::vector<int>& sync_ids,
|
||||
RegTree *p_tree) override {
|
||||
builder->builder_monitor_.Start("SyncHistograms");
|
||||
const size_t nbins = builder->hist_builder_.GetNumBins();
|
||||
|
||||
hist_sync_events_.resize(builder->nodes_for_explicit_hist_build_.size());
|
||||
for (int i = 0; i < builder->nodes_for_explicit_hist_build_.size(); i++) {
|
||||
const auto entry = builder->nodes_for_explicit_hist_build_[i];
|
||||
auto& this_hist = builder->hist_[entry.nid];
|
||||
|
||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
||||
auto& parent_hist = builder->hist_[parent_id];
|
||||
auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)];
|
||||
hist_sync_events_[i] = common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist,
|
||||
this_hist, nbins, ::sycl::event());
|
||||
}
|
||||
}
|
||||
builder->qu_.wait_and_throw();
|
||||
|
||||
builder->builder_monitor_.Stop("SyncHistograms");
|
||||
}
|
||||
|
||||
std::vector<::sycl::event> GetEvents() const {
|
||||
return hist_sync_events_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<::sycl::event> hist_sync_events_;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_
|
||||
@ -7,10 +7,69 @@
|
||||
|
||||
#include <oneapi/dpl/random>
|
||||
|
||||
#include "../common/hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace tree {
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::SetHistSynchronizer(
|
||||
HistSynchronizer<GradientSumT> *sync) {
|
||||
hist_synchronizer_.reset(sync);
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::SetHistRowsAdder(
|
||||
HistRowsAdder<GradientSumT> *adder) {
|
||||
hist_rows_adder_.reset(adder);
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::BuildHistogramsLossGuide(
|
||||
ExpandEntry entry,
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair_device) {
|
||||
nodes_for_explicit_hist_build_.clear();
|
||||
nodes_for_subtraction_trick_.clear();
|
||||
nodes_for_explicit_hist_build_.push_back(entry);
|
||||
|
||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||
auto sibling_id = entry.GetSiblingId(p_tree);
|
||||
nodes_for_subtraction_trick_.emplace_back(sibling_id, p_tree->GetDepth(sibling_id));
|
||||
}
|
||||
|
||||
std::vector<int> sync_ids;
|
||||
hist_rows_adder_->AddHistRows(this, &sync_ids, p_tree);
|
||||
qu_.wait_and_throw();
|
||||
BuildLocalHistograms(gmat, p_tree, gpair_device);
|
||||
hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree);
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::BuildLocalHistograms(
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair_device) {
|
||||
builder_monitor_.Start("BuildLocalHistograms");
|
||||
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
|
||||
::sycl::event event;
|
||||
|
||||
for (size_t i = 0; i < n_nodes; i++) {
|
||||
const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
|
||||
|
||||
if (row_set_collection_[nid].Size() > 0) {
|
||||
event = BuildHist(gpair_device, row_set_collection_[nid], gmat, &(hist_[nid]),
|
||||
&(hist_buffer_.GetDeviceBuffer()), event);
|
||||
} else {
|
||||
common::InitHist(qu_, &(hist_[nid]), hist_[nid].Size(), &event);
|
||||
}
|
||||
}
|
||||
qu_.wait_and_throw();
|
||||
builder_monitor_.Stop("BuildLocalHistograms");
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
void HistUpdater<GradientSumT>::InitSampling(
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
@ -70,6 +129,21 @@ void HistUpdater<GradientSumT>::InitData(
|
||||
// initialize the row set
|
||||
{
|
||||
row_set_collection_.Clear();
|
||||
|
||||
// initialize histogram collection
|
||||
uint32_t nbins = gmat.cut.Ptrs().back();
|
||||
hist_.Init(qu_, nbins);
|
||||
|
||||
hist_buffer_.Init(qu_, nbins);
|
||||
size_t buffer_size = kBufferSize;
|
||||
if (buffer_size > info.num_row_ / kMinBlockSize + 1) {
|
||||
buffer_size = info.num_row_ / kMinBlockSize + 1;
|
||||
}
|
||||
hist_buffer_.Reset(buffer_size);
|
||||
|
||||
// initialize histogram builder
|
||||
hist_builder_ = common::GHistBuilder<GradientSumT>(qu_, nbins);
|
||||
|
||||
USMVector<size_t, MemoryType::on_device>* row_indices = &(row_set_collection_.Data());
|
||||
row_indices->Resize(&qu_, info.num_row_);
|
||||
size_t* p_row_indices = row_indices->Data();
|
||||
@ -122,6 +196,25 @@ void HistUpdater<GradientSumT>::InitData(
|
||||
}
|
||||
}
|
||||
row_set_collection_.Init();
|
||||
|
||||
{
|
||||
/* determine layout of data */
|
||||
const size_t nrow = info.num_row_;
|
||||
const size_t ncol = info.num_col_;
|
||||
const size_t nnz = info.num_nonzero_;
|
||||
// number of discrete bins for feature 0
|
||||
const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
|
||||
if (nrow * ncol == nnz) {
|
||||
// dense data with zero-based indexing
|
||||
data_layout_ = kDenseDataZeroBased;
|
||||
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
|
||||
// dense data with one-based indexing
|
||||
data_layout_ = kDenseDataOneBased;
|
||||
} else {
|
||||
// sparse data
|
||||
data_layout_ = kSparseData;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template class HistUpdater<float>;
|
||||
|
||||
@ -12,10 +12,13 @@
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "../common/partition_builder.h"
|
||||
#include "split_evaluator.h"
|
||||
#include "hist_synchronizer.h"
|
||||
#include "hist_row_adder.h"
|
||||
|
||||
#include "../data.h"
|
||||
|
||||
@ -26,6 +29,10 @@ namespace tree {
|
||||
template<typename GradientSumT>
|
||||
class HistUpdater {
|
||||
public:
|
||||
template <MemoryType memory_type = MemoryType::shared>
|
||||
using GHistRowT = common::GHistRow<GradientSumT, memory_type>;
|
||||
using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
|
||||
|
||||
explicit HistUpdater(::sycl::queue qu,
|
||||
const xgboost::tree::TrainParam& param,
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
@ -43,7 +50,13 @@ class HistUpdater {
|
||||
sub_group_size_ = sub_group_sizes.back();
|
||||
}
|
||||
|
||||
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
|
||||
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
|
||||
|
||||
protected:
|
||||
friend class BatchHistSynchronizer<GradientSumT>;
|
||||
friend class BatchHistRowsAdder<GradientSumT>;
|
||||
|
||||
void InitSampling(const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
USMVector<size_t, MemoryType::on_device>* row_indices);
|
||||
|
||||
@ -54,6 +67,27 @@ class HistUpdater {
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
|
||||
inline ::sycl::event BuildHist(
|
||||
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
|
||||
const common::RowSetCollection::Elem row_indices,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
GHistRowT<MemoryType::on_device>* hist,
|
||||
GHistRowT<MemoryType::on_device>* hist_buffer,
|
||||
::sycl::event event_priv) {
|
||||
return hist_builder_.BuildHist(gpair_device, row_indices, gmat, hist,
|
||||
data_layout_ != kSparseData, hist_buffer, event_priv);
|
||||
}
|
||||
|
||||
void BuildLocalHistograms(const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair);
|
||||
|
||||
void BuildHistogramsLossGuide(
|
||||
ExpandEntry entry,
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair);
|
||||
|
||||
// --data fields--
|
||||
size_t sub_group_size_;
|
||||
|
||||
@ -69,11 +103,30 @@ class HistUpdater {
|
||||
const RegTree* p_last_tree_;
|
||||
DMatrix const* const p_last_fmat_;
|
||||
|
||||
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
|
||||
DataLayout data_layout_;
|
||||
|
||||
constexpr static size_t kBufferSize = 2048;
|
||||
constexpr static size_t kMinBlockSize = 128;
|
||||
common::GHistBuilder<GradientSumT> hist_builder_;
|
||||
common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
|
||||
/*! \brief culmulative histogram of gradients. */
|
||||
common::HistCollection<GradientSumT, MemoryType::on_device> hist_;
|
||||
|
||||
xgboost::common::Monitor builder_monitor_;
|
||||
xgboost::common::Monitor kernel_monitor_;
|
||||
|
||||
uint64_t seed_ = 0;
|
||||
|
||||
// key is the node id which should be calculated by Subtraction Trick, value is the node which
|
||||
// provides the evidence for substracts
|
||||
std::vector<ExpandEntry> nodes_for_subtraction_trick_;
|
||||
// list of nodes whose histograms would be built explicitly.
|
||||
std::vector<ExpandEntry> nodes_for_explicit_hist_build_;
|
||||
|
||||
std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
|
||||
std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
|
||||
@ -28,16 +28,40 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
|
||||
HistUpdater<GradientSumT>::InitSampling(gpair, row_indices);
|
||||
}
|
||||
|
||||
const auto* TestInitData(Context const * ctx,
|
||||
auto* TestInitData(Context const * ctx,
|
||||
const common::GHistIndexMatrix& gmat,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
HistUpdater<GradientSumT>::InitData(ctx, gmat, gpair, fmat, tree);
|
||||
return &(HistUpdater<GradientSumT>::row_set_collection_.Data());
|
||||
return &(HistUpdater<GradientSumT>::row_set_collection_);
|
||||
}
|
||||
|
||||
const auto* TestBuildHistogramsLossGuide(ExpandEntry entry,
|
||||
const common::GHistIndexMatrix &gmat,
|
||||
RegTree *p_tree,
|
||||
const USMVector<GradientPair, MemoryType::on_device> &gpair) {
|
||||
HistUpdater<GradientSumT>::BuildHistogramsLossGuide(entry, gmat, p_tree, gpair);
|
||||
return &(HistUpdater<GradientSumT>::hist_);
|
||||
}
|
||||
};
|
||||
|
||||
void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
|
||||
qu->submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||
[=](::sycl::item<1> pid) {
|
||||
uint64_t i = pid.get_linear_id();
|
||||
|
||||
constexpr uint32_t seed = 777;
|
||||
oneapi::dpl::minstd_rand engine(seed, i);
|
||||
GradientPair::ValueT smallest_hess_val = has_neg_hess ? -1. : 0.;
|
||||
oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(smallest_hess_val, 1.);
|
||||
gpair_ptr[i] = {distr(engine), distr(engine)};
|
||||
});
|
||||
});
|
||||
qu->wait();
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
|
||||
const size_t num_rows = 1u << 12;
|
||||
@ -60,18 +84,7 @@ void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) {
|
||||
USMVector<size_t, MemoryType::on_device> row_indices_0(&qu, num_rows);
|
||||
USMVector<size_t, MemoryType::on_device> row_indices_1(&qu, num_rows);
|
||||
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
|
||||
auto* gpair_ptr = gpair.Data();
|
||||
qu.submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||
[=](::sycl::item<1> pid) {
|
||||
uint64_t i = pid.get_linear_id();
|
||||
|
||||
constexpr uint32_t seed = 777;
|
||||
oneapi::dpl::minstd_rand engine(seed, i);
|
||||
oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(-1., 1.);
|
||||
gpair_ptr[i] = {distr(engine), distr(engine)};
|
||||
});
|
||||
}).wait();
|
||||
GenerateRandomGPairs(&qu, gpair.Data(), num_rows, true);
|
||||
|
||||
updater.TestInitSampling(gpair, &row_indices_0);
|
||||
|
||||
@ -125,19 +138,7 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne
|
||||
TestHistUpdater<GradientSumT> updater(qu, param, std::move(pruner), int_constraints, p_fmat.get());
|
||||
|
||||
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
|
||||
auto* gpair_ptr = gpair.Data();
|
||||
qu.submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
|
||||
[=](::sycl::item<1> pid) {
|
||||
uint64_t i = pid.get_linear_id();
|
||||
|
||||
constexpr uint32_t seed = 777;
|
||||
oneapi::dpl::minstd_rand engine(seed, i);
|
||||
GradientPair::ValueT smallest_hess_val = has_neg_hess ? -1. : 0.;
|
||||
oneapi::dpl::uniform_real_distribution<GradientPair::ValueT> distr(smallest_hess_val, 1.);
|
||||
gpair_ptr[i] = {distr(engine), distr(engine)};
|
||||
});
|
||||
}).wait();
|
||||
GenerateRandomGPairs(&qu, gpair.Data(), num_rows, has_neg_hess);
|
||||
|
||||
DeviceMatrix dmat;
|
||||
dmat.Init(qu, p_fmat.get());
|
||||
@ -145,10 +146,11 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne
|
||||
gmat.Init(qu, &ctx, dmat, n_bins);
|
||||
RegTree tree;
|
||||
|
||||
const auto* row_indices = updater.TestInitData(&ctx, gmat, gpair, *p_fmat, tree);
|
||||
auto* row_set_collection = updater.TestInitData(&ctx, gmat, gpair, *p_fmat, tree);
|
||||
auto& row_indices = row_set_collection->Data();
|
||||
|
||||
std::vector<size_t> row_indices_host(row_indices->Size());
|
||||
qu.memcpy(row_indices_host.data(), row_indices->DataConst(), row_indices->Size()*sizeof(size_t)).wait();
|
||||
std::vector<size_t> row_indices_host(row_indices.Size());
|
||||
qu.memcpy(row_indices_host.data(), row_indices.DataConst(), row_indices.Size()*sizeof(size_t)).wait();
|
||||
|
||||
if (!has_neg_hess) {
|
||||
for (size_t i = 0; i < num_rows; ++i) {
|
||||
@ -171,6 +173,70 @@ void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_ne
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void TestHistUpdaterBuildHistogramsLossGuide(const xgboost::tree::TrainParam& param, float sparsity) {
|
||||
const size_t num_rows = 1u << 8;
|
||||
const size_t num_columns = 1;
|
||||
const size_t n_bins = 32;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
|
||||
|
||||
FeatureInteractionConstraintHost int_constraints;
|
||||
std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};
|
||||
|
||||
TestHistUpdater<GradientSumT> updater(qu, param, std::move(pruner), int_constraints, p_fmat.get());
|
||||
updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
|
||||
updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
|
||||
|
||||
USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
|
||||
auto* gpair_ptr = gpair.Data();
|
||||
GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
|
||||
|
||||
DeviceMatrix dmat;
|
||||
dmat.Init(qu, p_fmat.get());
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(qu, &ctx, dmat, n_bins);
|
||||
|
||||
RegTree tree;
|
||||
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
ExpandEntry node0(0, tree.GetDepth(0));
|
||||
ExpandEntry node1(1, tree.GetDepth(1));
|
||||
ExpandEntry node2(2, tree.GetDepth(2));
|
||||
|
||||
auto* row_set_collection = updater.TestInitData(&ctx, gmat, gpair, *p_fmat, tree);
|
||||
row_set_collection->AddSplit(0, 1, 2, 42, num_rows - 42);
|
||||
|
||||
updater.TestBuildHistogramsLossGuide(node0, gmat, &tree, gpair);
|
||||
const auto* hist = updater.TestBuildHistogramsLossGuide(node1, gmat, &tree, gpair);
|
||||
|
||||
ASSERT_EQ((*hist)[0].Size(), n_bins);
|
||||
ASSERT_EQ((*hist)[1].Size(), n_bins);
|
||||
ASSERT_EQ((*hist)[2].Size(), n_bins);
|
||||
|
||||
std::vector<xgboost::detail::GradientPairInternal<GradientSumT>> hist0_host(n_bins);
|
||||
std::vector<xgboost::detail::GradientPairInternal<GradientSumT>> hist1_host(n_bins);
|
||||
std::vector<xgboost::detail::GradientPairInternal<GradientSumT>> hist2_host(n_bins);
|
||||
qu.memcpy(hist0_host.data(), (*hist)[0].DataConst(), sizeof(xgboost::detail::GradientPairInternal<GradientSumT>) * n_bins);
|
||||
qu.memcpy(hist1_host.data(), (*hist)[1].DataConst(), sizeof(xgboost::detail::GradientPairInternal<GradientSumT>) * n_bins);
|
||||
qu.memcpy(hist2_host.data(), (*hist)[2].DataConst(), sizeof(xgboost::detail::GradientPairInternal<GradientSumT>) * n_bins);
|
||||
qu.wait();
|
||||
|
||||
for (size_t idx_bin = 0; idx_bin < n_bins; ++idx_bin) {
|
||||
EXPECT_NEAR(hist0_host[idx_bin].GetGrad(), hist1_host[idx_bin].GetGrad() + hist2_host[idx_bin].GetGrad(), 1e-6);
|
||||
EXPECT_NEAR(hist0_host[idx_bin].GetHess(), hist1_host[idx_bin].GetHess() + hist2_host[idx_bin].GetHess(), 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SyclHistUpdater, Sampling) {
|
||||
xgboost::tree::TrainParam param;
|
||||
param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
|
||||
@ -190,4 +256,14 @@ TEST(SyclHistUpdater, InitData) {
|
||||
TestHistUpdaterInitData<double>(param, false);
|
||||
}
|
||||
|
||||
TEST(SyclHistUpdater, BuildHistogramsLossGuide) {
|
||||
xgboost::tree::TrainParam param;
|
||||
param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
|
||||
|
||||
TestHistUpdaterBuildHistogramsLossGuide<float>(param, 0.0);
|
||||
TestHistUpdaterBuildHistogramsLossGuide<float>(param, 0.5);
|
||||
TestHistUpdaterBuildHistogramsLossGuide<double>(param, 0.0);
|
||||
TestHistUpdaterBuildHistogramsLossGuide<double>(param, 0.5);
|
||||
}
|
||||
|
||||
} // namespace xgboost::sycl::tree
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user