Small cleanup for mock tests. (#10085)
This commit is contained in:
parent
7a61216690
commit
d07b7fe8c8
@ -29,14 +29,14 @@ if(PLUGIN_SYCL)
|
|||||||
${xgboost_SOURCE_DIR}/rabit/include)
|
${xgboost_SOURCE_DIR}/rabit/include)
|
||||||
|
|
||||||
target_compile_definitions(plugin_sycl_test PUBLIC -DXGBOOST_USE_SYCL=1)
|
target_compile_definitions(plugin_sycl_test PUBLIC -DXGBOOST_USE_SYCL=1)
|
||||||
|
|
||||||
target_link_libraries(plugin_sycl_test PUBLIC -fsycl)
|
target_link_libraries(plugin_sycl_test PUBLIC -fsycl)
|
||||||
|
target_link_libraries(plugin_sycl_test PRIVATE ${GTEST_LIBRARIES})
|
||||||
|
|
||||||
set_target_properties(plugin_sycl_test PROPERTIES
|
set_target_properties(plugin_sycl_test PROPERTIES
|
||||||
COMPILE_FLAGS -fsycl
|
COMPILE_FLAGS -fsycl
|
||||||
CXX_STANDARD 17
|
CXX_STANDARD 17
|
||||||
CXX_STANDARD_REQUIRED ON
|
CXX_STANDARD_REQUIRED ON
|
||||||
POSITION_INDEPENDENT_CODE ON)
|
POSITION_INDEPENDENT_CODE ON)
|
||||||
if(USE_OPENMP)
|
if(USE_OPENMP)
|
||||||
find_package(OpenMP REQUIRED)
|
find_package(OpenMP REQUIRED)
|
||||||
set_target_properties(plugin_sycl_test PROPERTIES
|
set_target_properties(plugin_sycl_test PROPERTIES
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
// Copyright 2016-2021 by Contributors
|
/**
|
||||||
|
* Copyright 2016-2024, XGBoost contributors
|
||||||
|
*/
|
||||||
#include "test_metainfo.h"
|
#include "test_metainfo.h"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
#include <dmlc/io.h>
|
#include <dmlc/io.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
@ -9,7 +12,7 @@
|
|||||||
|
|
||||||
#include "../../../src/common/version.h"
|
#include "../../../src/common/version.h"
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h" // for GMockTHrow
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -46,6 +49,8 @@ TEST(MetaInfo, GetSet) {
|
|||||||
|
|
||||||
TEST(MetaInfo, GetSetFeature) {
|
TEST(MetaInfo, GetSetFeature) {
|
||||||
xgboost::MetaInfo info;
|
xgboost::MetaInfo info;
|
||||||
|
ASSERT_THAT([&] { info.SetFeatureInfo("", nullptr, 0); },
|
||||||
|
GMockThrow("Unknown feature info name"));
|
||||||
EXPECT_THROW(info.SetFeatureInfo("", nullptr, 0), dmlc::Error);
|
EXPECT_THROW(info.SetFeatureInfo("", nullptr, 0), dmlc::Error);
|
||||||
EXPECT_THROW(info.SetFeatureInfo("foo", nullptr, 0), dmlc::Error);
|
EXPECT_THROW(info.SetFeatureInfo("foo", nullptr, 0), dmlc::Error);
|
||||||
EXPECT_NO_THROW(info.SetFeatureInfo("feature_name", nullptr, 0));
|
EXPECT_NO_THROW(info.SetFeatureInfo("feature_name", nullptr, 0));
|
||||||
@ -86,7 +91,8 @@ void VerifyGetSetFeatureColumnSplit() {
|
|||||||
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
||||||
[](auto const &str) { return str.c_str(); });
|
[](auto const &str) { return str.c_str(); });
|
||||||
info.num_col_ = kCols;
|
info.num_col_ = kCols;
|
||||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error);
|
ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()); },
|
||||||
|
GMockThrow("Length of feature_type must be equal to number of columns"));
|
||||||
info.num_col_ = kCols * world_size;
|
info.num_col_ = kCols * world_size;
|
||||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
|
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
|
||||||
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
|
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
|
||||||
@ -103,7 +109,8 @@ void VerifyGetSetFeatureColumnSplit() {
|
|||||||
std::transform(names.cbegin(), names.cend(), c_names.begin(),
|
std::transform(names.cbegin(), names.cend(), c_names.begin(),
|
||||||
[](auto const &str) { return str.c_str(); });
|
[](auto const &str) { return str.c_str(); });
|
||||||
info.num_col_ = kCols;
|
info.num_col_ = kCols;
|
||||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error);
|
ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()); },
|
||||||
|
GMockThrow("Length of feature_name must be equal to number of columns"));
|
||||||
info.num_col_ = kCols * world_size;
|
info.num_col_ = kCols * world_size;
|
||||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
|
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
|
||||||
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
|
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2016-2024 by XGBoost contributors
|
* Copyright 2016-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
@ -12,7 +13,7 @@
|
|||||||
#include <xgboost/learner.h> // for LearnerModelParam
|
#include <xgboost/learner.h> // for LearnerModelParam
|
||||||
#include <xgboost/model.h> // for Configurable
|
#include <xgboost/model.h> // for Configurable
|
||||||
|
|
||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@ -573,30 +574,7 @@ class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{};
|
|||||||
|
|
||||||
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); }
|
||||||
|
|
||||||
/**
|
inline auto GMockThrow(StringView msg) {
|
||||||
* @brief poor man's gmock for message matching.
|
return ::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr(msg));
|
||||||
*
|
|
||||||
* @tparam Error The type of expected execption.
|
|
||||||
*
|
|
||||||
* @param submsg A substring of the actual error message.
|
|
||||||
* @param fn The function that throws Error
|
|
||||||
*/
|
|
||||||
template <typename Error, typename Fn>
|
|
||||||
void ExpectThrow(std::string submsg, Fn&& fn) {
|
|
||||||
try {
|
|
||||||
fn();
|
|
||||||
} catch (Error const& exc) {
|
|
||||||
auto actual = std::string{exc.what()};
|
|
||||||
ASSERT_NE(actual.find(submsg), std::string::npos)
|
|
||||||
<< "Expecting substring `" << submsg << "` from the error message."
|
|
||||||
<< " Got:\n"
|
|
||||||
<< actual << "\n";
|
|
||||||
return;
|
|
||||||
} catch (std::exception const& exc) {
|
|
||||||
auto actual = exc.what();
|
|
||||||
ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ASSERT_TRUE(false) << "No exception is thrown";
|
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023, XGBoost contributors
|
* Copyright 2022-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
#include "../../../../plugin/federated/federated_comm.h"
|
#include "../../../../plugin/federated/federated_comm.h"
|
||||||
#include "../../collective/test_worker.h" // for SocketTest
|
#include "../../collective/test_worker.h" // for SocketTest
|
||||||
#include "../../helpers.h" // for ExpectThrow
|
#include "../../helpers.h" // for GMockThrow
|
||||||
#include "test_worker.h" // for TestFederated
|
#include "test_worker.h" // for TestFederated
|
||||||
#include "xgboost/json.h" // for Json
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
@ -20,19 +20,19 @@ class FederatedCommTest : public SocketTest {};
|
|||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||||
ASSERT_THAT(construct,
|
ASSERT_THAT(construct, GMockThrow("Invalid world size"));
|
||||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid world size")));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||||
ASSERT_THAT(construct,
|
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid worker rank.")));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; };
|
auto construct = [] {
|
||||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
FederatedComm comm{"localhost", 0, 1, 1};
|
||||||
|
};
|
||||||
|
ASSERT_THAT(construct, GMockThrow("Invalid worker rank."));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||||
@ -43,7 +43,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
|||||||
config["federated_rank"] = Integer(0);
|
config["federated_rank"] = Integer(0);
|
||||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||||
};
|
};
|
||||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
ASSERT_THAT(construct, GMockThrow("got: `String`"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||||
@ -54,7 +54,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
|||||||
config["federated_rank"] = std::string("0");
|
config["federated_rank"] = std::string("0");
|
||||||
FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
|
FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
|
||||||
};
|
};
|
||||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
ASSERT_THAT(construct, GMockThrow("got: `String`"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright (c) 2017-2023, XGBoost contributors
|
* Copyright 2017-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
@ -82,9 +82,7 @@ TEST(Learner, ParameterValidation) {
|
|||||||
|
|
||||||
// whitespace
|
// whitespace
|
||||||
learner->SetParam("tree method", "exact");
|
learner->SetParam("tree method", "exact");
|
||||||
EXPECT_THAT([&] { learner->Configure(); },
|
ASSERT_THAT([&] { learner->Configure(); }, GMockThrow(R"("tree method" contains whitespace)"));
|
||||||
::testing::ThrowsMessage<dmlc::Error>(
|
|
||||||
::testing::HasSubstr(R"("tree method" contains whitespace)")));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Learner, CheckGroup) {
|
TEST(Learner, CheckGroup) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user