Extract optional weight. (#8747)
- Extract optional weight from coommon.h to reduce dependency on this header. - Add test.
This commit is contained in:
parent
0f37a01dd9
commit
28bb01aa22
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2022 by XGBoost Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file common.h
|
* \file common.h
|
||||||
* \brief Common utilities
|
* \brief Common utilities
|
||||||
*/
|
*/
|
||||||
@ -199,17 +199,6 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct OptionalWeights {
|
|
||||||
Span<float const> weights;
|
|
||||||
float dft{1.0f}; // fixme: make this compile time constant
|
|
||||||
|
|
||||||
explicit OptionalWeights(Span<float const> w) : weights{w} {}
|
|
||||||
explicit OptionalWeights(float w) : dft{w} {}
|
|
||||||
|
|
||||||
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
|
|
||||||
auto Empty() const { return weights.empty(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Last index of a group in a CSR style of index pointer.
|
* Last index of a group in a CSR style of index pointer.
|
||||||
*/
|
*/
|
||||||
|
|||||||
33
src/common/optional_weight.h
Normal file
33
src/common/optional_weight.h
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
|
||||||
|
#define XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
|
||||||
|
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||||
|
#include "xgboost/context.h" // Context
|
||||||
|
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||||
|
#include "xgboost/span.h" // Span
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
struct OptionalWeights {
|
||||||
|
Span<float const> weights;
|
||||||
|
float dft{1.0f}; // fixme: make this compile time constant
|
||||||
|
|
||||||
|
explicit OptionalWeights(Span<float const> w) : weights{w} {}
|
||||||
|
explicit OptionalWeights(float w) : dft{w} {}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
|
||||||
|
auto Empty() const { return weights.empty(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
inline OptionalWeights MakeOptionalWeights(Context const* ctx,
|
||||||
|
HostDeviceVector<float> const& weights) {
|
||||||
|
if (ctx->IsCUDA()) {
|
||||||
|
weights.SetDevice(ctx->gpu_id);
|
||||||
|
}
|
||||||
|
return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
|
||||||
@ -19,6 +19,7 @@
|
|||||||
|
|
||||||
#include "categorical.h"
|
#include "categorical.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "optional_weight.h" // OptionalWeights
|
||||||
#include "threading_utils.h"
|
#include "threading_utils.h"
|
||||||
#include "timer.h"
|
#include "timer.h"
|
||||||
|
|
||||||
|
|||||||
@ -6,9 +6,9 @@
|
|||||||
|
|
||||||
#include <cstddef> // size_t
|
#include <cstddef> // size_t
|
||||||
|
|
||||||
#include "common.h" // common::OptionalWeights
|
|
||||||
#include "cuda_context.cuh" // CUDAContext
|
#include "cuda_context.cuh" // CUDAContext
|
||||||
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
|
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
|
||||||
|
#include "optional_weight.h" // common::OptionalWeights
|
||||||
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
|
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
|
||||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||||
#include "xgboost/context.h" // Context
|
#include "xgboost/context.h" // Context
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.h" // AssertGPUSupport, OptionalWeights
|
#include "common.h" // AssertGPUSupport, OptionalWeights
|
||||||
|
#include "optional_weight.h" // OptionalWeights
|
||||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||||
#include "xgboost/context.h" // Context
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h"
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
#include "../common/optional_weight.h" // OptionalWeights
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h"
|
||||||
#include "xgboost/metric.h"
|
#include "xgboost/metric.h"
|
||||||
|
|||||||
@ -1,21 +1,22 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 by XGBoost Contributors
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <thrust/scan.h>
|
#include <thrust/scan.h>
|
||||||
#include <cub/cub.cuh>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <cub/cub.cuh>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "xgboost/span.h"
|
|
||||||
#include "xgboost/data.h"
|
|
||||||
#include "auc.h"
|
|
||||||
#include "../collective/device_communicator.cuh"
|
#include "../collective/device_communicator.cuh"
|
||||||
|
#include "../common/optional_weight.h" // OptionalWeights
|
||||||
#include "../common/ranking_utils.cuh"
|
#include "../common/ranking_utils.cuh"
|
||||||
|
#include "auc.h"
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/span.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace metric {
|
namespace metric {
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
|
#include "../common/optional_weight.h" // OptionalWeights
|
||||||
#include "../common/pseudo_huber.h"
|
#include "../common/pseudo_huber.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "metric_common.h"
|
#include "metric_common.h"
|
||||||
|
|||||||
@ -14,7 +14,8 @@
|
|||||||
|
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "../common/linalg_op.h"
|
#include "../common/linalg_op.h"
|
||||||
#include "../common/numeric.h" // Reduce
|
#include "../common/numeric.h" // Reduce
|
||||||
|
#include "../common/optional_weight.h" // OptionalWeights
|
||||||
#include "../common/pseudo_huber.h"
|
#include "../common/pseudo_huber.h"
|
||||||
#include "../common/stats.h"
|
#include "../common/stats.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
|
|||||||
24
tests/cpp/common/test_optional_weight.cc
Normal file
24
tests/cpp/common/test_optional_weight.cc
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/context.h> // Context
|
||||||
|
#include <xgboost/host_device_vector.h> // HostDeviceVector
|
||||||
|
|
||||||
|
#include "../../../src/common/optional_weight.h"
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
TEST(OptionalWeight, Basic) {
|
||||||
|
HostDeviceVector<float> weight{{2.0f, 3.0f, 4.0f}};
|
||||||
|
Context ctx;
|
||||||
|
auto opt_w = MakeOptionalWeights(&ctx, weight);
|
||||||
|
ASSERT_EQ(opt_w[0], 2.0f);
|
||||||
|
ASSERT_FALSE(opt_w.Empty());
|
||||||
|
|
||||||
|
weight.HostVector().clear();
|
||||||
|
opt_w = MakeOptionalWeights(&ctx, weight);
|
||||||
|
ASSERT_EQ(opt_w[0], 1.0f);
|
||||||
|
ASSERT_TRUE(opt_w.Empty());
|
||||||
|
}
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user