Extract optional weight. (#8747)
- Extract optional weight from coommon.h to reduce dependency on this header. - Add test.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file common.h
|
||||
* \brief Common utilities
|
||||
*/
|
||||
@@ -199,17 +199,6 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
|
||||
return result;
|
||||
}
|
||||
|
||||
struct OptionalWeights {
|
||||
Span<float const> weights;
|
||||
float dft{1.0f}; // fixme: make this compile time constant
|
||||
|
||||
explicit OptionalWeights(Span<float const> w) : weights{w} {}
|
||||
explicit OptionalWeights(float w) : dft{w} {}
|
||||
|
||||
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
|
||||
auto Empty() const { return weights.empty(); }
|
||||
};
|
||||
|
||||
/**
|
||||
* Last index of a group in a CSR style of index pointer.
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user