[SYCL] Initial implementation of GHistIndexMatrix (#10045)
Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
7cc256e246
commit
057f03cacc
@ -1,10 +1,7 @@
|
||||
if(PLUGIN_SYCL)
|
||||
set(CMAKE_CXX_COMPILER "icpx")
|
||||
add_library(plugin_sycl OBJECT
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/device_manager.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/predictor/predictor.cc)
|
||||
file(GLOB_RECURSE SYCL_SOURCES "sycl/*.cc")
|
||||
add_library(plugin_sycl OBJECT ${SYCL_SOURCES})
|
||||
target_include_directories(plugin_sycl
|
||||
PRIVATE
|
||||
${xgboost_SOURCE_DIR}/include
|
||||
|
||||
@ -26,8 +26,13 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
enum class MemoryType { shared, on_device};
|
||||
template <typename T>
|
||||
using AtomicRef = ::sycl::atomic_ref<T,
|
||||
::sycl::memory_order::relaxed,
|
||||
::sycl::memory_scope::device,
|
||||
::sycl::access::address_space::ext_intel_global_device_space>;
|
||||
|
||||
enum class MemoryType { shared, on_device};
|
||||
|
||||
template <typename T>
|
||||
class USMDeleter {
|
||||
|
||||
177
plugin/sycl/data/gradient_index.cc
Normal file
177
plugin/sycl/data/gradient_index.cc
Normal file
@ -0,0 +1,177 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file gradient_index.cc
|
||||
*/
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
|
||||
#include "gradient_index.h"
|
||||
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace common {
|
||||
|
||||
uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) {
|
||||
auto beg = cut_ptrs[e.index];
|
||||
auto end = cut_ptrs[e.index + 1];
|
||||
auto it = std::upper_bound(cut_values + beg, cut_values + end, e.fvalue);
|
||||
uint32_t idx = it - cut_values;
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename BinIdxType>
|
||||
void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
|
||||
const size_t total_len = end - begin;
|
||||
for (size_t block_len = 1; block_len < total_len; block_len <<= 1) {
|
||||
for (size_t cur_block = 0; cur_block + block_len < total_len; cur_block += 2 * block_len) {
|
||||
size_t start = cur_block;
|
||||
size_t mid = start + block_len;
|
||||
size_t finish = mid + block_len < total_len ? mid + block_len : total_len;
|
||||
size_t left_pos = start;
|
||||
size_t right_pos = mid;
|
||||
size_t pos = start;
|
||||
while (left_pos < mid || right_pos < finish) {
|
||||
if (left_pos < mid && (right_pos == finish || begin[left_pos] < begin[right_pos])) {
|
||||
buf[pos++] = begin[left_pos++];
|
||||
} else {
|
||||
buf[pos++] = begin[right_pos++];
|
||||
}
|
||||
}
|
||||
for (size_t i = start; i < finish; i++) begin[i] = buf[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BinIdxType>
|
||||
void GHistIndexMatrix::SetIndexData(::sycl::queue qu,
|
||||
BinIdxType* index_data,
|
||||
const DeviceMatrix &dmat,
|
||||
size_t nbins,
|
||||
size_t row_stride,
|
||||
uint32_t* offsets) {
|
||||
if (nbins == 0) return;
|
||||
const xgboost::Entry *data_ptr = dmat.data.DataConst();
|
||||
const bst_row_t *offset_vec = dmat.row_ptr.DataConst();
|
||||
const size_t num_rows = dmat.row_ptr.Size() - 1;
|
||||
const bst_float* cut_values = cut_device.Values().DataConst();
|
||||
const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst();
|
||||
size_t* hit_count_ptr = hit_count_buff.Data();
|
||||
|
||||
// Sparse case only
|
||||
if (!offsets) {
|
||||
// sort_buff has type uint8_t
|
||||
sort_buff.Resize(&qu, num_rows * row_stride * sizeof(BinIdxType));
|
||||
}
|
||||
BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());
|
||||
|
||||
auto event = qu.submit([&](::sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
|
||||
const size_t i = pid.get_id(0);
|
||||
const size_t ibegin = offset_vec[i];
|
||||
const size_t iend = offset_vec[i + 1];
|
||||
const size_t size = iend - ibegin;
|
||||
const size_t start = i * row_stride;
|
||||
for (bst_uint j = 0; j < size; ++j) {
|
||||
uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]);
|
||||
index_data[start + j] = offsets ? idx - offsets[j] : idx;
|
||||
AtomicRef<size_t> hit_count_ref(hit_count_ptr[idx]);
|
||||
hit_count_ref.fetch_add(1);
|
||||
}
|
||||
if (!offsets) {
|
||||
// Sparse case only
|
||||
mergeSort<BinIdxType>(index_data + start, index_data + start + size, sort_data + start);
|
||||
for (bst_uint j = size; j < row_stride; ++j) {
|
||||
index_data[start + j] = nbins;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
qu.memcpy(hit_count.data(), hit_count_ptr, nbins * sizeof(size_t), event);
|
||||
qu.wait();
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::ResizeIndex(size_t n_index, bool isDense) {
|
||||
if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
|
||||
index.Resize((sizeof(uint8_t)) * n_index);
|
||||
} else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
|
||||
max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
|
||||
index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
|
||||
index.Resize((sizeof(uint16_t)) * n_index);
|
||||
} else {
|
||||
index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
|
||||
index.Resize((sizeof(uint32_t)) * n_index);
|
||||
}
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(::sycl::queue qu,
|
||||
Context const * ctx,
|
||||
const DeviceMatrix& p_fmat_device,
|
||||
int max_bins) {
|
||||
nfeatures = p_fmat_device.p_mat->Info().num_col_;
|
||||
|
||||
cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins);
|
||||
cut_device.Init(qu, cut);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const uint32_t nbins = cut.Ptrs().back();
|
||||
this->nbins = nbins;
|
||||
hit_count.resize(nbins, 0);
|
||||
hit_count_buff.Resize(&qu, nbins, 0);
|
||||
|
||||
this->p_fmat = p_fmat_device.p_mat;
|
||||
const bool isDense = p_fmat_device.p_mat->IsDense();
|
||||
this->isDense_ = isDense;
|
||||
|
||||
index.setQueue(qu);
|
||||
|
||||
row_stride = 0;
|
||||
for (const auto& batch : p_fmat_device.p_mat->GetBatches<SparsePage>()) {
|
||||
const auto& row_offset = batch.offset.ConstHostVector();
|
||||
for (auto i = 1ull; i < row_offset.size(); i++) {
|
||||
row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
|
||||
}
|
||||
}
|
||||
|
||||
const size_t n_offsets = cut_device.Ptrs().Size() - 1;
|
||||
const size_t n_rows = p_fmat_device.row_ptr.Size() - 1;
|
||||
const size_t n_index = n_rows * row_stride;
|
||||
ResizeIndex(n_index, isDense);
|
||||
|
||||
CHECK_GT(cut_device.Values().Size(), 0U);
|
||||
|
||||
uint32_t* offsets = nullptr;
|
||||
if (isDense) {
|
||||
index.ResizeOffset(n_offsets);
|
||||
offsets = index.Offset();
|
||||
qu.memcpy(offsets, cut_device.Ptrs().DataConst(),
|
||||
sizeof(uint32_t) * n_offsets).wait_and_throw();
|
||||
}
|
||||
|
||||
if (isDense) {
|
||||
BinTypeSize curent_bin_size = index.GetBinTypeSize();
|
||||
if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
|
||||
SetIndexData(qu, index.data<uint8_t>(), p_fmat_device, nbins, row_stride, offsets);
|
||||
|
||||
} else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
|
||||
SetIndexData(qu, index.data<uint16_t>(), p_fmat_device, nbins, row_stride, offsets);
|
||||
} else {
|
||||
CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
|
||||
SetIndexData(qu, index.data<uint32_t>(), p_fmat_device, nbins, row_stride, offsets);
|
||||
}
|
||||
/* For sparse DMatrix we have to store index of feature for each bin
|
||||
in index field to chose right offset. So offset is nullptr and index is not reduced */
|
||||
} else {
|
||||
SetIndexData(qu, index.data<uint32_t>(), p_fmat_device, nbins, row_stride, offsets);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
216
plugin/sycl/data/gradient_index.h
Normal file
216
plugin/sycl/data/gradient_index.h
Normal file
@ -0,0 +1,216 @@
|
||||
/*!
|
||||
* Copyright 2017-2024 by Contributors
|
||||
* \file gradient_index.h
|
||||
*/
|
||||
#ifndef PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
|
||||
#define PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../data.h"
|
||||
#include "../../src/common/hist_util.h"
|
||||
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace common {
|
||||
|
||||
/*!
|
||||
* \brief SYCL implementation of HistogramCuts stored in USM buffers to provide access from device kernels
|
||||
*/
|
||||
class HistogramCuts {
|
||||
protected:
|
||||
using BinIdx = uint32_t;
|
||||
|
||||
public:
|
||||
HistogramCuts() {}
|
||||
|
||||
explicit HistogramCuts(::sycl::queue qu) {}
|
||||
|
||||
~HistogramCuts() {
|
||||
}
|
||||
|
||||
void Init(::sycl::queue qu, xgboost::common::HistogramCuts const& cuts) {
|
||||
qu_ = qu;
|
||||
cut_values_.Init(&qu_, cuts.cut_values_.HostVector());
|
||||
cut_ptrs_.Init(&qu_, cuts.cut_ptrs_.HostVector());
|
||||
min_vals_.Init(&qu_, cuts.min_vals_.HostVector());
|
||||
}
|
||||
|
||||
// Getters for USM buffers to pass pointers into device kernels
|
||||
const USMVector<uint32_t>& Ptrs() const { return cut_ptrs_; }
|
||||
const USMVector<float>& Values() const { return cut_values_; }
|
||||
const USMVector<float>& MinValues() const { return min_vals_; }
|
||||
|
||||
private:
|
||||
USMVector<bst_float> cut_values_;
|
||||
USMVector<uint32_t> cut_ptrs_;
|
||||
USMVector<float> min_vals_;
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
using BinTypeSize = ::xgboost::common::BinTypeSize;
|
||||
|
||||
/*!
|
||||
* \brief Index data and offsets stored in USM buffers to provide access from device kernels
|
||||
*/
|
||||
struct Index {
|
||||
Index() {
|
||||
SetBinTypeSize(binTypeSize_);
|
||||
}
|
||||
Index(const Index& i) = delete;
|
||||
Index& operator=(Index i) = delete;
|
||||
Index(Index&& i) = delete;
|
||||
Index& operator=(Index&& i) = delete;
|
||||
uint32_t operator[](size_t i) const {
|
||||
if (!offset_.Empty()) {
|
||||
return func_(data_.DataConst(), i) + offset_[i%p_];
|
||||
} else {
|
||||
return func_(data_.DataConst(), i);
|
||||
}
|
||||
}
|
||||
void SetBinTypeSize(BinTypeSize binTypeSize) {
|
||||
binTypeSize_ = binTypeSize;
|
||||
switch (binTypeSize) {
|
||||
case BinTypeSize::kUint8BinsTypeSize:
|
||||
func_ = &GetValueFromUint8;
|
||||
break;
|
||||
case BinTypeSize::kUint16BinsTypeSize:
|
||||
func_ = &GetValueFromUint16;
|
||||
break;
|
||||
case BinTypeSize::kUint32BinsTypeSize:
|
||||
func_ = &GetValueFromUint32;
|
||||
break;
|
||||
default:
|
||||
CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
|
||||
binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
|
||||
binTypeSize == BinTypeSize::kUint32BinsTypeSize);
|
||||
}
|
||||
}
|
||||
BinTypeSize GetBinTypeSize() const {
|
||||
return binTypeSize_;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* data() {
|
||||
return reinterpret_cast<T*>(data_.Data());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T* data() const {
|
||||
return reinterpret_cast<const T*>(data_.DataConst());
|
||||
}
|
||||
|
||||
uint32_t* Offset() {
|
||||
return offset_.Data();
|
||||
}
|
||||
|
||||
const uint32_t* Offset() const {
|
||||
return offset_.DataConst();
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
return data_.Size() / (binTypeSize_);
|
||||
}
|
||||
|
||||
void Resize(const size_t nBytesData) {
|
||||
data_.Resize(&qu_, nBytesData);
|
||||
}
|
||||
|
||||
void ResizeOffset(const size_t nDisps) {
|
||||
offset_.Resize(&qu_, nDisps);
|
||||
p_ = nDisps;
|
||||
}
|
||||
|
||||
uint8_t* begin() const {
|
||||
return data_.Begin();
|
||||
}
|
||||
|
||||
uint8_t* end() const {
|
||||
return data_.End();
|
||||
}
|
||||
|
||||
void setQueue(::sycl::queue qu) {
|
||||
qu_ = qu;
|
||||
}
|
||||
|
||||
private:
|
||||
static uint32_t GetValueFromUint8(const uint8_t* t, size_t i) {
|
||||
return reinterpret_cast<const uint8_t*>(t)[i];
|
||||
}
|
||||
static uint32_t GetValueFromUint16(const uint8_t* t, size_t i) {
|
||||
return reinterpret_cast<const uint16_t*>(t)[i];
|
||||
}
|
||||
static uint32_t GetValueFromUint32(const uint8_t* t, size_t i) {
|
||||
return reinterpret_cast<const uint32_t*>(t)[i];
|
||||
}
|
||||
|
||||
using Func = uint32_t (*)(const uint8_t*, size_t);
|
||||
|
||||
USMVector<uint8_t, MemoryType::on_device> data_;
|
||||
// size of this field is equal to number of features
|
||||
USMVector<uint32_t, MemoryType::on_device> offset_;
|
||||
BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize};
|
||||
size_t p_ {1};
|
||||
Func func_;
|
||||
|
||||
::sycl::queue qu_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Preprocessed global index matrix, in CSR format, stored in USM buffers
|
||||
*
|
||||
* Transform floating values to integer index in histogram
|
||||
*/
|
||||
struct GHistIndexMatrix {
|
||||
/*! \brief row pointer to rows by element position */
|
||||
/*! \brief The index data */
|
||||
Index index;
|
||||
/*! \brief hit count of each index */
|
||||
std::vector<size_t> hit_count;
|
||||
/*! \brief buffers for calculations */
|
||||
USMVector<size_t, MemoryType::on_device> hit_count_buff;
|
||||
USMVector<uint8_t, MemoryType::on_device> sort_buff;
|
||||
/*! \brief The corresponding cuts */
|
||||
xgboost::common::HistogramCuts cut;
|
||||
HistogramCuts cut_device;
|
||||
DMatrix* p_fmat;
|
||||
size_t max_num_bins;
|
||||
size_t nbins;
|
||||
size_t nfeatures;
|
||||
size_t row_stride;
|
||||
|
||||
// Create a global histogram matrix based on a given DMatrix device wrapper
|
||||
void Init(::sycl::queue qu, Context const * ctx,
|
||||
const sycl::DeviceMatrix& p_fmat_device, int max_num_bins);
|
||||
|
||||
template <typename BinIdxType>
|
||||
void SetIndexData(::sycl::queue qu, BinIdxType* index_data,
|
||||
const sycl::DeviceMatrix &dmat_device,
|
||||
size_t nbins, size_t row_stride, uint32_t* offsets);
|
||||
|
||||
void ResizeIndex(size_t n_index, bool isDense);
|
||||
|
||||
inline void GetFeatureCounts(size_t* counts) const {
|
||||
auto nfeature = cut_device.Ptrs().Size() - 1;
|
||||
for (unsigned fid = 0; fid < nfeature; ++fid) {
|
||||
auto ibegin = cut_device.Ptrs()[fid];
|
||||
auto iend = cut_device.Ptrs()[fid + 1];
|
||||
for (auto i = ibegin; i < iend; ++i) {
|
||||
*(counts + fid) += hit_count[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
inline bool IsDense() const {
|
||||
return isDense_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool isDense_;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
#endif // PLUGIN_SYCL_DATA_GRADIENT_INDEX_H_
|
||||
30
tests/cpp/plugin/sycl_helpers.h
Normal file
30
tests/cpp/plugin/sycl_helpers.h
Normal file
@ -0,0 +1,30 @@
|
||||
/*!
|
||||
* Copyright 2022-2024 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::sycl {
|
||||
template<typename T, typename Container>
|
||||
void VerifySyclVector(const USMVector<T, MemoryType::shared>& sycl_vector,
|
||||
const Container& host_vector) {
|
||||
ASSERT_EQ(sycl_vector.Size(), host_vector.size());
|
||||
|
||||
size_t size = sycl_vector.Size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
ASSERT_EQ(sycl_vector[i], host_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename Container>
|
||||
void VerifySyclVector(const std::vector<T>& sycl_vector, const Container& host_vector) {
|
||||
ASSERT_EQ(sycl_vector.size(), host_vector.size());
|
||||
|
||||
size_t size = sycl_vector.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
ASSERT_EQ(sycl_vector[i], host_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::sycl
|
||||
79
tests/cpp/plugin/test_sycl_gradient_index.cc
Normal file
79
tests/cpp/plugin/test_sycl_gradient_index.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/**
|
||||
* Copyright 2021-2024 by XGBoost contributors
|
||||
*/
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "../../../plugin/sycl/data/gradient_index.h"
|
||||
#include "../../../plugin/sycl/device_manager.h"
|
||||
#include "sycl_helpers.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost::sycl::data {
|
||||
|
||||
TEST(SyclGradientIndex, HistogramCuts) {
|
||||
size_t max_bins = 8;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
|
||||
auto p_fmat = RandomDataGenerator{512, 16, 0.5}.GenerateDMatrix(true);
|
||||
|
||||
xgboost::common::HistogramCuts cut =
|
||||
xgboost::common::SketchOnDMatrix(&ctx, p_fmat.get(), max_bins);
|
||||
|
||||
common::HistogramCuts cut_sycl;
|
||||
cut_sycl.Init(qu, cut);
|
||||
|
||||
VerifySyclVector(cut_sycl.Ptrs(), cut.cut_ptrs_.HostVector());
|
||||
VerifySyclVector(cut_sycl.Values(), cut.cut_values_.HostVector());
|
||||
VerifySyclVector(cut_sycl.MinValues(), cut.min_vals_.HostVector());
|
||||
}
|
||||
|
||||
TEST(SyclGradientIndex, Init) {
|
||||
size_t n_rows = 128;
|
||||
size_t n_columns = 7;
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
|
||||
|
||||
DeviceManager device_manager;
|
||||
auto qu = device_manager.GetQueue(ctx.Device());
|
||||
|
||||
auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix();
|
||||
|
||||
sycl::DeviceMatrix dmat(qu, p_fmat.get());
|
||||
|
||||
int max_bins = 256;
|
||||
common::GHistIndexMatrix gmat_sycl;
|
||||
gmat_sycl.Init(qu, &ctx, dmat, max_bins);
|
||||
|
||||
xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), max_bins, 0.3, false};
|
||||
|
||||
{
|
||||
ASSERT_EQ(gmat_sycl.max_num_bins, max_bins);
|
||||
ASSERT_EQ(gmat_sycl.nfeatures, n_columns);
|
||||
}
|
||||
|
||||
{
|
||||
VerifySyclVector(gmat_sycl.hit_count, gmat.hit_count);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<size_t> feature_count_sycl(n_columns, 0);
|
||||
gmat_sycl.GetFeatureCounts(feature_count_sycl.data());
|
||||
|
||||
std::vector<size_t> feature_count(n_columns, 0);
|
||||
gmat.GetFeatureCounts(feature_count.data());
|
||||
VerifySyclVector(feature_count_sycl, feature_count);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost::sycl::data
|
||||
Loading…
x
Reference in New Issue
Block a user