diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 20923519a..b1a2e0ded 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -29,14 +29,14 @@ if(PLUGIN_SYCL) ${xgboost_SOURCE_DIR}/rabit/include) target_compile_definitions(plugin_sycl_test PUBLIC -DXGBOOST_USE_SYCL=1) - target_link_libraries(plugin_sycl_test PUBLIC -fsycl) + target_link_libraries(plugin_sycl_test PRIVATE ${GTEST_LIBRARIES}) set_target_properties(plugin_sycl_test PROPERTIES - COMPILE_FLAGS -fsycl - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON) + COMPILE_FLAGS -fsycl + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON) if(USE_OPENMP) find_package(OpenMP REQUIRED) set_target_properties(plugin_sycl_test PROPERTIES diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 67c5b39a4..9229832c0 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -1,6 +1,9 @@ -// Copyright 2016-2021 by Contributors +/** + * Copyright 2016-2024, XGBoost contributors + */ #include "test_metainfo.h" +#include #include #include @@ -9,7 +12,7 @@ #include "../../../src/common/version.h" #include "../filesystem.h" // dmlc::TemporaryDirectory -#include "../helpers.h" +#include "../helpers.h" // for GMockTHrow #include "xgboost/base.h" namespace xgboost { @@ -46,6 +49,8 @@ TEST(MetaInfo, GetSet) { TEST(MetaInfo, GetSetFeature) { xgboost::MetaInfo info; + ASSERT_THAT([&] { info.SetFeatureInfo("", nullptr, 0); }, + GMockThrow("Unknown feature info name")); EXPECT_THROW(info.SetFeatureInfo("", nullptr, 0), dmlc::Error); EXPECT_THROW(info.SetFeatureInfo("foo", nullptr, 0), dmlc::Error); EXPECT_NO_THROW(info.SetFeatureInfo("feature_name", nullptr, 0)); @@ -86,7 +91,8 @@ void VerifyGetSetFeatureColumnSplit() { std::transform(types.cbegin(), types.cend(), c_types.begin(), [](auto const &str) { return str.c_str(); }); info.num_col_ = kCols; - EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error); + ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()); }, + GMockThrow("Length of feature_type must be equal to number of columns")); info.num_col_ = kCols * world_size; EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size())); std::vector expected_type_names{u8"float", u8"c", u8"float", @@ -103,7 +109,8 @@ void VerifyGetSetFeatureColumnSplit() { std::transform(names.cbegin(), names.cend(), c_names.begin(), [](auto const &str) { return str.c_str(); }); info.num_col_ = kCols; - EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error); + ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()); }, + GMockThrow("Length of feature_name must be equal to number of columns")); info.num_col_ = kCols * world_size; EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size())); std::vector expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0", diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index d603685eb..56b9d7739 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -1,8 +1,9 @@ /** - * Copyright 2016-2024 by XGBoost contributors + * Copyright 2016-2024, XGBoost contributors */ #pragma once +#include #include #include #include @@ -12,7 +13,7 @@ #include // for LearnerModelParam #include // for Configurable -#include // std::int32_t +#include // std::int32_t #include #include #include @@ -573,30 +574,7 @@ class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{}; inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); } -/** - * @brief poor man's gmock for message matching. - * - * @tparam Error The type of expected execption. - * - * @param submsg A substring of the actual error message. - * @param fn The function that throws Error - */ -template -void ExpectThrow(std::string submsg, Fn&& fn) { - try { - fn(); - } catch (Error const& exc) { - auto actual = std::string{exc.what()}; - ASSERT_NE(actual.find(submsg), std::string::npos) - << "Expecting substring `" << submsg << "` from the error message." - << " Got:\n" - << actual << "\n"; - return; - } catch (std::exception const& exc) { - auto actual = exc.what(); - ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual; - return; - } - ASSERT_TRUE(false) << "No exception is thrown"; +inline auto GMockThrow(StringView msg) { + return ::testing::ThrowsMessage(::testing::HasSubstr(msg)); } } // namespace xgboost diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc index 0d0692b5f..16edc685f 100644 --- a/tests/cpp/plugin/federated/test_federated_comm.cc +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost contributors + * Copyright 2022-2024, XGBoost contributors */ #include #include @@ -9,7 +9,7 @@ #include "../../../../plugin/federated/federated_comm.h" #include "../../collective/test_worker.h" // for SocketTest -#include "../../helpers.h" // for ExpectThrow +#include "../../helpers.h" // for GMockThrow #include "test_worker.h" // for TestFederated #include "xgboost/json.h" // for Json @@ -20,19 +20,19 @@ class FederatedCommTest : public SocketTest {}; TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; - ASSERT_THAT(construct, - ::testing::ThrowsMessage(::testing::HasSubstr("Invalid world size"))); + ASSERT_THAT(construct, GMockThrow("Invalid world size")); } TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; - ASSERT_THAT(construct, - ::testing::ThrowsMessage(::testing::HasSubstr("Invalid worker rank."))); + ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } TEST_F(FederatedCommTest, ThrowOnRankTooBig) { - auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; }; - ExpectThrow("Invalid worker rank.", construct); + auto construct = [] { + FederatedComm comm{"localhost", 0, 1, 1}; + }; + ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { @@ -43,7 +43,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { config["federated_rank"] = Integer(0); FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; }; - ExpectThrow("got: `String`", construct); + ASSERT_THAT(construct, GMockThrow("got: `String`")); } TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { @@ -54,7 +54,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { config["federated_rank"] = std::string("0"); FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config); }; - ExpectThrow("got: `String`", construct); + ASSERT_THAT(construct, GMockThrow("got: `String`")); } TEST_F(FederatedCommTest, GetWorldSizeAndRank) { diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 04f1d35b4..2429e09eb 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2017-2023, XGBoost contributors + * Copyright 2017-2024, XGBoost contributors */ #include #include @@ -82,9 +82,7 @@ TEST(Learner, ParameterValidation) { // whitespace learner->SetParam("tree method", "exact"); - EXPECT_THAT([&] { learner->Configure(); }, - ::testing::ThrowsMessage( - ::testing::HasSubstr(R"("tree method" contains whitespace)"))); + ASSERT_THAT([&] { learner->Configure(); }, GMockThrow(R"("tree method" contains whitespace)")); } TEST(Learner, CheckGroup) {