Fix a bug when combining external memory with gpu_hist and subsampling. (#7481)
Instead of accessing data from `original_page_`, access the data from the first page of the available batches. Fixes #7476. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
parent
7f399eac8b
commit
29bfa94bb6
@ -185,12 +185,10 @@ GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DM
|
|||||||
return {dmat->Info().num_row_, page_, gpair};
|
return {dmat->Info().num_row_, page_, gpair};
|
||||||
}
|
}
|
||||||
|
|
||||||
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl const* page,
|
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
|
||||||
size_t n_rows,
|
|
||||||
BatchParam batch_param,
|
BatchParam batch_param,
|
||||||
float subsample)
|
float subsample)
|
||||||
: original_page_(page),
|
: batch_param_(std::move(batch_param)),
|
||||||
batch_param_(std::move(batch_param)),
|
|
||||||
subsample_(subsample),
|
subsample_(subsample),
|
||||||
sample_row_index_(n_rows) {}
|
sample_row_index_(n_rows) {}
|
||||||
|
|
||||||
@ -218,15 +216,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
|
|||||||
sample_row_index_.begin(),
|
sample_row_index_.begin(),
|
||||||
ClearEmptyRows());
|
ClearEmptyRows());
|
||||||
|
|
||||||
|
auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
|
||||||
|
auto first_page = (*batch_iterator.begin()).Impl();
|
||||||
// Create a new ELLPACK page with empty rows.
|
// Create a new ELLPACK page with empty rows.
|
||||||
page_.reset(); // Release the device memory first before reallocating
|
page_.reset(); // Release the device memory first before reallocating
|
||||||
page_.reset(new EllpackPageImpl(
|
page_.reset(new EllpackPageImpl(
|
||||||
batch_param_.gpu_id, original_page_->Cuts(), original_page_->is_dense,
|
batch_param_.gpu_id, first_page->Cuts(), first_page->is_dense,
|
||||||
original_page_->row_stride, sample_rows));
|
first_page->row_stride, sample_rows));
|
||||||
|
|
||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,12 +259,10 @@ GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpa
|
|||||||
}
|
}
|
||||||
|
|
||||||
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
|
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
|
||||||
EllpackPageImpl const* page,
|
|
||||||
size_t n_rows,
|
size_t n_rows,
|
||||||
BatchParam batch_param,
|
BatchParam batch_param,
|
||||||
float subsample)
|
float subsample)
|
||||||
: original_page_(page),
|
: batch_param_(std::move(batch_param)),
|
||||||
batch_param_(std::move(batch_param)),
|
|
||||||
subsample_(subsample),
|
subsample_(subsample),
|
||||||
threshold_(n_rows + 1, 0.0f),
|
threshold_(n_rows + 1, 0.0f),
|
||||||
grad_sum_(n_rows, 0.0f),
|
grad_sum_(n_rows, 0.0f),
|
||||||
@ -300,15 +298,17 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
|
|||||||
sample_row_index_.begin(),
|
sample_row_index_.begin(),
|
||||||
ClearEmptyRows());
|
ClearEmptyRows());
|
||||||
|
|
||||||
|
auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
|
||||||
|
auto first_page = (*batch_iterator.begin()).Impl();
|
||||||
// Create a new ELLPACK page with empty rows.
|
// Create a new ELLPACK page with empty rows.
|
||||||
page_.reset(); // Release the device memory first before reallocating
|
page_.reset(); // Release the device memory first before reallocating
|
||||||
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, original_page_->Cuts(),
|
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, first_page->Cuts(),
|
||||||
original_page_->is_dense,
|
first_page->is_dense,
|
||||||
original_page_->row_stride, sample_rows));
|
first_page->row_stride, sample_rows));
|
||||||
|
|
||||||
// Compact the ELLPACK pages into the single sample page.
|
// Compact the ELLPACK pages into the single sample page.
|
||||||
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
|
||||||
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
|
for (auto& batch : batch_iterator) {
|
||||||
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,7 +329,7 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
|
|||||||
switch (sampling_method) {
|
switch (sampling_method) {
|
||||||
case TrainParam::kUniform:
|
case TrainParam::kUniform:
|
||||||
if (is_external_memory) {
|
if (is_external_memory) {
|
||||||
strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample));
|
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
|
||||||
} else {
|
} else {
|
||||||
strategy_.reset(new UniformSampling(page, subsample));
|
strategy_.reset(new UniformSampling(page, subsample));
|
||||||
}
|
}
|
||||||
@ -337,7 +337,7 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
|
|||||||
case TrainParam::kGradientBased:
|
case TrainParam::kGradientBased:
|
||||||
if (is_external_memory) {
|
if (is_external_memory) {
|
||||||
strategy_.reset(
|
strategy_.reset(
|
||||||
new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample));
|
new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
|
||||||
} else {
|
} else {
|
||||||
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
|
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,14 +66,12 @@ class UniformSampling : public SamplingStrategy {
|
|||||||
/*! \brief No sampling in external memory mode. */
|
/*! \brief No sampling in external memory mode. */
|
||||||
class ExternalMemoryUniformSampling : public SamplingStrategy {
|
class ExternalMemoryUniformSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
ExternalMemoryUniformSampling(EllpackPageImpl const* page,
|
ExternalMemoryUniformSampling(size_t n_rows,
|
||||||
size_t n_rows,
|
|
||||||
BatchParam batch_param,
|
BatchParam batch_param,
|
||||||
float subsample);
|
float subsample);
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EllpackPageImpl const* original_page_;
|
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
float subsample_;
|
float subsample_;
|
||||||
std::unique_ptr<EllpackPageImpl> page_;
|
std::unique_ptr<EllpackPageImpl> page_;
|
||||||
@ -100,14 +98,12 @@ class GradientBasedSampling : public SamplingStrategy {
|
|||||||
/*! \brief Gradient-based sampling in external memory mode.. */
|
/*! \brief Gradient-based sampling in external memory mode.. */
|
||||||
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
|
||||||
public:
|
public:
|
||||||
ExternalMemoryGradientBasedSampling(EllpackPageImpl const* page,
|
ExternalMemoryGradientBasedSampling(size_t n_rows,
|
||||||
size_t n_rows,
|
|
||||||
BatchParam batch_param,
|
BatchParam batch_param,
|
||||||
float subsample);
|
float subsample);
|
||||||
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EllpackPageImpl const* original_page_;
|
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
float subsample_;
|
float subsample_;
|
||||||
dh::caching_device_vector<float> threshold_;
|
dh::caching_device_vector<float> threshold_;
|
||||||
|
|||||||
@ -17,16 +17,23 @@ def test_gpu_single_batch() -> None:
|
|||||||
|
|
||||||
@pytest.mark.skipif(**no_cupy())
|
@pytest.mark.skipif(**no_cupy())
|
||||||
@given(
|
@given(
|
||||||
strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13)
|
strategies.integers(0, 1024),
|
||||||
|
strategies.integers(1, 7),
|
||||||
|
strategies.integers(0, 13),
|
||||||
|
strategies.booleans(),
|
||||||
)
|
)
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
def test_gpu_data_iterator(
|
def test_gpu_data_iterator(
|
||||||
n_samples_per_batch: int, n_features: int, n_batches: int
|
n_samples_per_batch: int, n_features: int, n_batches: int, subsample: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", True)
|
run_data_iterator(
|
||||||
run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", False)
|
n_samples_per_batch, n_features, n_batches, "gpu_hist", subsample, True
|
||||||
|
)
|
||||||
|
run_data_iterator(
|
||||||
|
n_samples_per_batch, n_features, n_batches, "gpu_hist", subsample, False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_cpu_data_iterator() -> None:
|
def test_cpu_data_iterator() -> None:
|
||||||
"""Make sure CPU algorithm can handle GPU inputs"""
|
"""Make sure CPU algorithm can handle GPU inputs"""
|
||||||
run_data_iterator(1024, 2, 3, "approx", True)
|
run_data_iterator(1024, 2, 3, "approx", False, True)
|
||||||
|
|||||||
@ -68,9 +68,14 @@ def run_data_iterator(
|
|||||||
n_features: int,
|
n_features: int,
|
||||||
n_batches: int,
|
n_batches: int,
|
||||||
tree_method: str,
|
tree_method: str,
|
||||||
|
subsample: bool,
|
||||||
use_cupy: bool,
|
use_cupy: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
n_rounds = 2
|
n_rounds = 2
|
||||||
|
# The test is more difficult to pass if the subsample rate is smaller as the root_sum
|
||||||
|
# is accumulated in parallel. Reductions with different number of entries lead to
|
||||||
|
# different floating point errors.
|
||||||
|
subsample_rate = 0.8 if subsample else 1.0
|
||||||
|
|
||||||
it = IteratorForTest(
|
it = IteratorForTest(
|
||||||
*make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
|
*make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
|
||||||
@ -84,9 +89,19 @@ def run_data_iterator(
|
|||||||
assert Xy.num_row() == n_samples_per_batch * n_batches
|
assert Xy.num_row() == n_samples_per_batch * n_batches
|
||||||
assert Xy.num_col() == n_features
|
assert Xy.num_col() == n_features
|
||||||
|
|
||||||
|
parameters = {
|
||||||
|
"tree_method": tree_method,
|
||||||
|
"max_depth": 2,
|
||||||
|
"subsample": subsample_rate,
|
||||||
|
"seed": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tree_method == "gpu_hist":
|
||||||
|
parameters["sampling_method"] = "gradient_based"
|
||||||
|
|
||||||
results_from_it: xgb.callback.EvaluationMonitor.EvalsLog = {}
|
results_from_it: xgb.callback.EvaluationMonitor.EvalsLog = {}
|
||||||
from_it = xgb.train(
|
from_it = xgb.train(
|
||||||
{"tree_method": tree_method, "max_depth": 2},
|
parameters,
|
||||||
Xy,
|
Xy,
|
||||||
num_boost_round=n_rounds,
|
num_boost_round=n_rounds,
|
||||||
evals=[(Xy, "Train")],
|
evals=[(Xy, "Train")],
|
||||||
@ -102,7 +117,7 @@ def run_data_iterator(
|
|||||||
|
|
||||||
results_from_arrays: xgb.callback.EvaluationMonitor.EvalsLog = {}
|
results_from_arrays: xgb.callback.EvaluationMonitor.EvalsLog = {}
|
||||||
from_arrays = xgb.train(
|
from_arrays = xgb.train(
|
||||||
{"tree_method": tree_method, "max_depth": 2},
|
parameters,
|
||||||
Xy,
|
Xy,
|
||||||
num_boost_round=n_rounds,
|
num_boost_round=n_rounds,
|
||||||
evals=[(Xy, "Train")],
|
evals=[(Xy, "Train")],
|
||||||
@ -126,11 +141,21 @@ def run_data_iterator(
|
|||||||
|
|
||||||
|
|
||||||
@given(
|
@given(
|
||||||
strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13)
|
strategies.integers(0, 1024),
|
||||||
|
strategies.integers(1, 7),
|
||||||
|
strategies.integers(0, 13),
|
||||||
|
strategies.booleans(),
|
||||||
)
|
)
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
def test_data_iterator(
|
def test_data_iterator(
|
||||||
n_samples_per_batch: int, n_features: int, n_batches: int
|
n_samples_per_batch: int,
|
||||||
|
n_features: int,
|
||||||
|
n_batches: int,
|
||||||
|
subsample: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
run_data_iterator(n_samples_per_batch, n_features, n_batches, "approx", False)
|
run_data_iterator(
|
||||||
run_data_iterator(n_samples_per_batch, n_features, n_batches, "hist", False)
|
n_samples_per_batch, n_features, n_batches, "approx", subsample, False
|
||||||
|
)
|
||||||
|
run_data_iterator(
|
||||||
|
n_samples_per_batch, n_features, n_batches, "hist", subsample, False
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user