Device dmatrix (#5420)
This commit is contained in:
@@ -8,26 +8,20 @@
|
||||
#include <xgboost/data.h>
|
||||
#include "../common/random.h"
|
||||
#include "./simple_dmatrix.h"
|
||||
#include "../common/math.h"
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
XGBOOST_DEVICE bool IsValid(float value, float missing) {
|
||||
if (common::CheckNAN(value) || value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (IsValid(element.value, missing)) {
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
@@ -66,11 +60,12 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
size_t begin = 0;
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!IsValid(e.value, missing)) return;
|
||||
if (!is_valid(e)) return;
|
||||
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
@@ -79,15 +74,6 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
}
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const Entry& x) const {
|
||||
return IsValid(x.fvalue, missing);
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
|
||||
Reference in New Issue
Block a user