Switch back to the GPUIDX macro (#9438)

This commit is contained in:
Rong Ou
2023-08-04 00:14:31 -07:00
committed by GitHub
parent 1aabc690ec
commit bde1ebc209
21 changed files with 85 additions and 88 deletions

View File

@@ -16,7 +16,7 @@ namespace xgboost {
namespace common {
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
auto ctx = MakeCUDACtx(GetGPUId());
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &ctx));
objective->Configure({ {"aft_loss_distribution", "logistic"},
{"aft_loss_distribution_scale", "5"} });
@@ -77,7 +77,7 @@ static inline void CheckGPairOverGridPoints(
}
TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
auto ctx = MakeCUDACtx(GetGPUId());
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal",
@@ -101,7 +101,7 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
}
TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
auto ctx = MakeCUDACtx(GetGPUId());
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "normal",
@@ -122,7 +122,7 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
}
TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
auto ctx = MakeCUDACtx(GetGPUId());
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "normal",
@@ -146,7 +146,7 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
}
TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
auto ctx = MakeCUDACtx(GetGPUId());
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal",