Extract optional weight. (#8747)

- Extract optional weight from coommon.h to reduce dependency on this header.
- Add test.
This commit is contained in:
Jiaming Yuan
2023-02-07 03:11:53 +08:00
committed by GitHub
parent 0f37a01dd9
commit 28bb01aa22
10 changed files with 74 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file common.h
* \brief Common utilities
*/
@@ -199,17 +199,6 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
return result;
}
struct OptionalWeights {
Span<float const> weights;
float dft{1.0f}; // fixme: make this compile time constant
explicit OptionalWeights(Span<float const> w) : weights{w} {}
explicit OptionalWeights(float w) : dft{w} {}
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
auto Empty() const { return weights.empty(); }
};
/**
* Last index of a group in a CSR style of index pointer.
*/