Support CPU input for device QuantileDMatrix. (#8136)

- Copy `GHistIndexMatrix` to `Ellpack` when needed.
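With this change a `QuantileDMatrix` constructed from host (CPU) data can be consumed by GPU training: the CPU gradient index is converted to the Ellpack format on first use. A minimal sketch of the user-facing behaviour, mirroring the Python test added in this commit (the random data, parameters, and round count are illustrative only):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))   # host (numpy) input
y = rng.normal(size=64)

# Quantize on the CPU ...
Xy = xgb.QuantileDMatrix(X, y)
# ... then train with the GPU hist method; the GHistIndexMatrix backing the
# QuantileDMatrix is copied into an Ellpack page on demand.
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
```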
Jiaming Yuan 2022-08-11 21:21:26 +08:00 committed by GitHub
parent 36e7c5364d
commit 16bca5d4a1
11 changed files with 220 additions and 19 deletions


@@ -17,6 +17,28 @@
namespace xgboost {
namespace common {
namespace cuda {
/**
* copy and paste of the host version, we can't make it a __host__ __device__ function as
* the fn might be a host only or device only callable object, which is not allowed by nvcc.
*/
template <typename Fn>
auto __device__ DispatchBinType(BinTypeSize type, Fn&& fn) {
switch (type) {
case kUint8BinsTypeSize: {
return fn(uint8_t{});
}
case kUint16BinsTypeSize: {
return fn(uint16_t{});
}
case kUint32BinsTypeSize: {
return fn(uint32_t{});
}
}
SPAN_CHECK(false);
return fn(uint32_t{});
}
} // namespace cuda
namespace detail {
struct EntryCompareOp {


@@ -108,12 +108,12 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
}
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, Context::kCpuId);
dh::safe_cuda(cudaSetDevice(device_idx_));
for (auto& json_col : json_columns) {
auto column = ArrayInterface<1>(get<Object const>(json_col));
columns.push_back(column);
num_rows_ = std::max(num_rows_, column.Shape(0));
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
<< "All columns should use the same device.";
CHECK_EQ(num_rows_, column.Shape(0))
@@ -138,7 +138,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CudfAdapterBatch batch_;
dh::device_vector<ArrayInterface<1>> columns_;
size_t num_rows_{0};
int32_t device_idx_{Context::kCpuId};
};
class CupyAdapterBatch : public detail::NoMetaInfo {
@@ -173,7 +173,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
return;
}
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
CHECK_NE(device_idx_, Context::kCpuId);
}
explicit CupyAdapter(std::string cuda_interface_str)
: CupyAdapter{StringView{cuda_interface_str}} {}
@@ -186,7 +186,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
private:
ArrayInterface<2> array_interface_;
CupyAdapterBatch batch_;
int32_t device_idx_ {Context::kCpuId};
};
// Returns maximum row length


@@ -1,14 +1,16 @@
/*!
* Copyright 2019-2022 XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "./ellpack_page.cuh"
#include "device_adapter.cuh"
#include "gradient_index.h"
#include "xgboost/data.h"
namespace xgboost {
@@ -32,7 +34,7 @@ __global__ void CompressBinEllpackKernel(
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_ptrs, // HistogramCuts::cut_ptrs_
common::Span<FeatureType const> feature_types,
size_t base_row, // batch_row_begin
size_t n_rows,
@@ -50,8 +52,8 @@ __global__ void CompressBinEllpackKernel(
int feature = entry.index;
float fvalue = entry.fvalue;
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float* feature_cuts = &cuts[cut_ptrs[feature]];
int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature];
bool is_cat = common::IsCat(feature_types, ifeature);
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
@@ -69,7 +71,7 @@ __global__ void CompressBinEllpackKernel(
bin = ncuts - 1;
}
// Add the number of bins in previous features.
bin += cut_ptrs[feature];
}
// Write to gidx buffer.
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
@@ -284,6 +286,70 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
namespace {
void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const> d_row_ptr,
size_t row_stride, common::CompressedByteT* d_compressed_buffer,
size_t null) {
dh::device_vector<uint8_t> data(page.index.begin(), page.index.end());
auto d_data = dh::ToSpan(data);
dh::device_vector<size_t> csc_indptr(page.index.Offset(),
page.index.Offset() + page.index.OffsetSize());
auto d_csc_indptr = dh::ToSpan(csc_indptr);
auto bin_type = page.index.GetBinTypeSize();
common::CompressedBufferWriter writer{page.cut.TotalBins() + 1}; // +1 for null value
dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable {
auto ridx = idx / row_stride;
auto ifeature = idx % row_stride;
auto r_begin = d_row_ptr[ridx];
auto r_end = d_row_ptr[ridx + 1];
size_t r_size = r_end - r_begin;
if (ifeature >= r_size) {
writer.AtomicWriteSymbol(d_compressed_buffer, null, idx);
return;
}
size_t offset = 0;
if (!d_csc_indptr.empty()) {
// is dense, ifeature is the actual feature index.
offset = d_csc_indptr[ifeature];
}
common::cuda::DispatchBinType(bin_type, [&](auto t) {
using T = decltype(t);
auto ptr = reinterpret_cast<T const*>(d_data.data());
auto bin_idx = ptr[r_begin + ifeature] + offset;
writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx);
});
});
}
} // anonymous namespace
EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
common::Span<FeatureType const> ft)
: is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} {
auto it = common::MakeIndexTransformIter(
[&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
row_stride = *std::max_element(it, it + page.Size());
CHECK_GE(ctx->gpu_id, 0);
monitor_.Start("InitCompressedData");
InitCompressedData(ctx->gpu_id);
monitor_.Stop("InitCompressedData");
// copy gidx
common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer();
dh::device_vector<size_t> row_ptr(page.row_ptr);
auto d_row_ptr = dh::ToSpan(row_ptr);
auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
auto null = accessor.NullValue();
CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null);
}
// A functor that copies the data from one EllpackPage to another.
struct CopyPage {
common::CompressedBufferWriter cbw;

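Sketch of the copy performed by `CopyGHistToEllpack` above, written as plain Python for readability: every row is padded to `row_stride` slots, slots beyond the row's true length receive the reserved `null` symbol, and for dense pages the per-feature CSC offset from the gradient index is added back so bin indices become global. The function name and arguments below are illustrative, not part of the code base.

```python
def copy_ghist_to_ellpack(row_ptr, data, csc_indptr, row_stride, null):
    """Pure-Python model of the device kernel; not the actual implementation."""
    n_rows = len(row_ptr) - 1
    out = [null] * (n_rows * row_stride)          # pre-fill with the null symbol
    for ridx in range(n_rows):
        r_begin, r_end = row_ptr[ridx], row_ptr[ridx + 1]
        for slot in range(row_stride):
            if slot >= r_end - r_begin:
                continue                          # missing entry stays `null`
            # Dense page: add the per-feature offset to get a global bin index.
            offset = csc_indptr[slot] if csc_indptr else 0
            out[ridx * row_stride + slot] = data[r_begin + slot] + offset
    return out
```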

@@ -116,6 +116,8 @@ struct EllpackDeviceAccessor {
};
class GHistIndexMatrix;
class EllpackPageImpl {
public:
/*!
@@ -154,6 +156,11 @@ class EllpackPageImpl {
common::Span<size_t> row_counts_span,
common::Span<FeatureType const> feature_types, size_t row_stride,
size_t n_rows, common::HistogramCuts const& cuts);
/**
* \brief Constructor from an existing CPU gradient index.
*/
explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
common::Span<FeatureType const> ft);
/*! \brief Copy the elements of the given ELLPACK page into this page.
*


@@ -66,6 +66,7 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
max_num_bins(max_bin_per_feat),
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
GHistIndexMatrix::~GHistIndexMatrix() = default;
void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,


@@ -205,7 +205,12 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
CheckParam(param);
CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training.
Possible solutions:
- Use `DMatrix` instead.
- Use CPU input for `QuantileDMatrix`.
- Run training on GPU.
)";
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);

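The raw-string message above covers the direction this commit does not support: a `QuantileDMatrix` initialized from GPU data cannot yet be used for CPU training. A hedged sketch of the failure mode and of the first suggested workaround (inputs and parameters are illustrative):

```python
import cupy as cp
import xgboost as xgb

X, y = cp.random.randn(64, 3), cp.random.randn(64)
Xy = xgb.QuantileDMatrix(X, y)          # initialized from GPU (cupy) data

xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)   # fine
# xgb.train({"tree_method": "hist"}, Xy)  # raises: "... not initialized with CPU data ..."

# Workaround from the message: fall back to a regular DMatrix for CPU training.
Xy_cpu = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y))
xgb.train({"tree_method": "hist"}, Xy_cpu, num_boost_round=4)
```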

@@ -168,7 +168,17 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
CheckParam(param);
if (!ellpack_ && !ghist_) {
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
}
if (!ellpack_ && ghist_) {
ellpack_.reset(new EllpackPage());
this->ctx_.gpu_id = param.gpu_id;
this->Info().feature_types.SetDevice(param.gpu_id);
*ellpack_->Impl() =
EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
}
CHECK(ellpack_);
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(begin_iter);
}


@@ -22,7 +22,28 @@ class HistogramCuts;
}
namespace data {
/**
* \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatrix` is due to its
* construction process.
*
* `QuantileDMatrix` is an intermediate storage for quantilization results including
* quantile cuts and histogram index. Quantilization is designed to be performed on a
* stream of data (or batches of it). As a result, the `QuantileDMatrix` is also designed
* to work with batches of data. During initialization, it will walk through the data
* multiple times iteratively in order to perform quantilization. This design can help us
* reduce memory usage significantly by avoiding data concatenation along with removing
* the CSR matrix `SparsePage`. However, it has its limitations (which can be fixed if needed):
*
* - It's only supported by the hist tree method (both CPU and GPU) since approx requires a
* re-calculation of quantiles for each iteration. We can fix this by retaining a
* reference to the callback if there are feature requests.
*
* - The CPU format and the GPU format are different: the former uses CSR + CSC for the
* histogram index while the latter uses only Ellpack. This results in a design where
* we can obtain the GPU format from the CPU one, but the other way around is not yet
* supported. We can search the bin value from Ellpack to recover the feature index when
* we support copying data from GPU to CPU.
*/
class IterativeDMatrix : public DMatrix {
MetaInfo info_;
Context ctx_;
@@ -40,7 +61,8 @@ class IterativeDMatrix : public DMatrix {
LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin
<< " vs. " << batch_param_.max_bin;
}
CHECK(!param.regen && param.hess.empty())
<< "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
}
template <typename Page>
@@ -49,7 +71,6 @@ class IterativeDMatrix : public DMatrix {
return BatchSet<Page>(BatchIterator<Page>(nullptr));
}
void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr<DMatrix> ref);
void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr<DMatrix> ref);
@@ -73,8 +94,9 @@ class IterativeDMatrix : public DMatrix {
batch_param_ = BatchParam{d, max_bin};
batch_param_.sparse_thresh = 0.2; // default from TrainParam
ctx_.UpdateAllowUnknown(
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
if (ctx_.IsCPU()) {
this->InitFromCPU(iter_handle, missing, ref);
} else {
this->InitFromCUDA(iter_handle, missing, ref);
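The design comment at the top of this header describes `QuantileDMatrix` as being built from a stream of batches. A minimal sketch of that batch-wise construction through the public `DataIter` callback interface (batch contents and sizes are illustrative; the keyword arguments passed to `input_data` follow the Python iterator protocol used by `QuantileDMatrix`):

```python
import numpy as np
import xgboost as xgb

class NumpyBatches(xgb.DataIter):
    """Feed pre-computed (X, y) batches to QuantileDMatrix one at a time."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__()

    def next(self, input_data) -> int:
        if self._it == len(self._batches):
            return 0                    # end of the stream for this pass
        X, y = self._batches[self._it]
        input_data(data=X, label=y)     # hand one batch to XGBoost
        self._it += 1
        return 1

    def reset(self) -> None:
        self._it = 0                    # rewind; called between passes over the data

rng = np.random.default_rng(0)
batches = [(rng.normal(size=(32, 4)), rng.normal(size=32)) for _ in range(4)]
Xy = xgb.QuantileDMatrix(NumpyBatches(batches), max_bin=256)
```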

@@ -121,7 +121,6 @@ if __name__ == "__main__":
"python-package/xgboost/sklearn.py",
"python-package/xgboost/spark",
"python-package/xgboost/federated.py",
# tests
"tests/python/test_config.py",
"tests/python/test_spark/",


@@ -236,4 +236,45 @@ TEST(EllpackPage, Compact) {
}
}
}
namespace {
class EllpackPageTest : public testing::TestWithParam<float> {
protected:
void Run(float sparsity) {
// Only testing with small sample size as the cuts might be different between host and
// device.
size_t n_samples{128}, n_features{13};
Context ctx;
ctx.gpu_id = 0;
auto Xy = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
std::unique_ptr<EllpackPageImpl> from_ghist;
ASSERT_TRUE(Xy->SingleColBlock());
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(BatchParam{17, 0.6})) {
from_ghist.reset(new EllpackPageImpl{&ctx, page, {}});
}
for (auto const& page : Xy->GetBatches<EllpackPage>(BatchParam{0, 17})) {
auto from_sparse_page = page.Impl();
ASSERT_EQ(from_sparse_page->is_dense, from_ghist->is_dense);
ASSERT_EQ(from_sparse_page->base_rowid, 0);
ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid);
ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows);
ASSERT_EQ(from_sparse_page->gidx_buffer.Size(), from_ghist->gidx_buffer.Size());
auto const& h_gidx_from_sparse = from_sparse_page->gidx_buffer.HostVector();
auto const& h_gidx_from_ghist = from_ghist->gidx_buffer.HostVector();
ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
common::CompressedIterator<uint32_t> from_ghist_it(h_gidx_from_ghist.data(),
from_ghist->NumSymbols());
common::CompressedIterator<uint32_t> from_sparse_it(h_gidx_from_sparse.data(),
from_sparse_page->NumSymbols());
for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) {
EXPECT_EQ(from_ghist_it[i], from_sparse_it[i]);
}
}
}
};
} // namespace
TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); }
INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f));
} // namespace xgboost


@@ -31,6 +31,34 @@ class TestDeviceQuantileDMatrix:
data = cp.random.randn(5, 5)
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy())
def test_from_host(self) -> None:
import cupy as cp
n_samples = 64
n_features = 3
X, y, w = tm.make_batches(
n_samples, n_features=n_features, n_batches=1, use_cupy=False
)
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
X[0] = cp.array(X[0])
y[0] = cp.array(y[0])
w[0] = cp.array(w[0])
Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0])
booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
cp.testing.assert_allclose(
booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0])
)
with pytest.raises(ValueError, match="not initialized with CPU"):
# Training on CPU with GPU data is not supported.
xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
with pytest.raises(ValueError, match=r"Only.*hist.*"):
xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4)
@pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None:
import cupy as cp