Extract optional weight. (#8747)
- Extract optional weight from coommon.h to reduce dependency on this header. - Add test.
This commit is contained in:
24
tests/cpp/common/test_optional_weight.cc
Normal file
24
tests/cpp/common/test_optional_weight.cc
Normal file
@@ -0,0 +1,24 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/host_device_vector.h> // HostDeviceVector
|
||||
|
||||
#include "../../../src/common/optional_weight.h"
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(OptionalWeight, Basic) {
|
||||
HostDeviceVector<float> weight{{2.0f, 3.0f, 4.0f}};
|
||||
Context ctx;
|
||||
auto opt_w = MakeOptionalWeights(&ctx, weight);
|
||||
ASSERT_EQ(opt_w[0], 2.0f);
|
||||
ASSERT_FALSE(opt_w.Empty());
|
||||
|
||||
weight.HostVector().clear();
|
||||
opt_w = MakeOptionalWeights(&ctx, weight);
|
||||
ASSERT_EQ(opt_w[0], 1.0f);
|
||||
ASSERT_TRUE(opt_w.Empty());
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user