[SYCL] Implement row set collection. (#10057)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
0ce4372bd4
commit
761845f594
123
plugin/sycl/common/row_set.h
Normal file
123
plugin/sycl/common/row_set.h
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2017-2023 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifndef PLUGIN_SYCL_COMMON_ROW_SET_H_
|
||||||
|
#define PLUGIN_SYCL_COMMON_ROW_SET_H_
|
||||||
|
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||||
|
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||||
|
#include <xgboost/data.h>
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "../data.h"
|
||||||
|
|
||||||
|
#include <CL/sycl.hpp>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace sycl {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
|
||||||
|
/*! \brief Collection of rowsets stored on device in USM memory */
|
||||||
|
class RowSetCollection {
|
||||||
|
public:
|
||||||
|
/*! \brief data structure to store an instance set, a subset of
|
||||||
|
* rows (instances) associated with a particular node in a decision
|
||||||
|
* tree. */
|
||||||
|
struct Elem {
|
||||||
|
const size_t* begin{nullptr};
|
||||||
|
const size_t* end{nullptr};
|
||||||
|
bst_node_t node_id{-1}; // id of node associated with this instance set; -1 means uninitialized
|
||||||
|
Elem()
|
||||||
|
= default;
|
||||||
|
Elem(const size_t* begin,
|
||||||
|
const size_t* end,
|
||||||
|
bst_node_t node_id = -1)
|
||||||
|
: begin(begin), end(end), node_id(node_id) {}
|
||||||
|
|
||||||
|
|
||||||
|
inline size_t Size() const {
|
||||||
|
return end - begin;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline size_t Size() const {
|
||||||
|
return elem_of_each_node_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief return corresponding element set given the node_id */
|
||||||
|
inline const Elem& operator[](unsigned node_id) const {
|
||||||
|
const Elem& e = elem_of_each_node_[node_id];
|
||||||
|
CHECK(e.begin != nullptr)
|
||||||
|
<< "access element that is not in the set";
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief return corresponding element set given the node_id */
|
||||||
|
inline Elem& operator[](unsigned node_id) {
|
||||||
|
Elem& e = elem_of_each_node_[node_id];
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear up things
|
||||||
|
inline void Clear() {
|
||||||
|
elem_of_each_node_.clear();
|
||||||
|
}
|
||||||
|
// initialize node id 0->everything
|
||||||
|
inline void Init() {
|
||||||
|
CHECK_EQ(elem_of_each_node_.size(), 0U);
|
||||||
|
|
||||||
|
const size_t* begin = row_indices_.Begin();
|
||||||
|
const size_t* end = row_indices_.End();
|
||||||
|
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& Data() { return row_indices_; }
|
||||||
|
|
||||||
|
// split rowset into two
|
||||||
|
inline void AddSplit(unsigned node_id,
|
||||||
|
unsigned left_node_id,
|
||||||
|
unsigned right_node_id,
|
||||||
|
size_t n_left,
|
||||||
|
size_t n_right) {
|
||||||
|
const Elem e = elem_of_each_node_[node_id];
|
||||||
|
CHECK(e.begin != nullptr);
|
||||||
|
size_t* all_begin = row_indices_.Begin();
|
||||||
|
size_t* begin = all_begin + (e.begin - all_begin);
|
||||||
|
|
||||||
|
|
||||||
|
CHECK_EQ(n_left + n_right, e.Size());
|
||||||
|
CHECK_LE(begin + n_left, e.end);
|
||||||
|
CHECK_EQ(begin + n_left + n_right, e.end);
|
||||||
|
|
||||||
|
|
||||||
|
if (left_node_id >= elem_of_each_node_.size()) {
|
||||||
|
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||||
|
}
|
||||||
|
if (right_node_id >= elem_of_each_node_.size()) {
|
||||||
|
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
|
||||||
|
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
|
||||||
|
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// stores the row indexes in the set
|
||||||
|
USMVector<size_t, MemoryType::on_device> row_indices_;
|
||||||
|
// vector: node_id -> elements
|
||||||
|
std::vector<Elem> elem_of_each_node_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace sycl
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|
||||||
|
#endif // PLUGIN_SYCL_COMMON_ROW_SET_H_
|
||||||
78
tests/cpp/plugin/test_sycl_row_set_collection.cc
Normal file
78
tests/cpp/plugin/test_sycl_row_set_collection.cc
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020-2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../../plugin/sycl/common/row_set.h"
|
||||||
|
#include "../../../plugin/sycl/device_manager.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace xgboost::sycl::common {
|
||||||
|
TEST(SyclRowSetCollection, AddSplits) {
|
||||||
|
const size_t num_rows = 16;
|
||||||
|
|
||||||
|
DeviceManager device_manager;
|
||||||
|
auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());
|
||||||
|
|
||||||
|
RowSetCollection row_set_collection;
|
||||||
|
|
||||||
|
auto& row_indices = row_set_collection.Data();
|
||||||
|
row_indices.Resize(&qu, num_rows);
|
||||||
|
size_t* p_row_indices = row_indices.Data();
|
||||||
|
|
||||||
|
qu.submit([&](::sycl::handler& cgh) {
|
||||||
|
cgh.parallel_for<>(::sycl::range<1>(num_rows),
|
||||||
|
[p_row_indices](::sycl::item<1> pid) {
|
||||||
|
const size_t idx = pid.get_id(0);
|
||||||
|
p_row_indices[idx] = idx;
|
||||||
|
});
|
||||||
|
}).wait_and_throw();
|
||||||
|
row_set_collection.Init();
|
||||||
|
|
||||||
|
CHECK_EQ(row_set_collection.Size(), 1);
|
||||||
|
{
|
||||||
|
size_t nid_test = 0;
|
||||||
|
auto& elem = row_set_collection[nid_test];
|
||||||
|
CHECK_EQ(elem.begin, row_indices.Begin());
|
||||||
|
CHECK_EQ(elem.end, row_indices.End());
|
||||||
|
CHECK_EQ(elem.node_id , 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t nid = 0;
|
||||||
|
size_t nid_left = 1;
|
||||||
|
size_t nid_right = 2;
|
||||||
|
size_t n_left = 4;
|
||||||
|
size_t n_right = num_rows - n_left;
|
||||||
|
row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
|
||||||
|
CHECK_EQ(row_set_collection.Size(), 3);
|
||||||
|
|
||||||
|
{
|
||||||
|
size_t nid_test = 0;
|
||||||
|
auto& elem = row_set_collection[nid_test];
|
||||||
|
CHECK_EQ(elem.begin, nullptr);
|
||||||
|
CHECK_EQ(elem.end, nullptr);
|
||||||
|
CHECK_EQ(elem.node_id , -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
size_t nid_test = 1;
|
||||||
|
auto& elem = row_set_collection[nid_test];
|
||||||
|
CHECK_EQ(elem.begin, row_indices.Begin());
|
||||||
|
CHECK_EQ(elem.end, row_indices.Begin() + n_left);
|
||||||
|
CHECK_EQ(elem.node_id , nid_test);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
size_t nid_test = 2;
|
||||||
|
auto& elem = row_set_collection[nid_test];
|
||||||
|
CHECK_EQ(elem.begin, row_indices.Begin() + n_left);
|
||||||
|
CHECK_EQ(elem.end, row_indices.End());
|
||||||
|
CHECK_EQ(elem.node_id , nid_test);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace xgboost::sycl::common
|
||||||
Loading…
x
Reference in New Issue
Block a user