Support categorical data in ellpack. (#6140)

This commit is contained in:
Jiaming Yuan 2020-09-24 19:28:57 +08:00 committed by GitHub
parent 78d72ef936
commit 14afdb4d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 189 additions and 43 deletions

View File

@ -1,10 +1,10 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
*/
#include <xgboost/data.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "./ellpack_page.cuh"
@ -33,6 +33,7 @@ __global__ void CompressBinEllpackKernel(
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
common::Span<FeatureType const> feature_types,
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
@ -51,11 +52,19 @@ __global__ void CompressBinEllpackKernel(
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float* feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
bool is_cat = common::IsCat(feature_types, ifeature);
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
if (is_cat) {
auto it = dh::MakeTransformIterator<int>(
feature_cuts, [](float v) { return common::AsCat(v); });
bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it;
} else {
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
fvalue) -
feature_cuts;
}
if (bin >= ncuts) {
bin = ncuts - 1;
}
@ -84,13 +93,12 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
const SparsePage &page, bool is_dense,
size_t row_stride)
: cuts_(std::move(cuts)),
is_dense(is_dense),
n_rows(page.Size()),
size_t row_stride,
common::Span<FeatureType const> feature_types)
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()),
row_stride(row_stride) {
this->InitCompressedData(device);
this->CreateHistIndices(device, page);
this->CreateHistIndices(device, page, feature_types);
}
// Construct an ELLPACK matrix in memory.
@ -108,12 +116,14 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");
InitCompressedData(param.gpu_id);
this->InitCompressedData(param.gpu_id);
monitor_.Stop("InitCompressedData");
dmat->Info().feature_types.SetDevice(param.gpu_id);
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
monitor_.Start("BinningCompression");
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
CreateHistIndices(param.gpu_id, batch);
CreateHistIndices(param.gpu_id, batch, ft);
}
monitor_.Stop("BinningCompression");
}
@ -365,7 +375,8 @@ void EllpackPageImpl::InitCompressedData(int device) {
// Compress a CSR page into ELLPACK.
void EllpackPageImpl::CreateHistIndices(int device,
const SparsePage& row_batch) {
const SparsePage& row_batch,
common::Span<FeatureType const> feature_types) {
if (row_batch.Size() == 0) return;
unsigned int null_gidx_value = NumSymbols() - 1;
@ -397,7 +408,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
@ -408,7 +419,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()),
gidx_buffer.DevicePointer(), row_ptrs.data().get(),
entries_d.data().get(), device_accessor.gidx_fvalue_map.data(),
device_accessor.feature_segments.data(),
device_accessor.feature_segments.data(), feature_types,
row_batch.base_rowid + batch_row_begin, batch_nrows, row_stride,
null_gidx_value);
}

View File

@ -118,10 +118,12 @@ class EllpackPageImpl {
*/
EllpackPageImpl(int device, common::HistogramCuts cuts, bool is_dense,
size_t row_stride, size_t n_rows);
/*!
* \brief Constructor used for external memory.
*/
EllpackPageImpl(int device, common::HistogramCuts cuts,
const SparsePage& page,
bool is_dense, size_t row_stride);
const SparsePage &page, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types);
/*!
* \brief Constructor from an existing DMatrix.
@ -184,8 +186,8 @@ class EllpackPageImpl {
* @param row_batch The CSR page.
*/
void CreateHistIndices(int device,
const SparsePage& row_batch
);
const SparsePage& row_batch,
common::Span<FeatureType const> feature_types);
/*!
* \brief Initialize the buffer to store compressed features.
*/

View File

@ -55,6 +55,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
SparsePage temp_host_page;
writer.Alloc(&page);
auto* impl = page->Impl();
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
@ -66,7 +67,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
if (mem_cost_bytes >= page_size_) {
bytes_write += mem_cost_bytes;
*impl = EllpackPageImpl(device, cuts, temp_host_page, dmat->IsDense(),
row_stride);
row_stride, ft);
writer.PushWrite(std::move(page));
writer.Alloc(&page);
impl = page->Impl();
@ -79,7 +80,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
}
if (temp_host_page.Size() != 0) {
*impl = EllpackPageImpl(device, cuts, temp_host_page, dmat->IsDense(),
row_stride);
row_stride, ft);
writer.PushWrite(std::move(page));
}
}

View File

@ -60,19 +60,6 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
}
#endif
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
int num_categories) {
std::vector<float> x(n);
std::mt19937 rng(0);
std::uniform_int_distribution<int> dist(0, num_categories - 1);
std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
// Make sure each category is present
for(auto i = 0; i < num_categories; i++) {
x[i] = i;
}
return x;
}
inline std::shared_ptr<data::SimpleDMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns) {
data::DenseAdapter adapter(x.data(), num_rows, num_columns);

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
*/
#include <xgboost/base.h>
@ -9,6 +9,7 @@
#include "../histogram_helpers.h"
#include "gtest/gtest.h"
#include "../../../src/common/categorical.h"
#include "../../../src/common/hist_util.h"
#include "../../../src/data/ellpack_page.cuh"
@ -77,6 +78,45 @@ TEST(EllpackPage, BuildGidxSparse) {
}
}
TEST(EllpackPage, FromCategoricalBasic) {
using common::AsCat;
size_t constexpr kRows = 1000, kCats = 13, kCols = 1;
size_t max_bins = 8;
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
auto m = GetDMatrixFromData(x, kRows, 1);
auto& h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);
BatchParam p(0, max_bins);
auto ellpack = EllpackPage(m.get(), p);
auto accessor = ellpack.Impl()->GetDeviceAccessor(0);
ASSERT_EQ(kCats, accessor.NumBins());
auto x_copy = x;
std::sort(x_copy.begin(), x_copy.end());
auto n_uniques = std::unique(x_copy.begin(), x_copy.end()) - x_copy.begin();
ASSERT_EQ(n_uniques, kCats);
std::vector<uint32_t> h_cuts_ptr(accessor.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments);
std::vector<float> h_cuts_values(accessor.gidx_fvalue_map.size());
dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map);
ASSERT_EQ(h_cuts_ptr.size(), 2);
ASSERT_EQ(h_cuts_values.size(), kCats);
std::vector<common::CompressedByteT> const &h_gidx_buffer =
ellpack.Impl()->gidx_buffer.HostVector();
auto h_gidx_iter = common::CompressedIterator<uint32_t>(
h_gidx_buffer.data(), accessor.NumSymbols());
for (size_t i = 0; i < x.size(); ++i) {
auto bin = h_gidx_iter[i];
auto bin_value = h_cuts_values.at(bin);
ASSERT_EQ(AsCat(x[i]), AsCat(bin_value));
}
}
struct ReadRowFunction {
EllpackDeviceAccessor matrix;
int row;

View File

@ -17,6 +17,7 @@
#include "helpers.h"
#include "xgboost/c_api.h"
#include "../../src/data/adapter.h"
#include "../../src/data/simple_dmatrix.h"
#include "../../src/gbm/gbtree_model.h"
#include "xgboost/predictor.h"
@ -350,6 +351,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
return out;
}
std::shared_ptr<DMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns){
data::DenseAdapter adapter(x.data(), num_rows, num_columns);
return std::shared_ptr<DMatrix>(new data::SimpleDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
size_t n_entries, size_t page_size, std::string tmp_file) {
// Create sufficiently large data to make two row pages
@ -539,5 +547,4 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
}
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
} // namespace xgboost

View File

@ -42,6 +42,12 @@ struct LearnerModelParam;
class GradientBooster;
}
template <typename Float>
Float RelError(Float l, Float r) {
static_assert(std::is_floating_point<Float>::value, "");
return std::abs(1.0f - l / r);
}
bool FileExists(const std::string& filename);
int64_t GetFileSize(const std::string& filename);
@ -254,6 +260,22 @@ class RandomDataGenerator {
#endif
};
inline std::vector<float>
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
std::vector<float> x(n);
std::mt19937 rng(0);
std::uniform_int_distribution<size_t> dist(0, num_categories - 1);
std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
// Make sure each category is present
for(size_t i = 0; i < num_categories; i++) {
x[i] = i;
}
return x;
}
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float> &x,
int num_rows, int num_columns);
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
size_t n_entries, size_t page_size, std::string tmp_file);

View File

@ -45,7 +45,7 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
}
auto page = std::unique_ptr<EllpackPageImpl>(
new EllpackPageImpl(0, cmat, batch, dmat->IsDense(), row_stride));
new EllpackPageImpl(0, cmat, batch, dmat->IsDense(), row_stride, {}));
return page;
}

View File

@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include <vector>
#include "../../helpers.h"
#include "../../../../src/common/categorical.h"
#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
#include "../../../../src/tree/gpu_hist/histogram.cuh"
@ -97,5 +98,80 @@ TEST(Histogram, GPUDeterministic) {
}
}
}
std::vector<float> OneHotEncodeFeature(std::vector<float> x, size_t num_cat) {
std::vector<float> ret(x.size() * num_cat, 0);
size_t n_rows = x.size();
for (size_t r = 0; r < n_rows; ++r) {
bst_cat_t cat = common::AsCat(x[r]);
ret.at(num_cat * r + cat) = 1;
}
return ret;
}
// Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
void TestGPUHistogramCategorical(size_t num_categories) {
size_t constexpr kRows = 340;
size_t constexpr kBins = 256;
auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories);
auto cat_m = GetDMatrixFromData(x, kRows, 1);
cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical);
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0};
tree::RowPartitioner row_partitioner(0, kRows);
auto ridx = row_partitioner.GetRows(0);
dh::device_vector<GradientPairPrecise> cat_hist(num_categories);
auto gpair = GenerateRandomGradients(kRows, 0, 2);
gpair.SetDevice(0);
auto rounding = CreateRoundingFactor<GradientPairPrecise>(gpair.DeviceSpan());
// Generate hist with cat data.
for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(page->GetDeviceAccessor(0),
single_group.DeviceAccessor(0),
gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist),
rounding);
}
// Generate hist with one hot encoded data.
auto x_encoded = OneHotEncodeFeature(x, num_categories);
auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
dh::device_vector<GradientPairPrecise> encode_hist(2 * num_categories);
for (auto const &batch : encode_m->GetBatches<EllpackPage>(batch_param)) {
auto* page = batch.Impl();
FeatureGroups single_group(page->Cuts());
BuildGradientHistogram(page->GetDeviceAccessor(0),
single_group.DeviceAccessor(0),
gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist),
rounding);
}
std::vector<GradientPairPrecise> h_cat_hist(cat_hist.size());
thrust::copy(cat_hist.begin(), cat_hist.end(), h_cat_hist.begin());
auto cat_sum = std::accumulate(h_cat_hist.begin(), h_cat_hist.end(), GradientPairPrecise{});
std::vector<GradientPairPrecise> h_encode_hist(encode_hist.size());
thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin());
for (size_t c = 0; c < num_categories; ++c) {
auto zero = h_encode_hist[c * 2];
auto one = h_encode_hist[c * 2 + 1];
auto chosen = h_cat_hist[c];
auto not_chosen = cat_sum - chosen;
ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
}
}
TEST(Histogram, GPUHistCategorical) {
for (size_t num_categories = 2; num_categories < 8; ++num_categories) {
TestGPUHistogramCategorical(num_categories);
}
}
} // namespace tree
} // namespace xgboost