Pass DMatrix into metric for caching. (#8790)
This commit is contained in:
parent
31d3ec07af
commit
81b2ee1153
@ -32,6 +32,8 @@ class DMatrixCache {
|
||||
CacheT& Value() { return *value; }
|
||||
};
|
||||
|
||||
static constexpr std::size_t DefaultSize() { return 32; }
|
||||
|
||||
protected:
|
||||
std::unordered_map<DMatrix const*, Item> container_;
|
||||
std::queue<DMatrix const*> queue_;
|
||||
|
||||
@ -54,12 +54,15 @@ class Metric : public Configurable {
|
||||
out["name"] = String(this->Name());
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
* \param info information, including label etc.
|
||||
/**
|
||||
* \brief Evaluate a metric with DMatrix as input.
|
||||
*
|
||||
* \param preds Prediction
|
||||
* \param p_fmat DMatrix that contains related information like labels.
|
||||
*/
|
||||
virtual double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) = 0;
|
||||
virtual double Evaluate(HostDeviceVector<bst_float> const& preds,
|
||||
std::shared_ptr<DMatrix> p_fmat) = 0;
|
||||
|
||||
/*! \return name of metric */
|
||||
virtual const char* Name() const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
|
||||
@ -1339,7 +1339,7 @@ class LearnerImpl : public LearnerIO {
|
||||
|
||||
obj_->EvalTransform(&out);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Eval(out, m->Info());
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#include "../common/math.h"
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/metric.h"
|
||||
@ -253,7 +254,7 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
|
||||
}
|
||||
|
||||
template <typename Curve>
|
||||
class EvalAUC : public Metric {
|
||||
class EvalAUC : public MetricNoCache {
|
||||
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
|
||||
double auc {0};
|
||||
if (ctx_->gpu_id != Context::kCpuId) {
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/common.h" // MetricNoCache
|
||||
#include "../common/math.h"
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "../common/pseudo_huber.h"
|
||||
@ -23,8 +23,8 @@
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
#include <thrust/functional.h> // thrust::plus<>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
@ -167,7 +167,7 @@ struct EvalRowLogLoss {
|
||||
}
|
||||
};
|
||||
|
||||
class PseudoErrorLoss : public Metric {
|
||||
class PseudoErrorLoss : public MetricNoCache {
|
||||
PesudoHuberParam param_;
|
||||
|
||||
public:
|
||||
@ -339,7 +339,7 @@ struct EvalTweedieNLogLik {
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template <typename Policy>
|
||||
struct EvalEWiseBase : public Metric {
|
||||
struct EvalEWiseBase : public MetricNoCache {
|
||||
EvalEWiseBase() = default;
|
||||
explicit EvalEWiseBase(char const* policy_param) : policy_{policy_param} {}
|
||||
|
||||
|
||||
@ -53,20 +53,21 @@ Metric::Create(const std::string& name, Context const* ctx) {
|
||||
return metric;
|
||||
}
|
||||
|
||||
Metric *
|
||||
GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
|
||||
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
|
||||
auto metric = CreateMetricImpl<MetricGPUReg>(name);
|
||||
if (metric == nullptr) {
|
||||
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
|
||||
<< ". Resorting to the CPU builder";
|
||||
return metric;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Narrowing reference only for the compiler to allow assignment to a base class member.
|
||||
// As such, using this narrowed reference to refer to derived members will be an illegal op.
|
||||
// This is moot, as this type is stateless.
|
||||
static_cast<GPUMetric *>(metric)->ctx_ = ctx;
|
||||
return metric;
|
||||
auto casted = static_cast<GPUMetric*>(metric);
|
||||
CHECK(casted);
|
||||
casted->ctx_ = ctx;
|
||||
return casted;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -13,12 +13,21 @@
|
||||
|
||||
namespace xgboost {
|
||||
struct Context;
|
||||
// Metric that doesn't need to cache anything based on input data.
|
||||
class MetricNoCache : public Metric {
|
||||
public:
|
||||
virtual double Eval(HostDeviceVector<float> const &predts, MetaInfo const &info) = 0;
|
||||
|
||||
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
|
||||
return this->Eval(predts, p_fmat->Info());
|
||||
}
|
||||
};
|
||||
|
||||
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
|
||||
// present already. This is created when there is a device ordinal present and if xgboost
|
||||
// is compiled with CUDA support
|
||||
struct GPUMetric : Metric {
|
||||
static Metric *CreateGPUMetric(const std::string &name, Context const *tparam);
|
||||
struct GPUMetric : public MetricNoCache {
|
||||
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@ -9,16 +9,16 @@
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
#include <thrust/functional.h> // thrust::plus<>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
@ -162,7 +162,7 @@ class MultiClassMetricsReduction {
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalMClassBase : public Metric {
|
||||
struct EvalMClassBase : public MetricNoCache {
|
||||
double Eval(const HostDeviceVector<float> &preds, const MetaInfo &info) override {
|
||||
if (info.labels.Size() == 0) {
|
||||
CHECK_EQ(preds.Size(), 0);
|
||||
|
||||
@ -92,7 +92,7 @@ namespace metric {
|
||||
DMLC_REGISTRY_FILE_TAG(rank_metric);
|
||||
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public Metric {
|
||||
struct EvalAMS : public MetricNoCache {
|
||||
public:
|
||||
explicit EvalAMS(const char* param) {
|
||||
CHECK(param != nullptr) // NOLINT
|
||||
@ -155,10 +155,10 @@ struct EvalAMS : public Metric {
|
||||
};
|
||||
|
||||
/*! \brief Evaluate rank list */
|
||||
struct EvalRank : public Metric, public EvalRankConfig {
|
||||
struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
||||
private:
|
||||
// This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU.
|
||||
std::unique_ptr<xgboost::Metric> rank_gpu_;
|
||||
std::unique_ptr<MetricNoCache> rank_gpu_;
|
||||
|
||||
public:
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
@ -322,7 +322,7 @@ struct EvalMAP : public EvalRank {
|
||||
};
|
||||
|
||||
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
|
||||
struct EvalCox : public Metric {
|
||||
struct EvalCox : public MetricNoCache {
|
||||
public:
|
||||
EvalCox() = default;
|
||||
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
|
||||
|
||||
@ -1,21 +1,20 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
* \file rank_metric.cu
|
||||
* \brief prediction rank based metrics.
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <thrust/iterator/counting_iterator.h> // make_counting_iterator
|
||||
#include <thrust/reduce.h> // reduce
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
|
||||
#include <vector>
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory> // std::shared_ptr
|
||||
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "metric_common.h"
|
||||
|
||||
#include "../common/math.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
|
||||
@ -10,15 +10,14 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/survival_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
@ -194,10 +193,9 @@ struct EvalAFTNLogLik {
|
||||
AFTParam param_;
|
||||
};
|
||||
|
||||
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
||||
ctx_ = ctx;
|
||||
}
|
||||
template <typename Policy>
|
||||
struct EvalEWiseSurvivalBase : public MetricNoCache {
|
||||
explicit EvalEWiseSurvivalBase(Context const* ctx) { ctx_ = ctx; }
|
||||
EvalEWiseSurvivalBase() = default;
|
||||
|
||||
void Configure(const Args& args) override {
|
||||
@ -230,7 +228,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
|
||||
// This class exists because we want to perform dispatch according to the distribution type at
|
||||
// configuration time, not at prediction time.
|
||||
struct AFTNLogLikDispatcher : public Metric {
|
||||
struct AFTNLogLikDispatcher : public MetricNoCache {
|
||||
const char* Name() const override {
|
||||
return "aft-nloglik";
|
||||
}
|
||||
@ -270,7 +268,7 @@ struct AFTNLogLikDispatcher : public Metric {
|
||||
|
||||
private:
|
||||
AFTParam param_;
|
||||
std::unique_ptr<Metric> metric_;
|
||||
std::unique_ptr<MetricNoCache> metric_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AFTNLogLik, "aft-nloglik")
|
||||
|
||||
@ -156,14 +156,15 @@ double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
xgboost::linalg::Tensor<float, 2> const& labels,
|
||||
std::vector<xgboost::bst_float> weights,
|
||||
std::vector<xgboost::bst_uint> groups) {
|
||||
xgboost::MetaInfo info;
|
||||
std::shared_ptr<xgboost::DMatrix> p_fmat{xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix()};
|
||||
auto& info = p_fmat->Info();
|
||||
info.num_row_ = labels.Shape(0);
|
||||
info.labels.Reshape(labels.Shape()[0], labels.Shape()[1]);
|
||||
info.labels.Data()->Copy(*labels.Data());
|
||||
info.weights_.HostVector() = weights;
|
||||
info.group_ptr_ = groups;
|
||||
|
||||
return metric->Eval(preds, info);
|
||||
return metric->Evaluate(preds, p_fmat);
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
@ -661,4 +662,4 @@ void DeleteRMMResource(RMMAllocator*) {}
|
||||
|
||||
RMMAllocatorPtr SetUpRMMResourceForCppTests(int, char**) { return {nullptr, DeleteRMMResource}; }
|
||||
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
@ -301,6 +301,11 @@ class RandomDataGenerator {
|
||||
std::shared_ptr<DMatrix> GenerateQuantileDMatrix();
|
||||
};
|
||||
|
||||
// Generate an empty DMatrix, mostly for its meta info.
|
||||
inline std::shared_ptr<DMatrix> EmptyDMatrix() {
|
||||
return RandomDataGenerator{0, 0, 0.0}.GenerateDMatrix();
|
||||
}
|
||||
|
||||
inline std::vector<float>
|
||||
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
|
||||
std::vector<float> x(n);
|
||||
|
||||
@ -20,12 +20,13 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
|
||||
EXPECT_NEAR(GetMetricEval(metric, {1, 0, 0}, {0, 0, 1}), 0.25f, 1e-10);
|
||||
|
||||
// Invalid dataset
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.labels = linalg::Tensor<float, 2>{{0.0f, 0.0f}, {2}, -1};
|
||||
float auc = metric->Eval({1, 1}, info);
|
||||
float auc = metric->Evaluate({1, 1}, p_fmat);
|
||||
ASSERT_TRUE(std::isnan(auc));
|
||||
*info.labels.Data() = HostDeviceVector<float>{};
|
||||
auc = metric->Eval(HostDeviceVector<float>{}, info);
|
||||
auc = metric->Evaluate(HostDeviceVector<float>{}, p_fmat);
|
||||
ASSERT_TRUE(std::isnan(auc));
|
||||
|
||||
EXPECT_NEAR(GetMetricEval(metric, {0, 1, 0, 1}, {0, 1, 0, 1}), 1.0f, 1e-10);
|
||||
|
||||
@ -19,7 +19,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
|
||||
HostDeviceVector<float> predts;
|
||||
size_t n_samples = 2048;
|
||||
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.labels.Reshape(n_samples, 1);
|
||||
info.num_row_ = n_samples;
|
||||
auto &h_labels = info.labels.Data()->HostVector();
|
||||
@ -36,9 +37,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
|
||||
h_labels[i] = dist(&lcg);
|
||||
}
|
||||
|
||||
auto result = metric->Eval(predts, info);
|
||||
auto result = metric->Evaluate(predts, p_fmat);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Eval(predts, info), result);
|
||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
@ -10,7 +10,8 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
|
||||
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
@ -35,9 +36,9 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
|
||||
}
|
||||
}
|
||||
|
||||
auto result = metric->Eval(predts, info);
|
||||
auto result = metric->Evaluate(predts, p_fmat);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Eval(predts, info), result);
|
||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -18,7 +18,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
|
||||
metric->Configure(Args{});
|
||||
|
||||
HostDeviceVector<float> predts;
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
auto &h_predts = predts.HostVector();
|
||||
|
||||
SimpleLCG lcg;
|
||||
@ -40,9 +41,9 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device)
|
||||
h_upper[i] = 10;
|
||||
}
|
||||
|
||||
auto result = metric->Eval(predts, info);
|
||||
auto result = metric->Evaluate(predts, p_fmat);
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
ASSERT_EQ(metric->Eval(predts, info), result);
|
||||
ASSERT_EQ(metric->Evaluate(predts, p_fmat), result);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
@ -54,7 +55,8 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
|
||||
* Test aggregate output from the AFT metric over a small test data set.
|
||||
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
|
||||
**/
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector()
|
||||
= { 100.0f, 0.0f, 60.0f, 16.0f };
|
||||
@ -72,14 +74,15 @@ TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
|
||||
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &ctx));
|
||||
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
|
||||
{"aft_loss_distribution_scale", "1.0"} });
|
||||
EXPECT_NEAR(metric->Eval(preds, info), test_case.reference_value, 1e-4);
|
||||
EXPECT_NEAR(metric->Evaluate(preds, p_fmat), test_case.reference_value, 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
|
||||
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
MetaInfo info;
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
info.num_row_ = 4;
|
||||
info.labels_lower_bound_.HostVector() = { 20.0f, 0.0f, 60.0f, 16.0f };
|
||||
info.labels_upper_bound_.HostVector() = { 80.0f, 20.0f, 80.0f, 200.0f };
|
||||
@ -87,15 +90,15 @@ TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
|
||||
HostDeviceVector<bst_float> preds(4, std::log(60.0f));
|
||||
|
||||
std::unique_ptr<Metric> metric(Metric::Create("interval-regression-accuracy", &ctx));
|
||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.75f);
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.75f);
|
||||
info.labels_lower_bound_.HostVector()[2] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[2] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_upper_bound_.HostVector()[3] = std::numeric_limits<bst_float>::infinity();
|
||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.50f);
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.50f);
|
||||
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info), 0.25f);
|
||||
EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f);
|
||||
|
||||
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user