Initial GPU support for the approx tree method. (#9414)

This commit is contained in:
Jiaming Yuan 2023-07-31 15:50:28 +08:00 committed by GitHub
parent 8f0efb4ab3
commit 912e341d57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 639 additions and 360 deletions

View File

@ -162,7 +162,8 @@ Parameters for Tree Booster
- ``grow_colmaker``: non-distributed column-based construction of trees. - ``grow_colmaker``: non-distributed column-based construction of trees.
- ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting. - ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
- ``grow_quantile_histmaker``: Grow tree using quantized histogram. - ``grow_quantile_histmaker``: Grow tree using quantized histogram.
- ``grow_gpu_hist``: Grow tree with GPU. Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``. - ``grow_gpu_hist``: Enabled when ``tree_method`` is set to ``hist`` along with ``device=cuda``.
- ``grow_gpu_approx``: Enabled when ``tree_method`` is set to ``approx`` along with ``device=cuda``.
- ``sync``: synchronizes trees in all distributed nodes. - ``sync``: synchronizes trees in all distributed nodes.
- ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed. - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``. - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.

View File

@ -123,23 +123,23 @@ Feature Matrix
Following table summarizes some differences in supported features between 4 tree methods, Following table summarizes some differences in supported features between 4 tree methods,
`T` means supported while `F` means unsupported. `T` means supported while `F` means unsupported.
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| | Exact | Approx | Hist | Hist (GPU) | | | Exact | Approx | Approx (GPU) | Hist | Hist (GPU) |
+==================+===========+=====================+=====================+========================+ +==================+===========+=====================+========================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide | | grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| max_leaves | F | T | T | T | | max_leaves | F | T | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| sampling method | uniform | uniform | uniform | gradient_based/uniform | | sampling method | uniform | uniform | gradient_based/uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| categorical data | F | T | T | T | | categorical data | F | T | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| External memory | F | T | T | P | | External memory | F | T | P | T | P |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
| Distributed | F | T | T | T | | Distributed | F | T | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+ +------------------+-----------+---------------------+------------------------+---------------------+------------------------+
Features/parameters that are not mentioned here are universally supported for all 4 tree Features/parameters that are not mentioned here are universally supported for all 3 tree
methods (for instance, column sampling and constraints). The `P` in external memory means methods (for instance, column sampling and constraints). The `P` in external memory means
special handling. Please note that both categorical data and external memory are special handling. Please note that both categorical data and external memory are
experimental. experimental.

View File

@ -1,7 +1,7 @@
"""Tests for updaters.""" """Tests for updaters."""
import json import json
from functools import partial, update_wrapper from functools import partial, update_wrapper
from typing import Any, Dict from typing import Any, Dict, List
import numpy as np import numpy as np
@ -256,3 +256,141 @@ def check_get_quantile_cut(tree_method: str) -> None:
check_get_quantile_cut_device(tree_method, False) check_get_quantile_cut_device(tree_method, False)
if use_cupy: if use_cupy:
check_get_quantile_cut_device(tree_method, True) check_get_quantile_cut_device(tree_method, True)
# Values for the ``max_cat_to_onehot`` training parameter used by the categorical
# checks below.
# Largest int32: passed when a check wants one-hot splits to be used exclusively.
USE_ONEHOT = np.iinfo(np.int32).max
# Passed when a check wants to switch to partition-based splits.
USE_PART = 1
def check_categorical_ohe(  # pylint: disable=too-many-arguments
    rows: int, cols: int, rounds: int, cats: int, device: str, tree_method: str
) -> None:
    """Check built-in categorical splits against externally one-hot encoded data.

    Four models are trained with the given ``tree_method`` and ``device``:

    1. on data that was one-hot encoded up front, with categorical support disabled;
    2. on the categorical data with built-in one-hot splits (``USE_ONEHOT``);
    3. on the categorical data with partition-based splits (``USE_PART``) and
       ``reg_lambda`` set to 0;
    4. same as 3. but with ``reg_lambda`` set to 1.0 and a fixed 32 rounds.

    Runs 1 and 2 must produce close training RMSE, run 3 must never be worse than
    run 2, and the RMSE of runs 2 and 4 must be non-increasing.
    """
    x_onehot, y = tm.make_categorical(rows, cols, cats, True)
    x_cat, _ = tm.make_categorical(rows, cols, cats, False)

    params: Dict[str, Any] = {
        "tree_method": tree_method,
        # Force one-hot splits for every categorical feature.
        "max_cat_to_onehot": USE_ONEHOT,
        "device": device,
    }

    def fit(dtrain: xgb.DMatrix, n_rounds: int) -> Dict[str, Dict[str, List[float]]]:
        """Train with the current ``params`` and return the evaluation history."""
        history: Dict[str, Dict[str, List[float]]] = {}
        xgb.train(
            params,
            dtrain,
            num_boost_round=n_rounds,
            evals=[(dtrain, "Train")],
            evals_result=history,
        )
        return history

    # One-hot encoding done outside of XGBoost (ETL) vs. built-in one-hot splits.
    etl_history = fit(xgb.DMatrix(x_onehot, y, enable_categorical=False), rounds)
    builtin_history = fit(xgb.DMatrix(x_cat, y, enable_categorical=True), rounds)
    onehot_rmse = builtin_history["Train"]["rmse"]

    # Tree construction is extremely sensitive to floating point error: a 1e-5
    # difference in one histogram bin can produce an entirely different tree.  The
    # tolerance is therefore lenient, and hypothesis may still find a falsifying
    # example once in a while.
    np.testing.assert_allclose(
        np.array(etl_history["Train"]["rmse"]),
        np.array(onehot_rmse),
        rtol=1e-3,
    )
    assert tm.non_increasing(onehot_rmse)

    # Partition-based splits without regularization.
    params["max_cat_to_onehot"] = USE_PART
    params["reg_lambda"] = 0
    dmat_cat = xgb.DMatrix(x_cat, y, enable_categorical=True)
    partition_history = fit(dmat_cat, rounds)

    # Without regularization, partitioning is always at least as good as one-hot.
    for rmse_onehot, rmse_partition in zip(
        onehot_rmse, partition_history["Train"]["rmse"]
    ):
        assert rmse_onehot >= rmse_partition

    # Partition-based splits with regularization, on the same DMatrix.
    params["reg_lambda"] = 1.0
    partition_history = fit(dmat_cat, 32)
    assert tm.non_increasing(partition_history["Train"]["rmse"]), partition_history
def check_categorical_missing(
    rows: int, cols: int, cats: int, device: str, tree_method: str
) -> None:
    """Check training on categorical data that contains missing values.

    The data is generated with ``sparsity=0.5``.  A model is trained for 16 rounds
    once with one-hot splits and once with partition-based splits.  In both cases
    the training RMSE must be non-increasing and the RMSE recomputed from the
    booster's predictions must match the last value reported during training.
    """
    x_cat, y = tm.make_categorical(
        rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
    )
    dtrain = xgb.DMatrix(x_cat, y, enable_categorical=True)
    params: Dict[str, Any] = {"tree_method": tree_method, "device": device}

    # First one-hot splits, then partition-based splits.
    for threshold in (USE_ONEHOT, USE_PART):
        params["max_cat_to_onehot"] = threshold
        history: Dict[str, Dict] = {}
        model = xgb.train(
            params,
            dtrain,
            num_boost_round=16,
            evals=[(dtrain, "Train")],
            evals_result=history,
        )
        train_rmse = history["Train"]["rmse"]
        assert tm.non_increasing(train_rmse)

        # The metric reported by training must agree with the model's predictions.
        predt = model.predict(dtrain)
        np.testing.assert_allclose(
            tm.root_mean_square(y, predt), train_rmse[-1], rtol=2e-5
        )
def train_result(
    param: Dict[str, Any], dmat: xgb.DMatrix, num_rounds: int
) -> Dict[str, Any]:
    """Train a booster and return its evaluation history.

    The model is trained on ``dmat`` for ``num_rounds`` rounds with ``dmat`` itself
    as the only evaluation set (named ``"train"``).  Before returning, the booster
    is checked against the input: number of features, number of boosted rounds,
    feature names and feature types must all match.
    """
    history: Dict[str, Any] = {}
    model = xgb.train(
        param,
        dmat,
        num_boost_round=num_rounds,
        evals=[(dmat, "train")],
        verbose_eval=False,
        evals_result=history,
    )

    # Sanity checks on the trained booster.
    assert model.num_features() == dmat.num_col()
    assert model.num_boosted_rounds() == num_rounds
    assert model.feature_names == dmat.feature_names
    assert model.feature_types == dmat.feature_types

    return history

View File

@ -89,5 +89,10 @@ void WarnDeprecatedGPUId();
void WarnEmptyDataset(); void WarnEmptyDataset();
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
/**
 * \brief Error message for code that needs a CUDA device: states that `device` must be
 *        CUDA and that at least one GPU has to be available.
 *
 * \return The message as a compile-time constant StringView.
 */
constexpr StringView InvalidCUDAOrdinal() {
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
"available for using GPU.";
}
} // namespace xgboost::error } // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_ #endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -12,7 +12,7 @@
#include <vector> // for vector #include <vector> // for vector
#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD #include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD
#include "error_msg.h" // for GroupWeight, GroupSize #include "error_msg.h" // for GroupWeight, GroupSize, InvalidCUDAOrdinal
#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t #include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t
#include "xgboost/context.h" // for Context #include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo #include "xgboost/data.h" // for MetaInfo
@ -240,7 +240,7 @@ class RankingCache {
// The function simply returns a uninitialized buffer as this is only used by the // The function simply returns a uninitialized buffer as this is only used by the
// objective for creating pairs. // objective for creating pairs.
common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) { common::Span<std::size_t> SortedIdxY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA()); CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_sorted_idx_cache_.Empty()) { if (y_sorted_idx_cache_.Empty()) {
y_sorted_idx_cache_.SetDevice(ctx->gpu_id); y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
y_sorted_idx_cache_.Resize(n_samples); y_sorted_idx_cache_.Resize(n_samples);
@ -248,7 +248,7 @@ class RankingCache {
return y_sorted_idx_cache_.DeviceSpan(); return y_sorted_idx_cache_.DeviceSpan();
} }
common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) { common::Span<float> RankedY(Context const* ctx, std::size_t n_samples) {
CHECK(ctx->IsCUDA()); CHECK(ctx->IsCUDA()) << error::InvalidCUDAOrdinal();
if (y_ranked_by_model_.Empty()) { if (y_ranked_by_model_.Empty()) {
y_ranked_by_model_.SetDevice(ctx->gpu_id); y_ranked_by_model_.SetDevice(ctx->gpu_id);
y_ranked_by_model_.Resize(n_samples); y_ranked_by_model_.Resize(n_samples);

View File

@ -11,7 +11,6 @@
#include "../common/categorical.h" #include "../common/categorical.h"
#include "../common/cuda_context.cuh" #include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh" #include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter #include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh" #include "./ellpack_page.cuh"
#include "device_adapter.cuh" // for HasInfInData #include "device_adapter.cuh" // for HasInfInData
@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles"); monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts. // Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat); row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin); if (!param.hess.empty()) {
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
} else {
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
}
monitor_.Stop("Quantiles"); monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData"); monitor_.Start("InitCompressedData");

View File

@ -7,13 +7,12 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> // std::forward #include <utility> // for forward
#include "../common/column_matrix.h" #include "../common/column_matrix.h"
#include "../common/hist_util.h" #include "../common/hist_util.h"
#include "../common/numeric.h" #include "../common/numeric.h"
#include "../common/threading_utils.h" #include "../common/transform_iterator.h" // for MakeIndexTransformIter
#include "../common/transform_iterator.h" // MakeIndexTransformIter
namespace xgboost { namespace xgboost {

View File

@ -8,12 +8,12 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <numeric> // for accumulate
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "../common/error_msg.h" // for InconsistentMaxBin #include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../common/random.h" #include "../common/error_msg.h" // for InconsistentMaxBin
#include "../common/threading_utils.h"
#include "./simple_batch_iterator.h" #include "./simple_batch_iterator.h"
#include "adapter.h" #include "adapter.h"
#include "batch_utils.h" // for CheckEmpty, RegenGHist #include "batch_utils.h" // for CheckEmpty, RegenGHist

View File

@ -8,7 +8,6 @@
#include "./sparse_page_dmatrix.h" #include "./sparse_page_dmatrix.h"
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "batch_utils.h" // for RegenGHist #include "batch_utils.h" // for RegenGHist
#include "gradient_index.h" #include "gradient_index.h"

View File

@ -1,13 +1,15 @@
/** /**
* Copyright 2021-2023 by XGBoost contributors * Copyright 2021-2023 by XGBoost contributors
*/ */
#include <memory> #include <memory> // for unique_ptr
#include "../common/hist_util.cuh" #include "../common/hist_util.cuh"
#include "batch_utils.h" // for CheckEmpty, RegenGHist #include "../common/hist_util.h" // for HistogramCuts
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.cuh" #include "ellpack_page.cuh"
#include "sparse_page_dmatrix.h" #include "sparse_page_dmatrix.h"
#include "sparse_page_source.h" #include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchParam
namespace xgboost::data { namespace xgboost::data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx, BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
@ -25,8 +27,13 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id); cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts; std::unique_ptr<common::HistogramCuts> cuts;
cuts = if (!param.hess.empty()) {
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0)); cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
} else {
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
}
this->InitializeSparsePage(ctx); // reset after use. this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this); row_stride = GetRowStride(this);
@ -35,10 +42,10 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
batch_param_ = param; batch_param_ = param;
auto ft = this->info_.feature_types.ConstDeviceSpan(); auto ft = this->info_.feature_types.ConstDeviceSpan();
ellpack_page_source_.reset(); // release resources. ellpack_page_source_.reset(); // make sure resource is released before making new ones.
ellpack_page_source_.reset(new EllpackPageSource( ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id)); param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
} else { } else {
CHECK(sparse_page_source_); CHECK(sparse_page_source_);
ellpack_page_source_->Reset(); ellpack_page_source_->Reset();

View File

@ -47,15 +47,16 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method)
if (ctx->IsCUDA()) { if (ctx->IsCUDA()) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
switch (tree_method) { switch (tree_method) {
case TreeMethod::kAuto: // Use hist as default in 2.0 case TreeMethod::kAuto: // Use hist as default in 2.0
case TreeMethod::kHist: { case TreeMethod::kHist: {
return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; }, return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; },
[] { return "grow_gpu_hist"; }); [] { return "grow_gpu_hist"; });
} }
case TreeMethod::kApprox: case TreeMethod::kApprox: {
CHECK(ctx->IsCPU()) << "The `approx` tree method is not supported on GPU."; return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; });
return "grow_histmaker"; }
case TreeMethod::kExact: case TreeMethod::kExact:
CHECK(ctx->IsCPU()) << "The `exact` tree method is not supported on GPU."; CHECK(ctx->IsCPU()) << "The `exact` tree method is not supported on GPU.";
return "grow_colmaker,prune"; return "grow_colmaker,prune";

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2018-2019 by Contributors * Copyright 2018-2023 by Contributors
*/ */
#ifndef XGBOOST_TREE_CONSTRAINTS_H_ #ifndef XGBOOST_TREE_CONSTRAINTS_H_
#define XGBOOST_TREE_CONSTRAINTS_H_ #define XGBOOST_TREE_CONSTRAINTS_H_
@ -8,10 +8,8 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "xgboost/span.h"
#include "xgboost/base.h"
#include "param.h" #include "param.h"
#include "xgboost/base.h"
namespace xgboost { namespace xgboost {
/*! /*!

View File

@ -8,10 +8,10 @@
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <algorithm> #include <algorithm>
#include <cstddef> // for size_t
#include <limits> #include <limits>
#include <utility> #include <utility>
#include "../../common/compressed_iterator.h"
#include "../../common/cuda_context.cuh" // for CUDAContext #include "../../common/cuda_context.cuh" // for CUDAContext
#include "../../common/random.h" #include "../../common/random.h"
#include "../param.h" #include "../param.h"
@ -202,27 +202,27 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
common::Span<GradientPair> gpair, common::Span<GradientPair> gpair,
DMatrix* dmat) { DMatrix* dmat) {
auto cuctx = ctx->CUDACtx();
// Set gradient pair to 0 with p = 1 - subsample // Set gradient pair to 0 with p = 1 - subsample
thrust::replace_if(dh::tbegin(gpair), dh::tend(gpair), thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0), thrust::counting_iterator<std::size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_), BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{});
GradientPair());
// Count the sampled rows. // Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); size_t sample_rows =
thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{});
// Compact gradient pairs. // Compact gradient pairs.
gpair_.resize(sample_rows); gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero{});
// Index the sample rows. // Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero()); thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(), IsNonZero());
thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin()); sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
sample_row_index_.begin(), sample_row_index_.begin(), ClearEmptyRows());
sample_row_index_.begin(),
ClearEmptyRows());
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_); auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
auto first_page = (*batch_iterator.begin()).Impl(); auto first_page = (*batch_iterator.begin()).Impl();
@ -232,7 +232,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
first_page->row_stride, sample_rows)); first_page->row_stride, sample_rows));
// Compact the ELLPACK pages into the single sample page. // Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : batch_iterator) { for (auto& batch : batch_iterator) {
page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_)); page_->Compact(ctx->gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
} }

View File

@ -11,7 +11,6 @@
#include "../common/random.h" #include "../common/random.h"
#include "../data/gradient_index.h" #include "../data/gradient_index.h"
#include "common_row_partitioner.h" #include "common_row_partitioner.h"
#include "constraints.h"
#include "driver.h" #include "driver.h"
#include "hist/evaluate_splits.h" #include "hist/evaluate_splits.h"
#include "hist/histogram.h" #include "hist/histogram.h"

View File

@ -31,7 +31,6 @@
#include "gpu_hist/histogram.cuh" #include "gpu_hist/histogram.cuh"
#include "gpu_hist/row_partitioner.cuh" #include "gpu_hist/row_partitioner.cuh"
#include "param.h" #include "param.h"
#include "split_evaluator.h"
#include "updater_gpu_common.cuh" #include "updater_gpu_common.cuh"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/context.h" #include "xgboost/context.h"
@ -49,13 +48,30 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST) #endif // !defined(GTEST_TEST)
// training parameters specific to this algorithm // training parameters specific to this algorithm
struct GPUHistMakerTrainParam struct GPUHistMakerTrainParam : public XGBoostParameter<GPUHistMakerTrainParam> {
: public XGBoostParameter<GPUHistMakerTrainParam> {
bool debug_synchronize; bool debug_synchronize;
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( DMLC_DECLARE_FIELD(debug_synchronize)
"Check if all distributed tree are identical after tree construction."); .set_default(false)
.describe("Check if all distributed tree are identical after tree construction.");
}
// Only call this method for testing
void CheckTreesSynchronized(RegTree const* local_tree) const {
if (this->debug_synchronize) {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
} }
}; };
#if !defined(GTEST_TEST) #if !defined(GTEST_TEST)
@ -170,16 +186,15 @@ class DeviceHistogramStorage {
}; };
// Manage memory for a single GPU // Manage memory for a single GPU
template <typename GradientSumT>
struct GPUHistMakerDevice { struct GPUHistMakerDevice {
private: private:
GPUHistEvaluator evaluator_; GPUHistEvaluator evaluator_;
Context const* ctx_; Context const* ctx_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
public: public:
EllpackPageImpl const* page{nullptr}; EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types; common::Span<FeatureType const> feature_types;
BatchParam batch_param;
std::unique_ptr<RowPartitioner> row_partitioner; std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogramStorage<> hist{}; DeviceHistogramStorage<> hist{};
@ -199,7 +214,6 @@ struct GPUHistMakerDevice {
dh::PinnedMemory pinned2; dh::PinnedMemory pinned2;
common::Monitor monitor; common::Monitor monitor;
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints; FeatureInteractionConstraintDevice interaction_constraints;
std::unique_ptr<GradientBasedSampler> sampler; std::unique_ptr<GradientBasedSampler> sampler;
@ -208,22 +222,22 @@ struct GPUHistMakerDevice {
GPUHistMakerDevice(Context const* ctx, bool is_external_memory, GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
common::Span<FeatureType const> _feature_types, bst_row_t _n_rows, common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
BatchParam _batch_param) uint32_t n_features, BatchParam batch_param)
: evaluator_{_param, n_features, ctx->gpu_id}, : evaluator_{_param, n_features, ctx->gpu_id},
ctx_(ctx), ctx_(ctx),
feature_types{_feature_types}, feature_types{_feature_types},
param(std::move(_param)), param(std::move(_param)),
column_sampler(column_sampler_seed), column_sampler_(std::move(column_sampler)),
interaction_constraints(param, n_features), interaction_constraints(param, n_features) {
batch_param(std::move(_batch_param)) { sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample, param.sampling_method, is_external_memory);
param.sampling_method, is_external_memory));
if (!param.monotone_constraints.empty()) { if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds // Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints; monotone_constraints = param.monotone_constraints;
} }
CHECK(column_sampler_);
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
} }
@ -234,16 +248,16 @@ struct GPUHistMakerDevice {
CHECK(page); CHECK(page);
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(ctx_->gpu_id), dh::MaxSharedMemoryOptin(ctx_->gpu_id),
sizeof(GradientSumT))); sizeof(GradientPairPrecise)));
} }
} }
// Reset values for each update iteration // Reset values for each update iteration
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) { void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
auto const& info = dmat->Info(); auto const& info = dmat->Info();
this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(), this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel, param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree); param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
this->interaction_constraints.Reset(); this->interaction_constraints.Reset();
@ -275,8 +289,8 @@ struct GPUHistMakerDevice {
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
int nidx = RegTree::kRoot; int nidx = RegTree::kRoot;
GPUTrainingParam gpu_param(param); GPUTrainingParam gpu_param(param);
auto sampled_features = column_sampler.GetFeatureSet(0); auto sampled_features = column_sampler_->GetFeatureSet(0);
sampled_features->SetDevice(ctx_->gpu_id); sampled_features->SetDevice(ctx_->Device());
common::Span<bst_feature_t> feature_set = common::Span<bst_feature_t> feature_set =
interaction_constraints.Query(sampled_features->DeviceSpan(), nidx); interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
auto matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto matrix = page->GetDeviceAccessor(ctx_->gpu_id);
@ -316,13 +330,13 @@ struct GPUHistMakerDevice {
int right_nidx = tree[candidate.nid].RightChild(); int right_nidx = tree[candidate.nid].RightChild();
nidx[i * 2] = left_nidx; nidx[i * 2] = left_nidx;
nidx[i * 2 + 1] = right_nidx; nidx[i * 2 + 1] = right_nidx;
auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx)); auto left_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(left_nidx));
left_sampled_features->SetDevice(ctx_->gpu_id); left_sampled_features->SetDevice(ctx_->Device());
feature_sets.emplace_back(left_sampled_features); feature_sets.emplace_back(left_sampled_features);
common::Span<bst_feature_t> left_feature_set = common::Span<bst_feature_t> left_feature_set =
interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx); interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx)); auto right_sampled_features = column_sampler_->GetFeatureSet(tree.GetDepth(right_nidx));
right_sampled_features->SetDevice(ctx_->gpu_id); right_sampled_features->SetDevice(ctx_->Device());
feature_sets.emplace_back(right_sampled_features); feature_sets.emplace_back(right_sampled_features);
common::Span<bst_feature_t> right_feature_set = common::Span<bst_feature_t> right_feature_set =
interaction_constraints.Query(right_sampled_features->DeviceSpan(), interaction_constraints.Query(right_sampled_features->DeviceSpan(),
@ -657,7 +671,6 @@ struct GPUHistMakerDevice {
evaluator_.ApplyTreeSplit(candidate, p_tree); evaluator_.ApplyTreeSplit(candidate, p_tree);
const auto& parent = tree[candidate.nid]; const auto& parent = tree[candidate.nid];
std::size_t max_nidx = std::max(parent.LeftChild(), parent.RightChild());
interaction_constraints.Split(candidate.nid, parent.SplitIndex(), parent.LeftChild(), interaction_constraints.Split(candidate.nid, parent.SplitIndex(), parent.LeftChild(),
parent.RightChild()); parent.RightChild());
} }
@ -693,9 +706,8 @@ struct GPUHistMakerDevice {
return root_entry; return root_entry;
} }
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
ObjInfo const* task, RegTree* p_tree, RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree; auto& tree = *p_tree;
// Process maximum 32 nodes at a time // Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32); Driver<GPUExpandEntry> driver(param, 32);
@ -720,7 +732,6 @@ struct GPUHistMakerDevice {
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
[&](const auto& e) { return driver.IsChildValid(e); }); [&](const auto& e) { return driver.IsChildValid(e); });
auto new_candidates = auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry()); pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
@ -753,8 +764,7 @@ class GPUHistMaker : public TreeUpdater {
using GradientSumT = GradientPairPrecise; using GradientSumT = GradientPairPrecise;
public: public:
explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) explicit GPUHistMaker(Context const* ctx, ObjInfo const* task) : TreeUpdater(ctx), task_{task} {};
: TreeUpdater(ctx), task_{task} {};
void Configure(const Args& args) override { void Configure(const Args& args) override {
// Used in test to count how many configurations are performed // Used in test to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure"; LOG(DEBUG) << "[GPU Hist]: Configure";
@ -786,13 +796,10 @@ class GPUHistMaker : public TreeUpdater {
// build tree // build tree
try { try {
size_t t_idx{0}; std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) { for (xgboost::RegTree* tree : trees) {
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]); this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
this->hist_maker_param_.CheckTreesSynchronized(tree);
if (hist_maker_param_.debug_synchronize) {
this->CheckTreesSynchronized(tree);
}
++t_idx; ++t_idx;
} }
dh::safe_cuda(cudaGetLastError()); dh::safe_cuda(cudaGetLastError());
@ -809,13 +816,14 @@ class GPUHistMaker : public TreeUpdater {
// Synchronise the column sampling seed // Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()(); uint32_t column_sampling_seed = common::GlobalRandom()();
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
info_->feature_types.SetDevice(ctx_->gpu_id); info_->feature_types.SetDevice(ctx_->gpu_id);
maker.reset(new GPUHistMakerDevice<GradientSumT>( maker = std::make_unique<GPUHistMakerDevice>(
ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_, ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
*param, column_sampling_seed, info_->num_col_, batch_param)); *param, column_sampler_, info_->num_col_, batch_param);
p_last_fmat_ = dmat; p_last_fmat_ = dmat;
initialised_ = true; initialised_ = true;
@ -830,21 +838,6 @@ class GPUHistMaker : public TreeUpdater {
p_last_tree_ = p_tree; p_last_tree_ = p_tree;
} }
// Only call this method for testing
void CheckTreesSynchronized(RegTree* local_tree) const {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) { RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
monitor_.Start("InitData"); monitor_.Start("InitData");
@ -868,7 +861,7 @@ class GPUHistMaker : public TreeUpdater {
MetaInfo* info_{}; // NOLINT MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT std::unique_ptr<GPUHistMakerDevice> maker; // NOLINT
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; } [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
[[nodiscard]] bool HasNodePosition() const override { return true; } [[nodiscard]] bool HasNodePosition() const override { return true; }
@ -883,6 +876,7 @@ class GPUHistMaker : public TreeUpdater {
ObjInfo const* task_{nullptr}; ObjInfo const* task_{nullptr};
common::Monitor monitor_; common::Monitor monitor_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
}; };
#if !defined(GTEST_TEST) #if !defined(GTEST_TEST)
@ -892,4 +886,131 @@ XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
return new GPUHistMaker(ctx, task); return new GPUHistMaker(ctx, task);
}); });
#endif // !defined(GTEST_TEST) #endif // !defined(GTEST_TEST)
/**
 * @brief Tree updater registered under the name `grow_gpu_approx`.
 *
 * Every call to Update() extracts the hessian from the gradient pairs, hands it to
 * BatchParam and constructs a new GPUHistMakerDevice; growing the individual trees is
 * delegated to that device maker.
 */
class GPUGlobalApproxMaker : public TreeUpdater {
 public:
  /**
   * @param ctx  Context forwarded to the TreeUpdater base class.
   * @param task Objective information; `const_hess` is read from it in Update().
   */
  explicit GPUGlobalApproxMaker(Context const* ctx, ObjInfo const* task)
      : TreeUpdater(ctx), task_{task} {}

  /**
   * @brief Update the updater parameters from `args` and reset the initialisation flag so
   *        that InitDataOnce() runs again on the next update.
   */
  void Configure(Args const& args) override {
    // Used in test to count how many configurations are performed
    LOG(DEBUG) << "[GPU Approx]: Configure";
    hist_maker_param_.UpdateAllowUnknown(args);
    dh::CheckComputeCapability();
    initialised_ = false;

    monitor_.Init(this->Name());
  }

  /** @brief Restore the updater parameters from the "approx_train_param" JSON entry. */
  void LoadConfig(Json const& in) override {
    auto const& config = get<Object const>(in);
    FromJson(config.at("approx_train_param"), &this->hist_maker_param_);
    // Force InitDataOnce() to run again with the restored configuration.
    initialised_ = false;
  }

  /** @brief Store the updater parameters under the "approx_train_param" JSON entry. */
  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["approx_train_param"] = ToJson(hist_maker_param_);
  }

  ~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }

  /**
   * @brief Grow all trees in `trees` using the given gradient pairs.
   *
   * @param param        Training parameters; `max_bin` is forwarded to BatchParam and the
   *                     whole struct to the device maker.
   * @param gpair        Gradient pairs; moved to the context's device.
   * @param p_fmat       Training data.
   * @param out_position One node-position vector per tree, indexed like `trees`.
   * @param trees        Trees to be grown.
   */
  void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
              common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) override {
    monitor_.Start("Update");

    // Creates the column sampler on first use; it is needed by the device maker below.
    this->InitDataOnce(p_fmat);

    // Copy the hessian of every gradient pair into the member buffer `hess_`.
    hess_.resize(gpair->Size());
    auto hess = dh::ToSpan(hess_);

    gpair->SetDevice(ctx_->Device());
    auto d_gpair = gpair->ConstDeviceSpan();
    auto cuctx = ctx_->CUDACtx();
    thrust::transform(cuctx->CTP(), dh::tcbegin(d_gpair), dh::tcend(d_gpair), dh::tbegin(hess),
                      [=] XGBOOST_DEVICE(GradientPair const& g) { return g.GetHess(); });

    auto const& info = p_fmat->Info();
    info.feature_types.SetDevice(ctx_->Device());
    // NOTE(review): the third argument presumably asks for the quantised data to be
    // regenerated when the hessian is not constant -- confirm against BatchParam.
    auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
    // A new device maker is built on every call, using the batch parameter from above.
    maker_ = std::make_unique<GPUHistMakerDevice>(
        ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
        *param, column_sampler_, info.num_col_, batch);

    std::size_t t_idx{0};
    for (xgboost::RegTree* tree : trees) {
      this->UpdateTree(gpair, p_fmat, tree, &out_position[t_idx]);
      this->hist_maker_param_.CheckTreesSynchronized(tree);
      ++t_idx;
    }

    monitor_.Stop("Update");
  }

  /**
   * @brief One-time initialisation: checks that the context is CUDA, creates the column
   *        sampler from a broadcast seed and records `p_fmat` as the last seen DMatrix.
   *
   * Does nothing when the updater is already initialised.
   */
  void InitDataOnce(DMatrix* p_fmat) {
    if (this->initialised_) {
      return;
    }

    monitor_.Start(__func__);
    CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
    // Synchronise the column sampling seed
    uint32_t column_sampling_seed = common::GlobalRandom()();
    collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
    this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);

    p_last_fmat_ = p_fmat;
    initialised_ = true;
    monitor_.Stop(__func__);
  }

  /** @brief Run InitDataOnce() and remember `p_tree` for UpdatePredictionCache(). */
  void InitData(DMatrix* p_fmat, RegTree const* p_tree) {
    this->InitDataOnce(p_fmat);
    p_last_tree_ = p_tree;
  }

  /**
   * @brief Grow a single tree by delegating to the device maker created in Update().
   *
   * @param gpair          Gradient pairs; moved to the context's device.
   * @param p_fmat         Training data.
   * @param p_tree         Tree to be grown.
   * @param p_out_position Output node positions for this tree.
   */
  void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
                  HostDeviceVector<bst_node_t>* p_out_position) {
    monitor_.Start("InitData");
    this->InitData(p_fmat, p_tree);
    monitor_.Stop("InitData");

    // Use the same device accessor as Update() instead of the raw `gpu_id` ordinal.
    gpair->SetDevice(ctx_->Device());
    maker_->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
  }

  /**
   * @brief Update the cached predictions through the device maker.
   *
   * @return false when no device maker exists yet or when `data` is not the DMatrix
   *         recorded by InitDataOnce(); otherwise the result of the device maker.
   */
  bool UpdatePredictionCache(const DMatrix* data,
                             linalg::MatrixView<bst_float> p_out_preds) override {
    if (maker_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
      return false;
    }
    monitor_.Start("UpdatePredictionCache");
    bool result = maker_->UpdatePredictionCache(p_out_preds, p_last_tree_);
    monitor_.Stop("UpdatePredictionCache");
    return result;
  }

  [[nodiscard]] char const* Name() const override { return "grow_gpu_approx"; }
  [[nodiscard]] bool HasNodePosition() const override { return true; }

 private:
  // Set by InitDataOnce(); cleared by Configure() and LoadConfig().
  bool initialised_{false};

  GPUHistMakerTrainParam hist_maker_param_;
  // Device buffer holding the hessian extracted in Update().
  dh::device_vector<float> hess_;
  // Created in InitDataOnce() and handed to every device maker.
  std::shared_ptr<common::ColumnSampler> column_sampler_;
  // Rebuilt on every call to Update().
  std::unique_ptr<GPUHistMakerDevice> maker_;
  // DMatrix and tree seen most recently; consulted by UpdatePredictionCache().
  DMatrix* p_last_fmat_{nullptr};
  RegTree const* p_last_tree_{nullptr};
  ObjInfo const* task_{nullptr};
  common::Monitor monitor_;
};
// Register the updater under the name "grow_gpu_approx". The registry entry identifier is
// `GPUApproxMaker`, while the factory returns a `GPUGlobalApproxMaker`.
// NOTE(review): the registration is compiled out when GTEST_TEST is defined -- presumably
// because the tests include this .cu file directly; confirm.
#if !defined(GTEST_TEST)
XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx")
.describe("Grow tree with GPU.")
.set_body([](Context const* ctx, ObjInfo const* task) {
return new GPUGlobalApproxMaker(ctx, task);
});
#endif // !defined(GTEST_TEST)
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -13,10 +13,7 @@
#include "../../../src/common/common.h" #include "../../../src/common/common.h"
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl #include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
#include "../../../src/data/ellpack_page.h" // for EllpackPage #include "../../../src/data/ellpack_page.h" // for EllpackPage
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/tree/constraints.cuh"
#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/param.h" // for TrainParam
#include "../../../src/tree/updater_gpu_common.cuh"
#include "../../../src/tree/updater_gpu_hist.cu" #include "../../../src/tree/updater_gpu_hist.cu"
#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" #include "../helpers.h"
@ -94,8 +91,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
auto page = BuildEllpackPage(kNRows, kNCols); auto page = BuildEllpackPage(kNRows, kNCols);
BatchParam batch_param{}; BatchParam batch_param{};
Context ctx{MakeCUDACtx(0)}; Context ctx{MakeCUDACtx(0)};
GPUHistMakerDevice<GradientSumT> maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, auto cs = std::make_shared<common::ColumnSampler>(0);
kNCols, kNCols, batch_param); GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
batch_param);
xgboost::SimpleLCG gen; xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f); xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
HostDeviceVector<GradientPair> gpair(kNRows); HostDeviceVector<GradientPair> gpair(kNRows);

View File

@ -24,15 +24,11 @@ class TestPredictionCache : public ::testing::Test {
Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true); Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true);
} }
void RunLearnerTest(std::string updater_name, float subsample, std::string const& grow_policy, void RunLearnerTest(Context const* ctx, std::string updater_name, float subsample,
std::string const& strategy) { std::string const& grow_policy, std::string const& strategy) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})}; std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
if (updater_name == "grow_gpu_hist") { learner->SetParam("device", ctx->DeviceName());
// gpu_id setup learner->SetParam("updater", updater_name);
learner->SetParam("tree_method", "gpu_hist");
} else {
learner->SetParam("updater", updater_name);
}
learner->SetParam("multi_strategy", strategy); learner->SetParam("multi_strategy", strategy);
learner->SetParam("grow_policy", grow_policy); learner->SetParam("grow_policy", grow_policy);
learner->SetParam("subsample", std::to_string(subsample)); learner->SetParam("subsample", std::to_string(subsample));
@ -65,20 +61,14 @@ class TestPredictionCache : public ::testing::Test {
} }
} }
void RunTest(std::string const& updater_name, std::string const& strategy) { void RunTest(Context* ctx, std::string const& updater_name, std::string const& strategy) {
{ {
Context ctx; ctx->InitAllowUnknown(Args{{"nthread", "8"}});
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
if (updater_name == "grow_gpu_hist") {
ctx = ctx.MakeCUDA(0);
} else {
ctx = ctx.MakeCPU();
}
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, &ctx, &task)}; std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(updater_name, ctx, &task)};
RegTree tree; RegTree tree;
std::vector<RegTree *> trees{&tree}; std::vector<RegTree*> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples_); auto gpair = GenerateRandomGradients(n_samples_);
tree::TrainParam param; tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}}); param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
@ -86,33 +76,46 @@ class TestPredictionCache : public ::testing::Test {
std::vector<HostDeviceVector<bst_node_t>> position(1); std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Update(&param, &gpair, Xy_.get(), position, trees); updater->Update(&param, &gpair, Xy_.get(), position, trees);
HostDeviceVector<float> out_prediction_cached; HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.SetDevice(ctx.gpu_id); out_prediction_cached.SetDevice(ctx->Device());
out_prediction_cached.Resize(n_samples_); out_prediction_cached.Resize(n_samples_);
auto cache = auto cache =
linalg::MakeTensorView(&ctx, &out_prediction_cached, out_prediction_cached.Size(), 1); linalg::MakeTensorView(ctx, &out_prediction_cached, out_prediction_cached.Size(), 1);
ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache)); ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache));
} }
for (auto policy : {"depthwise", "lossguide"}) { for (auto policy : {"depthwise", "lossguide"}) {
for (auto subsample : {1.0f, 0.4f}) { for (auto subsample : {1.0f, 0.4f}) {
this->RunLearnerTest(updater_name, subsample, policy, strategy); this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
this->RunLearnerTest(updater_name, subsample, policy, strategy); this->RunLearnerTest(ctx, updater_name, subsample, policy, strategy);
} }
} }
} }
}; };
TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker", "one_output_per_tree"); } TEST_F(TestPredictionCache, Approx) {
Context ctx;
this->RunTest(&ctx, "grow_histmaker", "one_output_per_tree");
}
TEST_F(TestPredictionCache, Hist) { TEST_F(TestPredictionCache, Hist) {
this->RunTest("grow_quantile_histmaker", "one_output_per_tree"); Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker", "one_output_per_tree");
} }
TEST_F(TestPredictionCache, HistMulti) { TEST_F(TestPredictionCache, HistMulti) {
this->RunTest("grow_quantile_histmaker", "multi_output_tree"); Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker", "multi_output_tree");
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist", "one_output_per_tree"); } TEST_F(TestPredictionCache, GpuHist) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_hist", "one_output_per_tree");
}
// Run the shared prediction-cache driver (RunTest) with the "grow_gpu_approx" updater and
// one output per tree, using a context obtained from MakeCUDACtx(0).
TEST_F(TestPredictionCache, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_approx", "one_output_per_tree");
}
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost } // namespace xgboost

View File

@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test {
auto constexpr Iter() const { return 4; } auto constexpr Iter() const { return 4; }
template <typename Page> template <typename Page>
size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const { size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj,
bool reset = true) const {
auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})}; auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("tree_method", tree_method); learner->SetParam("tree_method", tree_method);
learner->SetParam("objective", obj); learner->SetParam("objective", obj);
learner->Configure(); learner->Configure();
@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test {
} // anonymous namespace } // anonymous namespace
TEST_F(RegenTest, Approx) { TEST_F(RegenTest, Approx) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror"); Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:squarederror");
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic"); n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic");
ASSERT_EQ(n, this->Iter()); ASSERT_EQ(n, this->Iter());
} }
TEST_F(RegenTest, Hist) { TEST_F(RegenTest, Hist) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror"); Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror");
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic"); n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
} }
TEST_F(RegenTest, Mixed) { TEST_F(RegenTest, Mixed) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false); Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true); n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1); ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false); n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter()); ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true); n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1); ASSERT_EQ(n, this->Iter() + 1);
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(RegenTest, GpuHist) { TEST_F(RegenTest, GpuApprox) {
auto n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror"); auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:squarederror", true);
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false); n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() * 2);
}
TEST_F(RegenTest, GpuHist) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic", false);
ASSERT_EQ(n, 1); ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic"); {
ASSERT_EQ(n, 2); Context ctx;
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 2);
}
}
TEST_F(RegenTest, GpuMixed) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost } // namespace xgboost

View File

@ -20,10 +20,11 @@ class TestGrowPolicy : public ::testing::Test {
true); true);
} }
std::unique_ptr<Learner> TrainOneIter(std::string tree_method, std::string policy, std::unique_ptr<Learner> TrainOneIter(Context const* ctx, std::string tree_method,
int32_t max_leaves, int32_t max_depth) { std::string policy, int32_t max_leaves, int32_t max_depth) {
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})}; std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
learner->SetParam("tree_method", tree_method); learner->SetParam("tree_method", tree_method);
learner->SetParam("device", ctx->DeviceName());
if (max_leaves >= 0) { if (max_leaves >= 0) {
learner->SetParam("max_leaves", std::to_string(max_leaves)); learner->SetParam("max_leaves", std::to_string(max_leaves));
} }
@ -63,7 +64,7 @@ class TestGrowPolicy : public ::testing::Test {
if (max_leaves == 0 && max_depth == 0) { if (max_leaves == 0 && max_depth == 0) {
// unconstrained // unconstrained
if (tree_method != "gpu_hist") { if (ctx->IsCPU()) {
// GPU pre-allocates for all nodes. // GPU pre-allocates for all nodes.
learner->UpdateOneIter(0, Xy_); learner->UpdateOneIter(0, Xy_);
} }
@ -86,23 +87,23 @@ class TestGrowPolicy : public ::testing::Test {
return learner; return learner;
} }
void TestCombination(std::string tree_method) { void TestCombination(Context const* ctx, std::string tree_method) {
for (auto policy : {"depthwise", "lossguide"}) { for (auto policy : {"depthwise", "lossguide"}) {
// -1 means default // -1 means default
for (auto leaves : {-1, 0, 3}) { for (auto leaves : {-1, 0, 3}) {
for (auto depth : {-1, 0, 3}) { for (auto depth : {-1, 0, 3}) {
this->TrainOneIter(tree_method, policy, leaves, depth); this->TrainOneIter(ctx, tree_method, policy, leaves, depth);
} }
} }
} }
} }
void TestTreeGrowPolicy(std::string tree_method, std::string policy) { void TestTreeGrowPolicy(Context const* ctx, std::string tree_method, std::string policy) {
{ {
/** /**
* max_leaves * max_leaves
*/ */
auto learner = this->TrainOneIter(tree_method, policy, 16, -1); auto learner = this->TrainOneIter(ctx, tree_method, policy, 16, -1);
Json model{Object{}}; Json model{Object{}};
learner->SaveModel(&model); learner->SaveModel(&model);
@ -115,7 +116,7 @@ class TestGrowPolicy : public ::testing::Test {
/** /**
* max_depth * max_depth
*/ */
auto learner = this->TrainOneIter(tree_method, policy, -1, 3); auto learner = this->TrainOneIter(ctx, tree_method, policy, -1, 3);
Json model{Object{}}; Json model{Object{}};
learner->SaveModel(&model); learner->SaveModel(&model);
@ -133,25 +134,36 @@ class TestGrowPolicy : public ::testing::Test {
}; };
TEST_F(TestGrowPolicy, Approx) { TEST_F(TestGrowPolicy, Approx) {
this->TestTreeGrowPolicy("approx", "depthwise"); Context ctx;
this->TestTreeGrowPolicy("approx", "lossguide"); this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
this->TestCombination("approx"); this->TestCombination(&ctx, "approx");
} }
TEST_F(TestGrowPolicy, Hist) { TEST_F(TestGrowPolicy, Hist) {
this->TestTreeGrowPolicy("hist", "depthwise"); Context ctx;
this->TestTreeGrowPolicy("hist", "lossguide"); this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
this->TestCombination("hist"); this->TestCombination(&ctx, "hist");
} }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(TestGrowPolicy, GpuHist) { TEST_F(TestGrowPolicy, GpuHist) {
this->TestTreeGrowPolicy("gpu_hist", "depthwise"); auto ctx = MakeCUDACtx(0);
this->TestTreeGrowPolicy("gpu_hist", "lossguide"); this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
this->TestCombination("gpu_hist"); this->TestCombination(&ctx, "hist");
}
TEST_F(TestGrowPolicy, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
this->TestCombination(&ctx, "approx");
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost } // namespace xgboost

View File

@ -135,7 +135,7 @@ class TestMinSplitLoss : public ::testing::Test {
gpair_ = GenerateRandomGradients(kRows); gpair_ = GenerateRandomGradients(kRows);
} }
std::int32_t Update(std::string updater, float gamma) { std::int32_t Update(Context const* ctx, std::string updater, float gamma) {
Args args{{"max_depth", "1"}, Args args{{"max_depth", "1"},
{"max_leaves", "0"}, {"max_leaves", "0"},
@ -154,8 +154,7 @@ class TestMinSplitLoss : public ::testing::Test {
param.UpdateAllowUnknown(args); param.UpdateAllowUnknown(args);
ObjInfo task{ObjInfo::kRegression}; ObjInfo task{ObjInfo::kRegression};
Context ctx{MakeCUDACtx(updater == "grow_gpu_hist" ? 0 : Context::kCpuId)}; auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, ctx, &task)};
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
up->Configure({}); up->Configure({});
RegTree tree; RegTree tree;
@ -167,16 +166,16 @@ class TestMinSplitLoss : public ::testing::Test {
} }
public: public:
void RunTest(std::string updater) { void RunTest(Context const* ctx, std::string updater) {
{ {
int32_t n_nodes = Update(updater, 0.01); int32_t n_nodes = Update(ctx, updater, 0.01);
// This is not strictly verified, meaning the number `2` is whatever GPU_Hist returned // This is not strictly verified, meaning the number `2` is whatever GPU_Hist returned
// when writing this test, and only used for testing larger gamma (below) does prevent // when writing this test, and only used for testing larger gamma (below) does prevent
// building tree. // building tree.
ASSERT_EQ(n_nodes, 2); ASSERT_EQ(n_nodes, 2);
} }
{ {
int32_t n_nodes = Update(updater, 100.0); int32_t n_nodes = Update(ctx, updater, 100.0);
// No new nodes with gamma == 100. // No new nodes with gamma == 100.
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0)); ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
} }
@ -185,10 +184,25 @@ class TestMinSplitLoss : public ::testing::Test {
/* Exact tree method requires a pruner as an additional updater, so not tested here. */ /* Exact tree method requires a pruner as an additional updater, so not tested here. */
TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); } TEST_F(TestMinSplitLoss, Approx) {
Context ctx;
this->RunTest(&ctx, "grow_histmaker");
}
TEST_F(TestMinSplitLoss, Hist) {
Context ctx;
this->RunTest(&ctx, "grow_quantile_histmaker");
}
TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); } TEST_F(TestMinSplitLoss, GpuHist) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_hist");
}
// Run the shared min-split-loss check (RunTest) against the "grow_gpu_approx" updater,
// using a context obtained from MakeCUDACtx(0).
TEST_F(TestMinSplitLoss, GpuApprox) {
auto ctx = MakeCUDACtx(0);
this->RunTest(&ctx, "grow_gpu_approx");
}
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost } // namespace xgboost

View File

@ -7,11 +7,18 @@ from hypothesis import assume, given, note, settings, strategies
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy from xgboost.testing.params import (
cat_parameter_strategy,
exact_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import ( from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
check_get_quantile_cut, check_get_quantile_cut,
check_init_estimation, check_init_estimation,
check_quantile_loss, check_quantile_loss,
train_result,
) )
sys.path.append("tests/python") sys.path.append("tests/python")
@ -20,22 +27,6 @@ import test_updaters as test_up
pytestmark = tm.timeout(30) pytestmark = tm.timeout(30)
def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
result: xgb.callback.TrainingCallback.EvalsLog = {}
booster = xgb.train(
param,
dmat,
num_rounds,
[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
return result
class TestGPUUpdatersMulti: class TestGPUUpdatersMulti:
@given( @given(
hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
@ -53,14 +44,45 @@ class TestGPUUpdaters:
cputest = test_up.TestTreeMethod() cputest = test_up.TestTreeMethod()
@given( @given(
hist_parameter_strategy, strategies.integers(1, 20), tm.make_dataset_strategy() exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
) )
@settings(deadline=None, max_examples=50, print_blob=True) @settings(deadline=None, max_examples=50, print_blob=True)
def test_gpu_hist(self, param, num_rounds, dataset): def test_gpu_hist(
param["tree_method"] = "gpu_hist" self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "hist", "device": "cuda"})
param.update(hist_param)
param = dataset.set_params(param) param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds) result = train_result(param, dataset.get_dmat(), num_rounds)
note(result) note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.make_dataset_strategy(),
)
@settings(deadline=None, print_blob=True)
def test_gpu_approx(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param.update({"tree_method": "approx", "device": "cuda"})
param.update(hist_param)
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric]) assert tm.non_increasing(result["train"][dataset.metric])
@given(tm.sparse_datasets_strategy) @given(tm.sparse_datasets_strategy)
@ -69,23 +91,27 @@ class TestGPUUpdaters:
param = {"tree_method": "hist", "max_bin": 64} param = {"tree_method": "hist", "max_bin": 64}
hist_result = train_result(param, dataset.get_dmat(), 16) hist_result = train_result(param, dataset.get_dmat(), 16)
note(hist_result) note(hist_result)
assert tm.non_increasing(hist_result['train'][dataset.metric]) assert tm.non_increasing(hist_result["train"][dataset.metric])
param = {"tree_method": "gpu_hist", "max_bin": 64} param = {"tree_method": "gpu_hist", "max_bin": 64}
gpu_hist_result = train_result(param, dataset.get_dmat(), 16) gpu_hist_result = train_result(param, dataset.get_dmat(), 16)
note(gpu_hist_result) note(gpu_hist_result)
assert tm.non_increasing(gpu_hist_result['train'][dataset.metric]) assert tm.non_increasing(gpu_hist_result["train"][dataset.metric])
np.testing.assert_allclose( np.testing.assert_allclose(
hist_result["train"]["rmse"], gpu_hist_result["train"]["rmse"], rtol=1e-2 hist_result["train"]["rmse"], gpu_hist_result["train"]["rmse"], rtol=1e-2
) )
@given(strategies.integers(10, 400), strategies.integers(3, 8), @given(
strategies.integers(1, 2), strategies.integers(4, 7)) strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(1, 2),
strategies.integers(4, 7),
)
@settings(deadline=None, max_examples=20, print_blob=True) @settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical_ohe(self, rows, cols, rounds, cats): def test_categorical_ohe(self, rows, cols, rounds, cats):
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist") check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist")
@given( @given(
tm.categorical_dataset_strategy, tm.categorical_dataset_strategy,
@ -95,7 +121,7 @@ class TestGPUUpdaters:
) )
@settings(deadline=None, max_examples=20, print_blob=True) @settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical( def test_categorical_hist(
self, self,
dataset: tm.TestDataset, dataset: tm.TestDataset,
hist_parameters: Dict[str, Any], hist_parameters: Dict[str, Any],
@ -103,7 +129,30 @@ class TestGPUUpdaters:
n_rounds: int, n_rounds: int,
) -> None: ) -> None:
cat_parameters.update(hist_parameters) cat_parameters.update(hist_parameters)
cat_parameters["tree_method"] = "gpu_hist" cat_parameters["tree_method"] = "hist"
cat_parameters["device"] = "cuda"
results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
tm.non_increasing(results["train"]["rmse"])
@given(
    tm.categorical_dataset_strategy,
    hist_parameter_strategy,
    cat_parameter_strategy,
    strategies.integers(4, 32),
)
@settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_approx(
    self,
    dataset: tm.TestDataset,
    hist_parameters: Dict[str, Any],
    cat_parameters: Dict[str, Any],
    n_rounds: int,
) -> None:
    """Train on a categorical dataset with ``approx`` on CUDA and check the RMSE.

    ``hist_parameters`` are merged into ``cat_parameters``, the tree method and
    device are then forced to ``"approx"`` / ``"cuda"``, and the model is
    trained for ``n_rounds`` rounds.  The test passes when the training
    ``rmse`` history is non-increasing.
    """
    cat_parameters.update(hist_parameters)
    cat_parameters["tree_method"] = "approx"
    cat_parameters["device"] = "cuda"

    results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
    # ``tm.non_increasing`` only reports the outcome as a return value; the
    # result has to be asserted, otherwise this test can never fail on a
    # regressing metric.
    assert tm.non_increasing(results["train"]["rmse"])
@ -129,24 +178,25 @@ class TestGPUUpdaters:
@given( @given(
strategies.integers(10, 400), strategies.integers(10, 400),
strategies.integers(3, 8), strategies.integers(3, 8),
strategies.integers(4, 7) strategies.integers(4, 7),
) )
@settings(deadline=None, max_examples=20, print_blob=True) @settings(deadline=None, max_examples=20, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats): def test_categorical_missing(self, rows, cols, cats):
self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist") check_categorical_missing(rows, cols, cats, "cuda", "approx")
check_categorical_missing(rows, cols, cats, "cuda", "hist")
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_max_cat(self) -> None: def test_max_cat(self) -> None:
self.cputest.run_max_cat("gpu_hist") self.cputest.run_max_cat("gpu_hist")
def test_categorical_32_cat(self): def test_categorical_32_cat(self):
'''32 hits the bound of integer bitset, so special test''' """32 hits the bound of integer bitset, so special test"""
rows = 1000 rows = 1000
cols = 10 cols = 10
cats = 32 cats = 32
rounds = 4 rounds = 4
self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist") check_categorical_ohe(rows, cols, rounds, cats, "cuda", "hist")
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_invalid_category(self): def test_invalid_category(self):
@ -164,15 +214,15 @@ class TestGPUUpdaters:
) -> None: ) -> None:
# We cannot handle empty dataset yet # We cannot handle empty dataset yet
assume(len(dataset.y) > 0) assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist' param["tree_method"] = "gpu_hist"
param = dataset.set_params(param) param = dataset.set_params(param)
result = train_result( result = train_result(
param, param,
dataset.get_device_dmat(max_bin=param.get("max_bin", None)), dataset.get_device_dmat(max_bin=param.get("max_bin", None)),
num_rounds num_rounds,
) )
note(result) note(result)
assert tm.non_increasing(result['train'][dataset.metric], tolerance=1e-3) assert tm.non_increasing(result["train"][dataset.metric], tolerance=1e-3)
@given( @given(
hist_parameter_strategy, hist_parameter_strategy,
@ -185,12 +235,12 @@ class TestGPUUpdaters:
return return
# We cannot handle empty dataset yet # We cannot handle empty dataset yet
assume(len(dataset.y) > 0) assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist' param["tree_method"] = "gpu_hist"
param = dataset.set_params(param) param = dataset.set_params(param)
m = dataset.get_external_dmat() m = dataset.get_external_dmat()
external_result = train_result(param, m, num_rounds) external_result = train_result(param, m, num_rounds)
del m del m
assert tm.non_increasing(external_result['train'][dataset.metric]) assert tm.non_increasing(external_result["train"][dataset.metric])
def test_empty_dmatrix_prediction(self): def test_empty_dmatrix_prediction(self):
# FIXME(trivialfis): This should be done with all updaters # FIXME(trivialfis): This should be done with all updaters
@ -207,7 +257,7 @@ class TestGPUUpdaters:
dtrain, dtrain,
verbose_eval=True, verbose_eval=True,
num_boost_round=6, num_boost_round=6,
evals=[(dtrain, 'Train')] evals=[(dtrain, "Train")],
) )
kRows = 100 kRows = 100
@ -222,10 +272,10 @@ class TestGPUUpdaters:
@given(tm.make_dataset_strategy(), strategies.integers(0, 10)) @given(tm.make_dataset_strategy(), strategies.integers(0, 10))
@settings(deadline=None, max_examples=10, print_blob=True) @settings(deadline=None, max_examples=10, print_blob=True)
def test_specified_gpu_id_gpu_update(self, dataset, gpu_id): def test_specified_gpu_id_gpu_update(self, dataset, gpu_id):
param = {'tree_method': 'gpu_hist', 'gpu_id': gpu_id} param = {"tree_method": "gpu_hist", "gpu_id": gpu_id}
param = dataset.set_params(param) param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10) result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result['train'][dataset.metric]) assert tm.non_increasing(result["train"][dataset.metric])
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("weighted", [True, False]) @pytest.mark.parametrize("weighted", [True, False])

View File

@ -1,6 +1,6 @@
import json import json
from string import ascii_lowercase from string import ascii_lowercase
from typing import Any, Dict, List from typing import Any, Dict
import numpy as np import numpy as np
import pytest import pytest
@ -15,30 +15,15 @@ from xgboost.testing.params import (
hist_parameter_strategy, hist_parameter_strategy,
) )
from xgboost.testing.updater import ( from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
check_get_quantile_cut, check_get_quantile_cut,
check_init_estimation, check_init_estimation,
check_quantile_loss, check_quantile_loss,
train_result,
) )
def train_result(param, dmat, num_rounds):
    """Train a booster on ``dmat`` and return its evaluation history.

    ``dmat`` doubles as the only evaluation set (named ``"train"``), so the
    returned dict is keyed as ``history["train"][<metric>]``.  After training,
    basic booster metadata is checked against the input matrix.
    """
    history = {}
    watchlist = [(dmat, "train")]
    model = xgb.train(
        param,
        dmat,
        num_boost_round=num_rounds,
        evals=watchlist,
        verbose_eval=False,
        evals_result=history,
    )

    # The trained model must agree with the training matrix on shape, number
    # of boosted rounds and feature metadata.
    assert model.num_features() == dmat.num_col()
    assert model.num_boosted_rounds() == num_rounds
    assert model.feature_names == dmat.feature_names
    assert model.feature_types == dmat.feature_types

    return history
class TestTreeMethodMulti: class TestTreeMethodMulti:
@given( @given(
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
@ -281,115 +266,6 @@ class TestTreeMethod:
def test_max_cat(self, tree_method) -> None: def test_max_cat(self, tree_method) -> None:
self.run_max_cat(tree_method) self.run_max_cat(tree_method)
def run_categorical_missing(
    self, rows: int, cols: int, cats: int, tree_method: str
) -> None:
    """Train on sparse categorical data and cross-check the reported RMSE.

    A categorical dataset is generated with ``sparsity=0.5`` and trained for
    16 rounds, once with ``max_cat_to_onehot`` set to ``self.USE_ONEHOT`` and
    once with ``self.USE_PART``.  For each run the training RMSE must be
    non-increasing, and the last reported value must match the RMSE
    recomputed from the booster's predictions.
    """
    features, label = tm.make_categorical(
        rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
    )
    Xy = xgb.DMatrix(features, label, enable_categorical=True)

    # First the one-hot based split, then the partition-based split.
    for threshold in (self.USE_ONEHOT, self.USE_PART):
        parameters: Dict[str, Any] = {
            "tree_method": tree_method,
            "max_cat_to_onehot": threshold,
        }
        history: Dict[str, Dict] = {}
        booster = xgb.train(
            parameters,
            Xy,
            num_boost_round=16,
            evals=[(Xy, "Train")],
            evals_result=history,
        )
        reported = history["Train"]["rmse"]
        assert tm.non_increasing(reported)

        # The metric logged during training has to agree with the metric
        # computed from an explicit prediction on the same data.
        recomputed = tm.root_mean_square(label, booster.predict(Xy))
        np.testing.assert_allclose(recomputed, reported[-1], rtol=2e-5)
def run_categorical_ohe(
    self, rows: int, cols: int, rounds: int, cats: int, tree_method: str
) -> None:
    """Compare pre-encoded one-hot data against built-in categorical support.

    Four models are trained on the same generated data:

    1. on externally one-hot encoded features,
    2. on categorical features with one-hot splits (``self.USE_ONEHOT``),
    3. on categorical features with partition-based splits
       (``self.USE_PART``) and ``reg_lambda=0``,
    4. same as 3 but with ``reg_lambda=1.0`` and a fixed 32 rounds.

    Runs 1 and 2 must report close RMSE curves, run 3 must never be worse
    than run 2, and runs 2 and 4 must have a non-increasing RMSE.
    """
    onehot, label = tm.make_categorical(rows, cols, cats, True)
    cat, _ = tm.make_categorical(rows, cols, cats, False)

    parameters: Dict[str, Any] = {
        "tree_method": tree_method,
        # Use one-hot exclusively
        "max_cat_to_onehot": self.USE_ONEHOT,
    }

    def fit(dmat, n_rounds: int) -> Dict[str, Dict[str, List[float]]]:
        # Train with the current ``parameters`` and hand back the history.
        history: Dict[str, Dict[str, List[float]]] = {}
        xgb.train(
            parameters,
            dmat,
            num_boost_round=n_rounds,
            evals=[(dmat, "Train")],
            evals_result=history,
        )
        return history

    by_etl_results = fit(xgb.DMatrix(onehot, label, enable_categorical=False), rounds)
    by_builtin_results = fit(xgb.DMatrix(cat, label, enable_categorical=True), rounds)

    # Tolerance guidelines that treat the output as random variables do not
    # help much here: tree construction is extremely sensitive to floating
    # point errors, and a 1e-5 error in a single histogram bin can produce an
    # entirely different tree.  Even with this lenient tolerance, hypothesis
    # can still find falsifying examples from time to time.
    np.testing.assert_allclose(
        np.array(by_etl_results["Train"]["rmse"]),
        np.array(by_builtin_results["Train"]["rmse"]),
        rtol=1e-3,
    )
    assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

    # Switch to partition-based splits, without regularization.
    parameters["max_cat_to_onehot"] = self.USE_PART
    parameters["reg_lambda"] = 0
    m = xgb.DMatrix(cat, label, enable_categorical=True)
    by_grouping = fit(m, rounds)

    # Without regularization, partitioning is always at least as good as
    # one-hot splits in every round.
    rmse_oh = by_builtin_results["Train"]["rmse"]
    rmse_group = by_grouping["Train"]["rmse"]
    assert all(oh >= grouped for oh, grouped in zip(rmse_oh, rmse_group))

    # With regularization restored, only monotonicity is checked.
    parameters["reg_lambda"] = 1.0
    by_grouping = fit(m, 32)
    assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
@given(strategies.integers(10, 400), strategies.integers(3, 8), @given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7)) strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True) @settings(deadline=None, print_blob=True)
@ -397,8 +273,8 @@ class TestTreeMethod:
def test_categorical_ohe( def test_categorical_ohe(
self, rows: int, cols: int, rounds: int, cats: int self, rows: int, cols: int, rounds: int, cats: int
) -> None: ) -> None:
self.run_categorical_ohe(rows, cols, rounds, cats, "approx") check_categorical_ohe(rows, cols, rounds, cats, "cpu", "approx")
self.run_categorical_ohe(rows, cols, rounds, cats, "hist") check_categorical_ohe(rows, cols, rounds, cats, "cpu", "hist")
@given( @given(
tm.categorical_dataset_strategy, tm.categorical_dataset_strategy,
@ -454,8 +330,8 @@ class TestTreeMethod:
@settings(deadline=None, print_blob=True) @settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats): def test_categorical_missing(self, rows, cols, cats):
self.run_categorical_missing(rows, cols, cats, "approx") check_categorical_missing(rows, cols, cats, "cpu", "approx")
self.run_categorical_missing(rows, cols, cats, "hist") check_categorical_missing(rows, cols, cats, "cpu", "hist")
def run_adaptive(self, tree_method, weighted) -> None: def run_adaptive(self, tree_method, weighted) -> None:
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)

View File

@ -154,7 +154,6 @@ def run_gpu_hist(
DMatrixT: Type, DMatrixT: Type,
client: Client, client: Client,
) -> None: ) -> None:
params["tree_method"] = "hist"
params["device"] = "cuda" params["device"] = "cuda"
params = dataset.set_params(params) params = dataset.set_params(params)
# It doesn't make sense to distribute a completely # It doesn't make sense to distribute a completely
@ -275,8 +274,31 @@ class TestDistributedGPU:
dmatrix_type: type, dmatrix_type: type,
local_cuda_client: Client, local_cuda_client: Client,
) -> None: ) -> None:
params["tree_method"] = "hist"
run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client) run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client)
@given(
    params=hist_parameter_strategy,
    num_rounds=strategies.integers(1, 20),
    dataset=tm.make_dataset_strategy(),
)
@settings(
    deadline=duration(seconds=120),
    max_examples=20,
    suppress_health_check=suppress,
    print_blob=True,
)
@pytest.mark.skipif(**tm.no_cupy())
def test_gpu_approx(
    self,
    params: Dict,
    num_rounds: int,
    dataset: tm.TestDataset,
    local_cuda_client: Client,
) -> None:
    """Run the shared distributed GPU check with ``tree_method="approx"``.

    The tree method is set here and everything else is delegated to
    ``run_gpu_hist``, always using ``dxgb.DaskDMatrix`` as the matrix type.
    """
    params.update(tree_method="approx")
    dmatrix_type = dxgb.DaskDMatrix
    run_gpu_hist(params, num_rounds, dataset, dmatrix_type, local_cuda_client)
def test_empty_quantile_dmatrix(self, local_cuda_client: Client) -> None: def test_empty_quantile_dmatrix(self, local_cuda_client: Client) -> None:
client = local_cuda_client client = local_cuda_client
X, y = make_categorical(client, 1, 30, 13) X, y = make_categorical(client, 1, 30, 13)