Initial GPU support for the approx tree method. (#9414)

This commit is contained in:
Jiaming Yuan 2023-07-31 15:50:28 +08:00 committed by GitHub
parent 8f0efb4ab3
commit 912e341d57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 639 additions and 360 deletions

View File

@ -162,7 +162,8 @@ Parameters for Tree Booster
- ``grow_colmaker``: non-distributed column-based construction of trees.
- ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
- ``grow_quantile_histmaker``: Grow tree using quantized histogram.
- ``grow_gpu_hist``: Grow tree with GPU. Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``.
- ``grow_gpu_hist``: Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``.
- ``grow_gpu_approx``: Enabled when ``tree_method`` is set to ``approx`` along with ``device=cuda``.
- ``sync``: synchronizes trees in all distributed nodes.
- ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.

View File

@ -123,23 +123,23 @@ Feature Matrix
Following table summarizes some differences in supported features between 4 tree methods,
`T` means supported while `F` means unsupported.
+------------------+-----------+---------------------+---------------------+------------------------+
| | Exact | Approx | Hist | Hist (GPU) |
+==================+===========+=====================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+---------------------+------------------------+
| max_leaves | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| sampling method | uniform | uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+---------------------+------------------------+
| categorical data | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| External memory | F | T | T | P |
+------------------+-----------+---------------------+---------------------+------------------------+
| Distributed | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| | Exact | Approx | Approx (GPU) | Hist | Hist (GPU) |
+==================+===========+=====================+========================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| max_leaves | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| sampling method | uniform | uniform | gradient_based/uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| categorical data | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| External memory | F | T | P | T | P |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| Distributed | F | T | T | T | T |
+------------------+-----------+---------------------+------------------------+---------------------+------------------------+
Features/parameters that are not mentioned here are universally supported for all 4 tree
Features/parameters that are not mentioned here are universally supported for all 3 tree
methods (for instance, column sampling and constraints). The `P` in external memory means
special handling. Please note that both categorical data and external memory are
experimental.

View File

@ -1,7 +1,7 @@
"""Tests for updaters."""
import json
from functools import partial, update_wrapper
from typing import Any, Dict
from typing import Any, Dict, List
import numpy as np
@ -256,3 +256,141 @@ def check_get_quantile_cut(tree_method: str) -> None:
check_get_quantile_cut_device(tree_method, False)
if use_cupy:
check_get_quantile_cut_device(tree_method, True)
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1
def check_categorical_ohe( # pylint: disable=too-many-arguments
rows: int, cols: int, rounds: int, cats: int, device: str, tree_method: str
) -> None:
"Test for one-hot encoding with categorical data."
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results: Dict[str, Dict[str, List[float]]] = {}
by_builtin_results: Dict[str, Dict[str, List[float]]] = {}
parameters: Dict[str, Any] = {
"tree_method": tree_method,
# Use one-hot exclusively
"max_cat_to_onehot": USE_ONEHOT,
"device": device,
}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output
# as random variables. But in here the tree construction is extremely sensitive
# to floating point errors. An 1e-5 error in a histogram bin can lead to an
# entirely different tree. So even though the test is quite lenient, hypothesis
# can still pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
by_grouping: Dict[str, Dict[str, List[float]]] = {}
# switch to partition-based splits
parameters["max_cat_to_onehot"] = USE_PART
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_grouping,
)
rmse_oh = by_builtin_results["Train"]["rmse"]
rmse_group = by_grouping["Train"]["rmse"]
# always better or equal to onehot when there's no regularization.
for a, b in zip(rmse_oh, rmse_group):
assert a >= b
parameters["reg_lambda"] = 1.0
by_grouping = {}
xgb.train(
parameters,
m,
num_boost_round=32,
evals=[(m, "Train")],
evals_result=by_grouping,
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
def check_categorical_missing(
rows: int, cols: int, cats: int, device: str, tree_method: str
) -> None:
"""Check categorical data with missing values."""
parameters: Dict[str, Any] = {"tree_method": tree_method, "device": device}
cat, label = tm.make_categorical(
rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
def run(max_cat_to_onehot: int) -> None:
# Test with onehot splits
parameters["max_cat_to_onehot"] = max_cat_to_onehot
evals_result: Dict[str, Dict] = {}
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result,
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)
rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5)
# Test with OHE split
run(USE_ONEHOT)
# Test with partition-based split
run(USE_PART)
def train_result(
param: Dict[str, Any], dmat: xgb.DMatrix, num_rounds: int
) -> Dict[str, Any]:
"""Get training result from parameters and data."""
result: Dict[str, Any] = {}
booster = xgb.train(
param,
dmat,
num_rounds,
evals=[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
assert booster.feature_names == dmat.feature_names
assert booster.feature_types == dmat.feature_types
return result

View File

@ -89,5 +89,10 @@ void WarnDeprecatedGPUId();
void WarnEmptyDataset();
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
constexpr StringView InvalidCUDAOrdinal() {
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
"available for using GPU.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -12,7 +12,7 @@
#include <vector> // for vector
#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD
#include "error_msg.h" // for GroupWeight, GroupSize
#include "error_msg.h" // for GroupWeight, GroupSize, InvalidCUDAOrdinal
#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
@ -240,7 +240,7 @@ class RankingCache {
// The function simply returns a uninitialized buffer as this is only used by the
// objective for creating pairs.
common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA());
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_sorted_idx_cache_.Empty()) {
y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
y_sorted_idx_cache_.Resize(n_samples);
@ -248,7 +248,7 @@ class RankingCache {
return y_sorted_idx_cache_.DeviceSpan();
}
common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA());
CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_ranked_by_model_.Empty()) {
y_ranked_by_model_.SetDevice(ctx->gpu_id);
y_ranked_by_model_.Resize(n_samples);

View File

@ -11,7 +11,6 @@
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh" // for HasInfInData
@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
if (!param.hess.empty()) {
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
} else {
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
}
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");

View File

@ -7,13 +7,12 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // std::forward
#include <utility> // for forward
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
namespace xgboost {

View File

@ -8,12 +8,12 @@
#include <algorithm>
#include <limits>
#include <numeric> // for accumulate
#include <type_traits>
#include <vector>
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "./simple_batch_iterator.h"
#include "adapter.h"
#include "batch_utils.h" // for CheckEmpty, RegenGHist

View File

@ -8,7 +8,6 @@
#include "./sparse_page_dmatrix.h"
#include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"

View File

@ -1,13 +1,15 @@
/**
* Copyright 2021-2023 by XGBoost contributors
*/
#include <memory>
#include <memory> // for unique_ptr
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "../common/hist_util.h" // for HistogramCuts
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.cuh"
#include "sparse_page_dmatrix.h"
#include "sparse_page_source.h"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchParam
namespace xgboost::data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
@ -25,8 +27,13 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
if (!param.hess.empty()) {
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
} else {
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
}
this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this);
@ -35,10 +42,10 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
batch_param_ = param;
auto ft = this->info_.feature_types.ConstDeviceSpan();
ellpack_page_source_.reset(); // release resources.
ellpack_page_source_.reset(new EllpackPageSource(
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id));
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
} else {
CHECK(sparse_page_source_);
ellpack_page_source_->Reset();

View File

@ -47,15 +47,16 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
if (ctx->IsCUDA()) {
common::AssertGPUSupport();
}
switch (tree_method) {
case TreeMethod::kAuto: // Use hist as default in 2.0
case TreeMethod::kHist: {
return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
[] { return "grow_gpu_hist"; });
}
case TreeMethod::kApprox:
CHECK(ctx->IsCPU()) << "The `approx` tree method is not supported on GPU.";
return "grow_histmaker";
case TreeMethod::kApprox: {
return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });
}
case TreeMethod::kExact:
CHECK(ctx->IsCPU()) << "The `exact` tree method is not supported on GPU.";
return "grow_colmaker,prune";

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2018-2019 by Contributors
/**
* Copyright 2018-2023 by Contributors
*/
#ifndef XGBOOST_TREE_CONSTRAINTS_H_
#define XGBOOST_TREE_CONSTRAINTS_H_
@ -8,10 +8,8 @@
#include <unordered_set>
#include <vector>
#include "xgboost/span.h"
#include "xgboost/base.h"
#include "param.h"
#include "xgboost/base.h"
namespace xgboost {
/*!

View File

@ -8,10 +8,10 @@
#include <xgboost/logging.h>
#include <algorithm>
#include <cstddef> // for size_t
#include <limits>
#include <utility>
#include "../../common/compressed_iterator.h"
#include "../../common/cuda_context.cuh" // for CUDAContext
#include "../../common/random.h"
#include "../param.h"
@ -202,27 +202,27 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
common::Span<GradientPair> gpair,
DMatrix* dmat) {
auto cuctx = ctx->CUDACtx();
// Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_),
GradientPair());
thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<std::size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{});
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
size_t sample_rows =
thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{});
// Compact gradient pairs.
gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero{});
// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
IsNonZero());
thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
sample_row_index_.begin(),
sample_row_index_.begin(),
ClearEmptyRows());
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
sample_row_index_.begin(), ClearEmptyRows());
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
auto first_page = (*batch_iterator.begin()).Impl();
@ -232,7 +232,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
first_page->row_stride, sample_rows));
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : batch_iterator) {
page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}

View File

@ -11,7 +11,6 @@
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "common_row_partitioner.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"

View File

@ -31,7 +31,6 @@
#include "gpu_hist/histogram.cuh"
#include "gpu_hist/row_partitioner.cuh"
#include "param.h"
#include "split_evaluator.h"
#include "updater_gpu_common.cuh"
#include "xgboost/base.h"
#include "xgboost/context.h"
@ -49,13 +48,30 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam
: public XGBoostParameter<GPUHistMakerTrainParam> {
struct GPUHistMakerTrainParam : public XGBoostParameter<GPUHistMakerTrainParam> {
bool debug_synchronize;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
"Check if all distributed tree are identical after tree construction.");
DMLC_DECLARE_FIELD(debug_synchronize)
.set_default(false)
.describe("Check if all distributed tree are identical after tree construction.");
}
// Only call this method for testing
void CheckTreesSynchronized(RegTree const* local_tree) const {
if (this->debug_synchronize) {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
}
};
#if !defined(GTEST_TEST)
@ -170,16 +186,15 @@ class DeviceHistogramStorage {
};
// Manage memory for a single GPU
template <typename GradientSumT>
struct GPUHistMakerDevice {
private:
GPUHistEvaluator evaluator_;
Context const* ctx_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
public:
EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types;
BatchParam batch_param;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogramStorage<> hist{};
@ -199,7 +214,6 @@ struct GPUHistMakerDevice {
dh::PinnedMemory pinned2;
common::Monitor monitor;
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints;
std::unique_ptr<GradientBasedSampler> sampler;
@ -208,22 +222,22 @@ struct GPUHistMakerDevice {
GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
BatchParam _batch_param)
TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
uint32_t n_features, BatchParam batch_param)
: evaluator_{_param, n_features, ctx->gpu_id},
ctx_(ctx),
feature_types{_feature_types},
param(std::move(_param)),
column_sampler(column_sampler_seed),
interaction_constraints(param, n_features),
batch_param(std::move(_batch_param)) {
sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample,
param.sampling_method, is_external_memory));
column_sampler_(std::move(column_sampler)),
interaction_constraints(param, n_features) {
sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
param.sampling_method, is_external_memory);
if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints;
}
CHECK(column_sampler_);
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
}
@ -234,16 +248,16 @@ struct GPUHistMakerDevice {
CHECK(page);
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(ctx_->gpu_id),
sizeof(GradientSumT)));
sizeof(GradientPairPrecise)));
}
}
// Reset values for each update iteration
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
auto const& info = dmat->Info();
this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
this->interaction_constraints.Reset();
@ -275,8 +289,8 @@ struct GPUHistMakerDevice {
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
int nidx = RegTree::kRoot;
GPUTrainingParam gpu_param(param);
auto sampled_features = column_sampler.GetFeatureSet(0);
sampled_features->SetDevice(ctx_->gpu_id);
auto sampled_features = column_sampler_->GetFeatureSet(0);
sampled_features->SetDevice(ctx_->Device());
common::Span<bst_feature_t> feature_set =
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
@ -316,13 +330,13 @@ struct GPUHistMakerDevice {
int right_nidx = tree[candidate.nid].RightChild();
nidx[i * 2] = left_nidx;
nidx[i * 2 + 1] = right_nidx;
auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
left_sampled_features->SetDevice(ctx_->gpu_id);
auto left_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(left_nidx));
left_sampled_features->SetDevice(ctx_->Device());
feature_sets.emplace_back(left_sampled_features);
common::Span<bst_feature_t> left_feature_set =
interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
right_sampled_features->SetDevice(ctx_->gpu_id);
auto right_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(right_nidx));
right_sampled_features->SetDevice(ctx_->Device());
feature_sets.emplace_back(right_sampled_features);
common::Span<bst_feature_t> right_feature_set =
interaction_constraints.Query(right_sampled_features->DeviceSpan(),
@ -657,7 +671,6 @@ struct GPUHistMakerDevice {
evaluator_.ApplyTreeSplit(candidate, p_tree);
const auto& parent = tree[candidate.nid];
std::size_t max_nidx = std::max(parent.LeftChild(), parent.RightChild());
interaction_constraints.Split(candidate.nid, parent.SplitIndex(), parent.LeftChild(),
parent.RightChild());
}
@ -693,9 +706,8 @@ struct GPUHistMakerDevice {
return root_entry;
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
ObjInfo const* task, RegTree* p_tree,
HostDeviceVector<bst_node_t>* p_out_position) {
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32);
@ -720,7 +732,6 @@ struct GPUHistMakerDevice {
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
[&](const auto& e) { return driver.IsChildValid(e); });
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
@ -753,8 +764,7 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise;
public:
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {};
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
@ -786,13 +796,10 @@ class GPUHistMaker : public TreeUpdater {
// build tree
try {
size_t t_idx{0};
std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) {
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
if (hist_maker_param_.debug_synchronize) {
this->CheckTreesSynchronized(tree);
}
this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx;
}
dh::safe_cuda(cudaGetLastError());
@ -809,13 +816,14 @@ class GPUHistMaker : public TreeUpdater {
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
info_->feature_types.SetDevice(ctx_->gpu_id);
maker.reset(new GPUHistMakerDevice<GradientSumT>(
maker = std::make_unique<GPUHistMakerDevice>(
ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
*param, column_sampling_seed, info_->num_col_, batch_param));
*param, column_sampler_, info_->num_col_, batch_param);
p_last_fmat_ = dmat;
initialised_ = true;
@ -830,21 +838,6 @@ class GPUHistMaker : public TreeUpdater {
p_last_tree_ = p_tree;
}
// Only call this method for testing
void CheckTreesSynchronized(RegTree* local_tree) const {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
monitor_.Start("InitData");
@ -868,7 +861,7 @@ class GPUHistMaker : public TreeUpdater {
MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
std::unique_ptr<GPUHistMakerDevice> maker; // NOLINT
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
[[nodiscard]] bool HasNodePosition() const override { return true; }
@ -883,6 +876,7 @@ class GPUHistMaker : public TreeUpdater {
ObjInfo const* task_{nullptr};
common::Monitor monitor_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
};
#if !defined(GTEST_TEST)
@ -892,4 +886,131 @@ XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
return new GPUHistMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)
class GPUGlobalApproxMaker : public TreeUpdater {
public:
explicit GPUGlobalApproxMaker(Context const* ctx, ObjInfo const* task)
: TreeUpdater(ctx), task_{task} {};
void Configure(Args const& args) override {
// Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Approx]: Configure";
hist_maker_param_.UpdateAllowUnknown(args);
dh::CheckComputeCapability();
initialised_ = false;
monitor_.Init(this->Name());
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("approx_train_param"), &this->hist_maker_param_);
initialised_ = false;
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["approx_train_param"] = ToJson(hist_maker_param_);
}
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override {
monitor_.Start("Update");
this->InitDataOnce(p_fmat);
// build tree
hess_.resize(gpair->Size());
auto hess = dh::ToSpan(hess_);
gpair->SetDevice(ctx_->Device());
auto d_gpair = gpair->ConstDeviceSpan();
auto cuctx = ctx_->CUDACtx();
thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess),
[=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); });
auto const& info = p_fmat->Info();
info.feature_types.SetDevice(ctx_->Device());
auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
maker_ = std::make_unique<GPUHistMakerDevice>(
ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
*param, column_sampler_, info.num_col_, batch);
std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) {
this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree);
++t_idx;
}
monitor_.Stop("Update");
}
void InitDataOnce(DMatrix* p_fmat) {
if (this->initialised_) {
return;
}
monitor_.Start(__func__);
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
p_last_fmat_ = p_fmat;
initialised_ = true;
monitor_.Stop(__func__);
}
void InitData(DMatrix* p_fmat, RegTree const* p_tree) {
this->InitDataOnce(p_fmat);
p_last_tree_ = p_tree;
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
HostDeviceVector<bst_node_t>* p_out_position) {
monitor_.Start("InitData");
this->InitData(p_fmat, p_tree);
monitor_.Stop("InitData");
gpair->SetDevice(ctx_->gpu_id);
maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
}
bool UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<bst_float> p_out_preds) override {
if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
monitor_.Start("UpdatePredictionCache");
bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_);
monitor_.Stop("UpdatePredictionCache");
return result;
}
[[nodiscard]] char const* Name() const override { return "grow_gpu_approx"; }
[[nodiscard]] bool HasNodePosition() const override { return true; }
private:
bool initialised_{false};
GPUHistMakerTrainParam hist_maker_param_;
dh::device_vector<float> hess_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
std::unique_ptr<GPUHistMakerDevice> maker_;
DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
ObjInfo const* task_{nullptr};
common::Monitor monitor_;
};
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx")
.describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUGlobalApproxMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)
} // namespace xgboost::tree

View File

@ -13,10 +13,7 @@
#include "../../../src/common/common.h"
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
#include "../../../src/data/ellpack_page.h" // for EllpackPage
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/tree/constraints.cuh"
#include "../../../src/tree/param.h" // for TrainParam
#include "../../../src/tree/updater_gpu_common.cuh"
#include "../../../src/tree/updater_gpu_hist.cu"
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
@ -94,8 +91,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{};
Context ctx{MakeCUDACtx(0)};
GPUHistMakerDevice<GradientSumT> maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param,
kNCols, kNCols, batch_param);
auto cs = std::make_shared<common::ColumnSampler>(0);
GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
batch_param);
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows);

View File

@ -24,15 +24,11 @@ class TestPredictionCache : public ::testing::Test {
Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
}
void RunLearnerTest(std::string updater_name, float subsample, std::string const& grow_policy,
std::string const& strategy) {
void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
std::string const& grow_policy, std::string const& strategy) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
if (updater_name == "grow_gpu_hist") {
// gpu_id setup
learner->SetParam("tree_method", "gpu_hist");
} else {
learner->SetParam("updater", updater_name);
}
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("updater", updater_name);
learner->SetParam("multi_strategy", strategy);
learner->SetParam("grow_policy", grow_policy);
learner->SetParam("subsample", std::to_string(subsample));
@ -65,20 +61,14 @@ class TestPredictionCache : public ::testing::Test {
}
}
void RunTest(std::string const& updater_name, std::string const& strategy) {
void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
{
Context ctx;
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
if (updater_name == "grow_gpu_hist") {
ctx = ctx.MakeCUDA(0);
} else {
ctx = ctx.MakeCPU();
}
ctx->InitAllowUnknown(Args{{"nthread", "8"}});
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, &ctx, &task)};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
RegTree tree;
std::vector<RegTree *> trees{&tree};
std::vector<RegTree*> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples_);
tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
@ -86,33 +76,46 @@ class TestPredictionCache : public ::testing::Test {
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Update(&param, &gpair, Xy_.get(), position, trees);
HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.SetDevice(ctx.gpu_id);
out_prediction_cached.SetDevice(ctx->Device());
out_prediction_cached.Resize(n_samples_);
auto cache =
linalg::MakeTensorView(&ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
}
for (auto policy : {"depthwise", "lossguide"}) {
for (auto subsample : {1.0f, 0.4f}) {
this->RunLearnerTest(updater_name, subsample, policy, strategy);
this->RunLearnerTest(updater_name, subsample, policy, strategy);
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
}
}
}
};
TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker", "one_output_per_tree"); }
TEST_F(TestPredictionCache, Approx) {
Context ctx;
this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree");
}
TEST_F(TestPredictionCache, Hist) {
this->RunTest("grow_quantile_histmaker", "one_output_per_tree");
Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker", "one_output_per_tree");
}
TEST_F(TestPredictionCache, HistMulti) {
this->RunTest("grow_quantile_histmaker", "multi_output_tree");
Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker", "multi_output_tree");
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist", "one_output_per_tree"); }
TEST_F(TestPredictionCache, GpuHist) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_hist", "one_output_per_tree");
}
TEST_F(TestPredictionCache, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_approx", "one_output_per_tree");
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test {
auto constexpr Iter() const { return 4; }
template <typename Page>
size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const {
size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj,
bool reset = true) const {
auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("tree_method", tree_method);
learner->SetParam("objective", obj);
learner->Configure();
@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test {
} // anonymous namespace
TEST_F(RegenTest, Approx) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic");
ASSERT_EQ(n, this->Iter());
}
TEST_F(RegenTest, Hist) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 1);
}
TEST_F(RegenTest, Mixed) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false);
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(RegenTest, GpuHist) {
auto n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror");
TEST_F(RegenTest, GpuApprox) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() * 2);
}
TEST_F(RegenTest, GpuHist) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic");
ASSERT_EQ(n, 2);
{
Context ctx;
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 2);
}
}
TEST_F(RegenTest, GpuMixed) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -20,10 +20,11 @@ class TestGrowPolicy : public ::testing::Test {
true);
}
std::unique_ptr<Learner> TrainOneIter(std::string tree_method, std::string policy,
int32_t max_leaves, int32_t max_depth) {
std::unique_ptr<Learner> TrainOneIter(Context const* ctx, std::string tree_method,
std::string policy, int32_t max_leaves, int32_t max_depth) {
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
learner->SetParam("tree_method", tree_method);
learner->SetParam("device", ctx->DeviceName());
if (max_leaves >= 0) {
learner->SetParam("max_leaves", std::to_string(max_leaves));
}
@ -63,7 +64,7 @@ class TestGrowPolicy : public ::testing::Test {
if (max_leaves == 0 && max_depth == 0) {
// unconstrainted
if (tree_method != "gpu_hist") {
if (ctx->IsCPU()) {
// GPU pre-allocates for all nodes.
learner->UpdateOneIter(0, Xy_);
}
@ -86,23 +87,23 @@ class TestGrowPolicy : public ::testing::Test {
return learner;
}
void TestCombination(std::string tree_method) {
void TestCombination(Context const* ctx, std::string tree_method) {
for (auto policy : {"depthwise", "lossguide"}) {
// -1 means default
for (auto leaves : {-1, 0, 3}) {
for (auto depth : {-1, 0, 3}) {
this->TrainOneIter(tree_method, policy, leaves, depth);
this->TrainOneIter(ctx, tree_method, policy, leaves, depth);
}
}
}
}
void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
void TestTreeGrowPolicy(Context const* ctx, std::string tree_method, std::string policy) {
{
/**
* max_leaves
*/
auto learner = this->TrainOneIter(tree_method, policy, 16, -1);
auto learner = this->TrainOneIter(ctx, tree_method, policy, 16, -1);
Json model{Object{}};
learner->SaveModel(&model);
@ -115,7 +116,7 @@ class TestGrowPolicy : public ::testing::Test {
/**
* max_depth
*/
auto learner = this->TrainOneIter(tree_method, policy, -1, 3);
auto learner = this->TrainOneIter(ctx, tree_method, policy, -1, 3);
Json model{Object{}};
learner->SaveModel(&model);
@ -133,25 +134,36 @@ class TestGrowPolicy : public ::testing::Test {
};
TEST_F(TestGrowPolicy, Approx) {
this->TestTreeGrowPolicy("approx", "depthwise");
this->TestTreeGrowPolicy("approx", "lossguide");
Context ctx;
this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
this->TestCombination("approx");
this->TestCombination(&ctx, "approx");
}
TEST_F(TestGrowPolicy, Hist) {
this->TestTreeGrowPolicy("hist", "depthwise");
this->TestTreeGrowPolicy("hist", "lossguide");
Context ctx;
this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
this->TestCombination("hist");
this->TestCombination(&ctx, "hist");
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(TestGrowPolicy, GpuHist) {
this->TestTreeGrowPolicy("gpu_hist", "depthwise");
this->TestTreeGrowPolicy("gpu_hist", "lossguide");
auto ctx = MakeCUDACtx(0);
this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
this->TestCombination("gpu_hist");
this->TestCombination(&ctx, "hist");
}
TEST_F(TestGrowPolicy, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
this->TestCombination(&ctx, "approx");
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -135,7 +135,7 @@ class TestMinSplitLoss : public ::testing::Test {
gpair_ = GenerateRandomGradients(kRows);
}
std::int32_t Update(std::string updater, float gamma) {
std::int32_t Update(Context const* ctx, std::string updater, float gamma) {
Args args{{"max_depth", "1"},
{"max_leaves", "0"},
@ -154,8 +154,7 @@ class TestMinSplitLoss : public ::testing::Test {
param.UpdateAllowUnknown(args);
ObjInfo task{ObjInfo::kRegression};
Context ctx{MakeCUDACtx(updater == "grow_gpu_hist" ? 0 : Context::kCpuId)};
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, ctx, &task)};
up->Configure({});
RegTree tree;
@ -167,16 +166,16 @@ class TestMinSplitLoss : public ::testing::Test {
}
public:
void RunTest(std::string updater) {
void RunTest(Context const* ctx, std::string updater) {
{
int32_t n_nodes = Update(updater, 0.01);
int32_t n_nodes = Update(ctx, updater, 0.01);
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
// when writing this test, and only used for testing larger gamma (below) does prevent
// building tree.
ASSERT_EQ(n_nodes, 2);
}
{
int32_t n_nodes = Update(updater, 100.0);
int32_t n_nodes = Update(ctx, updater, 100.0);
// No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
}
@ -185,10 +184,25 @@ class TestMinSplitLoss : public ::testing::Test {
/* Exact tree method requires a pruner as an additional updater, so not tested here. */
TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); }
TEST_F(TestMinSplitLoss, Approx) {
Context ctx;
this->RunTest(&ctx, "grow_histmaker");
}
TEST_F(TestMinSplitLoss, Hist) {
Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker");
}
TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
#if defined(XGBOOST_USE_CUDA)
TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); }
TEST_F(TestMinSplitLoss, GpuHist) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_hist");
}
TEST_F(TestMinSplitLoss, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_approx");
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -7,11 +7,18 @@ from hypothesis import assume, given, note, settings, strategies
import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
from xgboost.testing.params import (
cat_parameter_strategy,
exact_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
check_get_quantile_cut,
check_init_estimation,
check_quantile_loss,
train_result,
)
sys.path.append("tests/python")
@ -20,22 +27,6 @@ import test_updaters as test_up
pytestmark = tm.timeout(30)
def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
result: xgb.callback.TrainingCallback.EvalsLog = {}
booster = xgb.train(
param,
dmat,
num_rounds,
[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
return result
class TestGPUUpdatersMulti:
@given(
hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
@ -53,14 +44,45 @@ class TestGPUUpdaters:
cputest = test_up.TestTreeMethod()
@given(
hist_parameter_strategy, strategies.integers(1, 20), tm.make_dataset_strategy()
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
)
@settings(deadline=None, max_examples=50, print_blob=True)
def test_gpu_hist(self, param, num_rounds, dataset):
param["tree_method"] = "gpu_hist"
def test_gpu_hist(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "hist", "device": "cuda"})
param.update(hist_param)
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
)
@settings(deadline=None, print_blob=True)
def test_gpu_approx(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "approx", "device": "cuda"})
param.update(hist_param)
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
@given(tm.sparse_datasets_strategy)
@ -69,23 +91,27 @@ class TestGPUUpdaters:
param = {"tree_method": "hist", "max_bin": 64}
hist_result = train_result(param, dataset.get_dmat(), 16)
note(hist_result)
assert tm.non_increasing(hist_result['train'][dataset.metric])
assert tm.non_increasing(hist_result["train"][dataset.metric])
param = {"tree_method": "gpu_hist", "max_bin": 64}
gpu_hist_result = train_result(param, dataset.get_dmat(), 16)
note(gpu_hist_result)
assert tm.non_increasing(gpu_hist_result['train'][dataset.metric])
assert tm.non_increasing(gpu_hist_result["train"][dataset.metric])
np.testing.assert_allclose(
hist_result["train"]["rmse"], gpu_hist_result["train"]["rmse"], rtol=1e-2
)
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@given(
strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(1, 2),
strategies.integers(4, 7),
)
@settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_ohe(self, rows, cols, rounds, cats):
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist")
@given(
tm.categorical_dataset_strategy,
@ -95,7 +121,7 @@ class TestGPUUpdaters:
)
@settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(
def test_categorical_hist(
self,
dataset: tm.TestDataset,
hist_parameters: Dict[str, Any],
@ -103,7 +129,30 @@ class TestGPUUpdaters:
n_rounds: int,
) -> None:
cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = "gpu_hist"
cat_parameters["tree_method"] = "hist"
cat_parameters["device"] = "cuda"
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
tm.non_increasing(results["train"]["rmse"])
@given(
tm.categorical_dataset_strategy,
hist_parameter_strategy,
cat_parameter_strategy,
strategies.integers(4, 32),
)
@settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_approx(
self,
dataset: tm.TestDataset,
hist_parameters: Dict[str, Any],
cat_parameters: Dict[str, Any],
n_rounds: int,
) -> None:
cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = "approx"
cat_parameters["device"] = "cuda"
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
tm.non_increasing(results["train"]["rmse"])
@ -129,24 +178,25 @@ class TestGPUUpdaters:
@given(
strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(4, 7)
strategies.integers(4, 7),
)
@settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist")
check_categorical_missing(rows, cols, cats, "cuda", "approx")
check_categorical_missing(rows, cols, cats, "cuda", "hist")
@pytest.mark.skipif(**tm.no_pandas())
def test_max_cat(self) -> None:
self.cputest.run_max_cat("gpu_hist")
def test_categorical_32_cat(self):
'''32 hits the bound of integer bitset, so special test'''
"""32 hits the bound of integer bitset, so special test"""
rows = 1000
cols = 10
cats = 32
rounds = 4
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist")
@pytest.mark.skipif(**tm.no_cupy())
def test_invalid_category(self):
@ -164,15 +214,15 @@ class TestGPUUpdaters:
) -> None:
# We cannot handle empty dataset yet
assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist'
param["tree_method"] = "gpu_hist"
param = dataset.set_params(param)
result = train_result(
param,
dataset.get_device_dmat(max_bin=param.get("max_bin", None)),
num_rounds
num_rounds,
)
note(result)
assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3)
assert tm.non_increasing(result["train"][dataset.metric], tolerance=1e-3)
@given(
hist_parameter_strategy,
@ -185,12 +235,12 @@ class TestGPUUpdaters:
return
# We cannot handle empty dataset yet
assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist'
param["tree_method"] = "gpu_hist"
param = dataset.set_params(param)
m = dataset.get_external_dmat()
external_result = train_result(param, m, num_rounds)
del m
assert tm.non_increasing(external_result['train'][dataset.metric])
assert tm.non_increasing(external_result["train"][dataset.metric])
def test_empty_dmatrix_prediction(self):
# FIXME(trivialfis): This should be done with all updaters
@ -207,7 +257,7 @@ class TestGPUUpdaters:
dtrain,
verbose_eval=True,
num_boost_round=6,
evals=[(dtrain, 'Train')]
evals=[(dtrain, "Train")],
)
kRows = 100
@ -222,10 +272,10 @@ class TestGPUUpdaters:
@given(tm.make_dataset_strategy(), strategies.integers(0, 10))
@settings(deadline=None, max_examples=10, print_blob=True)
def test_specified_gpu_id_gpu_update(self, dataset, gpu_id):
param = {'tree_method': 'gpu_hist', 'gpu_id': gpu_id}
param = {"tree_method": "gpu_hist", "gpu_id": gpu_id}
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result['train'][dataset.metric])
assert tm.non_increasing(result["train"][dataset.metric])
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("weighted", [True, False])

View File

@ -1,6 +1,6 @@
import json
from string import ascii_lowercase
from typing import Any, Dict, List
from typing import Any, Dict
import numpy as np
import pytest
@ -15,30 +15,15 @@ from xgboost.testing.params import (
hist_parameter_strategy,
)
from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
check_get_quantile_cut,
check_init_estimation,
check_quantile_loss,
train_result,
)
def train_result(param, dmat, num_rounds):
result = {}
booster = xgb.train(
param,
dmat,
num_rounds,
evals=[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
assert booster.feature_names == dmat.feature_names
assert booster.feature_types == dmat.feature_types
return result
class TestTreeMethodMulti:
@given(
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
@ -281,115 +266,6 @@ class TestTreeMethod:
def test_max_cat(self, tree_method) -> None:
self.run_max_cat(tree_method)
def run_categorical_missing(
self, rows: int, cols: int, cats: int, tree_method: str
) -> None:
parameters: Dict[str, Any] = {"tree_method": tree_method}
cat, label = tm.make_categorical(
rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
def run(max_cat_to_onehot: int):
# Test with onehot splits
parameters["max_cat_to_onehot"] = max_cat_to_onehot
evals_result: Dict[str, Dict] = {}
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)
rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(
rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5
)
# Test with OHE split
run(self.USE_ONEHOT)
# Test with partition-based split
run(self.USE_PART)
def run_categorical_ohe(
self, rows: int, cols: int, rounds: int, cats: int, tree_method: str
) -> None:
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results: Dict[str, Dict[str, List[float]]] = {}
by_builtin_results: Dict[str, Dict[str, List[float]]] = {}
parameters: Dict[str, Any] = {
"tree_method": tree_method,
# Use one-hot exclusively
"max_cat_to_onehot": self.USE_ONEHOT
}
m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_etl_results,
)
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_builtin_results,
)
# There are guidelines on how to specify tolerance based on considering output
# as random variables. But in here the tree construction is extremely sensitive
# to floating point errors. An 1e-5 error in a histogram bin can lead to an
# entirely different tree. So even though the test is quite lenient, hypothesis
# can still pick up falsifying examples from time to time.
np.testing.assert_allclose(
np.array(by_etl_results["Train"]["rmse"]),
np.array(by_builtin_results["Train"]["rmse"]),
rtol=1e-3,
)
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
by_grouping: Dict[str, Dict[str, List[float]]] = {}
# switch to partition-based splits
parameters["max_cat_to_onehot"] = self.USE_PART
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_grouping,
)
rmse_oh = by_builtin_results["Train"]["rmse"]
rmse_group = by_grouping["Train"]["rmse"]
# always better or equal to onehot when there's no regularization.
for a, b in zip(rmse_oh, rmse_group):
assert a >= b
parameters["reg_lambda"] = 1.0
by_grouping = {}
xgb.train(
parameters,
m,
num_boost_round=32,
evals=[(m, "Train")],
evals_result=by_grouping,
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
@ -397,8 +273,8 @@ class TestTreeMethod:
def test_categorical_ohe(
self, rows: int, cols: int, rounds: int, cats: int
) -> None:
self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
check_categorical_ohe(rows, cols, rounds, cats, "cpu", "approx")
check_categorical_ohe(rows, cols, rounds, cats, "cpu", "hist")
@given(
tm.categorical_dataset_strategy,
@ -454,8 +330,8 @@ class TestTreeMethod:
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
self.run_categorical_missing(rows, cols, cats, "approx")
self.run_categorical_missing(rows, cols, cats, "hist")
check_categorical_missing(rows, cols, cats, "cpu", "approx")
check_categorical_missing(rows, cols, cats, "cpu", "hist")
def run_adaptive(self, tree_method, weighted) -> None:
rng = np.random.RandomState(1994)

View File

@ -154,7 +154,6 @@ def run_gpu_hist(
DMatrixT: Type,
client: Client,
) -> None:
params["tree_method"] = "hist"
params["device"] = "cuda"
params = dataset.set_params(params)
# It doesn't make sense to distribute a completely
@ -275,8 +274,31 @@ class TestDistributedGPU:
dmatrix_type: type,
local_cuda_client: Client,
) -> None:
params["tree_method"] = "hist"
run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client)
@given(
params=hist_parameter_strategy,
num_rounds=strategies.integers(1, 20),
dataset=tm.make_dataset_strategy(),
)
@settings(
deadline=duration(seconds=120),
max_examples=20,
suppress_health_check=suppress,
print_blob=True,
)
@pytest.mark.skipif(**tm.no_cupy())
def test_gpu_approx(
self,
params: Dict,
num_rounds: int,
dataset: tm.TestDataset,
local_cuda_client: Client,
) -> None:
params["tree_method"] = "approx"
run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix, local_cuda_client)
def test_empty_quantile_dmatrix(self, local_cuda_client: Client) -> None:
client = local_cuda_client
X, y = make_categorical(client, 1, 30, 13)