Remove omp_get_max_threads in objective. (#7589)

This commit is contained in:
Jiaming Yuan 2022-01-24 04:35:49 +08:00 committed by GitHub
parent 5817840858
commit 6967ef7267
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 76 additions and 74 deletions

View File

@ -25,7 +25,7 @@ namespace xgboost {
/*! \brief interface of objective function */ /*! \brief interface of objective function */
class ObjFunction : public Configurable { class ObjFunction : public Configurable {
protected: protected:
GenericParameter const* tparam_; GenericParameter const* ctx_;
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */

View File

@ -1,22 +1,21 @@
/*! /*!
* Copyright 2018 XGBoost contributors * Copyright 2018-2022 XGBoost contributors
*/ */
#ifndef XGBOOST_COMMON_TRANSFORM_H_ #ifndef XGBOOST_COMMON_TRANSFORM_H_
#define XGBOOST_COMMON_TRANSFORM_H_ #define XGBOOST_COMMON_TRANSFORM_H_
#include <dmlc/omp.h>
#include <dmlc/common.h> #include <dmlc/common.h>
#include <dmlc/omp.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <type_traits> // enable_if
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <type_traits> // enable_if
#include "xgboost/host_device_vector.h"
#include "xgboost/span.h"
#include "common.h" #include "common.h"
#include "threading_utils.h" #include "threading_utils.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/span.h"
#if defined (__CUDACC__) #if defined (__CUDACC__)
#include "device_helpers.cuh" #include "device_helpers.cuh"
@ -61,10 +60,8 @@ class Transform {
template <typename Functor> template <typename Functor>
struct Evaluator { struct Evaluator {
public: public:
Evaluator(Functor func, Range range, int device, bool shard) : Evaluator(Functor func, Range range, int32_t n_threads, int32_t device_idx)
func_(func), range_{std::move(range)}, : func_(func), range_{std::move(range)}, n_threads_{n_threads}, device_{device_idx} {}
shard_{shard},
device_{device} {}
/*! /*!
* \brief Evaluate the functor with input pointers to HostDeviceVector. * \brief Evaluate the functor with input pointers to HostDeviceVector.
@ -134,9 +131,7 @@ class Transform {
template <typename std::enable_if<CompiledWithCuda>::type* = nullptr, template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename... HDV> typename... HDV>
void LaunchCUDA(Functor _func, HDV*... _vectors) const { void LaunchCUDA(Functor _func, HDV*... _vectors) const {
if (shard_) {
UnpackShard(device_, _vectors...); UnpackShard(device_, _vectors...);
}
size_t range_size = *range_.end() - *range_.begin(); size_t range_size = *range_.end() - *range_.begin();
@ -167,12 +162,10 @@ class Transform {
#endif // defined(__CUDACC__) #endif // defined(__CUDACC__)
template <typename... HDV> template <typename... HDV>
void LaunchCPU(Functor func, HDV*... vectors) const { void LaunchCPU(Functor func, HDV *...vectors) const {
omp_ulong end = static_cast<omp_ulong>(*(range_.end())); omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
SyncHost(vectors...); SyncHost(vectors...);
ParallelFor(end, [&](omp_ulong idx) { ParallelFor(end, n_threads_, [&](omp_ulong idx) { func(idx, UnpackHDV(vectors)...); });
func(idx, UnpackHDV(vectors)...);
});
} }
private: private:
@ -180,9 +173,8 @@ class Transform {
Functor func_; Functor func_;
/*! \brief Range object specifying parallel threads index range. */ /*! \brief Range object specifying parallel threads index range. */
Range range_; Range range_;
/*! \brief Whether sharding for vectors is required. */ int32_t n_threads_;
bool shard_; int32_t device_;
int device_;
}; };
public: public:
@ -195,14 +187,13 @@ class Transform {
* \param func A callable object, accepting a size_t thread index, * \param func A callable object, accepting a size_t thread index,
* followed by a set of Span classes. * followed by a set of Span classes.
* \param range Range object specifying parallel threads index range. * \param range Range object specifying parallel threads index range.
* \param device Specify GPU to use. * \param n_threads Number of CPU threads
* \param shard Whether Shard for HostDeviceVector is needed. * \param device_idx GPU device ordinal
*/ */
template <typename Functor> template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range, static Evaluator<Functor> Init(Functor func, Range const range, int32_t n_threads,
int device, int32_t device_idx) {
bool const shard = true) { return Evaluator<Functor>{func, std::move(range), n_threads, device_idx};
return Evaluator<Functor> {func, std::move(range), device, shard};
} }
}; };

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2020 by Contributors * Copyright 2019-2022 by Contributors
* \file aft_obj.cu * \file aft_obj.cu
* \brief Definition of AFT loss for survival analysis. * \brief Definition of AFT loss for survival analysis.
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking * \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@ -65,7 +65,7 @@ class AFTObj : public ObjFunction {
const bst_float w = is_null_weight ? 1.0f : _weights[_idx]; const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
_out_gpair[_idx] = GradientPair(grad * w, hess * w); _out_gpair[_idx] = GradientPair(grad * w, hess * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_, out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_); &info.weights_);
} }
@ -78,7 +78,7 @@ class AFTObj : public ObjFunction {
CHECK_EQ(info.labels_lower_bound_.Size(), ndata); CHECK_EQ(info.labels_lower_bound_.Size(), ndata);
CHECK_EQ(info.labels_upper_bound_.Size(), ndata); CHECK_EQ(info.labels_upper_bound_.Size(), ndata);
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
const int device = tparam_->gpu_id; const int device = ctx_->gpu_id;
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale; const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
const bool is_null_weight = info.weights_.Size() == 0; const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) { if (!is_null_weight) {
@ -109,7 +109,8 @@ class AFTObj : public ObjFunction {
common::Transform<>::Init( common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = exp(_preds[_idx]); _preds[_idx] = exp(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2018-2019 by Contributors * Copyright 2018-2022 by XGBoost Contributors
* \file hinge.cc * \file hinge.cc
* \brief Provides an implementation of the hinge loss function * \brief Provides an implementation of the hinge loss function
* \author Henry Gouk * \author Henry Gouk
@ -65,8 +65,8 @@ class HingeObj : public ObjFunction {
} }
_out_gpair[_idx] = GradientPair(g, h); _out_gpair[_idx] = GradientPair(g, h);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
tparam_->gpu_id).Eval( ctx_->gpu_id).Eval(
out_gpair, &preds, info.labels.Data(), &info.weights_); out_gpair, &preds, info.labels.Data(), &info.weights_);
} }
@ -75,7 +75,7 @@ class HingeObj : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2018 by Contributors * Copyright 2015-2022 by XGBoost Contributors
* \file multi_class.cc * \file multi_class.cc
* \brief Definition of multi-class classification objectives. * \brief Definition of multi-class classification objectives.
* \author Tianqi Chen * \author Tianqi Chen
@ -68,7 +68,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
const int nclass = param_.num_class; const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass); const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
out_gpair->SetDevice(device); out_gpair->SetDevice(device);
info.labels.SetDevice(device); info.labels.SetDevice(device);
info.weights_.SetDevice(device); info.weights_.SetDevice(device);
@ -114,7 +114,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
p = label == k ? p - 1.0f : p; p = label == k ? p - 1.0f : p;
gpair[idx * nclass + k] = GradientPair(p * wt, h); gpair[idx * nclass + k] = GradientPair(p * wt, h);
} }
}, common::Range{0, ndata}, device, false) }, common::Range{0, ndata}, ctx_->Threads(), device)
.Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_); .Eval(out_gpair, info.labels.Data(), &preds, &info.weights_, &label_correct_);
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -146,7 +146,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
_preds.subspan(_idx * nclass, nclass); _preds.subspan(_idx * nclass, nclass);
common::Softmax(point.begin(), point.end()); common::Softmax(point.begin(), point.end());
}, },
common::Range{0, ndata}, device) common::Range{0, ndata}, this->ctx_->Threads(), device)
.Eval(io_preds); .Eval(io_preds);
} else { } else {
io_preds->SetDevice(device); io_preds->SetDevice(device);
@ -162,7 +162,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
common::FindMaxIndex(point.cbegin(), point.cend()) - common::FindMaxIndex(point.cbegin(), point.cend()) -
point.cbegin(); point.cbegin();
}, },
common::Range{0, ndata}, device, false) common::Range{0, ndata}, this->ctx_->Threads(), device)
.Eval(io_preds, &max_preds); .Eval(io_preds, &max_preds);
io_preds->Resize(max_preds.Size()); io_preds->Resize(max_preds.Size());
io_preds->Copy(max_preds); io_preds->Copy(max_preds);

View File

@ -27,7 +27,7 @@ ObjFunction* ObjFunction::Create(const std::string& name, GenericParameter const
<< ss.str(); << ss.str();
} }
auto pobj = (e->body)(); auto pobj = (e->body)();
pobj->tparam_ = tparam; pobj->ctx_ = tparam;
return pobj; return pobj;
} }

View File

@ -773,7 +773,7 @@ class LambdaRankObj : public ObjFunction {
#if defined(__CUDACC__) #if defined(__CUDACC__)
// Check if we have a GPU assignment; else, revert back to CPU // Check if we have a GPU assignment; else, revert back to CPU
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
if (device >= 0) { if (device >= 0) {
ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr); ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr);
} else { } else {
@ -909,7 +909,7 @@ class LambdaRankObj : public ObjFunction {
const std::vector<unsigned> &gptr) { const std::vector<unsigned> &gptr) {
LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU."; LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU.";
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2019 by Contributors * Copyright 2015-2022 by XGBoost Contributors
* \file regression_obj.cu * \file regression_obj.cu
* \brief Definition of single-value regression and classification objectives. * \brief Definition of single-value regression and classification objectives.
* \author Tianqi Chen, Kailong Chen * \author Tianqi Chen, Kailong Chen
@ -70,7 +70,7 @@ class RegLossObj : public ObjFunction {
<< "Loss: " << Loss::Name(); << "Loss: " << Loss::Name();
size_t const ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag
bool is_null_weight = info.weights_.Size() == 0; bool is_null_weight = info.weights_.Size() == 0;
@ -82,7 +82,7 @@ class RegLossObj : public ObjFunction {
additional_input_.HostVector().begin()[1] = scale_pos_weight; additional_input_.HostVector().begin()[1] = scale_pos_weight;
additional_input_.HostVector().begin()[2] = is_null_weight; additional_input_.HostVector().begin()[2] = is_null_weight;
const size_t nthreads = tparam_->Threads(); const size_t nthreads = ctx_->Threads();
bool on_device = device >= 0; bool on_device = device >= 0;
// On CPU we run the transformation each thread processing a contigious block of data // On CPU we run the transformation each thread processing a contigious block of data
// for better performance. // for better performance.
@ -121,7 +121,7 @@ class RegLossObj : public ObjFunction {
Loss::SecondOrderGradient(p, label) * w); Loss::SecondOrderGradient(p, label) * w);
} }
}, },
common::Range{0, static_cast<int64_t>(n_data_blocks)}, device) common::Range{0, static_cast<int64_t>(n_data_blocks)}, nthreads, device)
.Eval(&additional_input_, out_gpair, &preds, info.labels.Data(), .Eval(&additional_input_, out_gpair, &preds, info.labels.Data(),
&info.weights_); &info.weights_);
@ -140,7 +140,8 @@ class RegLossObj : public ObjFunction {
common::Transform<>::Init( common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
_preds[_idx] = Loss::PredTransform(_preds[_idx]); _preds[_idx] = Loss::PredTransform(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
@ -228,7 +229,7 @@ class PoissonRegression : public ObjFunction {
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
size_t const ndata = preds.Size(); size_t const ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
@ -254,7 +255,7 @@ class PoissonRegression : public ObjFunction {
_out_gpair[_idx] = GradientPair{(expf(p) - y) * w, _out_gpair[_idx] = GradientPair{(expf(p) - y) * w,
expf(p + max_delta_step) * w}; expf(p + max_delta_step) * w};
}, },
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector(); std::vector<int>& label_correct_h = label_correct_.HostVector();
@ -269,7 +270,7 @@ class PoissonRegression : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
@ -381,7 +382,7 @@ class CoxRegression : public ObjFunction {
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override { void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
std::vector<bst_float> &preds = io_preds->HostVector(); std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*) const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
common::ParallelFor(ndata, [&](long j) { // NOLINT(*) common::ParallelFor(ndata, ctx_->Threads(), [&](long j) { // NOLINT(*)
preds[j] = std::exp(preds[j]); preds[j] = std::exp(preds[j]);
}); });
} }
@ -423,7 +424,7 @@ class GammaRegression : public ObjFunction {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty"; CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided"; CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
@ -448,7 +449,7 @@ class GammaRegression : public ObjFunction {
} }
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
}, },
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval( common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); &label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
@ -464,7 +465,7 @@ class GammaRegression : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }
@ -525,7 +526,7 @@ class TweedieRegression : public ObjFunction {
const size_t ndata = preds.Size(); const size_t ndata = preds.Size();
out_gpair->Resize(ndata); out_gpair->Resize(ndata);
auto device = tparam_->gpu_id; auto device = ctx_->gpu_id;
label_correct_.Resize(1); label_correct_.Resize(1);
label_correct_.Fill(1); label_correct_.Fill(1);
@ -555,7 +556,7 @@ class TweedieRegression : public ObjFunction {
std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p); std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p);
_out_gpair[_idx] = GradientPair(grad * w, hess * w); _out_gpair[_idx] = GradientPair(grad * w, hess * w);
}, },
common::Range{0, static_cast<int64_t>(ndata), 1}, device) common::Range{0, static_cast<int64_t>(ndata), 1}, this->ctx_->Threads(), device)
.Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_); .Eval(&label_correct_, out_gpair, &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host // copy "label correct" flags back to host
@ -571,7 +572,7 @@ class TweedieRegression : public ObjFunction {
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) { [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]); _preds[_idx] = expf(_preds[_idx]);
}, },
common::Range{0, static_cast<int64_t>(io_preds->Size())}, common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->DeviceIdx()) io_preds->DeviceIdx())
.Eval(io_preds); .Eval(io_preds);
} }

View File

@ -176,7 +176,7 @@ class TreeEvaluator {
lower[rightid] = mid; lower[rightid] = mid;
} }
}, },
common::Range(0, 1), device_, false) common::Range(0, 1), 1, device_)
.Eval(&lower_bounds_, &upper_bounds_, &monotone_); .Eval(&lower_bounds_, &upper_bounds_, &monotone_);
} }
}; };

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2018-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/span.h> #include <xgboost/span.h>
@ -42,7 +45,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
out_vec.Fill(0); out_vec.Fill(0);
Transform<>::Init(TestTransformRange<bst_float>{}, Transform<>::Init(TestTransformRange<bst_float>{},
Range{0, static_cast<Range::DifferenceType>(size)}, Range{0, static_cast<Range::DifferenceType>(size)}, common::OmpGetNumThreads(0),
TRANSFORM_GPU) TRANSFORM_GPU)
.Eval(&out_vec, &in_vec); .Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector(); std::vector<bst_float> res = out_vec.HostVector();
@ -55,11 +58,14 @@ TEST(TransformDeathTest, Exception) {
size_t const kSize {16}; size_t const kSize {16};
std::vector<bst_float> h_in(kSize); std::vector<bst_float> h_in(kSize);
const HostDeviceVector<bst_float> in_vec{h_in, -1}; const HostDeviceVector<bst_float> in_vec{h_in, -1};
EXPECT_DEATH({ EXPECT_DEATH(
{
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; }, Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
Range(0, static_cast<Range::DifferenceType>(kSize)), -1) Range(0, static_cast<Range::DifferenceType>(kSize)),
common::OmpGetNumThreads(0), -1)
.Eval(&in_vec); .Eval(&in_vec);
}, ""); },
"");
} }
#endif #endif

View File

@ -1,4 +1,7 @@
// This converts all tests from CPU to GPU. /*!
* Copyright 2018-2022 by XGBoost Contributors
* \brief This converts all tests from CPU to GPU.
*/
#include "test_transform_range.cc" #include "test_transform_range.cc"
#if defined(XGBOOST_USE_NCCL) #if defined(XGBOOST_USE_NCCL)
@ -22,8 +25,8 @@ TEST(Transform, MGPU_SpecifiedGpuId) { // NOLINT
const HostDeviceVector<bst_float> in_vec {h_in, device}; const HostDeviceVector<bst_float> in_vec {h_in, device};
HostDeviceVector<bst_float> out_vec {h_out, device}; HostDeviceVector<bst_float> out_vec {h_out, device};
ASSERT_NO_THROW( ASSERT_NO_THROW(Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, device) common::OmpGetNumThreads(0), device)
.Eval(&out_vec, &in_vec)); .Eval(&out_vec, &in_vec));
std::vector<bst_float> res = out_vec.HostVector(); std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));