GPU implementation of AFT survival objective and metric (#5714)
* Add interval accuracy * De-virtualize AFT functions * Lint * Refactor AFT metric using GPU-CPU reducer * Fix R build * Fix build on Windows * Fix copyright header * Clang-tidy * Fix crashing demo * Fix typos in comment; explain GPU ID * Remove unnecessary #include * Add C++ test for interval accuracy * Fix a bug in accuracy metric: use log pred * Refactor AFT objective using GPU-CPU Transform * Lint * Fix lint * Use Ninja to speed up build * Use time, not /usr/bin/time * Add cpu_build worker class, with concurrency = 1 * Use concurrency = 1 only for CUDA build * concurrency = 1 for clang-tidy * Address reviewer's feedback * Update link to AFT paper
This commit is contained in:
committed by
GitHub
parent
7c2686146e
commit
71b0528a2f
@@ -15,8 +15,8 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(Objective, AFTObjConfiguration) {
|
||||
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &lparam));
|
||||
objective->Configure({ {"aft_loss_distribution", "logistic"},
|
||||
{"aft_loss_distribution_scale", "5"} });
|
||||
@@ -76,8 +76,8 @@ static inline void CheckGPairOverGridPoints(
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Objective, AFTObjGPairUncensoredLabels) {
|
||||
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
|
||||
|
||||
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal",
|
||||
@@ -100,29 +100,29 @@ TEST(Objective, AFTObjGPairUncensoredLabels) {
|
||||
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
|
||||
}
|
||||
|
||||
TEST(Objective, AFTObjGPairLeftCensoredLabels) {
|
||||
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
|
||||
|
||||
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "normal",
|
||||
CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "normal",
|
||||
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
|
||||
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f },
|
||||
{ 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f,
|
||||
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f });
|
||||
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "logistic",
|
||||
CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "logistic",
|
||||
{ 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f,
|
||||
0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f },
|
||||
{ 0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f,
|
||||
0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f });
|
||||
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "extreme",
|
||||
CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "extreme",
|
||||
{ 0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f,
|
||||
0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f },
|
||||
{ 0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f,
|
||||
0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f });
|
||||
}
|
||||
|
||||
TEST(Objective, AFTObjGPairRightCensoredLabels) {
|
||||
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
|
||||
|
||||
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "normal",
|
||||
@@ -145,8 +145,8 @@ TEST(Objective, AFTObjGPairRightCensoredLabels) {
|
||||
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
|
||||
}
|
||||
|
||||
TEST(Objective, AFTObjGPairIntervalCensoredLabels) {
|
||||
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
|
||||
TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
|
||||
|
||||
CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal",
|
||||
|
||||
6
tests/cpp/objective/test_aft_obj.cu
Normal file
6
tests/cpp/objective/test_aft_obj.cu
Normal file
@@ -0,0 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2020 XGBoost contributors
|
||||
*/
|
||||
// Dummy file to keep the CUDA tests.
|
||||
|
||||
#include "test_aft_obj.cc"
|
||||
Reference in New Issue
Block a user