Define CUDA Context. (#8604)
We will transition to non-default, non-blocking CUDA streams.
parent e01639548a
commit c6a8754c62
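Background for the message above, since the diff assumes the terminology: a non-default stream is any stream other than the legacy default stream, and a non-blocking stream is one created with `cudaStreamNonBlocking`, so it does not implicitly synchronize with the legacy default stream. A minimal standalone sketch using only the plain CUDA runtime API (illustrative, not XGBoost code):

```cpp
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  // Non-default and non-blocking: no implicit synchronization with the
  // legacy default stream.
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  float* d_data = nullptr;
  cudaMalloc(&d_data, 1024 * sizeof(float));
  // Work queued here runs independently of work on the default stream.
  cudaMemsetAsync(d_data, 0, 1024 * sizeof(float), stream);
  cudaStreamSynchronize(stream);  // wait before the host relies on d_data

  cudaFree(d_data);
  cudaStreamDestroy(stream);
  return 0;
}
```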
--- a/include/xgboost/context.h
+++ b/include/xgboost/context.h
@@ -8,15 +8,14 @@
 #include <xgboost/logging.h>
 #include <xgboost/parameter.h>
 
+#include <memory>  // std::shared_ptr
 #include <string>
 
 namespace xgboost {
 
-struct Context : public XGBoostParameter<Context> {
- private:
-  // cached value for CFS CPU limit. (used in containerized env)
-  std::int32_t cfs_cpu_count_;  // NOLINT
+struct CUDAContext;
 
+struct Context : public XGBoostParameter<Context> {
  public:
   // Constant representing the device ID of CPU.
   static std::int32_t constexpr kCpuId = -1;
@@ -51,6 +50,7 @@ struct Context : public XGBoostParameter<Context> {
 
   bool IsCPU() const { return gpu_id == kCpuId; }
   bool IsCUDA() const { return !IsCPU(); }
+  CUDAContext const* CUDACtx() const;
 
   // declare parameters
   DMLC_DECLARE_PARAMETER(Context) {
@@ -73,6 +73,14 @@ struct Context : public XGBoostParameter<Context> {
         .set_default(false)
         .describe("Enable checking whether parameters are used or not.");
   }
 
+ private:
+  // mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
+  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
+  // while trying to hide CUDA code from host compiler.
+  mutable std::shared_ptr<CUDAContext> cuctx_;
+
+  // cached value for CFS CPU limit. (used in containerized env)
+  std::int32_t cfs_cpu_count_;  // NOLINT
 };
 }  // namespace xgboost
--- /dev/null
+++ b/src/common/cuda_context.cuh
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2022 by XGBoost Contributors
+ */
+#ifndef XGBOOST_COMMON_CUDA_CONTEXT_CUH_
+#define XGBOOST_COMMON_CUDA_CONTEXT_CUH_
+#include <thrust/execution_policy.h>
+
+#include "device_helpers.cuh"
+
+namespace xgboost {
+struct CUDAContext {
+ private:
+  dh::XGBCachingDeviceAllocator<char> caching_alloc_;
+  dh::XGBDeviceAllocator<char> alloc_;
+
+ public:
+  /**
+   * \brief Caching thrust policy.
+   */
+  auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
+  /**
+   * \brief Thrust policy without caching allocator.
+   */
+  auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
+  auto Stream() const { return dh::DefaultStream(); }
+};
+}  // namespace xgboost
+#endif  // XGBOOST_COMMON_CUDA_CONTEXT_CUH_
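The new `CUDAContext` bundles two thrust execution policies (one backed by a caching allocator for thrust temporaries, one plain) together with the stream accessor. A hedged sketch of what binding a policy to an explicit stream buys, using only stock thrust (the `dh::` allocators above are swapped for thrust's defaults here):

```cpp
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  thrust::device_vector<int> data(1 << 20, 1);
  // Same spirit as ctx->CTP()/ctx->TP(): the algorithm is ordered on an
  // explicit stream instead of the legacy default stream. XGBoost's CTP()
  // additionally routes temporary allocations through a caching allocator
  // to avoid repeated cudaMalloc/cudaFree inside hot loops.
  int sum = thrust::reduce(thrust::cuda::par.on(stream), data.begin(), data.end(), 0);

  cudaStreamDestroy(stream);
  return sum == (1 << 20) ? 0 : 1;
}
```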
--- a/src/context.cc
+++ b/src/context.cc
@@ -5,7 +5,7 @@
  */
 #include <xgboost/context.h>
 
-#include "common/common.h"
+#include "common/common.h"  // AssertGPUSupport
 #include "common/threading_utils.h"
 
 namespace xgboost {
@@ -59,4 +59,11 @@ std::int32_t Context::Threads() const {
   }
   return n_threads;
 }
+
+#if !defined(XGBOOST_USE_CUDA)
+CUDAContext const* Context::CUDACtx() const {
+  common::AssertGPUSupport();
+  return nullptr;
+}
+#endif  // defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost
--- /dev/null
+++ b/src/context.cu
@@ -0,0 +1,14 @@
+/**
+ * Copyright 2022 by XGBoost Contributors
+ */
+#include "common/cuda_context.cuh"  // CUDAContext
+#include "xgboost/context.h"
+
+namespace xgboost {
+CUDAContext const* Context::CUDACtx() const {
+  if (!cuctx_) {
+    cuctx_.reset(new CUDAContext{});
+  }
+  return cuctx_.get();
+}
+}  // namespace xgboost
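The `shared_ptr`-over-`unique_ptr` comment in `context.h` is the classic incomplete-type pimpl issue: `std::unique_ptr`'s default deleter must see the complete type wherever the owner's destructor is instantiated, which would drag the CUDA-only header into every host translation unit, while `std::shared_ptr` type-erases its deleter at construction time, inside the `.cu` file. A single-file illustration with hypothetical `Widget`/`Impl` names (not XGBoost code):

```cpp
#include <memory>

struct Impl;  // incomplete here; in the real code the definition is CUDA-only

class Widget {
  // shared_ptr captures its deleter where Impl is complete (at construction),
  // so TUs that only see this declaration never need Impl's definition.
  // With std::unique_ptr<Impl>, the implicitly generated ~Widget() would
  // instantiate the default deleter on an incomplete type and fail to
  // compile unless the destructor is declared here and defined out of line.
  std::shared_ptr<Impl> impl_;

 public:
  Widget();
};

// --- below stands in for the separate (.cu) translation unit ---
struct Impl {
  int value{0};
};

Widget::Widget() : impl_{std::make_shared<Impl>()} {}

int main() {
  Widget w;
  return 0;
}
```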
--- a/src/data/data.cu
+++ b/src/data/data.cu
@@ -1,18 +1,19 @@
-/*!
- * Copyright 2019-2021 by XGBoost Contributors
+/**
+ * Copyright 2019-2022 by XGBoost Contributors
  *
  * \file data.cu
  * \brief Handles setting metainfo from array interface.
  */
-#include "xgboost/data.h"
-#include "xgboost/logging.h"
-#include "xgboost/json.h"
-#include "array_interface.h"
+#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../common/linalg_op.cuh"
+#include "array_interface.h"
 #include "device_adapter.cuh"
 #include "simple_dmatrix.h"
 #include "validation.h"
+#include "xgboost/data.h"
+#include "xgboost/json.h"
+#include "xgboost/logging.h"
 
 namespace xgboost {
 namespace {
@@ -25,7 +26,7 @@ auto SetDeviceToPtr(void const* ptr) {
 }
 
 template <typename T, int32_t D>
-void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
+void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array(arr_interface);
   if (array.n == 0) {
     p_out->SetDevice(0);
@@ -43,15 +44,19 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
       // set data
       data->Resize(array.n);
       dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
-                                    cudaMemcpyDefault));
+                                    cudaMemcpyDefault, ctx->Stream()));
     });
     return;
   }
   p_out->Reshape(array.shape);
   auto t = p_out->View(ptr_device);
-  linalg::ElementWiseTransformDevice(t, [=] __device__(size_t i, T) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, array.shape));
-  });
+  linalg::ElementWiseTransformDevice(
+      t,
+      [=] __device__(size_t i, T) {
+        return linalg::detail::Apply(TypedIndex<T, D>{array},
+                                     linalg::UnravelIndex<D>(i, array.shape));
+      },
+      ctx->Stream());
 }
 
 void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector<bst_group_t>* out) {
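The hunk above threads `ctx->Stream()` into `cudaMemcpyAsync` and the element-wise transform. For reference, a hedged standalone sketch of an async copy ordered on an explicit stream (plain CUDA, not XGBoost code); note that real overlap requires pinned host memory, and the host must synchronize before depending on the result:

```cpp
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  constexpr size_t kN = 1 << 20;
  float* h_pinned = nullptr;
  cudaMallocHost(&h_pinned, kN * sizeof(float));  // pinned: copy can be truly async
  float* d_data = nullptr;
  cudaMalloc(&d_data, kN * sizeof(float));

  // cudaMemcpyDefault lets the runtime infer the direction from the pointers,
  // as in the CopyTensorInfoImpl call above; the last argument orders the
  // copy on our explicit stream.
  cudaMemcpyAsync(d_data, h_pinned, kN * sizeof(float), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);  // wait before the host relies on d_data

  cudaFree(d_data);
  cudaFreeHost(h_pinned);
  cudaStreamDestroy(stream);
  return 0;
}
```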
@@ -115,14 +120,13 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
 }
 }  // namespace
 
-// Context is not used until we have CUDA stream.
-void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
+void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
   // multi-dim float info
   if (key == "base_margin") {
-    CopyTensorInfoImpl(array, &base_margin_);
+    CopyTensorInfoImpl(ctx.CUDACtx(), array, &base_margin_);
     return;
   } else if (key == "label") {
-    CopyTensorInfoImpl(array, &labels);
+    CopyTensorInfoImpl(ctx.CUDACtx(), array, &labels);
     auto ptr = labels.Data()->ConstDevicePointer();
     auto valid = thrust::none_of(thrust::device, ptr, ptr + labels.Size(), data::LabelsCheck{});
     CHECK(valid) << "Label contains NaN, infinity or a value too large.";
@@ -142,7 +146,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
   }
   // float info
   linalg::Tensor<float, 1> t;
-  CopyTensorInfoImpl(array, &t);
+  CopyTensorInfoImpl(ctx.CUDACtx(), array, &t);
   if (key == "weight") {
     this->weights_ = std::move(*t.Data());
     auto ptr = weights_.ConstDevicePointer();
@@ -156,7 +160,7 @@ void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
     this->feature_weights = std::move(*t.Data());
     auto d_feature_weights = feature_weights.ConstDeviceSpan();
     auto valid =
-        thrust::none_of(thrust::device, d_feature_weights.data(),
+        thrust::none_of(ctx.CUDACtx()->CTP(), d_feature_weights.data(),
                         d_feature_weights.data() + d_feature_weights.size(), data::WeightsCheck{});
     CHECK(valid) << "Feature weight must be greater than 0.";
   } else {
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -35,7 +35,7 @@
 #include "common/version.h"
 #include "xgboost/base.h"
 #include "xgboost/c_api.h"
-#include "xgboost/context.h"
+#include "xgboost/context.h"  // Context
 #include "xgboost/data.h"
 #include "xgboost/feature_map.h"
 #include "xgboost/gbm.h"
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -267,12 +267,12 @@ __global__ void __launch_bounds__(kBlockThreads)
   }
 }
 
-void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                             FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
-                            common::Span<GradientPairInt64> histogram,
-                            GradientQuantiser rounding, bool force_global_memory) {
+                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
+                            bool force_global_memory) {
   // decide whether to use shared memory
   int device = 0;
   dh::safe_cuda(cudaGetDevice(&device));
@@ -318,9 +318,9 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
         min(grid_size,
             unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
 
-    dh::LaunchKernel {dim3(grid_size, num_groups),
-                      static_cast<uint32_t>(kBlockThreads), smem_size}(
-        kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
+    dh::LaunchKernel{dim3(grid_size, num_groups), static_cast<uint32_t>(kBlockThreads), smem_size,
+                     ctx->Stream()}(kernel, matrix, feature_groups, d_ridx, histogram.data(),
+                                    gpair.data(), rounding);
   };
 
   if (shared) {
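`dh::LaunchKernel` now receives the context stream alongside the grid, block, and dynamic shared-memory size. The raw CUDA equivalent is the fourth launch-configuration parameter; a minimal sketch (illustrative, not XGBoost's helper):

```cpp
#include <cuda_runtime.h>

__global__ void Scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  constexpr int kN = 1 << 20;
  float* d_data = nullptr;
  cudaMalloc(&d_data, kN * sizeof(float));

  int const block = 256;
  int const grid = (kN + block - 1) / block;
  size_t const shmem = 0;
  // <<<grid, block, dynamic shared memory, stream>>> -- the same four pieces
  // dh::LaunchKernel is constructed with in the hunk above.
  Scale<<<grid, block, shmem, stream>>>(d_data, 2.0f, kN);

  cudaStreamSynchronize(stream);
  cudaFree(d_data);
  cudaStreamDestroy(stream);
  return 0;
}
```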
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -5,9 +5,9 @@
 #define HISTOGRAM_CUH_
 #include <thrust/transform.h>
 
-#include "feature_groups.cuh"
-
+#include "../../common/cuda_context.cuh"
 #include "../../data/ellpack_page.cuh"
+#include "feature_groups.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -56,12 +56,11 @@ public:
   }
 };
 
-void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
                             FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> ridx,
-                            common::Span<GradientPairInt64> histogram,
-                            GradientQuantiser rounding,
+                            common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
                             bool force_global_memory = false);
 }  // namespace tree
 }  // namespace xgboost
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -20,6 +20,7 @@
 #include "../common/io.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
+#include "../common/cuda_context.cuh"  // CUDAContext
 #include "constraints.cuh"
 #include "driver.h"
 #include "gpu_hist/evaluate_splits.cuh"
@@ -344,9 +345,9 @@ struct GPUHistMakerDevice {
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
-                           feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
-                           d_ridx, d_node_hist, *quantiser);
+    BuildGradientHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->gpu_id),
+                           feature_groups->DeviceAccessor(ctx_->gpu_id), gpair, d_ridx, d_node_hist,
+                           *quantiser);
   }
 
   // Attempt to do subtraction trick
@@ -646,7 +647,7 @@ struct GPUHistMakerDevice {
       return quantiser.ToFixedPoint(gpair);
     });
     GradientPairInt64 root_sum_quantised =
-        dh::Reduce(thrust::cuda::par(alloc), gpair_it, gpair_it + gpair.size(),
+        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
                    GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
     collective::Allreduce<collective::Operation::kSum>(
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -11,6 +11,7 @@ namespace xgboost {
 namespace tree {
 
 void TestDeterministicHistogram(bool is_dense, int shm_size) {
+  Context ctx = CreateEmptyGenericParam(0);
   size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16;
   float constexpr kLower = -1e-2, kUpper = 1e2;
 
@@ -34,9 +35,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
                             sizeof(GradientPairInt64));
 
     auto quantiser = GradientQuantiser(gpair.DeviceSpan());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           feature_groups.DeviceAccessor(0), gpair.DeviceSpan(),
-                           ridx, d_histogram, quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx, d_histogram,
+                           quantiser);
 
     std::vector<GradientPairInt64> histogram_h(num_bins);
     dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
@@ -48,10 +49,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
     auto d_new_histogram = dh::ToSpan(new_histogram);
 
     auto quantiser = GradientQuantiser(gpair.DeviceSpan());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           feature_groups.DeviceAccessor(0),
-                           gpair.DeviceSpan(), ridx, d_new_histogram,
-                           quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                           d_new_histogram, quantiser);
 
     std::vector<GradientPairInt64> new_histogram_h(num_bins);
     dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
@@ -71,10 +71,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
       FeatureGroups single_group(page->Cuts());
 
      dh::device_vector<GradientPairInt64> baseline(num_bins);
-      BuildGradientHistogram(page->GetDeviceAccessor(0),
-                             single_group.DeviceAccessor(0),
-                             gpair.DeviceSpan(), ridx, dh::ToSpan(baseline),
-                             quantiser);
+      BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                             single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                             dh::ToSpan(baseline), quantiser);
 
       std::vector<GradientPairInt64> baseline_h(num_bins);
       dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
@@ -115,6 +114,7 @@ void ValidateCategoricalHistogram(size_t n_categories, common::Span<GradientPair
 
 // Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
 void TestGPUHistogramCategorical(size_t num_categories) {
+  auto ctx = CreateEmptyGenericParam(0);
   size_t constexpr kRows = 340;
   size_t constexpr kBins = 256;
   auto x = GenerateRandomCategoricalSingleColumn(kRows, num_categories);
@@ -133,10 +133,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           single_group.DeviceAccessor(0),
-                           gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist),
-                           quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                           dh::ToSpan(cat_hist), quantiser);
   }
 
   /**
@@ -148,10 +147,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   for (auto const &batch : encode_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
-    BuildGradientHistogram(page->GetDeviceAccessor(0),
-                           single_group.DeviceAccessor(0),
-                           gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist),
-                           quantiser);
+    BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                           single_group.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
+                           dh::ToSpan(encode_hist), quantiser);
   }
 
   std::vector<GradientPairInt64> h_cat_hist(cat_hist.size());
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -109,11 +109,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.gpair = gpair.DeviceSpan();
   maker.quantiser.reset(new GradientQuantiser(maker.gpair));
 
-  BuildGradientHistogram(
-      page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
-      gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
-      maker.hist.GetNodeHistogram(0), *maker.quantiser,
-      !use_shared_memory_histograms);
+  BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
+                         maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(),
+                         maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
+                         *maker.quantiser, !use_shared_memory_histograms);
 
   DeviceHistogramStorage<>& d_hist = maker.hist;