Group aware GPU sketching. (#5551)
* Group aware GPU weighted sketching. * Distribute group weights to each data point. * Relax the test. * Validate input meta info. * Fix metainfo copy ctor.
This commit is contained in:
@@ -3,22 +3,19 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
#include "test_hist_util.h"
|
||||
#include "../helpers.h"
|
||||
#include "../data/test_array_interface.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
|
||||
#include "../helpers.h"
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../data/test_array_interface.h"
|
||||
#include "../../../src/common/math.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "test_hist_util.h"
|
||||
#include "../../../include/xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -143,7 +140,6 @@ TEST(HistUtil, DeviceSketchMultipleColumns) {
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
|
||||
@@ -161,6 +157,29 @@ TEST(HistUtil, DeviceSketchMultipleColumnsWeights) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUitl, DeviceSketchWeights) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto weighted_dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto& h_weights = weighted_dmat->Info().weights_.HostVector();
|
||||
h_weights.resize(num_rows);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins);
|
||||
ASSERT_EQ(cuts.MinValues(), wcuts.MinValues());
|
||||
ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs());
|
||||
ASSERT_EQ(cuts.Values(), wcuts.Values());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
ValidateCuts(wcuts, weighted_dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchBatches) {
|
||||
int num_bins = 256;
|
||||
int num_rows = 5000;
|
||||
@@ -190,8 +209,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketch)
|
||||
{
|
||||
TEST(HistUtil, AdapterDeviceSketch) {
|
||||
int rows = 5;
|
||||
int cols = 1;
|
||||
int num_bins = 4;
|
||||
@@ -235,7 +253,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
|
||||
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
TEST(HistUtil, AdapterDeviceSketchCategorical) {
|
||||
int categorical_sizes[] = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
int sizes[] = {25, 100, 1000};
|
||||
@@ -268,6 +286,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, AdapterDeviceSketchBatches) {
|
||||
int num_bins = 256;
|
||||
int num_rows = 5000;
|
||||
@@ -305,7 +324,38 @@ TEST(HistUtil, SketchingEquivalent) {
|
||||
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
|
||||
auto& h_weights = m->Info().weights_.HostVector();
|
||||
h_weights.resize(kRows);
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
std::vector<bst_group_t> groups(kGroups);
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
|
||||
h_weights.clear();
|
||||
HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
|
||||
ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
|
||||
ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
|
||||
ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size());
|
||||
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i;
|
||||
}
|
||||
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
|
||||
ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]);
|
||||
}
|
||||
for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
|
||||
ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i));
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user