xgboost/tests/cpp/plugin/test_sycl_gradient_index.cc
Dmitry Razdoburdin 057f03cacc
[SYCL] Initial implementation of GHistIndexMatrix (#10045)
Co-authored-by: Dmitry Razdoburdin <>
2024-02-19 04:27:15 +08:00

80 lines
2.1 KiB
C++

/**
* Copyright 2021-2024 by XGBoost contributors
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#pragma GCC diagnostic pop
#include "../../../plugin/sycl/data/gradient_index.h"
#include "../../../plugin/sycl/device_manager.h"
#include "sycl_helpers.h"
#include "../helpers.h"
namespace xgboost::sycl::data {
TEST(SyclGradientIndex, HistogramCuts) {
size_t max_bins = 8;
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
auto p_fmat = RandomDataGenerator{512, 16, 0.5}.GenerateDMatrix(true);
xgboost::common::HistogramCuts cut =
xgboost::common::SketchOnDMatrix(&ctx, p_fmat.get(), max_bins);
common::HistogramCuts cut_sycl;
cut_sycl.Init(qu, cut);
VerifySyclVector(cut_sycl.Ptrs(), cut.cut_ptrs_.HostVector());
VerifySyclVector(cut_sycl.Values(), cut.cut_values_.HostVector());
VerifySyclVector(cut_sycl.MinValues(), cut.min_vals_.HostVector());
}
TEST(SyclGradientIndex, Init) {
size_t n_rows = 128;
size_t n_columns = 7;
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
DeviceManager device_manager;
auto qu = device_manager.GetQueue(ctx.Device());
auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix();
sycl::DeviceMatrix dmat(qu, p_fmat.get());
int max_bins = 256;
common::GHistIndexMatrix gmat_sycl;
gmat_sycl.Init(qu, &ctx, dmat, max_bins);
xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), max_bins, 0.3, false};
{
ASSERT_EQ(gmat_sycl.max_num_bins, max_bins);
ASSERT_EQ(gmat_sycl.nfeatures, n_columns);
}
{
VerifySyclVector(gmat_sycl.hit_count, gmat.hit_count);
}
{
std::vector<size_t> feature_count_sycl(n_columns, 0);
gmat_sycl.GetFeatureCounts(feature_count_sycl.data());
std::vector<size_t> feature_count(n_columns, 0);
gmat.GetFeatureCounts(feature_count.data());
VerifySyclVector(feature_count_sycl, feature_count);
}
}
} // namespace xgboost::sycl::data