Extract optional weight. (#8747)

- Extract optional weight from coommon.h to reduce dependency on this header.
- Add test.
This commit is contained in:
Jiaming Yuan
2023-02-07 03:11:53 +08:00
committed by GitHub
parent 0f37a01dd9
commit 28bb01aa22
10 changed files with 74 additions and 22 deletions

View File

@@ -0,0 +1,24 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // Context
#include <xgboost/host_device_vector.h> // HostDeviceVector
#include "../../../src/common/optional_weight.h"
namespace xgboost {
namespace common {
TEST(OptionalWeight, Basic) {
HostDeviceVector<float> weight{{2.0f, 3.0f, 4.0f}};
Context ctx;
auto opt_w = MakeOptionalWeights(&ctx, weight);
ASSERT_EQ(opt_w[0], 2.0f);
ASSERT_FALSE(opt_w.Empty());
weight.HostVector().clear();
opt_w = MakeOptionalWeights(&ctx, weight);
ASSERT_EQ(opt_w[0], 1.0f);
ASSERT_TRUE(opt_w.Empty());
}
} // namespace common
} // namespace xgboost