Fix memory usage of device sketching (#5407)
This commit is contained in:
parent
bb8c8df39d
commit
b745b7acce
@ -168,24 +168,15 @@ struct BatchParam {
|
||||
/*! \brief The GPU device to use. */
|
||||
int gpu_id;
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
int max_bin { 0 };
|
||||
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
|
||||
int gpu_batch_nrows;
|
||||
int max_bin{0};
|
||||
/*! \brief Page size for external memory mode. */
|
||||
size_t gpu_page_size;
|
||||
BatchParam() = default;
|
||||
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
|
||||
size_t gpu_page_size = 0) :
|
||||
gpu_id{device},
|
||||
max_bin{max_bin},
|
||||
gpu_batch_nrows{gpu_batch_nrows},
|
||||
gpu_page_size{gpu_page_size}
|
||||
{}
|
||||
BatchParam(int32_t device, int32_t max_bin, size_t gpu_page_size = 0)
|
||||
: gpu_id{device}, max_bin{max_bin}, gpu_page_size{gpu_page_size} {}
|
||||
inline bool operator!=(const BatchParam& other) const {
|
||||
return gpu_id != other.gpu_id ||
|
||||
max_bin != other.max_bin ||
|
||||
gpu_batch_nrows != other.gpu_batch_nrows ||
|
||||
gpu_page_size != other.gpu_page_size;
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin ||
|
||||
gpu_page_size != other.gpu_page_size;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -378,6 +378,11 @@ public:
|
||||
{
|
||||
return stats_.peak_allocated_bytes;
|
||||
}
|
||||
void Clear()
|
||||
{
|
||||
stats_ = DeviceStats();
|
||||
}
|
||||
|
||||
void Log() {
|
||||
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
|
||||
return;
|
||||
@ -475,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
template <typename T>
|
||||
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
|
||||
/*! Be careful that the initialization constructor is a no-op, which means calling
|
||||
* `vec.resize(n, 1)` won't initialize the memory region to 1. */
|
||||
* `vec.resize(n)` won't initialize the memory region to 0. Instead use
|
||||
* `vec.resize(n, 0)`*/
|
||||
template <typename T>
|
||||
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
|
||||
/** \brief Specialisation of thrust device vector using custom allocator. */
|
||||
|
||||
@ -97,6 +97,19 @@ struct EntryCompareOp {
|
||||
}
|
||||
};
|
||||
|
||||
// Compute number of sample cuts needed on local node to maintain accuracy
|
||||
// We take more cuts than needed and then reduce them later
|
||||
size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
num_rows, eps, &dummy_nlevel, &num_cuts);
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void GetColumnSizesScan(int device,
|
||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||
@ -210,7 +223,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), EntryCompareOp());
|
||||
@ -237,11 +250,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
@ -288,28 +301,29 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements) {
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = has_weights ? 24 : 16;
|
||||
size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements =
|
||||
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
HistogramCuts cuts;
|
||||
DenseCuts dense_cuts(&cuts);
|
||||
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
|
||||
dmat->Info().num_row_);
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = dmat->Info().num_nonzero_;
|
||||
}
|
||||
dmat->Info().weights_.SetDevice(device);
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
size_t batch_nnz = batch.data.Size();
|
||||
for (auto begin = 0ull; begin < batch_nnz;
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
|
||||
if (dmat->Info().weights_.Size() > 0) {
|
||||
if (has_weights) {
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
&sketch_container, num_cuts, dmat->Info().num_col_);
|
||||
@ -369,6 +383,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
// Work out how many valid entries we have in each column
|
||||
dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
@ -385,7 +400,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
size_t num_valid = host_column_sizes_scan.back();
|
||||
|
||||
// Copy current subset of valid elements into temporary storage and sort
|
||||
thrust::device_vector<Entry> sorted_entries(num_valid);
|
||||
dh::caching_device_vector<Entry> sorted_entries(num_valid);
|
||||
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
|
||||
entry_iter + end, sorted_entries.begin(), is_valid);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
@ -406,6 +421,17 @@ template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements) {
|
||||
size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = 16;
|
||||
size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry);
|
||||
size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) -
|
||||
bytes_cuts - bytes_num_columns) *
|
||||
0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
|
||||
|
||||
@ -421,16 +447,6 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
|
||||
adapter->NumRows());
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * num_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = batch.Size();
|
||||
}
|
||||
for (auto begin = 0ull; begin < batch.Size();
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
|
||||
@ -199,14 +199,15 @@ class DenseCuts : public CutsBuilder {
|
||||
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
||||
};
|
||||
|
||||
|
||||
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements = 10000000);
|
||||
size_t sketch_batch_num_elements = 0);
|
||||
|
||||
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements = 10000000);
|
||||
size_t sketch_batch_num_elements = 0);
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
|
||||
@ -101,8 +101,7 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
|
||||
monitor_.StartCuda("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
row_stride = GetRowStride(dmat);
|
||||
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
|
||||
param.gpu_batch_nrows);
|
||||
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
|
||||
@ -45,8 +45,7 @@ EllpackPageSource::EllpackPageSource(DMatrix* dmat,
|
||||
|
||||
monitor_.StartCuda("Quantiles");
|
||||
size_t row_stride = GetRowStride(dmat);
|
||||
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
|
||||
param.gpu_batch_nrows);
|
||||
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
monitor_.StartCuda("WriteEllpackPages");
|
||||
|
||||
@ -44,8 +44,6 @@ struct GPUHistMakerTrainParam
|
||||
: public XGBoostParameter<GPUHistMakerTrainParam> {
|
||||
bool single_precision_histogram;
|
||||
bool deterministic_histogram;
|
||||
// number of rows in a single GPU batch
|
||||
int gpu_batch_nrows;
|
||||
bool debug_synchronize;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
|
||||
@ -53,11 +51,6 @@ struct GPUHistMakerTrainParam
|
||||
"Use single precision to build histograms.");
|
||||
DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
|
||||
"Pre-round the gradient for obtaining deterministic gradient histogram.");
|
||||
DMLC_DECLARE_FIELD(gpu_batch_nrows)
|
||||
.set_lower_bound(-1)
|
||||
.set_default(0)
|
||||
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
|
||||
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
|
||||
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
|
||||
"Check if all distributed tree are identical after tree construction.");
|
||||
}
|
||||
@ -1018,7 +1011,6 @@ class GPUHistMakerSpecialised {
|
||||
BatchParam batch_param{
|
||||
device_,
|
||||
param_.max_bin,
|
||||
hist_maker_param_.gpu_batch_nrows,
|
||||
generic_param_->gpu_page_size
|
||||
};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
@ -20,6 +19,7 @@
|
||||
#include "../../../src/common/math.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "test_hist_util.h"
|
||||
#include "../../../include/xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -39,7 +39,7 @@ TEST(hist_util, DeviceSketch) {
|
||||
std::vector<float> x = {1.0, 2.0, 3.0, 4.0, 5.0};
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts;
|
||||
DenseCuts builder(&host_cuts);
|
||||
builder.Build(dmat.get(), num_bins);
|
||||
@ -49,6 +49,59 @@ TEST(hist_util, DeviceSketch) {
|
||||
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
|
||||
}
|
||||
|
||||
// Duplicate this function from hist_util.cu so we don't have to expose it in
|
||||
// header
|
||||
size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
num_rows, eps, &dummy_nlevel, &num_cuts);
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
TEST(hist_util, DeviceSketchMemory) {
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements = num_rows * num_columns*sizeof(Entry);
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
size_t bytes_constant = 1000;
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
bytes_num_elements + bytes_cuts + bytes_constant);
|
||||
}
|
||||
|
||||
TEST(hist_util, DeviceSketchMemoryWeights) {
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements =
|
||||
num_rows * num_columns * (sizeof(Entry) + sizeof(float));
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
size_t((bytes_num_elements + bytes_cuts) * 1.05));
|
||||
}
|
||||
|
||||
TEST(hist_util, DeviceSketchDeterminism) {
|
||||
int num_rows = 500;
|
||||
int num_columns = 5;
|
||||
@ -72,7 +125,7 @@ TEST(hist_util, DeviceSketchDeterminism) {
|
||||
for (auto num_categories : categorical_sizes) {
|
||||
auto x = GenerateRandomCategoricalSingleColumn(n, num_categories);
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -86,7 +139,7 @@ TEST(hist_util, DeviceSketchMultipleColumns) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -102,7 +155,7 @@ TEST(hist_util, DeviceSketchMultipleColumnsWeights) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -131,7 +184,7 @@ TEST(hist_util, DeviceSketchMultipleColumnsExternal) {
|
||||
auto dmat =
|
||||
GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 100, temp);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);
|
||||
auto cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -159,6 +212,29 @@ TEST(hist_util, AdapterDeviceSketch)
|
||||
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
|
||||
}
|
||||
|
||||
TEST(hist_util, AdapterDeviceSketchMemory) {
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
int num_bins = 256;
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
|
||||
dh::GlobalMemoryLogger().Clear();
|
||||
ConsoleLogger::Configure({{"verbosity", "3"}});
|
||||
auto cuts = AdapterDeviceSketch(&adapter, num_bins,
|
||||
std::numeric_limits<float>::quiet_NaN());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
|
||||
size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry);
|
||||
size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t);
|
||||
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
|
||||
sizeof(DenseCuts::WQSketch::Entry);
|
||||
size_t bytes_constant = 1000;
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
|
||||
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
|
||||
}
|
||||
|
||||
TEST(hist_util, AdapterDeviceSketchCategorical) {
|
||||
int categorical_sizes[] = {2, 6, 8, 12};
|
||||
int num_bins = 256;
|
||||
|
||||
@ -14,10 +14,10 @@
|
||||
namespace xgboost {
|
||||
|
||||
TEST(EllpackPage, EmptyDMatrix) {
|
||||
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256, kGpuBatchNRows = 64;
|
||||
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
|
||||
constexpr float kSparsity = 0;
|
||||
auto dmat = *CreateDMatrix(kNRows, kNCols, kSparsity);
|
||||
auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin, kGpuBatchNRows}).begin();
|
||||
auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin}).begin();
|
||||
auto impl = page.Impl();
|
||||
ASSERT_EQ(impl->row_stride, 0);
|
||||
ASSERT_EQ(impl->cuts_.TotalBins(), 0);
|
||||
@ -101,7 +101,7 @@ TEST(EllpackPage, Copy) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
BatchParam param{0, 256, 0, kPageSize};
|
||||
BatchParam param{0, 256, kPageSize};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
|
||||
// Create an empty result page.
|
||||
@ -147,7 +147,7 @@ TEST(EllpackPage, Compact) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
BatchParam param{0, 256, 0, kPageSize};
|
||||
BatchParam param{0, 256, kPageSize};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
|
||||
// Create an empty result page.
|
||||
|
||||
@ -33,7 +33,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
int64_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 0, 7UL})) {
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256, 7UL})) {
|
||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
@ -57,7 +57,7 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, 2, 0, 0};
|
||||
BatchParam param{0, 2, 0};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
EXPECT_EQ(impl->base_rowid, 0);
|
||||
EXPECT_EQ(impl->n_rows, kRows);
|
||||
@ -107,7 +107,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, kMaxBins, 0, kPageSize};
|
||||
BatchParam param{0, kMaxBins, kPageSize};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
EXPECT_EQ(impl->base_rowid, 0);
|
||||
EXPECT_EQ(impl->n_rows, kRows);
|
||||
@ -148,7 +148,7 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, kMaxBins, 0, kPageSize};
|
||||
BatchParam param{0, kMaxBins, kPageSize};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
|
||||
size_t current_row = 0;
|
||||
|
||||
@ -27,7 +27,7 @@ void VerifySampling(size_t page_size,
|
||||
}
|
||||
gpair.SetDevice(0);
|
||||
|
||||
BatchParam param{0, 256, 0, page_size};
|
||||
BatchParam param{0, 256, page_size};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
if (page_size != 0) {
|
||||
EXPECT_NE(page->n_rows, kRows);
|
||||
@ -82,7 +82,7 @@ TEST(GradientBasedSampler, NoSampling_ExternalMemory) {
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
gpair.SetDevice(0);
|
||||
|
||||
BatchParam param{0, 256, 0, kPageSize};
|
||||
BatchParam param{0, 256, kPageSize};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
EXPECT_NE(page->n_rows, kRows);
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ void TestDeterminsticHistogram() {
|
||||
|
||||
auto pp_m = CreateDMatrix(kRows, kCols, 0.5);
|
||||
auto& matrix = **pp_m;
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0, 0};
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0};
|
||||
|
||||
for (auto const& batch : matrix.GetBatches<EllpackPage>(batch_param)) {
|
||||
auto* page = batch.Impl();
|
||||
|
||||
@ -341,7 +341,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
int64_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin, 0, gpu_page_size})) {
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin, gpu_page_size})) {
|
||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user