merge latest changes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "../../../../plugin/federated/federated_comm.h"
|
||||
#include "../../collective/test_worker.h" // for SocketTest
|
||||
#include "../../helpers.h" // for ExpectThrow
|
||||
#include "../../helpers.h" // for GMockThrow
|
||||
#include "test_worker.h" // for TestFederated
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
@@ -20,19 +20,19 @@ class FederatedCommTest : public SocketTest {};
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid world size")));
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid worker rank.")));
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
auto construct = [] {
|
||||
FederatedComm comm{"localhost", 0, 1, 1};
|
||||
};
|
||||
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
@@ -43,7 +43,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
ASSERT_THAT(construct, GMockThrow("got: `String`"));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
@@ -54,7 +54,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
ASSERT_THAT(construct, GMockThrow("got: `String`"));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
|
||||
30
tests/cpp/plugin/sycl_helpers.h
Normal file
30
tests/cpp/plugin/sycl_helpers.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/*!
|
||||
* Copyright 2022-2024 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::sycl {
|
||||
template<typename T, typename Container>
|
||||
void VerifySyclVector(const USMVector<T, MemoryType::shared>& sycl_vector,
|
||||
const Container& host_vector) {
|
||||
ASSERT_EQ(sycl_vector.Size(), host_vector.size());
|
||||
|
||||
size_t size = sycl_vector.Size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
ASSERT_EQ(sycl_vector[i], host_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename Container>
|
||||
void VerifySyclVector(const std::vector<T>& sycl_vector, const Container& host_vector) {
|
||||
ASSERT_EQ(sycl_vector.size(), host_vector.size());
|
||||
|
||||
size_t size = sycl_vector.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
ASSERT_EQ(sycl_vector[i], host_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::sycl
|
||||
80
tests/cpp/plugin/test_sycl_gradient_index.cc
Normal file
80
tests/cpp/plugin/test_sycl_gradient_index.cc
Normal file
@@ -0,0 +1,80 @@
|
||||
/**
|
||||
* Copyright 2021-2024 by XGBoost contributors
|
||||
*/
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "../../../plugin/sycl/data/gradient_index.h"
|
||||
#include "../../../plugin/sycl/device_manager.h"
|
||||
#include "sycl_helpers.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::sycl::data {
|
||||
|
||||
TEST(SyclGradientIndex, HistogramCuts) {
|
||||
size_t max_bins = 8;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
|
||||
auto p_fmat = RandomDataGenerator{512, 16, 0.5}.GenerateDMatrix(true);
|
||||
|
||||
xgboost::common::HistogramCuts cut =
|
||||
xgboost::common::SketchOnDMatrix(&ctx, p_fmat.get(), max_bins);
|
||||
|
||||
common::HistogramCuts cut_sycl;
|
||||
cut_sycl.Init(qu, cut);
|
||||
|
||||
VerifySyclVector(cut_sycl.Ptrs(), cut.cut_ptrs_.HostVector());
|
||||
VerifySyclVector(cut_sycl.Values(), cut.cut_values_.HostVector());
|
||||
VerifySyclVector(cut_sycl.MinValues(), cut.min_vals_.HostVector());
|
||||
}
|
||||
|
||||
TEST(SyclGradientIndex, Init) {
|
||||
size_t n_rows = 128;
|
||||
size_t n_columns = 7;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
|
||||
auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix();
|
||||
|
||||
sycl::DeviceMatrix dmat;
|
||||
dmat.Init(qu, p_fmat.get());
|
||||
|
||||
int max_bins = 256;
|
||||
common::GHistIndexMatrix gmat_sycl;
|
||||
gmat_sycl.Init(qu, &ctx, dmat, max_bins);
|
||||
|
||||
xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), max_bins, 0.3, false};
|
||||
|
||||
{
|
||||
ASSERT_EQ(gmat_sycl.max_num_bins, max_bins);
|
||||
ASSERT_EQ(gmat_sycl.nfeatures, n_columns);
|
||||
}
|
||||
|
||||
{
|
||||
VerifySyclVector(gmat_sycl.hit_count, gmat.hit_count);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<size_t> feature_count_sycl(n_columns, 0);
|
||||
gmat_sycl.GetFeatureCounts(feature_count_sycl.data());
|
||||
|
||||
std::vector<size_t> feature_count(n_columns, 0);
|
||||
gmat.GetFeatureCounts(feature_count.data());
|
||||
VerifySyclVector(feature_count_sycl, feature_count);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::sycl::data
|
||||
@@ -13,6 +13,108 @@
|
||||
|
||||
namespace xgboost::sycl::common {
|
||||
|
||||
void TestPartitioning(float sparsity, int max_bins) {
|
||||
const size_t num_rows = 16;
|
||||
const size_t num_columns = 1;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
|
||||
auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix();
|
||||
sycl::DeviceMatrix dmat;
|
||||
dmat.Init(qu, p_fmat.get());
|
||||
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(qu, &ctx, dmat, max_bins);
|
||||
|
||||
RowSetCollection row_set_collection;
|
||||
auto& row_indices = row_set_collection.Data();
|
||||
row_indices.Resize(&qu, num_rows);
|
||||
size_t* p_row_indices = row_indices.Data();
|
||||
|
||||
qu.submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(num_rows),
|
||||
[p_row_indices](::sycl::item<1> pid) {
|
||||
const size_t idx = pid.get_id(0);
|
||||
p_row_indices[idx] = idx;
|
||||
});
|
||||
}).wait_and_throw();
|
||||
row_set_collection.Init();
|
||||
|
||||
RegTree tree;
|
||||
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
|
||||
const size_t n_nodes = row_set_collection.Size();
|
||||
PartitionBuilder partition_builder;
|
||||
partition_builder.Init(&qu, n_nodes, [&](size_t nid) {
|
||||
return row_set_collection[nid].Size();
|
||||
});
|
||||
|
||||
std::vector<tree::ExpandEntry> nodes;
|
||||
nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0)));
|
||||
|
||||
::sycl::event event;
|
||||
std::vector<int32_t> split_conditions = {2};
|
||||
partition_builder.Partition(gmat, nodes, row_set_collection,
|
||||
split_conditions, &tree, &event);
|
||||
qu.wait_and_throw();
|
||||
|
||||
size_t* data_result = const_cast<size_t*>(row_set_collection[0].begin);
|
||||
partition_builder.MergeToArray(0, data_result, &event);
|
||||
qu.wait_and_throw();
|
||||
|
||||
bst_float split_pt = gmat.cut.Values()[split_conditions[0]];
|
||||
|
||||
std::vector<uint8_t> ridx_left(num_rows, 0);
|
||||
std::vector<uint8_t> ridx_right(num_rows, 0);
|
||||
for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
|
||||
const auto& data_vec = batch.data.HostVector();
|
||||
const auto& offset_vec = batch.offset.HostVector();
|
||||
|
||||
size_t begin = offset_vec[0];
|
||||
for (size_t idx = 0; idx < offset_vec.size() - 1; ++idx) {
|
||||
size_t end = offset_vec[idx + 1];
|
||||
if (begin < end) {
|
||||
const auto& entry = data_vec[begin];
|
||||
if (entry.fvalue < split_pt) {
|
||||
ridx_left[idx] = 1;
|
||||
} else {
|
||||
ridx_right[idx] = 1;
|
||||
}
|
||||
} else {
|
||||
// missing value
|
||||
if (tree[0].DefaultLeft()) {
|
||||
ridx_left[idx] = 1;
|
||||
} else {
|
||||
ridx_right[idx] = 1;
|
||||
}
|
||||
}
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
auto n_left = std::accumulate(ridx_left.begin(), ridx_left.end(), 0);
|
||||
auto n_right = std::accumulate(ridx_right.begin(), ridx_right.end(), 0);
|
||||
|
||||
std::vector<size_t> row_indices_host(num_rows);
|
||||
qu.memcpy(row_indices_host.data(), row_indices.Data(), num_rows * sizeof(size_t));
|
||||
qu.wait_and_throw();
|
||||
|
||||
ASSERT_EQ(n_left, partition_builder.GetNLeftElems(0));
|
||||
for (size_t i = 0; i < n_left; ++i) {
|
||||
auto idx = row_indices_host[i];
|
||||
ASSERT_EQ(ridx_left[idx], 1);
|
||||
}
|
||||
|
||||
ASSERT_EQ(n_right, partition_builder.GetNRightElems(0));
|
||||
for (size_t i = 0; i < n_right; ++i) {
|
||||
auto idx = row_indices_host[num_rows - 1 - i];
|
||||
ASSERT_EQ(ridx_right[idx], 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SyclPartitionBuilder, BasicTest) {
|
||||
constexpr size_t kNodes = 5;
|
||||
// Number of rows for each node
|
||||
@@ -67,7 +169,7 @@ TEST(SyclPartitionBuilder, BasicTest) {
|
||||
std::vector<size_t> v(*std::max_element(rows.begin(), rows.end()));
|
||||
size_t row_id = 0;
|
||||
for(size_t nid = 0; nid < kNodes; ++nid) {
|
||||
builder.MergeToArray(nid, v.data(), event);
|
||||
builder.MergeToArray(nid, v.data(), &event);
|
||||
qu.wait();
|
||||
|
||||
// Check that row_id for left side are correct
|
||||
@@ -88,4 +190,20 @@ TEST(SyclPartitionBuilder, BasicTest) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SyclPartitionBuilder, PartitioningSparce) {
|
||||
TestPartitioning(0.3, 256);
|
||||
}
|
||||
|
||||
TEST(SyclPartitionBuilder, PartitioningDence8Bits) {
|
||||
TestPartitioning(0.0, 256);
|
||||
}
|
||||
|
||||
TEST(SyclPartitionBuilder, PartitioningDence16Bits) {
|
||||
TestPartitioning(0.0, 256 + 1);
|
||||
}
|
||||
|
||||
TEST(SyclPartitionBuilder, PartitioningDence32Bits) {
|
||||
TestPartitioning(0.0, (1u << 16) + 1);
|
||||
}
|
||||
|
||||
} // namespace xgboost::common
|
||||
|
||||
78
tests/cpp/plugin/test_sycl_row_set_collection.cc
Normal file
78
tests/cpp/plugin/test_sycl_row_set_collection.cc
Normal file
@@ -0,0 +1,78 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../plugin/sycl/common/row_set.h"
|
||||
#include "../../../plugin/sycl/device_manager.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::sycl::common {
|
||||
TEST(SyclRowSetCollection, AddSplits) {
|
||||
const size_t num_rows = 16;
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());
|
||||
|
||||
RowSetCollection row_set_collection;
|
||||
|
||||
auto& row_indices = row_set_collection.Data();
|
||||
row_indices.Resize(&qu, num_rows);
|
||||
size_t* p_row_indices = row_indices.Data();
|
||||
|
||||
qu.submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(num_rows),
|
||||
[p_row_indices](::sycl::item<1> pid) {
|
||||
const size_t idx = pid.get_id(0);
|
||||
p_row_indices[idx] = idx;
|
||||
});
|
||||
}).wait_and_throw();
|
||||
row_set_collection.Init();
|
||||
|
||||
CHECK_EQ(row_set_collection.Size(), 1);
|
||||
{
|
||||
size_t nid_test = 0;
|
||||
auto& elem = row_set_collection[nid_test];
|
||||
CHECK_EQ(elem.begin, row_indices.Begin());
|
||||
CHECK_EQ(elem.end, row_indices.End());
|
||||
CHECK_EQ(elem.node_id , 0);
|
||||
}
|
||||
|
||||
size_t nid = 0;
|
||||
size_t nid_left = 1;
|
||||
size_t nid_right = 2;
|
||||
size_t n_left = 4;
|
||||
size_t n_right = num_rows - n_left;
|
||||
row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
|
||||
CHECK_EQ(row_set_collection.Size(), 3);
|
||||
|
||||
{
|
||||
size_t nid_test = 0;
|
||||
auto& elem = row_set_collection[nid_test];
|
||||
CHECK_EQ(elem.begin, nullptr);
|
||||
CHECK_EQ(elem.end, nullptr);
|
||||
CHECK_EQ(elem.node_id , -1);
|
||||
}
|
||||
|
||||
{
|
||||
size_t nid_test = 1;
|
||||
auto& elem = row_set_collection[nid_test];
|
||||
CHECK_EQ(elem.begin, row_indices.Begin());
|
||||
CHECK_EQ(elem.end, row_indices.Begin() + n_left);
|
||||
CHECK_EQ(elem.node_id , nid_test);
|
||||
}
|
||||
|
||||
{
|
||||
size_t nid_test = 2;
|
||||
auto& elem = row_set_collection[nid_test];
|
||||
CHECK_EQ(elem.begin, row_indices.Begin() + n_left);
|
||||
CHECK_EQ(elem.end, row_indices.End());
|
||||
CHECK_EQ(elem.node_id , nid_test);
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace xgboost::sycl::common
|
||||
Reference in New Issue
Block a user