Pass DMatrix into metric for caching. (#8790)

This commit is contained in:
Jiaming Yuan
2023-02-13 22:15:05 +08:00
committed by GitHub
parent 31d3ec07af
commit 81b2ee1153
17 changed files with 95 additions and 70 deletions

View File

@@ -10,15 +10,14 @@
#include <memory>
#include <vector>
#include "xgboost/json.h"
#include "xgboost/metric.h"
#include "xgboost/host_device_vector.h"
#include "metric_common.h"
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h"
#include "../common/threading_utils.h"
#include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/metric.h"
#if defined(XGBOOST_USE_CUDA)
#include <thrust/execution_policy.h> // thrust::cuda::par
@@ -194,10 +193,9 @@ struct EvalAFTNLogLik {
AFTParam param_;
};
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
explicit EvalEWiseSurvivalBase(Context const *ctx) {
ctx_ = ctx;
}
template <typename Policy>
struct EvalEWiseSurvivalBase : public MetricNoCache {
explicit EvalEWiseSurvivalBase(Context const* ctx) { ctx_ = ctx; }
EvalEWiseSurvivalBase() = default;
void Configure(const Args& args) override {
@@ -230,7 +228,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
// This class exists because we want to perform dispatch according to the distribution type at
// configuration time, not at prediction time.
struct AFTNLogLikDispatcher : public Metric {
struct AFTNLogLikDispatcher : public MetricNoCache {
const char* Name() const override {
return "aft-nloglik";
}
@@ -270,7 +268,7 @@ struct AFTNLogLikDispatcher : public Metric {
private:
AFTParam param_;
std::unique_ptr<Metric> metric_;
std::unique_ptr<MetricNoCache> metric_;
};
XGBOOST_REGISTER_METRIC(AFTNLogLik, "aft-nloglik")