Pass DMatrix into metric for caching. (#8790)

This commit is contained in:
Jiaming Yuan 2023-02-13 22:15:05 +08:00 committed by GitHub
parent 31d3ec07af
commit 81b2ee1153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 95 additions and 70 deletions

View File

@ -32,6 +32,8 @@ class DMatrixCache {
CacheT& Value() { return *value; }
};
static constexpr std::size_t DefaultSize() { return 32; }
protected:
std::unordered_map<DMatrix const*, Item> container_;
std::queue<DMatrix const*> queue_;

View File

@ -54,12 +54,15 @@ class Metric : public Configurable {
out["name"] = String(this->Name());
}
/*!
* \brief evaluate a specific metric
* \param preds prediction
* \param info information, including label etc.
/**
* \brief Evaluate a metric with DMatrix as input.
*
* \param preds Prediction
* \param p_fmat DMatrix that contains related information like labels.
*/
virtual double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) = 0;
virtual double Evaluate(HostDeviceVector<bst_float> const& preds,
std::shared_ptr<DMatrix> p_fmat) = 0;
/*! \return name of metric */
virtual const char* Name() const = 0;
/*! \brief virtual destructor */

View File

@ -1339,7 +1339,7 @@ class LearnerImpl : public LearnerIO {
obj_->EvalTransform(&out);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Eval(out, m->Info());
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m);
}
}

View File

@ -16,6 +16,7 @@
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "metric_common.h" // MetricNoCache
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"
@ -253,7 +254,7 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
}
template <typename Curve>
class EvalAUC : public Metric {
class EvalAUC : public MetricNoCache {
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
double auc {0};
if (ctx_->gpu_id != Context::kCpuId) {

View File

@ -11,7 +11,7 @@
#include <cmath>
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/common.h" // MetricNoCache
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h"
@ -23,8 +23,8 @@
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
#include <thrust/functional.h> // thrust::plus<>
#include <thrust/transform_reduce.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
@ -167,7 +167,7 @@ struct EvalRowLogLoss {
}
};
class PseudoErrorLoss : public Metric {
class PseudoErrorLoss : public MetricNoCache {
PesudoHuberParam param_;
public:
@ -339,7 +339,7 @@ struct EvalTweedieNLogLik {
* \tparam Derived the name of subclass
*/
template <typename Policy>
struct EvalEWiseBase : public Metric {
struct EvalEWiseBase : public MetricNoCache {
EvalEWiseBase() = default;
explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param} {}

View File

@ -53,20 +53,21 @@ Metric::Create(const std::string& name, Context const* ctx) {
return metric;
}
Metric *
GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
auto metric = CreateMetricImpl<MetricGPUReg>(name);
if (metric == nullptr) {
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
<< ". Resorting to the CPU builder";
return metric;
return nullptr;
}
// Narrowing reference only for the compiler to allow assignment to a base class member.
// As such, using this narrowed reference to refer to derived members will be an illegal op.
// This is moot, as this type is stateless.
static_cast<GPUMetric *>(metric)->ctx_ = ctx;
return metric;
auto casted = static_cast<GPUMetric*>(metric);
CHECK(casted);
casted->ctx_ = ctx;
return casted;
}
} // namespace xgboost

View File

@ -13,12 +13,21 @@
namespace xgboost {
struct Context;
// Metric that doesn't need to cache anything based on input data.
class MetricNoCache : public Metric {
public:
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
return this->Eval(predts, p_fmat->Info());
}
};
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
// present already. This is created when there is a device ordinal present and if xgboost
// is compiled with CUDA support
struct GPUMetric : Metric {
static Metric *CreateGPUMetric(const std::string &name, Context const *tparam);
struct GPUMetric : public MetricNoCache {
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
};
/*!

View File

@ -9,16 +9,16 @@
#include <atomic>
#include <cmath>
#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
#include <thrust/functional.h> // thrust::plus<>
#include <thrust/transform_reduce.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>
#include "../common/device_helpers.cuh"
#endif // XGBOOST_USE_CUDA
@ -162,7 +162,7 @@ class MultiClassMetricsReduction {
* \tparam Derived the name of subclass
*/
template<typename Derived>
struct EvalMClassBase : public Metric {
struct EvalMClassBase : public MetricNoCache {
double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info) override {
if (info.labels.Size() == 0) {
CHECK_EQ(preds.Size(), 0);

View File

@ -92,7 +92,7 @@ namespace metric {
DMLC_REGISTRY_FILE_TAG(rank_metric);
/*! \brief AMS: also records best threshold */
struct EvalAMS : public Metric {
struct EvalAMS : public MetricNoCache {
public:
explicit EvalAMS(const char* param) {
CHECK(param != nullptr) // NOLINT
@ -155,10 +155,10 @@ struct EvalAMS : public Metric {
};
/*! \brief Evaluate rank list */
struct EvalRank : public Metric, public EvalRankConfig {
struct EvalRank : public MetricNoCache, public EvalRankConfig {
private:
// This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU.
std::unique_ptr<xgboost::Metric> rank_gpu_;
std::unique_ptr<MetricNoCache> rank_gpu_;
public:
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
@ -322,7 +322,7 @@ struct EvalMAP : public EvalRank {
};
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public Metric {
struct EvalCox : public MetricNoCache {
public:
EvalCox() = default;
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {

View File

@ -1,21 +1,20 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* \file rank_metric.cu
* \brief prediction rank based metrics.
* \author Kailong Chen, Tianqi Chen
*/
#include <dmlc/registry.h>
#include <thrust/iterator/counting_iterator.h> // make_counting_iterator
#include <thrust/reduce.h> // reduce
#include <xgboost/metric.h>
#include <xgboost/host_device_vector.h>
#include <thrust/iterator/discard_iterator.h>
#include <vector>
#include <cstddef> // std::size_t
#include <memory> // std::shared_ptr
#include "../common/cuda_context.cuh" // CUDAContext
#include "metric_common.h"
#include "../common/math.h"
#include "../common/device_helpers.cuh"
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
namespace xgboost {
namespace metric {

View File

@ -10,15 +10,14 @@
#include <memory>
#include <vector>
#include "xgboost/json.h"
#include "xgboost/metric.h"
#include "xgboost/host_device_vector.h"
#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h"
#include "../common/threading_utils.h"
#include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/metric.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
@ -194,10 +193,9 @@ struct EvalAFTNLogLik {
AFTParam param_;
};
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
explicit EvalEWiseSurvivalBase(Context const *ctx) {
ctx_ = ctx;
}
template <typename Policy>
struct EvalEWiseSurvivalBase : public MetricNoCache {
explicit EvalEWiseSurvivalBase(Context const* ctx) { ctx_ = ctx; }
EvalEWiseSurvivalBase() = default;
void Configure(const Args& args) override {
@ -230,7 +228,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
// This class exists because we want to perform dispatch according to the distribution type at
// configuration time, not at prediction time.
struct AFTNLogLikDispatcher : public Metric {
struct AFTNLogLikDispatcher : public MetricNoCache {
const char* Name() const override {
return "aft-nloglik";
}
@ -270,7 +268,7 @@ struct AFTNLogLikDispatcher : public Metric {
private:
AFTParam param_;
std::unique_ptr<Metric> metric_;
std::unique_ptr<MetricNoCache> metric_;
};
XGBOOST_REGISTER_METRIC(AFTNLogLik, "aft-nloglik")

View File

@ -156,14 +156,15 @@ double GetMultiMetricEval(xgboost::Metric* metric,
xgboost::linalg::Tensor<float, 2> const& labels,
std::vector<xgboost::bst_float> weights,
std::vector<xgboost::bst_uint> groups) {
xgboost::MetaInfo info;
std::shared_ptr<xgboost::DMatrix> p_fmat{xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix()};
auto& info = p_fmat->Info();
info.num_row_ = labels.Shape(0);
info.labels.Reshape(labels.Shape()[0], labels.Shape()[1]);
info.labels.Data()->Copy(*labels.Data());
info.weights_.HostVector() = weights;
info.group_ptr_ = groups;
return metric->Eval(preds, info);
return metric->Evaluate(preds, p_fmat);
}
namespace xgboost {
@ -661,4 +662,4 @@ void DeleteRMMResource(RMMAllocator*) {}
RMMAllocatorPtr SetUpRMMResourceForCppTests(int, char**) { return {nullptr, DeleteRMMResource}; }
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
} // namespace xgboost
} // namespace xgboost

View File

@ -301,6 +301,11 @@ class RandomDataGenerator {
std::shared_ptr<DMatrix> GenerateQuantileDMatrix();
};
// Generate an empty DMatrix, mostly for its meta info.
inline std::shared_ptr<DMatrix> EmptyDMatrix() {
return RandomDataGenerator{0, 0, 0.0}.GenerateDMatrix();
}
inline std::vector<float>
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
std::vector<float> x(n);

View File

@ -20,12 +20,13 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
EXPECT_NEAR(GetMetricEval(metric, {1, 0, 0}, {0, 0, 1}), 0.25f, 1e-10);
// Invalid dataset
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.labels = linalg::Tensor<float, 2>{{0.0f, 0.0f}, {2}, -1};
float auc = metric->Eval({1, 1}, info);
float auc = metric->Evaluate({1, 1}, p_fmat);
ASSERT_TRUE(std::isnan(auc));
*info.labels.Data() = HostDeviceVector<float>{};
auc = metric->Eval(HostDeviceVector<float>{}, info);
auc = metric->Evaluate(HostDeviceVector<float>{}, p_fmat);
ASSERT_TRUE(std::isnan(auc));
EXPECT_NEAR(GetMetricEval(metric, {0, 1, 0, 1}, {0, 1, 0, 1}), 1.0f, 1e-10);

View File

@ -19,7 +19,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
HostDeviceVector<float> predts;
size_t n_samples = 2048;
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.labels.Reshape(n_samples, 1);
info.num_row_ = n_samples;
auto &h_labels = info.labels.Data()->HostVector();
@ -36,9 +37,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
h_labels[i] = dist(&lcg);
}
auto result = metric->Eval(predts, info);
auto result = metric->Evaluate(predts, p_fmat);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info), result);
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
}
}
} // anonymous namespace

View File

@ -10,7 +10,8 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
HostDeviceVector<float> predts;
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
@ -35,9 +36,9 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
}
}
auto result = metric->Eval(predts, info);
auto result = metric->Evaluate(predts, p_fmat);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info), result);
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
}
}
} // namespace xgboost

View File

@ -18,7 +18,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
metric->Configure(Args{});
HostDeviceVector<float> predts;
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
auto &h_predts = predts.HostVector();
SimpleLCG lcg;
@ -40,9 +41,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
h_upper[i] = 10;
}
auto result = metric->Eval(predts, info);
auto result = metric->Evaluate(predts, p_fmat);
for (size_t i = 0; i < 8; ++i) {
ASSERT_EQ(metric->Eval(predts, info), result);
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
}
}
} // anonymous namespace
@ -54,7 +55,8 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
* Test aggregate output from the AFT metric over a small test data set.
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
**/
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.num_row_ = 4;
info.labels_lower_bound_.HostVector()
= { 100.0f, 0.0f, 60.0f, 16.0f };
@ -72,14 +74,15 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
{"aft_loss_distribution_scale", "1.0"} });
EXPECT_NEAR(metric->Eval(preds, info), test_case.reference_value, 1e-4);
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
}
}
TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
MetaInfo info;
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
info.num_row_ = 4;
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
@ -87,15 +90,15 @@ TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.75f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
info.labels_lower_bound_.HostVector()[2] = 70.0f;
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
info.labels_lower_bound_.HostVector()[0] = 70.0f;
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.25f);
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
}