Pass DMatrix into metric for caching. (#8790)
This commit is contained in:
@@ -10,15 +10,14 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "metric_common.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/survival_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||
@@ -194,10 +193,9 @@ struct EvalAFTNLogLik {
|
||||
AFTParam param_;
|
||||
};
|
||||
|
||||
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
explicit EvalEWiseSurvivalBase(Context const *ctx) {
|
||||
ctx_ = ctx;
|
||||
}
|
||||
template <typename Policy>
|
||||
struct EvalEWiseSurvivalBase : public MetricNoCache {
|
||||
explicit EvalEWiseSurvivalBase(Context const* ctx) { ctx_ = ctx; }
|
||||
EvalEWiseSurvivalBase() = default;
|
||||
|
||||
void Configure(const Args& args) override {
|
||||
@@ -230,7 +228,7 @@ template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||
|
||||
// This class exists because we want to perform dispatch according to the distribution type at
|
||||
// configuration time, not at prediction time.
|
||||
struct AFTNLogLikDispatcher : public Metric {
|
||||
struct AFTNLogLikDispatcher : public MetricNoCache {
|
||||
const char* Name() const override {
|
||||
return "aft-nloglik";
|
||||
}
|
||||
@@ -270,7 +268,7 @@ struct AFTNLogLikDispatcher : public Metric {
|
||||
|
||||
private:
|
||||
AFTParam param_;
|
||||
std::unique_ptr<Metric> metric_;
|
||||
std::unique_ptr<MetricNoCache> metric_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(AFTNLogLik, "aft-nloglik")
|
||||
|
||||
Reference in New Issue
Block a user