Define the new device parameter. (#9362)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <limits> // for numeric_limits
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
|
||||
#include "../../../src/data/adapter.h" // for ArrayAdapter
|
||||
#include "../../../src/data/device_adapter.cuh" // for CupyAdapter
|
||||
@@ -41,7 +42,7 @@ void TestInplaceFallback(Context const* ctx) {
|
||||
|
||||
// learner is configured to the device specified by ctx
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
ConfigLearnerByCtx(ctx, learner.get());
|
||||
learner->SetParam("device", ctx->DeviceName());
|
||||
for (std::int32_t i = 0; i < 3; ++i) {
|
||||
learner->UpdateOneIter(i, Xy);
|
||||
}
|
||||
@@ -56,18 +57,31 @@ void TestInplaceFallback(Context const* ctx) {
|
||||
|
||||
HostDeviceVector<float>* out_predt{nullptr};
|
||||
ConsoleLogger::Configure(Args{{"verbosity", "1"}});
|
||||
std::string output;
|
||||
// test whether the warning is raised
|
||||
#if !defined(_WIN32)
|
||||
// Windows has issue with CUDA and thread local storage. For some reason, on Windows a
|
||||
// cudaInitializationError is raised during destruction of `HostDeviceVector`. This
|
||||
// might be related to https://github.com/dmlc/xgboost/issues/5793
|
||||
::testing::internal::CaptureStderr();
|
||||
std::thread{[&] {
|
||||
// Launch a new thread to ensure a warning is raised as we prevent over-verbose
|
||||
// warning by using thread-local flags.
|
||||
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
|
||||
&out_predt, 0, 0);
|
||||
}}.join();
|
||||
output = testing::internal::GetCapturedStderr();
|
||||
ASSERT_NE(output.find("Falling back"), std::string::npos);
|
||||
#endif
|
||||
|
||||
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
|
||||
&out_predt, 0, 0);
|
||||
auto output = testing::internal::GetCapturedStderr();
|
||||
ASSERT_NE(output.find("Falling back"), std::string::npos);
|
||||
|
||||
// test when the contexts match
|
||||
Context new_ctx = *proxy->Ctx();
|
||||
ASSERT_NE(new_ctx.gpu_id, ctx->gpu_id);
|
||||
|
||||
ConfigLearnerByCtx(&new_ctx, learner.get());
|
||||
learner->SetParam("device", new_ctx.DeviceName());
|
||||
HostDeviceVector<float>* out_predt_1{nullptr};
|
||||
// no warning is raised
|
||||
::testing::internal::CaptureStderr();
|
||||
|
||||
Reference in New Issue
Block a user