[sycl] add partitioning and related tests (#10080)
Co-authored-by: Dmitry Razdoburdin <>
parent 2c12b956da
commit 7a61216690
@@ -21,6 +21,9 @@
#pragma GCC diagnostic pop

#include "../data.h"
#include "row_set.h"
#include "../data/gradient_index.h"
#include "../tree/expand_entry.h"

#include <CL/sycl.hpp>

@@ -28,6 +31,87 @@ namespace xgboost {
namespace sycl {
namespace common {

// Split row indexes (rid_span) into 2 parts (both stored in rid_buf) depending
// on the comparison of index values (idx_span) with the split point (split_cond).
// Handles dense columns.
template <bool default_left, typename BinIdxType>
inline ::sycl::event PartitionDenseKernel(
    ::sycl::queue* qu,
    const GHistIndexMatrix& gmat,
    const RowSetCollection::Elem& rid_span,
    const size_t fid,
    const int32_t split_cond,
    xgboost::common::Span<size_t>* rid_buf,
    size_t* parts_size,
    ::sycl::event event) {
  const size_t row_stride = gmat.row_stride;
  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
  const size_t* rid = rid_span.begin;
  const size_t range_size = rid_span.Size();
  const size_t offset = gmat.cut.Ptrs()[fid];

  size_t* p_rid_buf = rid_buf->data();

  return qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event);
    cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) {
      const size_t id = rid[nid.get_id(0)];
      const int32_t value = static_cast<int32_t>(gradient_index[id * row_stride + fid] + offset);
      const bool is_left = value <= split_cond;
      if (is_left) {
        AtomicRef<size_t> n_left(parts_size[0]);
        p_rid_buf[n_left.fetch_add(1)] = id;
      } else {
        AtomicRef<size_t> n_right(parts_size[1]);
        p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
      }
    });
  });
}
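Worth noting how the kernel packs both partitions into a single buffer: left rows are written from the front of rid_buf and right rows from the back, with parts_size[0] and parts_size[1] acting as atomic cursors. A minimal host-side sketch of the same idea, using std::atomic in place of AtomicRef (names below are illustrative and not part of the commit):

// Sketch only: two-sided packing into one buffer with atomic cursors.
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <vector>

void PartitionTwoSided(const std::vector<size_t>& rows,
                       const std::vector<int32_t>& values,  // one value per row
                       int32_t split_cond,
                       std::vector<size_t>* buf) {
  std::atomic<size_t> n_left{0}, n_right{0};
  const size_t n = rows.size();
  buf->resize(n);
  // The real kernel runs this loop body as a SYCL parallel_for work-item.
  for (size_t i = 0; i < n; ++i) {
    const size_t id = rows[i];
    if (values[i] <= split_cond) {
      (*buf)[n_left.fetch_add(1)] = id;           // left rows packed at the front
    } else {
      (*buf)[n - n_right.fetch_add(1) - 1] = id;  // right rows packed at the back
    }
  }
}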

// Split row indexes (rid_span) into 2 parts (both stored in rid_buf) depending
// on the comparison of index values (idx_span) with the split point (split_cond).
// Handles sparse columns.
template <bool default_left, typename BinIdxType>
inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu,
                                           const GHistIndexMatrix& gmat,
                                           const RowSetCollection::Elem& rid_span,
                                           const size_t fid,
                                           const int32_t split_cond,
                                           xgboost::common::Span<size_t>* rid_buf,
                                           size_t* parts_size,
                                           ::sycl::event event) {
  const size_t row_stride = gmat.row_stride;
  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
  const size_t* rid = rid_span.begin;
  const size_t range_size = rid_span.Size();
  const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst();

  size_t* p_rid_buf = rid_buf->data();
  return qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event);
    cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) {
      const size_t id = rid[nid.get_id(0)];

      const BinIdxType* gr_index_local = gradient_index + row_stride * id;
      const int32_t fid_local = std::lower_bound(gr_index_local,
                                                 gr_index_local + row_stride,
                                                 cut_ptrs[fid]) - gr_index_local;
      const bool is_left = (fid_local >= row_stride ||
                            gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ?
                               default_left :
                               gr_index_local[fid_local] <= split_cond;
      if (is_left) {
        AtomicRef<size_t> n_left(parts_size[0]);
        p_rid_buf[n_left.fetch_add(1)] = id;
      } else {
        AtomicRef<size_t> n_right(parts_size[1]);
        p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id;
      }
    });
  });
}
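The sparse kernel's direction logic: a row stores only the bins of its present features, in sorted order, so lower_bound locates where feature fid's bin range would fall; if no bin of the row lands inside [cut_ptrs[fid], cut_ptrs[fid + 1]) the feature is missing and the node's default direction applies. A standalone sketch of that decision (function and parameter names are illustrative only):

// Sketch only: per-row direction decision for a sparse feature.
#include <algorithm>
#include <cstdint>
#include <vector>

bool GoesLeft(const std::vector<uint32_t>& row_bins,  // sorted bins of one row
              uint32_t cut_begin, uint32_t cut_end,   // cut.Ptrs()[fid], cut.Ptrs()[fid + 1]
              int32_t split_cond, bool default_left) {
  auto it = std::lower_bound(row_bins.begin(), row_bins.end(), cut_begin);
  if (it == row_bins.end() || *it >= cut_end) {
    return default_left;                              // feature missing in this row
  }
  return static_cast<int32_t>(*it) <= split_cond;     // compare the bin with the split point
}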

// The builder is required to partition samples into the left and right children for a set of nodes.
class PartitionBuilder {
 public:
@@ -53,7 +137,6 @@ class PartitionBuilder {
    return result_rows_[2 * nid];
  }
-
  size_t GetNRightElems(int nid) const {
    return result_rows_[2 * nid + 1];
  }
@@ -72,19 +155,97 @@ class PartitionBuilder {
    return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] };
  }

  template <typename BinIdxType>
  ::sycl::event Partition(const int32_t split_cond,
                          const GHistIndexMatrix& gmat,
                          const RowSetCollection::Elem& rid_span,
                          const xgboost::RegTree::Node& node,
                          xgboost::common::Span<size_t>* rid_buf,
                          size_t* parts_size,
                          ::sycl::event event) {
    const bst_uint fid = node.SplitIndex();
    const bool default_left = node.DefaultLeft();

    if (gmat.IsDense()) {
      if (default_left) {
        return PartitionDenseKernel<true, BinIdxType>(qu_, gmat, rid_span, fid,
                                                      split_cond, rid_buf, parts_size, event);
      } else {
        return PartitionDenseKernel<false, BinIdxType>(qu_, gmat, rid_span, fid,
                                                       split_cond, rid_buf, parts_size, event);
      }
    } else {
      if (default_left) {
        return PartitionSparseKernel<true, BinIdxType>(qu_, gmat, rid_span, fid,
                                                       split_cond, rid_buf, parts_size, event);
      } else {
        return PartitionSparseKernel<false, BinIdxType>(qu_, gmat, rid_span, fid,
                                                        split_cond, rid_buf, parts_size, event);
      }
    }
  }

  // Entry point for Partition
  void Partition(const GHistIndexMatrix& gmat,
                 const std::vector<tree::ExpandEntry> nodes,
                 const RowSetCollection& row_set_collection,
                 const std::vector<int32_t>& split_conditions,
                 RegTree* p_tree,
                 ::sycl::event* general_event) {
    nodes_events_.resize(n_nodes_);

    parts_size_.ResizeAndFill(qu_, 2 * n_nodes_, 0, general_event);

    for (size_t node_in_set = 0; node_in_set < n_nodes_; node_in_set++) {
      const int32_t nid = nodes[node_in_set].nid;
      ::sycl::event& node_event = nodes_events_[node_in_set];
      const auto& rid_span = row_set_collection[nid];
      if (rid_span.Size() > 0) {
        const RegTree::Node& node = (*p_tree)[nid];
        xgboost::common::Span<size_t> rid_buf = GetData(node_in_set);
        size_t* part_size = parts_size_.Data() + 2 * node_in_set;
        int32_t split_condition = split_conditions[node_in_set];
        switch (gmat.index.GetBinTypeSize()) {
          case common::BinTypeSize::kUint8BinsTypeSize:
            node_event = Partition<uint8_t>(split_condition, gmat, rid_span, node,
                                            &rid_buf, part_size, *general_event);
            break;
          case common::BinTypeSize::kUint16BinsTypeSize:
            node_event = Partition<uint16_t>(split_condition, gmat, rid_span, node,
                                             &rid_buf, part_size, *general_event);
            break;
          case common::BinTypeSize::kUint32BinsTypeSize:
            node_event = Partition<uint32_t>(split_condition, gmat, rid_span, node,
                                             &rid_buf, part_size, *general_event);
            break;
          default:
            CHECK(false);  // no default behavior
        }
      } else {
        node_event = ::sycl::event();
      }
    }

    *general_event = qu_->memcpy(result_rows_.data(),
                                 parts_size_.DataConst(),
                                 sizeof(size_t) * 2 * n_nodes_,
                                 nodes_events_);
  }

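The switch above dispatches on the storage width of the quantized index. Assuming the usual xgboost convention that the histogram index uses the smallest unsigned type able to hold the largest bin id (which is also what the new tests at the bottom of this diff exercise with 256, 256 + 1 and (1u << 16) + 1 bins), the mapping looks roughly like this sketch (the enum values mirror common::BinTypeSize, but the helper itself is hypothetical):

// Sketch only: choosing the bin storage width from the number of bins.
#include <cstdint>

enum class BinTypeSize { kUint8BinsTypeSize = 1, kUint16BinsTypeSize = 2, kUint32BinsTypeSize = 4 };

inline BinTypeSize SelectBinType(size_t max_num_bins) {
  if (max_num_bins - 1 <= UINT8_MAX)  return BinTypeSize::kUint8BinsTypeSize;   // <= 256 bins
  if (max_num_bins - 1 <= UINT16_MAX) return BinTypeSize::kUint16BinsTypeSize;  // <= 65536 bins
  return BinTypeSize::kUint32BinsTypeSize;
}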
  void MergeToArray(size_t nid,
                    size_t* data_result,
-                   ::sycl::event event) {
+                   ::sycl::event* event) {
    size_t n_nodes_total = GetNLeftElems(nid) + GetNRightElems(nid);
    if (n_nodes_total > 0) {
      const size_t* data = data_.Data() + nodes_offsets_[nid];
-     qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, event);
+     qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, *event);
    }
  }

 protected:
  std::vector<size_t> nodes_offsets_;
  std::vector<size_t> result_rows_;
  std::vector<::sycl::event> nodes_events_;
  size_t n_nodes_;

  USMVector<size_t, MemoryType::on_device> parts_size_;

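For orientation, a condensed usage sketch of the builder, distilled from the new TestPartitioning unit test further down in this diff; the call sequence is grounded there, but this fragment is not itself part of the commit, and it assumes qu, gmat, tree, nodes and row_set_collection are already set up as in that test:

// Sketch only: driving PartitionBuilder for a single node.
PartitionBuilder builder;
builder.Init(&qu, n_nodes, [&](size_t nid) {
  return row_set_collection[nid].Size();       // per-node buffer size
});

::sycl::event event;
std::vector<int32_t> split_conditions = {2};
builder.Partition(gmat, nodes, row_set_collection,
                  split_conditions, &tree, &event);
qu.wait_and_throw();

size_t* dst = const_cast<size_t*>(row_set_collection[0].begin);
builder.MergeToArray(0, dst, &event);          // copy packed row ids back
qu.wait_and_throw();
size_t n_left = builder.GetNLeftElems(0);      // rows routed to the left child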
@@ -171,20 +171,20 @@ class USMVector {
    }
  }

- ::sycl::event ResizeAndFill(::sycl::queue* qu, size_t size_new, int v) {
+ void ResizeAndFill(::sycl::queue* qu, size_t size_new, int v, ::sycl::event* event) {
    if (size_new <= size_) {
      size_ = size_new;
-     return qu->memset(data_.get(), v, size_new * sizeof(T));
+     *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
    } else if (size_new <= capacity_) {
      size_ = size_new;
-     return qu->memset(data_.get(), v, size_new * sizeof(T));
+     *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
    } else {
      size_t size_old = size_;
      auto data_old = data_;
      size_ = size_new;
      capacity_ = size_new;
      data_ = allocate_memory_(qu, size_);
-     return qu->memset(data_.get(), v, size_new * sizeof(T));
+     *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event);
    }
  }

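The signature change here threads a ::sycl::event through the call instead of returning one, so a caller can chain several asynchronous USM operations on the same event before a single wait. A minimal sketch of that pattern (the function is illustrative; it only relies on standard SYCL 2020 queue::memset and queue::submit overloads):

// Sketch only: chaining USM operations through one event.
#include <CL/sycl.hpp>
#include <cstddef>

void ChainedFill(::sycl::queue* qu, int* data, size_t n, ::sycl::event* event) {
  // zero the buffer only after whatever *event currently guards has finished
  *event = qu->memset(data, 0, n * sizeof(int), *event);
  // the follow-up kernel depends on the memset via the same event
  *event = qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(*event);
    cgh.parallel_for<>(::sycl::range<1>(n), [=](::sycl::id<1> i) {
      data[i[0]] += 1;
    });
  });
  // the caller decides when to wait on *event
}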
@@ -211,11 +211,16 @@
struct DeviceMatrix {
  DMatrix* p_mat;  // Pointer to the original matrix on the host
  ::sycl::queue qu_;
- USMVector<size_t> row_ptr;
+ USMVector<size_t, MemoryType::on_device> row_ptr;
  USMVector<Entry, MemoryType::on_device> data;
  size_t total_offset;

- DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
+ DeviceMatrix() = default;
+
+ void Init(::sycl::queue qu, DMatrix* dmat) {
+   qu_ = qu;
+   p_mat = dmat;

    size_t num_row = 0;
    size_t num_nonzero = 0;
    for (auto &batch : dmat->GetBatches<SparsePage>()) {
@@ -226,27 +231,41 @@ struct DeviceMatrix {
    }

    row_ptr.Resize(&qu_, num_row + 1);
+   size_t* rows = row_ptr.Data();
    data.Resize(&qu_, num_nonzero);

    size_t data_offset = 0;
+   ::sycl::event event;
    for (auto &batch : dmat->GetBatches<SparsePage>()) {
      const auto& data_vec = batch.data.HostVector();
      const auto& offset_vec = batch.offset.HostVector();
      size_t batch_size = batch.Size();
      if (batch_size > 0) {
-       std::copy(offset_vec.data(), offset_vec.data() + batch_size,
-                 row_ptr.Data() + batch.base_rowid);
-       if (batch.base_rowid > 0) {
-         for (size_t i = 0; i < batch_size; i++)
-           row_ptr[i + batch.base_rowid] += batch.base_rowid;
+       const auto base_rowid = batch.base_rowid;
+       event = qu.memcpy(row_ptr.Data() + base_rowid, offset_vec.data(),
+                         sizeof(size_t) * batch_size, event);
+       if (base_rowid > 0) {
+         qu.submit([&](::sycl::handler& cgh) {
+           cgh.depends_on(event);
+           cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::id<1> pid) {
+             int row_id = pid[0];
+             rows[row_id] += base_rowid;
+           });
+         });
        }
-       qu.memcpy(data.Data() + data_offset,
-                 data_vec.data(),
-                 offset_vec[batch_size] * sizeof(Entry)).wait();
+       event = qu.memcpy(data.Data() + data_offset, data_vec.data(),
+                         sizeof(Entry) * offset_vec[batch_size], event);
        data_offset += offset_vec[batch_size];
        qu.wait();
      }
    }
-   row_ptr[num_row] = data_offset;
+   qu.submit([&](::sycl::handler& cgh) {
+     cgh.depends_on(event);
+     cgh.single_task<>([=] {
+       rows[num_row] = data_offset;
+     });
+   });
    qu.wait();
    total_offset = data_offset;
  }

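Because row_ptr is now device-resident USM, the old host-side assignment row_ptr[num_row] = data_offset is replaced by a device write. A small sketch of that pattern in isolation (helper name is illustrative):

// Sketch only: writing a single element of device-resident USM.
#include <CL/sycl.hpp>
#include <cstddef>

void WriteLastOffset(::sycl::queue* qu, size_t* device_row_ptr,
                     size_t num_row, size_t data_offset, ::sycl::event event) {
  qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event);                    // wait for the preceding async copies
    cgh.single_task<>([=] {
      device_row_ptr[num_row] = data_offset;  // device-side write of the tail offset
    });
  }).wait();
}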
@@ -280,7 +280,8 @@ class Predictor : public xgboost::Predictor {
                    uint32_t tree_end = 0) const override {
    ::sycl::queue qu = device_manager.GetQueue(ctx_->Device());
    // TODO(razdoburdin): remove temporary workaround after cache fix
-   sycl::DeviceMatrix device_matrix(qu, dmat);
+   sycl::DeviceMatrix device_matrix;
+   device_matrix.Init(qu, dmat);

    auto* out_preds = &predts->predictions;
    if (tree_end == 0) {
plugin/sycl/tree/expand_entry.h (new file, 50 lines)
@@ -0,0 +1,50 @@
/*!
 * Copyright 2017-2024 by Contributors
 * \file expand_entry.h
 */
#ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
#define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#include "../../src/tree/constraints.h"
#pragma GCC diagnostic pop
#include "../../src/tree/hist/expand_entry.h"

namespace xgboost {
namespace sycl {
namespace tree {
/* tree growing policies */
struct ExpandEntry : public xgboost::tree::ExpandEntryImpl<ExpandEntry> {
  static constexpr bst_node_t kRootNid = 0;

  xgboost::tree::SplitEntry split;

  ExpandEntry(int nid, int depth) : ExpandEntryImpl{nid, depth} {}

  inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree) const {
    CHECK_EQ((*p_tree)[nid].IsRoot(), false);
    const size_t parent_id = (*p_tree)[nid].Parent();
    return GetSiblingId(p_tree, parent_id);
  }

  inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree, size_t parent_id) const {
    return p_tree->IsLeftChild(nid) ? p_tree->RightChild(parent_id)
                                    : p_tree->LeftChild(parent_id);
  }

  bool IsValidImpl(xgboost::tree::TrainParam const &param, int32_t num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (split.loss_chg < param.min_split_loss) return false;
    if (param.max_depth > 0 && depth == param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false;

    return true;
  }
};

}  // namespace tree
}  // namespace sycl
}  // namespace xgboost

#endif  // PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_
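A brief sketch of how a tree updater is expected to consume this type: one ExpandEntry per candidate node, with IsValidImpl gating expansion on the split gain and on the depth/leaf limits. The fragment below is illustrative only; param is an xgboost::tree::TrainParam assumed to exist in the caller, and the gain value is made up:

// Sketch only: filtering a candidate node before expanding it.
xgboost::sycl::tree::ExpandEntry entry(xgboost::sycl::tree::ExpandEntry::kRootNid,
                                       /*depth=*/0);
entry.split.loss_chg = 0.5f;  // hypothetical gain produced by split evaluation
if (entry.IsValidImpl(param, /*num_leaves=*/1)) {
  // enqueue the node: build histograms, pick the split, then partition rows
}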
@@ -49,7 +49,8 @@ TEST(SyclGradientIndex, Init) {

  auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix();

- sycl::DeviceMatrix dmat(qu, p_fmat.get());
+ sycl::DeviceMatrix dmat;
+ dmat.Init(qu, p_fmat.get());

  int max_bins = 256;
  common::GHistIndexMatrix gmat_sycl;

@@ -13,6 +13,108 @@

namespace xgboost::sycl::common {

void TestPartitioning(float sparsity, int max_bins) {
  const size_t num_rows = 16;
  const size_t num_columns = 1;

  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(ctx.Device());

  auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
  sycl::DeviceMatrix dmat;
  dmat.Init(qu, p_fmat.get());

  common::GHistIndexMatrix gmat;
  gmat.Init(qu, &ctx, dmat, max_bins);

  RowSetCollection row_set_collection;
  auto& row_indices = row_set_collection.Data();
  row_indices.Resize(&qu, num_rows);
  size_t* p_row_indices = row_indices.Data();

  qu.submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(num_rows),
                       [p_row_indices](::sycl::item<1> pid) {
      const size_t idx = pid.get_id(0);
      p_row_indices[idx] = idx;
    });
  }).wait_and_throw();
  row_set_collection.Init();

  RegTree tree;
  tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

  const size_t n_nodes = row_set_collection.Size();
  PartitionBuilder partition_builder;
  partition_builder.Init(&qu, n_nodes, [&](size_t nid) {
    return row_set_collection[nid].Size();
  });

  std::vector<tree::ExpandEntry> nodes;
  nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));

  ::sycl::event event;
  std::vector<int32_t> split_conditions = {2};
  partition_builder.Partition(gmat, nodes, row_set_collection,
                              split_conditions, &tree, &event);
  qu.wait_and_throw();

  size_t* data_result = const_cast<size_t*>(row_set_collection[0].begin);
  partition_builder.MergeToArray(0, data_result, &event);
  qu.wait_and_throw();

  bst_float split_pt = gmat.cut.Values()[split_conditions[0]];

  std::vector<uint8_t> ridx_left(num_rows, 0);
  std::vector<uint8_t> ridx_right(num_rows, 0);
  for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
    const auto& data_vec = batch.data.HostVector();
    const auto& offset_vec = batch.offset.HostVector();

    size_t begin = offset_vec[0];
    for (size_t idx = 0; idx < offset_vec.size() - 1; ++idx) {
      size_t end = offset_vec[idx + 1];
      if (begin < end) {
        const auto& entry = data_vec[begin];
        if (entry.fvalue < split_pt) {
          ridx_left[idx] = 1;
        } else {
          ridx_right[idx] = 1;
        }
      } else {
        // missing value
        if (tree[0].DefaultLeft()) {
          ridx_left[idx] = 1;
        } else {
          ridx_right[idx] = 1;
        }
      }
      begin = end;
    }
  }
  auto n_left = std::accumulate(ridx_left.begin(), ridx_left.end(), 0);
  auto n_right = std::accumulate(ridx_right.begin(), ridx_right.end(), 0);

  std::vector<size_t> row_indices_host(num_rows);
  qu.memcpy(row_indices_host.data(), row_indices.Data(), num_rows * sizeof(size_t));
  qu.wait_and_throw();

  ASSERT_EQ(n_left, partition_builder.GetNLeftElems(0));
  for (size_t i = 0; i < n_left; ++i) {
    auto idx = row_indices_host[i];
    ASSERT_EQ(ridx_left[idx], 1);
  }

  ASSERT_EQ(n_right, partition_builder.GetNRightElems(0));
  for (size_t i = 0; i < n_right; ++i) {
    auto idx = row_indices_host[num_rows - 1 - i];
    ASSERT_EQ(ridx_right[idx], 1);
  }
}

TEST(SyclPartitionBuilder, BasicTest) {
  constexpr size_t kNodes = 5;
  // Number of rows for each node
@@ -67,7 +169,7 @@ TEST(SyclPartitionBuilder, BasicTest) {
  std::vector<size_t> v(*std::max_element(rows.begin(), rows.end()));
  size_t row_id = 0;
  for (size_t nid = 0; nid < kNodes; ++nid) {
-   builder.MergeToArray(nid, v.data(), event);
+   builder.MergeToArray(nid, v.data(), &event);
    qu.wait();

    // Check that the row_ids for the left side are correct
@@ -88,4 +190,20 @@ TEST(SyclPartitionBuilder, BasicTest) {
  }
}

TEST(SyclPartitionBuilder, PartitioningSparce) {
  TestPartitioning(0.3, 256);
}

TEST(SyclPartitionBuilder, PartitioningDence8Bits) {
  TestPartitioning(0.0, 256);
}

TEST(SyclPartitionBuilder, PartitioningDence16Bits) {
  TestPartitioning(0.0, 256 + 1);
}

TEST(SyclPartitionBuilder, PartitioningDence32Bits) {
  TestPartitioning(0.0, (1u << 16) + 1);
}

}  // namespace xgboost::sycl::common
