Extract CPU sampling routines. (#8697)

This commit is contained in:
Jiaming Yuan
2023-01-19 23:28:18 +08:00
committed by GitHub
parent 7a068af1a3
commit e49e0998c0
6 changed files with 211 additions and 119 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023 by XGBoost Contributors
* \file linalg.h
* \brief Linear algebra related utilities.
*/
@@ -8,7 +8,7 @@
#include <dmlc/endian.h>
#include <xgboost/base.h>
#include <xgboost/context.h> // fixme(jiamingy): Remove the dependency on this header.
#include <xgboost/context.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/json.h>
#include <xgboost/span.h>
@@ -834,9 +834,26 @@ class Tensor {
int32_t DeviceIdx() const { return data_.DeviceIdx(); }
};
template <typename T>
using Matrix = Tensor<T, 2>;
template <typename T>
using Vector = Tensor<T, 1>;
/**
* \brief Create an array without initialization.
*/
template <typename T, typename... Index>
auto Empty(Context const *ctx, Index &&...index) {
Tensor<T, sizeof...(Index)> t;
t.SetDevice(ctx->gpu_id);
t.Reshape(index...);
return t;
}
/**
* \brief Create an array with value v.
*/
template <typename T, typename... Index>
auto Constant(Context const *ctx, T v, Index &&...index) {
Tensor<T, sizeof...(Index)> t;
@@ -846,7 +863,6 @@ auto Constant(Context const *ctx, T v, Index &&...index) {
return t;
}
/**
* \brief Like `np.zeros`, return a new array of given shape and type, filled with zeros.
*/