Unify test helpers for creating ctx. (#9274)

This commit is contained in:
Jiaming Yuan
2023-06-10 03:35:22 +08:00
committed by GitHub
parent ea0deeca68
commit 152e2fb072
36 changed files with 161 additions and 169 deletions

View File

@@ -8,7 +8,7 @@ namespace xgboost {
namespace metric {
inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) {
auto ctx = CreateEmptyGenericParam(device);
auto ctx = MakeCUDACtx(device);
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &ctx)};
HostDeviceVector<float> predts;
@@ -45,7 +45,7 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device)
}
inline void TestMultiClassError(int device, DataSplitMode data_split_mode) {
auto ctx = xgboost::CreateEmptyGenericParam(device);
auto ctx = MakeCUDACtx(device);
ctx.gpu_id = device;
xgboost::Metric * metric = xgboost::Metric::Create("merror", &ctx);
metric->Configure({});
@@ -66,7 +66,7 @@ inline void VerifyMultiClassError(DataSplitMode data_split_mode = DataSplitMode:
}
inline void TestMultiClassLogLoss(int device, DataSplitMode data_split_mode) {
auto ctx = xgboost::CreateEmptyGenericParam(device);
auto ctx = MakeCUDACtx(device);
ctx.gpu_id = device;
xgboost::Metric * metric = xgboost::Metric::Create("mlogloss", &ctx);
metric->Configure({});