initial merge
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "../common/random.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "device_adapter.cuh" // for HasInfInData
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@@ -203,9 +203,8 @@ struct TupleScanOp {
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
EllpackPageImpl *dst, int device_idx, float missing) {
|
||||
void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
|
||||
EllpackPageImpl* dst, int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ELLPACK matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
@@ -215,6 +214,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
bool valid = data::HasInfInData(batch, is_valid);
|
||||
CHECK(valid) << error::InfInData();
|
||||
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) {
|
||||
@@ -255,9 +257,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
|
||||
TupleScanOp<Tuple>, cub::NullType, int64_t>;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr);
|
||||
dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
nullptr));
|
||||
#else
|
||||
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
|
||||
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
|
||||
@@ -265,9 +267,9 @@ void CopyDataToEllpack(const AdapterBatchT &batch,
|
||||
#endif
|
||||
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr);
|
||||
dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
cub::NullType(), batch.Size(), nullptr));
|
||||
#else
|
||||
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
|
||||
key_value_index_iter, out, TupleScanOp<Tuple>(),
|
||||
|
||||
Reference in New Issue
Block a user