Support CPU input for device QuantileDMatrix. (#8136)
- Copy `GHistIndexMatrix` to `Ellpack` when needed.
parent 36e7c5364d
commit 16bca5d4a1
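
Taken together, the changes let a `QuantileDMatrix` built from host (CPU) data feed GPU training: the CPU `GHistIndexMatrix` is copied into an `Ellpack` page the first time an ELLPACK batch is requested. Below is a minimal usage sketch of the enabled workflow; it mirrors the new Python test in this commit, and the data values are purely illustrative.

    # Sketch of the workflow this commit enables (illustrative data only).
    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(64, 3))  # host (CPU) input
    y = rng.normal(size=64)

    # Quantile sketching and the histogram index are built on the CPU.
    Xy = xgb.QuantileDMatrix(X, y)

    # gpu_hist can now consume the CPU-initialized QuantileDMatrix; the
    # GHistIndexMatrix is converted to an Ellpack page on first use.
    booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)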
@@ -17,6 +17,28 @@
 
 namespace xgboost {
 namespace common {
+namespace cuda {
+/**
+ * copy and paste of the host version, we can't make it a __host__ __device__ function as
+ * the fn might be a host only or device only callable object, which is not allowed by nvcc.
+ */
+template <typename Fn>
+auto __device__ DispatchBinType(BinTypeSize type, Fn&& fn) {
+  switch (type) {
+    case kUint8BinsTypeSize: {
+      return fn(uint8_t{});
+    }
+    case kUint16BinsTypeSize: {
+      return fn(uint16_t{});
+    }
+    case kUint32BinsTypeSize: {
+      return fn(uint32_t{});
+    }
+  }
+  SPAN_CHECK(false);
+  return fn(uint32_t{});
+}
+}  // namespace cuda
 
 namespace detail {
 struct EntryCompareOp {
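
`DispatchBinType` dispatches on the storage width of the gradient index: the callback receives a dummy value whose static type (`uint8_t`, `uint16_t`, or `uint32_t`) tells it how to reinterpret the compressed index buffer. A rough host-side analogue in Python/numpy, purely for illustration (not XGBoost API):

    # Illustrative analogue: pick the numpy dtype matching the bin storage width.
    import numpy as np

    def dispatch_bin_type(bin_type_size: int, raw: bytes) -> np.ndarray:
        # bin_type_size plays the role of kUint8/16/32BinsTypeSize (bytes per bin).
        dtypes = {1: np.uint8, 2: np.uint16, 4: np.uint32}
        return np.frombuffer(raw, dtype=dtypes[bin_type_size])

    # A gradient index stored with 2-byte bins:
    bins = dispatch_bin_type(2, np.array([3, 7, 11], dtype=np.uint16).tobytes())
    print(bins)  # [ 3  7 11]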
@@ -108,12 +108,12 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     }
 
     device_idx_ = dh::CudaGetPointerDevice(first_column.data);
-    CHECK_NE(device_idx_, -1);
+    CHECK_NE(device_idx_, Context::kCpuId);
     dh::safe_cuda(cudaSetDevice(device_idx_));
     for (auto& json_col : json_columns) {
       auto column = ArrayInterface<1>(get<Object const>(json_col));
       columns.push_back(column);
-      num_rows_ = std::max(num_rows_, size_t(column.Shape(0)));
+      num_rows_ = std::max(num_rows_, column.Shape(0));
       CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
           << "All columns should use the same device.";
       CHECK_EQ(num_rows_, column.Shape(0))
@@ -138,7 +138,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
   CudfAdapterBatch batch_;
   dh::device_vector<ArrayInterface<1>> columns_;
   size_t num_rows_{0};
-  int device_idx_;
+  int32_t device_idx_{Context::kCpuId};
 };
 
 class CupyAdapterBatch : public detail::NoMetaInfo {
@@ -173,7 +173,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
       return;
     }
     device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
-    CHECK_NE(device_idx_, -1);
+    CHECK_NE(device_idx_, Context::kCpuId);
   }
   explicit CupyAdapter(std::string cuda_interface_str)
       : CupyAdapter{StringView{cuda_interface_str}} {}
@@ -186,7 +186,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
  private:
   ArrayInterface<2> array_interface_;
   CupyAdapterBatch batch_;
-  int32_t device_idx_ {-1};
+  int32_t device_idx_ {Context::kCpuId};
 };
 
 // Returns maximum row length
@@ -1,14 +1,16 @@
 /*!
- * Copyright 2019-2020 XGBoost contributors
+ * Copyright 2019-2022 XGBoost contributors
  */
-#include <xgboost/data.h>
 #include <thrust/iterator/discard_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>
 
 #include "../common/categorical.h"
 #include "../common/hist_util.cuh"
 #include "../common/random.h"
 #include "./ellpack_page.cuh"
 #include "device_adapter.cuh"
+#include "gradient_index.h"
+#include "xgboost/data.h"
 
 namespace xgboost {
 
@@ -32,7 +34,7 @@ __global__ void CompressBinEllpackKernel(
     const size_t* __restrict__ row_ptrs,    // row offset of input data
     const Entry* __restrict__ entries,      // One batch of input data
     const float* __restrict__ cuts,         // HistogramCuts::cut_values_
-    const uint32_t* __restrict__ cut_rows,  // HistogramCuts::cut_ptrs_
+    const uint32_t* __restrict__ cut_ptrs,  // HistogramCuts::cut_ptrs_
     common::Span<FeatureType const> feature_types,
     size_t base_row,                        // batch_row_begin
     size_t n_rows,
@@ -50,8 +52,8 @@ __global__ void CompressBinEllpackKernel(
     int feature = entry.index;
     float fvalue = entry.fvalue;
     // {feature_cuts, ncuts} forms the array of cuts of `feature'.
-    const float* feature_cuts = &cuts[cut_rows[feature]];
-    int ncuts = cut_rows[feature + 1] - cut_rows[feature];
+    const float* feature_cuts = &cuts[cut_ptrs[feature]];
+    int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature];
     bool is_cat = common::IsCat(feature_types, ifeature);
     // Assigning the bin in current entry.
     // S.t.: fvalue < feature_cuts[bin]
@@ -69,7 +71,7 @@ __global__ void CompressBinEllpackKernel(
         bin = ncuts - 1;
       }
       // Add the number of bins in previous features.
-      bin += cut_rows[feature];
+      bin += cut_ptrs[feature];
     }
     // Write to gidx buffer.
     wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
@@ -284,6 +286,70 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
 ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
 ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
 
+namespace {
+void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const> d_row_ptr,
+                        size_t row_stride, common::CompressedByteT* d_compressed_buffer,
+                        size_t null) {
+  dh::device_vector<uint8_t> data(page.index.begin(), page.index.end());
+  auto d_data = dh::ToSpan(data);
+
+  dh::device_vector<size_t> csc_indptr(page.index.Offset(),
+                                       page.index.Offset() + page.index.OffsetSize());
+  auto d_csc_indptr = dh::ToSpan(csc_indptr);
+
+  auto bin_type = page.index.GetBinTypeSize();
+  common::CompressedBufferWriter writer{page.cut.TotalBins() + 1};  // +1 for null value
+
+  dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable {
+    auto ridx = idx / row_stride;
+    auto ifeature = idx % row_stride;
+
+    auto r_begin = d_row_ptr[ridx];
+    auto r_end = d_row_ptr[ridx + 1];
+    size_t r_size = r_end - r_begin;
+
+    if (ifeature >= r_size) {
+      writer.AtomicWriteSymbol(d_compressed_buffer, null, idx);
+      return;
+    }
+
+    size_t offset = 0;
+    if (!d_csc_indptr.empty()) {
+      // is dense, ifeature is the actual feature index.
+      offset = d_csc_indptr[ifeature];
+    }
+    common::cuda::DispatchBinType(bin_type, [&](auto t) {
+      using T = decltype(t);
+      auto ptr = reinterpret_cast<T const*>(d_data.data());
+      auto bin_idx = ptr[r_begin + ifeature] + offset;
+      writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx);
+    });
+  });
+}
+}  // anonymous namespace
+
+EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
+                                 common::Span<FeatureType const> ft)
+    : is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} {
+  auto it = common::MakeIndexTransformIter(
+      [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
+  row_stride = *std::max_element(it, it + page.Size());
+
+  CHECK_GE(ctx->gpu_id, 0);
+  monitor_.Start("InitCompressedData");
+  InitCompressedData(ctx->gpu_id);
+  monitor_.Stop("InitCompressedData");
+
+  // copy gidx
+  common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer();
+  dh::device_vector<size_t> row_ptr(page.row_ptr);
+  auto d_row_ptr = dh::ToSpan(row_ptr);
+
+  auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
+  auto null = accessor.NullValue();
+  CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null);
+}
+
 // A functor that copies the data from one EllpackPage to another.
 struct CopyPage {
   common::CompressedBufferWriter cbw;
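
`CopyGHistToEllpack` flattens the CPU gradient index (CSR row pointers plus a compressed bin array) into the row-major ELLPACK layout: every row occupies `row_stride` slots, slots past the row's length are filled with the `null` bin, and for dense pages the per-feature offset from `index.Offset()` is added so that bin indices become global. A small host-side model of that indexing, using plain Python lists (illustration only, not the CUDA kernel):

    # Host-side model of the CopyGHistToEllpack indexing scheme (illustrative).
    def ghist_to_ellpack(row_ptr, bins, csc_indptr, row_stride, null):
        n_rows = len(row_ptr) - 1
        out = [null] * (n_rows * row_stride)  # pad everything with the null bin
        for ridx in range(n_rows):
            r_begin, r_end = row_ptr[ridx], row_ptr[ridx + 1]
            for ifeature in range(r_end - r_begin):
                # Dense page: bins are stored relative to the feature, so add the
                # CSC-style offset back to obtain the global bin index.
                offset = csc_indptr[ifeature] if csc_indptr else 0
                out[ridx * row_stride + ifeature] = bins[r_begin + ifeature] + offset
        return out

    # Two rows, row_stride 3; the second row has only two entries -> padded with null.
    print(ghist_to_ellpack([0, 3, 5], [0, 1, 2, 0, 2], [], 3, null=9))
    # [0, 1, 2, 0, 2, 9]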
@@ -116,6 +116,8 @@ struct EllpackDeviceAccessor {
 };
 
+class GHistIndexMatrix;
+
 class EllpackPageImpl {
  public:
   /*!
@@ -154,6 +156,11 @@ class EllpackPageImpl {
                   common::Span<size_t> row_counts_span,
                   common::Span<FeatureType const> feature_types, size_t row_stride,
                   size_t n_rows, common::HistogramCuts const& cuts);
+  /**
+   * \brief Constructor from an existing CPU gradient index.
+   */
+  explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
+                           common::Span<FeatureType const> ft);
 
   /*! \brief Copy the elements of the given ELLPACK page into this page.
    *
@@ -66,6 +66,7 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
       max_num_bins(max_bin_per_feat),
       isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
 
+
 GHistIndexMatrix::~GHistIndexMatrix() = default;
 
 void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
@@ -205,7 +205,12 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
 
 BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
   CheckParam(param);
-  CHECK(ghist_) << "Not initialized with CPU data";
+  CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training.
+Possible solutions:
+- Use `DMatrix` instead.
+- Use CPU input for `QuantileDMatrix`.
+- Run training on GPU.
+)";
   auto begin_iter =
       BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
   return BatchSet<GHistIndexMatrix>(begin_iter);
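
The expanded message fires when a `QuantileDMatrix` initialized from GPU data is later asked for the CPU gradient index, e.g. by `hist` training. A hedged sketch of the failure mode this guards against (it mirrors the new Python test; cupy arrays are the GPU input):

    # Sketch: a GPU-initialized QuantileDMatrix cannot feed CPU ("hist") training.
    import cupy as cp
    import xgboost as xgb

    X = cp.random.randn(64, 3)  # device (GPU) input
    y = cp.random.randn(64)
    Xy = xgb.QuantileDMatrix(X, y)

    xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)  # fine

    try:
        xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
    except ValueError as err:  # XGBoostError derives from ValueError
        print(err)  # "... not initialized with CPU data but used for CPU training ..."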
@@ -168,7 +168,17 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
 
 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
   CheckParam(param);
-  CHECK(ellpack_) << "Not initialized with GPU data";
+  if (!ellpack_ && !ghist_) {
+    LOG(FATAL) << "`QuantileDMatrix` not initialized.";
+  }
+  if (!ellpack_ && ghist_) {
+    ellpack_.reset(new EllpackPage());
+    this->ctx_.gpu_id = param.gpu_id;
+    this->Info().feature_types.SetDevice(param.gpu_id);
+    *ellpack_->Impl() =
+        EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
+  }
+  CHECK(ellpack_);
   auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
   return BatchSet<EllpackPage>(begin_iter);
 }
@@ -22,7 +22,28 @@ class HistogramCuts;
 }
 
 namespace data {
+/**
+ * \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatrix` is due to its
+ *        construction process.
+ *
+ * `QuantileDMatrix` is an intermediate storage for quantilization results including
+ * quantile cuts and histogram index. Quantilization is designed to be performed on a stream
+ * of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work
+ * with batches of data. During initialization, it will walk through the data multiple
+ * times iteratively in order to perform quantilization. This design can help us reduce
+ * memory usage significantly by avoiding data concatenation along with removing the CSR
+ * matrix `SparsePage`. However, it has its limitations (which can be fixed if needed):
+ *
+ * - It's only supported by the hist tree method (both CPU and GPU) since approx requires a
+ *   re-calculation of quantiles for each iteration. We can fix this by retaining a
+ *   reference to the callback if there are feature requests.
+ *
+ * - The CPU format and the GPU format are different: the former uses a CSR + CSC for the
+ *   histogram index while the latter uses only Ellpack. This results in a design where
+ *   we can obtain the GPU format from the CPU one, but the other way around is not yet
+ *   supported. We can search the bin value from Ellpack to recover the feature index when
+ *   we support copying data from GPU to CPU.
+ */
 class IterativeDMatrix : public DMatrix {
   MetaInfo info_;
   Context ctx_;
@@ -40,7 +61,8 @@ class IterativeDMatrix : public DMatrix {
       LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin
                    << " vs. " << batch_param_.max_bin;
     }
-    CHECK(!param.regen) << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
+    CHECK(!param.regen && param.hess.empty())
+        << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
   }
 
   template <typename Page>
@@ -49,7 +71,6 @@ class IterativeDMatrix : public DMatrix {
     return BatchSet<Page>(BatchIterator<Page>(nullptr));
   }
 
- public:
   void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr<DMatrix> ref);
   void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr<DMatrix> ref);
 
@@ -73,8 +94,9 @@ class IterativeDMatrix : public DMatrix {
     batch_param_ = BatchParam{d, max_bin};
     batch_param_.sparse_thresh = 0.2;  // default from TrainParam
 
-    ctx_.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}});
-    if (d == Context::kCpuId) {
+    ctx_.UpdateAllowUnknown(
+        Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
+    if (ctx_.IsCPU()) {
       this->InitFromCPU(iter_handle, missing, ref);
     } else {
       this->InitFromCUDA(iter_handle, missing, ref);
@@ -121,7 +121,6 @@ if __name__ == "__main__":
         "python-package/xgboost/sklearn.py",
         "python-package/xgboost/spark",
         "python-package/xgboost/federated.py",
-        "python-package/xgboost/spark",
         # tests
         "tests/python/test_config.py",
         "tests/python/test_spark/",
@@ -236,4 +236,45 @@ TEST(EllpackPage, Compact) {
     }
   }
 }
 
+namespace {
+class EllpackPageTest : public testing::TestWithParam<float> {
+ protected:
+  void Run(float sparsity) {
+    // Only testing with small sample size as the cuts might be different between host and
+    // device.
+    size_t n_samples{128}, n_features{13};
+    Context ctx;
+    ctx.gpu_id = 0;
+    auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
+    std::unique_ptr<EllpackPageImpl> from_ghist;
+    ASSERT_TRUE(Xy->SingleColBlock());
+    for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(BatchParam{17, 0.6})) {
+      from_ghist.reset(new EllpackPageImpl{&ctx, page, {}});
+    }
+
+    for (auto const& page : Xy->GetBatches<EllpackPage>(BatchParam{0, 17})) {
+      auto from_sparse_page = page.Impl();
+      ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense);
+      ASSERT_EQ(from_sparse_page->base_rowid, 0);
+      ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid);
+      ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows);
+      ASSERT_EQ(from_sparse_page->gidx_buffer.Size(), from_ghist->gidx_buffer.Size());
+      auto const& h_gidx_from_sparse = from_sparse_page->gidx_buffer.HostVector();
+      auto const& h_gidx_from_ghist = from_ghist->gidx_buffer.HostVector();
+      ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
+      common::CompressedIterator<uint32_t> from_ghist_it(h_gidx_from_ghist.data(),
+                                                         from_ghist->NumSymbols());
+      common::CompressedIterator<uint32_t> from_sparse_it(h_gidx_from_sparse.data(),
+                                                          from_sparse_page->NumSymbols());
+      for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) {
+        EXPECT_EQ(from_ghist_it[i], from_sparse_it[i]);
+      }
+    }
+  }
+};
+}  // namespace
+
+TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); }
+INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f));
 }  // namespace xgboost
@@ -31,6 +31,34 @@ class TestDeviceQuantileDMatrix:
         data = cp.random.randn(5, 5)
         xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
 
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_from_host(self) -> None:
+        import cupy as cp
+        n_samples = 64
+        n_features = 3
+        X, y, w = tm.make_batches(
+            n_samples, n_features=n_features, n_batches=1, use_cupy=False
+        )
+        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
+        booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+
+        X[0] = cp.array(X[0])
+        y[0] = cp.array(y[0])
+        w[0] = cp.array(w[0])
+
+        Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
+        booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
+        cp.testing.assert_allclose(
+            booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
+        )
+
+        with pytest.raises(ValueError, match="not initialized with CPU"):
+            # Training on CPU with GPU data is not supported.
+            xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
+
+        with pytest.raises(ValueError, match=r"Only.*hist.*"):
+            xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4)
+
     @pytest.mark.skipif(**tm.no_cupy())
     def test_metainfo(self) -> None:
         import cupy as cp