Extract optional weight. (#8747)

- Extract optional weight from coommon.h to reduce dependency on this header.
- Add test.
This commit is contained in:
Jiaming Yuan 2023-02-07 03:11:53 +08:00 committed by GitHub
parent 0f37a01dd9
commit 28bb01aa22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 74 additions and 22 deletions

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file common.h
* \brief Common utilities
*/
@ -199,17 +199,6 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
return result;
}
struct OptionalWeights {
Span<float const> weights;
float dft{1.0f}; // fixme: make this compile time constant
explicit OptionalWeights(Span<float const> w) : weights{w} {}
explicit OptionalWeights(float w) : dft{w} {}
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
auto Empty() const { return weights.empty(); }
};
/**
* Last index of a group in a CSR style of index pointer.
*/

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
#define XGBOOST_COMMON_OPTIONAL_WEIGHT_H_
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/span.h" // Span
namespace xgboost {
namespace common {
struct OptionalWeights {
Span<float const> weights;
float dft{1.0f}; // fixme: make this compile time constant
explicit OptionalWeights(Span<float const> w) : weights{w} {}
explicit OptionalWeights(float w) : dft{w} {}
XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
auto Empty() const { return weights.empty(); }
};
inline OptionalWeights MakeOptionalWeights(Context const* ctx,
HostDeviceVector<float> const& weights) {
if (ctx->IsCUDA()) {
weights.SetDevice(ctx->gpu_id);
}
return OptionalWeights{ctx->IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()};
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_OPTIONAL_WEIGHT_H_

View File

@ -19,6 +19,7 @@
#include "categorical.h"
#include "common.h"
#include "optional_weight.h" // OptionalWeights
#include "threading_utils.h"
#include "timer.h"

View File

@ -6,9 +6,9 @@
#include <cstddef> // size_t
#include "common.h" // common::OptionalWeights
#include "cuda_context.cuh" // CUDAContext
#include "device_helpers.cuh" // dh::MakeTransformIterator, tcbegin, tcend
#include "optional_weight.h" // common::OptionalWeights
#include "stats.cuh" // common::SegmentedQuantile, common::SegmentedWeightedQuantile
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/context.h" // Context

View File

@ -9,6 +9,7 @@
#include <vector>
#include "common.h" // AssertGPUSupport, OptionalWeights
#include "optional_weight.h" // OptionalWeights
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h"

View File

@ -15,6 +15,7 @@
#include <vector>
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"

View File

@ -1,21 +1,22 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#include <thrust/scan.h>
#include <cub/cub.cuh>
#include <algorithm>
#include <cassert>
#include <cub/cub.cuh>
#include <limits>
#include <memory>
#include <utility>
#include <tuple>
#include <utility>
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../collective/device_communicator.cuh"
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/ranking_utils.cuh"
#include "auc.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
namespace xgboost {
namespace metric {

View File

@ -14,6 +14,7 @@
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h"
#include "../common/threading_utils.h"
#include "metric_common.h"

View File

@ -14,7 +14,8 @@
#include "../common/common.h"
#include "../common/linalg_op.h"
#include "../common/numeric.h" // Reduce
#include "../common/numeric.h" // Reduce
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h"
#include "../common/stats.h"
#include "../common/threading_utils.h"

View File

@ -0,0 +1,24 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // Context
#include <xgboost/host_device_vector.h> // HostDeviceVector
#include "../../../src/common/optional_weight.h"
namespace xgboost {
namespace common {
TEST(OptionalWeight, Basic) {
HostDeviceVector<float> weight{{2.0f, 3.0f, 4.0f}};
Context ctx;
auto opt_w = MakeOptionalWeights(&ctx, weight);
ASSERT_EQ(opt_w[0], 2.0f);
ASSERT_FALSE(opt_w.Empty());
weight.HostVector().clear();
opt_w = MakeOptionalWeights(&ctx, weight);
ASSERT_EQ(opt_w[0], 1.0f);
ASSERT_TRUE(opt_w.Empty());
}
} // namespace common
} // namespace xgboost