[EM] Merge GPU partitioning with histogram building. (#10766)
- Stop concatenating pages if there's no subsampling. - Use a single iteration for histogram build and partitioning.
This commit is contained in:
parent
98ac153265
commit
e1a2c1bbb3
@ -222,10 +222,12 @@ def check_extmem_qdm(
|
||||
Xy = xgb.QuantileDMatrix(X, y, weight=w)
|
||||
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
|
||||
|
||||
cut_it = Xy_it.get_quantile_cut()
|
||||
cut = Xy.get_quantile_cut()
|
||||
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||
if device == "cpu":
|
||||
# Get cuts from ellpack without CPU-GPU interpolation is not yet supported.
|
||||
cut_it = Xy_it.get_quantile_cut()
|
||||
cut = Xy.get_quantile_cut()
|
||||
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||
|
||||
predt_it = booster_it.predict(Xy_it)
|
||||
predt = booster.predict(Xy)
|
||||
|
||||
@ -158,28 +158,10 @@ GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair
|
||||
ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
|
||||
: batch_param_{std::move(batch_param)} {}
|
||||
|
||||
GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
|
||||
GradientBasedSample ExternalMemoryNoSampling::Sample(Context const*,
|
||||
common::Span<GradientPair> gpair,
|
||||
DMatrix* p_fmat) {
|
||||
std::shared_ptr<EllpackPage> new_page;
|
||||
if (!page_concatenated_) {
|
||||
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
|
||||
bst_idx_t offset = 0;
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
|
||||
auto page = batch.Impl();
|
||||
if (!new_page) {
|
||||
new_page = std::make_shared<EllpackPage>();
|
||||
*new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
|
||||
page->row_stride, p_fmat->Info().num_row_);
|
||||
}
|
||||
bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
|
||||
offset += num_elements;
|
||||
}
|
||||
page_concatenated_ = true;
|
||||
this->p_fmat_new_ =
|
||||
std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
|
||||
}
|
||||
return {this->p_fmat_new_.get(), gpair};
|
||||
return {p_fmat, gpair};
|
||||
}
|
||||
|
||||
UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
|
||||
|
||||
@ -46,8 +46,6 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
|
||||
|
||||
private:
|
||||
BatchParam batch_param_;
|
||||
std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
|
||||
bool page_concatenated_{false};
|
||||
};
|
||||
|
||||
/*! \brief Uniform sampling in in-memory mode. */
|
||||
|
||||
@ -22,6 +22,10 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
|
||||
NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});
|
||||
|
||||
thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
|
||||
|
||||
// Pre-allocate some host memory
|
||||
this->pinned_.GetSpan<std::int32_t>(1 << 11);
|
||||
this->pinned2_.GetSpan<std::int32_t>(1 << 13);
|
||||
}
|
||||
|
||||
RowPartitioner::~RowPartitioner() = default;
|
||||
|
||||
@ -200,6 +200,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
// Reset values for each update iteration
|
||||
[[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
|
||||
this->monitor.Start(__func__);
|
||||
auto const& info = p_fmat->Info();
|
||||
this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
|
||||
param.colsample_bynode, param.colsample_bylevel,
|
||||
@ -252,7 +253,7 @@ struct GPUHistMakerDevice {
|
||||
this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
|
||||
feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(),
|
||||
false);
|
||||
|
||||
this->monitor.Stop(__func__);
|
||||
return p_fmat;
|
||||
}
|
||||
|
||||
@ -346,6 +347,38 @@ struct GPUHistMakerDevice {
|
||||
monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
void ReduceHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
|
||||
std::vector<bst_node_t> const& build_nidx,
|
||||
std::vector<bst_node_t> const& subtraction_nidx) {
|
||||
if (candidates.empty()) {
|
||||
return;
|
||||
}
|
||||
this->monitor.Start(__func__);
|
||||
|
||||
// Reduce all in one go
|
||||
// This gives much better latency in a distributed setting when processing a large batch
|
||||
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
|
||||
// Perform subtraction for sibiling nodes
|
||||
auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx);
|
||||
if (need_build.empty()) {
|
||||
this->monitor.Stop(__func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// Build the nodes that can not obtain the histogram using subtraction. This is the slow path.
|
||||
std::int32_t k = 0;
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
for (auto nidx : need_build) {
|
||||
this->BuildHist(page, k, nidx);
|
||||
}
|
||||
++k;
|
||||
}
|
||||
for (auto nidx : need_build) {
|
||||
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), nidx, 1);
|
||||
}
|
||||
this->monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
|
||||
std::vector<NodeSplitData> const& split_data,
|
||||
std::vector<bst_node_t> const& nidx,
|
||||
@ -434,56 +467,74 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
};
|
||||
|
||||
void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
|
||||
RegTree* p_tree) {
|
||||
if (candidates.empty()) {
|
||||
// Update position and build histogram.
|
||||
void PartitionAndBuildHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& expand_set,
|
||||
std::vector<GPUExpandEntry> const& candidates, RegTree const* p_tree) {
|
||||
if (expand_set.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
monitor.Start(__func__);
|
||||
CHECK_LE(candidates.size(), expand_set.size());
|
||||
|
||||
auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);
|
||||
// Update all the nodes if working with external memory, this saves us from working
|
||||
// with the finalize position call, which adds an additional iteration and requires
|
||||
// special handling for row index.
|
||||
bool const is_single_block = p_fmat->SingleColBlock();
|
||||
|
||||
for (size_t i = 0; i < candidates.size(); i++) {
|
||||
auto const& e = candidates[i];
|
||||
RegTree::Node const& split_node = (*p_tree)[e.nid];
|
||||
auto split_type = p_tree->NodeSplitType(e.nid);
|
||||
nidx[i] = e.nid;
|
||||
left_nidx[i] = split_node.LeftChild();
|
||||
right_nidx[i] = split_node.RightChild();
|
||||
split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
|
||||
// Prepare for update partition
|
||||
auto [nidx, left_nidx, right_nidx, split_data] =
|
||||
this->CreatePartitionNodes(p_tree, is_single_block ? candidates : expand_set);
|
||||
|
||||
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
|
||||
}
|
||||
// Prepare for build hist
|
||||
std::vector<bst_node_t> build_nidx(candidates.size());
|
||||
std::vector<bst_node_t> subtraction_nidx(candidates.size());
|
||||
auto prefetch_copy =
|
||||
AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);
|
||||
|
||||
CHECK_EQ(p_fmat->NumBatches(), 1);
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);
|
||||
|
||||
monitor.Start("Partition-BuildHist");
|
||||
|
||||
std::int32_t k{0};
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
|
||||
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
|
||||
auto go_left = GoLeftOp{d_matrix};
|
||||
|
||||
// Partition histogram.
|
||||
monitor.Start("UpdatePositionBatch");
|
||||
if (p_fmat->Info().IsColumnSplit()) {
|
||||
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
|
||||
monitor.Stop(__func__);
|
||||
return;
|
||||
} else {
|
||||
partitioners_.at(k)->UpdatePositionBatch(
|
||||
nidx, left_nidx, right_nidx, split_data,
|
||||
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
|
||||
const NodeSplitData& data) { return go_left(ridx, data); });
|
||||
}
|
||||
auto go_left = GoLeftOp{d_matrix};
|
||||
partitioners_.front()->UpdatePositionBatch(
|
||||
nidx, left_nidx, right_nidx, split_data,
|
||||
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
|
||||
const NodeSplitData& data) { return go_left(ridx, data); });
|
||||
monitor.Stop("UpdatePositionBatch");
|
||||
|
||||
for (auto nidx : build_nidx) {
|
||||
this->BuildHist(page, k, nidx);
|
||||
}
|
||||
|
||||
++k;
|
||||
}
|
||||
|
||||
monitor.Stop("Partition-BuildHist");
|
||||
|
||||
this->ReduceHist(p_fmat, candidates, build_nidx, subtraction_nidx);
|
||||
|
||||
monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
// After tree update is finished, update the position of all training
|
||||
// instances to their final leaf. This information is used later to update the
|
||||
// prediction cache
|
||||
void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
|
||||
void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
|
||||
LOG(FATAL) << "Current objective function can not be used with external memory.";
|
||||
}
|
||||
if (p_fmat->Info().num_row_ != n_samples) {
|
||||
if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
|
||||
// External memory with concatenation. Not supported.
|
||||
p_out_position->Resize(0);
|
||||
positions_.clear();
|
||||
@ -577,60 +628,6 @@ struct GPUHistMakerDevice {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
|
||||
const RegTree& tree) {
|
||||
if (candidates.empty()) {
|
||||
return;
|
||||
}
|
||||
this->monitor.Start(__func__);
|
||||
// Some nodes we will manually compute histograms
|
||||
// others we will do by subtraction
|
||||
std::vector<bst_node_t> hist_nidx(candidates.size());
|
||||
std::vector<bst_node_t> subtraction_nidx(candidates.size());
|
||||
auto prefetch_copy =
|
||||
AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);
|
||||
|
||||
std::vector<int> all_new = hist_nidx;
|
||||
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
|
||||
// Allocate the histograms
|
||||
// Guaranteed contiguous memory
|
||||
histogram_.AllocateHistograms(ctx_, all_new);
|
||||
|
||||
std::int32_t k = 0;
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
|
||||
for (auto nidx : hist_nidx) {
|
||||
this->BuildHist(page, k, nidx);
|
||||
}
|
||||
++k;
|
||||
}
|
||||
|
||||
// Reduce all in one go
|
||||
// This gives much better latency in a distributed setting
|
||||
// when processing a large batch
|
||||
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
|
||||
|
||||
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
|
||||
auto build_hist_nidx = hist_nidx.at(i);
|
||||
auto subtraction_trick_nidx = subtraction_nidx.at(i);
|
||||
auto parent_nidx = candidates.at(i).nid;
|
||||
|
||||
if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
|
||||
subtraction_trick_nidx)) {
|
||||
// Calculate other histogram manually
|
||||
std::int32_t k = 0;
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
this->BuildHist(page, k, subtraction_trick_nidx);
|
||||
++k;
|
||||
}
|
||||
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
|
||||
}
|
||||
}
|
||||
this->monitor.Stop(__func__);
|
||||
}
|
||||
|
||||
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
@ -681,8 +678,9 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
|
||||
constexpr bst_node_t kRootNIdx = 0;
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
this->monitor.Start(__func__);
|
||||
|
||||
constexpr bst_node_t kRootNIdx = RegTree::kRoot;
|
||||
auto quantiser = *this->quantiser;
|
||||
auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
|
||||
dh::tbegin(gpair),
|
||||
@ -697,6 +695,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
histogram_.AllocateHistograms(ctx_, {kRootNIdx});
|
||||
std::int32_t k = 0;
|
||||
CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size());
|
||||
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
this->BuildHist(page, k, kRootNIdx);
|
||||
++k;
|
||||
@ -712,25 +711,18 @@ struct GPUHistMakerDevice {
|
||||
|
||||
// Generate first split
|
||||
auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);
|
||||
|
||||
this->monitor.Stop(__func__);
|
||||
return root_entry;
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
|
||||
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
bool const is_single_block = p_fmat->SingleColBlock();
|
||||
bst_idx_t const n_samples = p_fmat->Info().num_row_;
|
||||
|
||||
auto& tree = *p_tree;
|
||||
// Process maximum 32 nodes at a time
|
||||
Driver<GPUExpandEntry> driver(param, 32);
|
||||
|
||||
monitor.Start("Reset");
|
||||
p_fmat = this->Reset(gpair_all, p_fmat);
|
||||
monitor.Stop("Reset");
|
||||
|
||||
monitor.Start("InitRoot");
|
||||
driver.Push({this->InitRoot(p_fmat, p_tree)});
|
||||
monitor.Stop("InitRoot");
|
||||
|
||||
// The set of leaves that can be expanded asynchronously
|
||||
auto expand_set = driver.Pop();
|
||||
@ -740,20 +732,17 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
// Get the candidates we are allowed to expand further
|
||||
// e.g. We do not bother further processing nodes whose children are beyond max depth
|
||||
std::vector<GPUExpandEntry> filtered_expand_set;
|
||||
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
|
||||
[&](const auto& e) { return driver.IsChildValid(e); });
|
||||
std::vector<GPUExpandEntry> valid_candidates;
|
||||
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(valid_candidates),
|
||||
[&](auto const& e) { return driver.IsChildValid(e); });
|
||||
|
||||
// Allocaate children nodes.
|
||||
auto new_candidates =
|
||||
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
|
||||
// Update all the nodes if working with external memory, this saves us from working
|
||||
// with the finalize position call, which adds an additional iteration and requires
|
||||
// special handling for row index.
|
||||
this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);
|
||||
pinned.GetSpan<GPUExpandEntry>(valid_candidates.size() * 2, GPUExpandEntry());
|
||||
|
||||
this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);
|
||||
this->PartitionAndBuildHist(p_fmat, expand_set, valid_candidates, p_tree);
|
||||
|
||||
this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
|
||||
this->EvaluateSplits(p_fmat, valid_candidates, *p_tree, new_candidates);
|
||||
dh::DefaultStream().Sync();
|
||||
|
||||
driver.Push(new_candidates.begin(), new_candidates.end());
|
||||
@ -764,10 +753,10 @@ struct GPUHistMakerDevice {
|
||||
// be spliable before evaluation but invalid after evaluation as we have more
|
||||
// restrictions like min loss change after evalaution. Therefore, the check condition
|
||||
// is greater than or equal to.
|
||||
if (is_single_block) {
|
||||
if (p_fmat->SingleColBlock()) {
|
||||
CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
|
||||
}
|
||||
this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
|
||||
this->FinalisePosition(p_fmat, p_tree, *task, p_out_position);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -67,7 +67,6 @@ TEST(GradientBasedSampler, NoSampling) {
|
||||
VerifySampling(kPageSize, kSubsample, kSamplingMethod);
|
||||
}
|
||||
|
||||
// In external mode, when not sampling, we concatenate the pages together.
|
||||
TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
||||
constexpr size_t kRows = 2048;
|
||||
constexpr size_t kCols = 1;
|
||||
@ -81,34 +80,11 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
||||
gpair.SetDevice(ctx.Device());
|
||||
|
||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||
EXPECT_NE(page->n_rows, kRows);
|
||||
|
||||
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
|
||||
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
|
||||
auto p_fmat = sample.p_fmat;
|
||||
EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
|
||||
EXPECT_EQ(sample.gpair.size(), gpair.Size());
|
||||
EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
|
||||
EXPECT_EQ(p_fmat->Info().num_row_, kRows);
|
||||
|
||||
ASSERT_EQ(p_fmat->NumBatches(), 1);
|
||||
for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer;
|
||||
auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
|
||||
|
||||
std::size_t offset = 0;
|
||||
for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
auto page = batch.Impl();
|
||||
std::vector<common::CompressedByteT> h_page_gidx_buffer;
|
||||
auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
|
||||
size_t num_elements = page->n_rows * page->row_stride;
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
|
||||
}
|
||||
offset += num_elements;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(p_fmat, dmat.get());
|
||||
}
|
||||
|
||||
TEST(GradientBasedSampler, UniformSampling) {
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
from hypothesis import given, settings, strategies
|
||||
|
||||
from xgboost.testing import no_cupy
|
||||
from xgboost.testing.updater import check_quantile_loss_extmem
|
||||
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
|
||||
|
||||
sys.path.append("tests/python")
|
||||
from test_data_iterator import run_data_iterator
|
||||
@ -59,6 +59,14 @@ def test_cpu_data_iterator() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_quantile_objective() -> None:
|
||||
with pytest.raises(ValueError, match="external memory"):
|
||||
check_quantile_loss_extmem(2, 2, 2, "hist", "cuda")
|
||||
@given(
|
||||
strategies.integers(1, 2048),
|
||||
strategies.integers(1, 8),
|
||||
strategies.integers(1, 4),
|
||||
strategies.booleans(),
|
||||
)
|
||||
@settings(deadline=None, max_examples=10, print_blob=True)
|
||||
def test_extmem_qdm(
|
||||
n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool
|
||||
) -> None:
|
||||
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user