Group aware GPU sketching. (#5551)

* Group aware GPU weighted sketching.

* Distribute group weights to each data point.
* Relax the test.
* Validate input meta info.
* Fix metainfo copy ctor.
This commit is contained in:
Jiaming Yuan 2020-04-20 17:18:52 +08:00 committed by GitHub
parent 397d8f0ee7
commit 29a4cfe400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 296 additions and 124 deletions

View File

@ -87,11 +87,23 @@ class MetaInfo {
this->weights_.Resize(that.weights_.Size()); this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_); this->weights_.Copy(that.weights_);
this->base_margin_.Resize(that.base_margin_.Size()); this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_); this->base_margin_.Copy(that.base_margin_);
this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);
this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this; return *this;
} }
/*!
* \brief Validate all metainfo.
*/
void Validate() const;
MetaInfo Slice(common::Span<int32_t const> ridxs) const; MetaInfo Slice(common::Span<int32_t const> ridxs) const;
/*! /*!
* \brief Get weight of each instance. * \brief Get weight of each instance.

View File

@ -138,25 +138,26 @@ void GetColumnSizesScan(int device,
* \param column_sizes_scan Describes the boundaries of column segments in * \param column_sizes_scan Describes the boundaries of column segments in
* sorted data * sorted data
*/ */
void ExtractCuts(int device, Span<SketchEntry> cuts, void ExtractCuts(int device,
size_t num_cuts_per_feature, Span<Entry> sorted_data, size_t num_cuts_per_feature,
Span<size_t> column_sizes_scan) { Span<Entry const> sorted_data,
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input // Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature; size_t column_idx = idx / num_cuts_per_feature;
size_t column_size = size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size); min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature; size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return; if (cut_idx >= num_available_cuts) return;
Span<Entry const> column_entries =
Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size); sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) /
size_t rank = (column_entries.size() * cut_idx) / num_available_cuts; static_cast<float>(num_available_cuts);
auto value = column_entries[rank].fvalue; out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
cuts[idx] = SketchEntry(rank, rank + 1, 1, value); column_entries[rank].fvalue);
}); });
} }
@ -170,31 +171,32 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
* \param weights_scan Inclusive scan of weights for each entry in sorted_data. * \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data. * \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/ */
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts, void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature, Span<Entry> sorted_data, size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan, Span<float> weights_scan,
Span<size_t> column_sizes_scan) { Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input // Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature; size_t column_idx = idx / num_cuts_per_feature;
size_t column_size = size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size); min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature; size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return; if (cut_idx >= num_available_cuts) return;
Span<Entry> column_entries = Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size); sorted_data.subspan(column_sizes_scan[column_idx], column_size);
Span<float> column_weights =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights.back(); Span<float> column_weights_scan =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights_scan.back();
size_t sample_idx = 0; size_t sample_idx = 0;
if (cut_idx == 0) { if (cut_idx == 0) {
// First cut // First cut
sample_idx = 0; sample_idx = 0;
} else if (cut_idx == num_available_cuts - 1) { } else if (cut_idx == num_available_cuts) {
// Last cut // Last cut
sample_idx = column_entries.size() - 1; sample_idx = column_entries.size() - 1;
} else if (num_available_cuts == column_size) { } else if (num_available_cuts == column_size) {
@ -204,15 +206,18 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
} else { } else {
bst_float rank = (total_column_weight * cut_idx) / bst_float rank = (total_column_weight * cut_idx) /
static_cast<float>(num_available_cuts); static_cast<float>(num_available_cuts);
sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(), sample_idx = thrust::upper_bound(thrust::seq,
column_weights.end(), rank) - column_weights_scan.begin(),
column_weights.begin() - 1; column_weights_scan.end(),
rank) -
column_weights_scan.begin();
sample_idx = sample_idx =
max(size_t(0), min(sample_idx, column_entries.size() - 1)); max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1));
} }
// repeated values will be filtered out on the CPU // repeated values will be filtered out on the CPU
bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0; bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
bst_float rmax = column_weights[sample_idx]; bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
column_entries[sample_idx].fvalue); column_entries[sample_idx].fvalue);
}); });
@ -235,9 +240,10 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan); thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts); dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts, ExtractCuts(device, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()}, dh::ToSpan(sorted_entries),
{column_sizes_scan.data().get(), column_sizes_scan.size()}); dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
// add cuts into sketches // add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts); thrust::host_vector<SketchEntry> host_cuts(cuts);
@ -246,8 +252,9 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
void ProcessWeightedBatch(int device, const SparsePage& page, void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end, Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts, SketchContainer* sketch_container, int num_cuts_per_feature,
size_t num_columns) { size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector(); const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin, dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
@ -259,6 +266,25 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
page.offset.SetDevice(device); page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan(); auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid; size_t base_rowid = page.base_rowid;
if (is_ranking) {
CHECK_GE(d_group_ptr.size(), 2)
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
auto it =
thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx + base_rowid) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
d_temp_weights[idx] = weights[group];
});
} else {
CHECK_EQ(weights.size(), page.offset.Size() - 1);
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin; size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
@ -266,8 +292,9 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
row_ptrs.begin() - 1; row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid]; d_temp_weights[idx] = weights[ridx + base_rowid];
}); });
}
// Sort // Sort both entries and weights.
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(), thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(), sorted_entries.end(), temp_weights.begin(),
EntryCompareOp()); EntryCompareOp());
@ -287,26 +314,26 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan); thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
// Extract cuts // Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts); dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts_per_feature);
ExtractWeightedCuts( ExtractWeightedCuts(device, num_cuts_per_feature,
device, {cuts.data().get(), cuts.size()}, num_cuts, dh::ToSpan(sorted_entries),
{sorted_entries.data().get(), sorted_entries.size()}, dh::ToSpan(temp_weights),
{temp_weights.data().get(), temp_weights.size()}, dh::ToSpan(column_sizes_scan),
{column_sizes_scan.data().get(), column_sizes_scan.size()}); dh::ToSpan(cuts));
// add cuts into sketches // add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts); thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
} }
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) { size_t sketch_batch_num_elements) {
// Configure batch size based on available memory // Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0; bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_); size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
if (sketch_batch_num_elements == 0) { if (sketch_batch_num_elements == 0) {
int bytes_per_element = has_weights ? 24 : 16; int bytes_per_element = has_weights ? 24 : 16;
size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry); size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
// use up to 80% of available space // use up to 80% of available space
sketch_batch_num_elements = sketch_batch_num_elements =
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element; (dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
@ -320,15 +347,21 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dmat->Info().weights_.SetDevice(device); dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) { for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size(); size_t batch_nnz = batch.data.Size();
for (auto begin = 0ull; begin < batch_nnz; auto const& info = dmat->Info();
begin += sketch_batch_num_elements) { dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (has_weights) { if (has_weights) {
bool is_ranking = CutsBuilder::UseGroup(dmat);
ProcessWeightedBatch( ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container, num_cuts, dmat->Info().num_col_); &sketch_container,
num_cuts_per_feature,
dmat->Info().num_col_,
is_ranking, dh::ToSpan(groups));
} else { } else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts, ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature,
dmat->Info().num_col_); dmat->Info().num_col_);
} }
} }
@ -383,9 +416,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
// Extract the cuts from all columns concurrently // Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts); dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts, ExtractCuts(adapter->DeviceIdx(), num_cuts,
{sorted_entries.data().get(), sorted_entries.size()}, dh::ToSpan(sorted_entries),
{column_sizes_scan.data().get(), column_sizes_scan.size()}); dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
// Push cuts into sketches stored in host memory // Push cuts into sketches stored in host memory
thrust::host_vector<SketchEntry> host_cuts(cuts); thrust::host_vector<SketchEntry> host_cuts(cuts);

View File

@ -127,11 +127,11 @@ class HistogramCuts {
class CutsBuilder { class CutsBuilder {
public: public:
using WQSketch = common::WQuantileSketch<bst_float, bst_float>; using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);
protected: protected:
HistogramCuts* p_cuts_; HistogramCuts* p_cuts_;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);
public: public:
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {} explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}

View File

@ -338,6 +338,45 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
} }
} }
void MetaInfo::Validate() const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
<< "Size of weights must equal to number of groups when ranking "
"group is used.";
return;
}
if (group_ptr_.size() != 0) {
CHECK_EQ(group_ptr_.back(), num_row_)
<< "Invalid group structure. Number of rows obtained from groups "
"doesn't equal to actual number of rows given by data.";
}
if (weights_.Size() != 0) {
CHECK_EQ(weights_.Size(), num_row_)
<< "Size of weights must equal to number of rows.";
return;
}
if (labels_.Size() != 0) {
CHECK_EQ(labels_.Size(), num_row_)
<< "Size of labels must equal to number of rows.";
return;
}
if (labels_lower_bound_.Size() != 0) {
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
<< "Size of label_lower_bound must equal to number of rows.";
return;
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";
return;
}
CHECK_LE(num_nonzero_, num_col_ * num_row_);
if (base_margin_.Size() != 0) {
CHECK_EQ(base_margin_.Size() % num_row_, 0)
<< "Size of base margin must be a multiple of number of rows.";
}
}
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
common::AssertGPUSupport(); common::AssertGPUSupport();

View File

@ -1048,15 +1048,7 @@ class LearnerImpl : public LearnerIO {
void ValidateDMatrix(DMatrix* p_fmat) const { void ValidateDMatrix(DMatrix* p_fmat) const {
MetaInfo const& info = p_fmat->Info(); MetaInfo const& info = p_fmat->Info();
auto const& weights = info.weights_; info.Validate();
if (info.group_ptr_.size() != 0 && weights.Size() != 0) {
CHECK(weights.Size() == info.group_ptr_.size() - 1)
<< "\n"
<< "weights size: " << weights.Size() << ", "
<< "groups size: " << info.group_ptr_.size() -1 << ", "
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
<< "Number of weights should be equal to number of groups in ranking task.";
}
auto const row_based_split = [this]() { auto const row_based_split = [this]() {
return tparam_.dsplit == DataSplitMode::kRow || return tparam_.dsplit == DataSplitMode::kRow ||

View File

@ -3,22 +3,19 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include "xgboost/c_api.h" #include <xgboost/data.h>
#include <xgboost/c_api.h>
#include "test_hist_util.h"
#include "../helpers.h"
#include "../data/test_array_interface.h"
#include "../../../src/common/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h" #include "../../../src/common/hist_util.h"
#include "../helpers.h"
#include <xgboost/data.h>
#include "../../../src/data/device_adapter.cuh" #include "../../../src/data/device_adapter.cuh"
#include "../data/test_array_interface.h"
#include "../../../src/common/math.h" #include "../../../src/common/math.h"
#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/simple_dmatrix.h"
#include "test_hist_util.h"
#include "../../../include/xgboost/logging.h" #include "../../../include/xgboost/logging.h"
namespace xgboost { namespace xgboost {
@ -143,7 +140,6 @@ TEST(HistUtil, DeviceSketchMultipleColumns) {
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
} }
TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
@ -161,6 +157,29 @@ TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
} }
} }
TEST(HistUitl, DeviceSketchWeights) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto weighted_dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto& h_weights = weighted_dmat->Info().weights_.HostVector();
h_weights.resize(num_rows);
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
for (auto num_bins : bin_sizes) {
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins);
ASSERT_EQ(cuts.MinValues(), wcuts.MinValues());
ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs());
ASSERT_EQ(cuts.Values(), wcuts.Values());
ValidateCuts(cuts, dmat.get(), num_bins);
ValidateCuts(wcuts, weighted_dmat.get(), num_bins);
}
}
}
TEST(HistUtil, DeviceSketchBatches) { TEST(HistUtil, DeviceSketchBatches) {
int num_bins = 256; int num_bins = 256;
int num_rows = 5000; int num_rows = 5000;
@ -190,8 +209,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
} }
} }
TEST(HistUtil, AdapterDeviceSketch) TEST(HistUtil, AdapterDeviceSketch) {
{
int rows = 5; int rows = 5;
int cols = 1; int cols = 1;
int num_bins = 4; int num_bins = 4;
@ -268,6 +286,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
} }
} }
} }
TEST(HistUtil, AdapterDeviceSketchBatches) { TEST(HistUtil, AdapterDeviceSketchBatches) {
int num_bins = 256; int num_bins = 256;
int num_rows = 5000; int num_rows = 5000;
@ -305,7 +324,38 @@ TEST(HistUtil, SketchingEquivalent) {
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues()); EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
} }
} }
}
TEST(HistUtil, DeviceSketchFromGroupWeights) {
size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
size_t constexpr kGroups = 10;
auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
auto& h_weights = m->Info().weights_.HostVector();
h_weights.resize(kRows);
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
std::vector<bst_group_t> groups(kGroups);
for (size_t i = 0; i < kGroups; ++i) {
groups[i] = kRows / kGroups;
}
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
h_weights.clear();
HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0);
ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size());
for (size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i;
}
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]);
}
for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i));
}
} }
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -9,6 +9,11 @@
#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/simple_dmatrix.h"
#include "../../../src/data/adapter.h" #include "../../../src/data/adapter.h"
#ifdef __CUDACC__
#include <xgboost/json.h>
#include "../../../src/data/device_adapter.cuh"
#endif // __CUDACC__
// Some helper functions used to test both GPU and CPU algorithms // Some helper functions used to test both GPU and CPU algorithms
// //
namespace xgboost { namespace xgboost {
@ -69,11 +74,11 @@ inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
return x; return x;
} }
inline std::shared_ptr<data::SimpleDMatrix> GetDMatrixFromData(const std::vector<float>& x, int num_rows, int num_columns) { inline std::shared_ptr<data::SimpleDMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns) {
data::DenseAdapter adapter(x.data(), num_rows, num_columns); data::DenseAdapter adapter(x.data(), num_rows, num_columns);
return std::shared_ptr<data::SimpleDMatrix>(new data::SimpleDMatrix( return std::shared_ptr<data::SimpleDMatrix>(new data::SimpleDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), &adapter, std::numeric_limits<float>::quiet_NaN(), 1));
1));
} }
inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData( inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
@ -97,7 +102,8 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
// Test that elements are approximately equally distributed among bins // Test that elements are approximately equally distributed among bins
inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx, inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
const std::vector<float>& sorted_column,const std::vector<float >&sorted_weights, const std::vector<float> &sorted_column,
const std::vector<float> &sorted_weights,
int num_bins) { int num_bins) {
std::map<int, int> bin_weights; std::map<int, int> bin_weights;
for (auto i = 0ull; i < sorted_column.size(); i++) { for (auto i = 0ull; i < sorted_column.size(); i++) {
@ -117,9 +123,9 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx,
} }
} }
// Test sketch quantiles against the real quantiles // Test sketch quantiles against the real quantiles Not a very strict
// Not a very strict test // test
inline void TestRank(const std::vector<float>& cuts, inline void TestRank(const std::vector<float> &column_cuts,
const std::vector<float> &sorted_x, const std::vector<float> &sorted_x,
const std::vector<float> &sorted_weights) { const std::vector<float> &sorted_weights) {
double eps = 0.05; double eps = 0.05;
@ -128,14 +134,14 @@ inline void TestRank(const std::vector<float>& cuts,
// Ignore the last cut, it's special // Ignore the last cut, it's special
double sum_weight = 0.0; double sum_weight = 0.0;
size_t j = 0; size_t j = 0;
for (size_t i = 0; i < cuts.size() - 1; i++) { for (size_t i = 0; i < column_cuts.size() - 1; i++) {
while (cuts[i] > sorted_x[j]) { while (column_cuts[i] > sorted_x[j]) {
sum_weight += sorted_weights[j]; sum_weight += sorted_weights[j];
j++; j++;
} }
double expected_rank = ((i + 1) * total_weight) / cuts.size(); double expected_rank = ((i + 1) * total_weight) / column_cuts.size();
double acceptable_error = std::max(2.0, total_weight * eps); double acceptable_error = std::max(2.9, total_weight * eps);
ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error); EXPECT_LE(std::abs(expected_rank - sum_weight), acceptable_error);
} }
} }
@ -167,8 +173,7 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i); ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i);
i++; i++;
} }
} } else {
else {
int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx];
std::vector<float> column_cuts(num_cuts_column); std::vector<float> column_cuts(num_cuts_column);
std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx],
@ -196,10 +201,8 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
const auto& w = dmat->Info().weights_.HostVector(); const auto& w = dmat->Info().weights_.HostVector();
std::vector<size_t > index(col.size()); std::vector<size_t > index(col.size());
std::iota(index.begin(), index.end(), 0); std::iota(index.begin(), index.end(), 0);
std::sort(index.begin(), index.end(),[=](size_t a,size_t b) std::sort(index.begin(), index.end(),
{ [=](size_t a, size_t b) { return col[a] < col[b]; });
return col[a] < col[b];
});
std::vector<float> sorted_column(col.size()); std::vector<float> sorted_column(col.size());
std::vector<float> sorted_weights(col.size(), 1.0); std::vector<float> sorted_weights(col.size(), 1.0);

View File

@ -141,3 +141,17 @@ TEST(MetaInfo, LoadQid) {
CHECK(batch.data.HostVector() == expected_data); CHECK(batch.data.HostVector() == expected_data);
} }
} }
TEST(MetaInfo, Validate) {
xgboost::MetaInfo info;
info.num_row_ = 10;
info.num_nonzero_ = 12;
info.num_col_ = 3;
std::vector<xgboost::bst_group_t> groups (11);
info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11);
EXPECT_THROW(info.Validate(), dmlc::Error);
std::vector<float> labels(info.num_row_ + 1);
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
EXPECT_THROW(info.Validate(), dmlc::Error);
}

View File

@ -1,14 +1,13 @@
import numpy as np import numpy as np
from scipy.sparse import csr_matrix
import xgboost import xgboost
import os import os
import math
import unittest import unittest
import itertools import itertools
import shutil import shutil
import urllib.request import urllib.request
import zipfile import zipfile
class TestRanking(unittest.TestCase): class TestRanking(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -50,17 +49,30 @@ class TestRanking(unittest.TestCase):
cls.qid_test = qid_test cls.qid_test = qid_test
cls.qid_valid = qid_valid cls.qid_valid = qid_valid
def setup_weighted(x, y, groups):
# Setup weighted data
data = xgboost.DMatrix(x, y)
groups_segment = [len(list(items))
for _key, items in itertools.groupby(groups)]
data.set_group(groups_segment)
n_groups = len(groups_segment)
weights = np.ones((n_groups,))
data.set_weight(weights)
return data
cls.dtrain_w = setup_weighted(x_train, y_train, qid_train)
cls.dtest_w = setup_weighted(x_test, y_test, qid_test)
cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid)
# model training parameters # model training parameters
cls.params = {'booster': 'gbtree', cls.params = {'booster': 'gbtree',
'tree_method': 'gpu_hist', 'tree_method': 'gpu_hist',
'gpu_id': 0, 'gpu_id': 0,
'predictor': 'gpu_predictor' 'predictor': 'gpu_predictor'}
}
cls.cpu_params = {'booster': 'gbtree', cls.cpu_params = {'booster': 'gbtree',
'tree_method': 'hist', 'tree_method': 'hist',
'gpu_id': -1, 'gpu_id': -1,
'predictor': 'cpu_predictor' 'predictor': 'cpu_predictor'}
}
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
@ -87,7 +99,8 @@ class TestRanking(unittest.TestCase):
evals_result = {} evals_result = {}
cls.params['objective'] = rank_objective cls.params['objective'] = rank_objective
cls.params['eval_metric'] = metric_name cls.params['eval_metric'] = metric_name
bst = xgboost.train(cls.params, cls.dtrain, num_boost_round=num_trees, bst = xgboost.train(
cls.params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds, early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result) evals=watchlist, evals_result=evals_result)
gpu_map_metric = evals_result['train'][metric_name][-1] gpu_map_metric = evals_result['train'][metric_name][-1]
@ -95,16 +108,31 @@ class TestRanking(unittest.TestCase):
evals_result = {} evals_result = {}
cls.cpu_params['objective'] = rank_objective cls.cpu_params['objective'] = rank_objective
cls.cpu_params['eval_metric'] = metric_name cls.cpu_params['eval_metric'] = metric_name
bstc = xgboost.train(cls.cpu_params, cls.dtrain, num_boost_round=num_trees, bstc = xgboost.train(
cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds, early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result) evals=watchlist, evals_result=evals_result)
cpu_map_metric = evals_result['train'][metric_name][-1] cpu_map_metric = evals_result['train'][metric_name][-1]
print("{0} gpu {1} metric {2}".format(rank_objective, metric_name, gpu_map_metric)) assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance,
print("{0} cpu {1} metric {2}".format(rank_objective, metric_name, cpu_map_metric)) tolerance)
print("gpu best score {0} cpu best score {1}".format(bst.best_score, bstc.best_score)) assert np.allclose(bst.best_score, bstc.best_score, tolerance,
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance) tolerance)
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance)
evals_result_weighted = {}
watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')]
bst_w = xgboost.train(
cls.params, cls.dtrain_w, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result_weighted)
weighted_metric = evals_result_weighted['train'][metric_name][-1]
# GPU Ranking is not deterministic due to `AtomicAddGpair`,
# remove tolerance once the issue is resolved.
# https://github.com/dmlc/xgboost/issues/5561
assert np.allclose(bst_w.best_score, bst.best_score,
tolerance, tolerance)
assert np.allclose(weighted_metric, gpu_map_metric,
tolerance, tolerance)
def test_training_rank_pairwise_map_metric(self): def test_training_rank_pairwise_map_metric(self):
""" """