[sycl] add partitioning and related tests (#10080)

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin 2024-03-01 18:49:27 +01:00 committed by GitHub
parent 2c12b956da
commit 7a61216690
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 371 additions and 21 deletions

View File

@ -21,6 +21,9 @@
#pragma GCC diagnostic pop
#include "../data.h"
#include "row_set.h"
#include "../data/gradient_index.h"
#include "../tree/expand_entry.h"
#include <CL/sycl.hpp>
@ -28,6 +31,87 @@ namespace xgboost {
namespace sycl {
namespace common {
// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending
// on comparison of indexes values (idx_span) and split point (split_cond)
// Handle dense columns
template <bool default_left, typename BinIdxType>
inline ::sycl::event PartitionDenseKernel(
::sycl::queue* qu,
const GHistIndexMatrix& gmat,
const RowSetCollection::Elem& rid_span,
const size_t fid,
const int32_t split_cond,
xgboost::common::Span<size_t>* rid_buf,
size_t* parts_size,
::sycl::event event) {
const size_t row_stride = gmat.row_stride;
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const size_t* rid = rid_span.begin;
const size_t range_size = rid_span.Size();
const size_t offset = gmat.cut.Ptrs()[fid];
size_t* p_rid_buf = rid_buf->data();
return qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) {
const size_t id = rid[nid.get_id(0)];
const int32_t value = static_cast<int32_t>(gradient_index[id * row_stride + fid] + offset);
const bool is_left = value <= split_cond;
if (is_left) {
AtomicRef<size_t> n_left(parts_size[0]);
p_rid_buf[n_left.fetch_add(1)] = id;
} else {
AtomicRef<size_t> n_right(parts_size[1]);
p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
}
});
});
}
// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending
// on comparison of indexes values (idx_span) and split point (split_cond)
// Handle sparce columns
template <bool default_left, typename BinIdxType>
inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu,
const GHistIndexMatrix& gmat,
const RowSetCollection::Elem& rid_span,
const size_t fid,
const int32_t split_cond,
xgboost::common::Span<size_t>* rid_buf,
size_t* parts_size,
::sycl::event event) {
const size_t row_stride = gmat.row_stride;
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const size_t* rid = rid_span.begin;
const size_t range_size = rid_span.Size();
const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst();
size_t* p_rid_buf = rid_buf->data();
return qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) {
const size_t id = rid[nid.get_id(0)];
const BinIdxType* gr_index_local = gradient_index + row_stride * id;
const int32_t fid_local = std::lower_bound(gr_index_local,
gr_index_local + row_stride,
cut_ptrs[fid]) - gr_index_local;
const bool is_left = (fid_local >= row_stride ||
gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ?
default_left :
gr_index_local[fid_local] <= split_cond;
if (is_left) {
AtomicRef<size_t> n_left(parts_size[0]);
p_rid_buf[n_left.fetch_add(1)] = id;
} else {
AtomicRef<size_t> n_right(parts_size[1]);
p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
}
});
});
}
// The builder is required for samples partition to left and rights children for set of nodes
class PartitionBuilder {
public:
@ -53,7 +137,6 @@ class PartitionBuilder {
return result_rows_[2 * nid];
}
size_t GetNRightElems(int nid) const {
return result_rows_[2 * nid + 1];
}
@ -72,19 +155,97 @@ class PartitionBuilder {
return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] };
}
template <typename BinIdxType>
::sycl::event Partition(const int32_t split_cond,
const GHistIndexMatrix& gmat,
const RowSetCollection::Elem& rid_span,
const xgboost::RegTree::Node& node,
xgboost::common::Span<size_t>* rid_buf,
size_t* parts_size,
::sycl::event event) {
const bst_uint fid = node.SplitIndex();
const bool default_left = node.DefaultLeft();
if (gmat.IsDense()) {
if (default_left) {
return PartitionDenseKernel<true, BinIdxType>(qu_, gmat, rid_span, fid,
split_cond, rid_buf, parts_size, event);
} else {
return PartitionDenseKernel<false, BinIdxType>(qu_, gmat, rid_span, fid,
split_cond, rid_buf, parts_size, event);
}
} else {
if (default_left) {
return PartitionSparseKernel<true, BinIdxType>(qu_, gmat, rid_span, fid,
split_cond, rid_buf, parts_size, event);
} else {
return PartitionSparseKernel<false, BinIdxType>(qu_, gmat, rid_span, fid,
split_cond, rid_buf, parts_size, event);
}
}
}
// Entry point for Partition
void Partition(const GHistIndexMatrix& gmat,
const std::vector<tree::ExpandEntry> nodes,
const RowSetCollection& row_set_collection,
const std::vector<int32_t>& split_conditions,
RegTree* p_tree,
::sycl::event* general_event) {
nodes_events_.resize(n_nodes_);
parts_size_.ResizeAndFill(qu_, 2 * n_nodes_, 0, general_event);
for (size_t node_in_set = 0; node_in_set < n_nodes_; node_in_set++) {
const int32_t nid = nodes[node_in_set].nid;
::sycl::event& node_event = nodes_events_[node_in_set];
const auto& rid_span = row_set_collection[nid];
if (rid_span.Size() > 0) {
const RegTree::Node& node = (*p_tree)[nid];
xgboost::common::Span<size_t> rid_buf = GetData(node_in_set);
size_t* part_size = parts_size_.Data() + 2 * node_in_set;
int32_t split_condition = split_conditions[node_in_set];
switch (gmat.index.GetBinTypeSize()) {
case common::BinTypeSize::kUint8BinsTypeSize:
node_event = Partition<uint8_t>(split_condition, gmat, rid_span, node,
&rid_buf, part_size, *general_event);
break;
case common::BinTypeSize::kUint16BinsTypeSize:
node_event = Partition<uint16_t>(split_condition, gmat, rid_span, node,
&rid_buf, part_size, *general_event);
break;
case common::BinTypeSize::kUint32BinsTypeSize:
node_event = Partition<uint32_t>(split_condition, gmat, rid_span, node,
&rid_buf, part_size, *general_event);
break;
default:
CHECK(false); // no default behavior
}
} else {
node_event = ::sycl::event();
}
}
*general_event = qu_->memcpy(result_rows_.data(),
parts_size_.DataConst(),
sizeof(size_t) * 2 * n_nodes_,
nodes_events_);
}
void MergeToArray(size_t nid,
size_t* data_result,
::sycl::event event) {
::sycl::event* event) {
size_t n_nodes_total = GetNLeftElems(nid) + GetNRightElems(nid);
if (n_nodes_total > 0) {
const size_t* data = data_.Data() + nodes_offsets_[nid];
qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, event);
qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, *event);
}
}
protected:
std::vector<size_t> nodes_offsets_;
std::vector<size_t> result_rows_;
std::vector<::sycl::event> nodes_events_;
size_t n_nodes_;
USMVector<size_t, MemoryType::on_device> parts_size_;

View File

@ -171,20 +171,20 @@ class USMVector {
}
}
::sycl::event ResizeAndFill(::sycl::queue* qu, size_t size_new, int v) {
void ResizeAndFill(::sycl::queue* qu, size_t size_new, int v, ::sycl::event* event) {
if (size_new <= size_) {
size_ = size_new;
return qu->memset(data_.get(), v, size_new * sizeof(T));
*event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
} else if (size_new <= capacity_) {
size_ = size_new;
return qu->memset(data_.get(), v, size_new * sizeof(T));
*event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
} else {
size_t size_old = size_;
auto data_old = data_;
size_ = size_new;
capacity_ = size_new;
data_ = allocate_memory_(qu, size_);
return qu->memset(data_.get(), v, size_new * sizeof(T));
*event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
}
}
@ -211,11 +211,16 @@ class USMVector {
struct DeviceMatrix {
DMatrix* p_mat; // Pointer to the original matrix on the host
::sycl::queue qu_;
USMVector<size_t> row_ptr;
USMVector<size_t, MemoryType::on_device> row_ptr;
USMVector<Entry, MemoryType::on_device> data;
size_t total_offset;
DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
DeviceMatrix() = default;
void Init(::sycl::queue qu, DMatrix* dmat) {
qu_ = qu;
p_mat = dmat;
size_t num_row = 0;
size_t num_nonzero = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
@ -226,27 +231,41 @@ struct DeviceMatrix {
}
row_ptr.Resize(&qu_, num_row + 1);
size_t* rows = row_ptr.Data();
data.Resize(&qu_, num_nonzero);
size_t data_offset = 0;
::sycl::event event;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.HostVector();
const auto& offset_vec = batch.offset.HostVector();
size_t batch_size = batch.Size();
if (batch_size > 0) {
std::copy(offset_vec.data(), offset_vec.data() + batch_size,
row_ptr.Data() + batch.base_rowid);
if (batch.base_rowid > 0) {
for (size_t i = 0; i < batch_size; i++)
row_ptr[i + batch.base_rowid] += batch.base_rowid;
const auto base_rowid = batch.base_rowid;
event = qu.memcpy(row_ptr.Data() + base_rowid, offset_vec.data(),
sizeof(size_t) * batch_size, event);
if (base_rowid > 0) {
qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::id<1> pid) {
int row_id = pid[0];
rows[row_id] += base_rowid;
});
});
}
qu.memcpy(data.Data() + data_offset,
data_vec.data(),
offset_vec[batch_size] * sizeof(Entry)).wait();
event = qu.memcpy(data.Data() + data_offset, data_vec.data(),
sizeof(Entry) * offset_vec[batch_size], event);
data_offset += offset_vec[batch_size];
qu.wait();
}
}
row_ptr[num_row] = data_offset;
qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.single_task<>([=] {
rows[num_row] = data_offset;
});
});
qu.wait();
total_offset = data_offset;
}

View File

@ -280,7 +280,8 @@ class Predictor : public xgboost::Predictor {
uint32_t tree_end = 0) const override {
::sycl::queue qu = device_manager.GetQueue(ctx_->Device());
// TODO(razdoburdin): remove temporary workaround after cache fix
sycl::DeviceMatrix device_matrix(qu, dmat);
sycl::DeviceMatrix device_matrix;
device_matrix.Init(qu, dmat);
auto* out_preds = &predts->predictions;
if (tree_end == 0) {

View File

@ -0,0 +1,50 @@
/*!
* Copyright 2017-2024 by Contributors
* \file expand_entry.h
*/
#ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
#define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#include "../../src/tree/constraints.h"
#pragma GCC diagnostic pop
#include "../../src/tree/hist/expand_entry.h"
namespace xgboost {
namespace sycl {
namespace tree {
/* tree growing policies */
struct ExpandEntry : public xgboost::tree::ExpandEntryImpl<ExpandEntry> {
static constexpr bst_node_t kRootNid = 0;
xgboost::tree::SplitEntry split;
ExpandEntry(int nid, int depth) : ExpandEntryImpl{nid, depth} {}
inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree) const {
CHECK_EQ((*p_tree)[nid].IsRoot(), false);
const size_t parent_id = (*p_tree)[nid].Parent();
return GetSiblingId(p_tree, parent_id);
}
inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree, size_t parent_id) const {
return p_tree->IsLeftChild(nid) ? p_tree->RightChild(parent_id)
: p_tree->LeftChild(parent_id);
}
bool IsValidImpl(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.loss_chg < param.min_split_loss) return false;
if (param.max_depth > 0 && depth == param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;
return true;
}
};
} // namespace tree
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_

View File

@ -49,7 +49,8 @@ TEST(SyclGradientIndex, Init) {
auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix();
sycl::DeviceMatrix dmat(qu, p_fmat.get());
sycl::DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());
int max_bins = 256;
common::GHistIndexMatrix gmat_sycl;

View File

@ -13,6 +13,108 @@
namespace xgboost::sycl::common {
void TestPartitioning(float sparsity, int max_bins) {
const size_t num_rows = 16;
const size_t num_columns = 1;
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
sycl::DeviceMatrix dmat;
dmat.Init(qu, p_fmat.get());
common::GHistIndexMatrix gmat;
gmat.Init(qu, &ctx, dmat, max_bins);
RowSetCollection row_set_collection;
auto& row_indices = row_set_collection.Data();
row_indices.Resize(&qu, num_rows);
size_t* p_row_indices = row_indices.Data();
qu.submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(num_rows),
[p_row_indices](::sycl::item<1> pid) {
const size_t idx = pid.get_id(0);
p_row_indices[idx] = idx;
});
}).wait_and_throw();
row_set_collection.Init();
RegTree tree;
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
const size_t n_nodes = row_set_collection.Size();
PartitionBuilder partition_builder;
partition_builder.Init(&qu, n_nodes, [&](size_t nid) {
return row_set_collection[nid].Size();
});
std::vector<tree::ExpandEntry> nodes;
nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));
::sycl::event event;
std::vector<int32_t> split_conditions = {2};
partition_builder.Partition(gmat, nodes, row_set_collection,
split_conditions, &tree, &event);
qu.wait_and_throw();
size_t* data_result = const_cast<size_t*>(row_set_collection[0].begin);
partition_builder.MergeToArray(0, data_result, &event);
qu.wait_and_throw();
bst_float split_pt = gmat.cut.Values()[split_conditions[0]];
std::vector<uint8_t> ridx_left(num_rows, 0);
std::vector<uint8_t> ridx_right(num_rows, 0);
for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.HostVector();
const auto& offset_vec = batch.offset.HostVector();
size_t begin = offset_vec[0];
for (size_t idx = 0; idx < offset_vec.size() - 1; ++idx) {
size_t end = offset_vec[idx + 1];
if (begin < end) {
const auto& entry = data_vec[begin];
if (entry.fvalue < split_pt) {
ridx_left[idx] = 1;
} else {
ridx_right[idx] = 1;
}
} else {
// missing value
if (tree[0].DefaultLeft()) {
ridx_left[idx] = 1;
} else {
ridx_right[idx] = 1;
}
}
begin = end;
}
}
auto n_left = std::accumulate(ridx_left.begin(), ridx_left.end(), 0);
auto n_right = std::accumulate(ridx_right.begin(), ridx_right.end(), 0);
std::vector<size_t> row_indices_host(num_rows);
qu.memcpy(row_indices_host.data(), row_indices.Data(), num_rows * sizeof(size_t));
qu.wait_and_throw();
ASSERT_EQ(n_left, partition_builder.GetNLeftElems(0));
for (size_t i = 0; i < n_left; ++i) {
auto idx = row_indices_host[i];
ASSERT_EQ(ridx_left[idx], 1);
}
ASSERT_EQ(n_right, partition_builder.GetNRightElems(0));
for (size_t i = 0; i < n_right; ++i) {
auto idx = row_indices_host[num_rows - 1 - i];
ASSERT_EQ(ridx_right[idx], 1);
}
}
TEST(SyclPartitionBuilder, BasicTest) {
constexpr size_t kNodes = 5;
// Number of rows for each node
@ -67,7 +169,7 @@ TEST(SyclPartitionBuilder, BasicTest) {
std::vector<size_t> v(*std::max_element(rows.begin(), rows.end()));
size_t row_id = 0;
for(size_t nid = 0; nid < kNodes; ++nid) {
builder.MergeToArray(nid, v.data(), event);
builder.MergeToArray(nid, v.data(), &event);
qu.wait();
// Check that row_id for left side are correct
@ -88,4 +190,20 @@ TEST(SyclPartitionBuilder, BasicTest) {
}
}
TEST(SyclPartitionBuilder, PartitioningSparce) {
TestPartitioning(0.3, 256);
}
TEST(SyclPartitionBuilder, PartitioningDence8Bits) {
TestPartitioning(0.0, 256);
}
TEST(SyclPartitionBuilder, PartitioningDence16Bits) {
TestPartitioning(0.0, 256 + 1);
}
TEST(SyclPartitionBuilder, PartitioningDence32Bits) {
TestPartitioning(0.0, (1u << 16) + 1);
}
} // namespace xgboost::common