Make SimpleDMatrix ctor reusable. (#7075)
This commit is contained in:
parent
d7e1fa7664
commit
116d711815
@ -1,89 +1,34 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2019 by Contributors
|
* Copyright 2019-2021 by XGBoost Contributors
|
||||||
* \file simple_dmatrix.cu
|
* \file simple_dmatrix.cu
|
||||||
*/
|
*/
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <thrust/execution_policy.h>
|
|
||||||
#include <thrust/sort.h>
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include "../common/random.h"
|
#include "simple_dmatrix.cuh"
|
||||||
#include "./simple_dmatrix.h"
|
#include "simple_dmatrix.h"
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
|
|
||||||
/*!
 * \brief Count the valid elements of every row and convert the counts into offsets.
 *
 * \param batch      Adapter batch, read through Size() and GetElement(idx).
 * \param offset     Device span of per-row counters. Counts are atomically added to
 *                   the values already stored, so the caller must pass zeroed
 *                   storage. The span is then overwritten in place with its
 *                   exclusive prefix sum.
 * \param device_idx CUDA device made current before any work is launched.
 * \param missing    Value handed to IsValidFunctor to decide which elements count.
 */
template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
                     int device_idx, float missing) {
  // The atomicAdd below reinterprets each counter as unsigned long long; make
  // that assumption a compile-time requirement instead of a silent one.
  static_assert(sizeof(bst_row_t) == sizeof(unsigned long long),  // NOLINT
                "bst_row_t must be as wide as unsigned long long for atomicAdd.");
  // Run on the device requested by the caller rather than whichever is current.
  dh::safe_cuda(cudaSetDevice(device_idx));
  IsValidFunctor is_valid(missing);
  // Count elements per row
  dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
    auto element = batch.GetElement(idx);
    if (is_valid(element)) {
      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                    &offset[element.row_idx]),
                static_cast<unsigned long long>(1));  // NOLINT
    }
  });

  // In-place scan: input and output ranges are the same storage.
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::exclusive_scan(thrust::cuda::par(alloc),
                         thrust::device_pointer_cast(offset.data()),
                         thrust::device_pointer_cast(offset.data() + offset.size()),
                         thrust::device_pointer_cast(offset.data()));
}
|
|
||||||
|
|
||||||
/*!
 * \brief Device functor that converts the idx-th element of an adapter batch
 *        into an Entry built from the element's column index and value.
 */
template <typename AdapterBatchT>
struct COOToEntryOp {
  // Batch whose elements are converted; stored by value in the functor.
  AdapterBatchT batch;
  __device__ Entry operator()(size_t idx) {
    auto const& element = batch.GetElement(idx);
    Entry converted(element.column_idx, element.value);
    return converted;
  }
};
|
|
||||||
|
|
||||||
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data.
//
// Takes the batch from adapter->Value(), views it as a sequence of Entry
// objects (one per element index in [0, batch.Size())) and hands that sequence
// to dh::CopyIf with IsValidFunctor(missing) as the predicate, writing the
// result to the start of `data`. `data` must already be sized by the caller.
template <typename AdapterT>
void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
                       float missing) {
  auto batch = adapter->Value();
  // Index sequence 0, 1, 2, ... mapped lazily to entries by the transform op.
  auto counting = thrust::make_counting_iterator(0llu);
  COOToEntryOp<decltype(batch)> transform_op{batch};
  thrust::transform_iterator<decltype(transform_op), decltype(counting)>
      transform_iter(counting, transform_op);
  auto begin_output = thrust::device_pointer_cast(data.data());
  dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
             IsValidFunctor(missing));
}
|
|
||||||
|
|
||||||
// Does not currently support metainfo as no on-device data source contains this
|
// Does not currently support metainfo as no on-device data source contains this
|
||||||
// Current implementation assumes a single batch. More batches can
|
// Current implementation assumes a single batch. More batches can
|
||||||
// be supported in future. Does not currently support inferring row/column size
|
// be supported in future. Does not currently support inferring row/column size
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||||
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
|
||||||
|
|
||||||
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
||||||
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
||||||
|
|
||||||
adapter->BeforeFirst();
|
adapter->BeforeFirst();
|
||||||
adapter->Next();
|
adapter->Next();
|
||||||
auto& batch = adapter->Value();
|
|
||||||
sparse_page_.offset.SetDevice(adapter->DeviceIdx());
|
|
||||||
sparse_page_.data.SetDevice(adapter->DeviceIdx());
|
|
||||||
|
|
||||||
// Enforce single batch
|
// Enforce single batch
|
||||||
CHECK(!adapter->Next());
|
CHECK(!adapter->Next());
|
||||||
sparse_page_.offset.Resize(adapter->NumRows() + 1);
|
|
||||||
auto s_offset = sparse_page_.offset.DeviceSpan();
|
|
||||||
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
|
|
||||||
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
|
|
||||||
sparse_page_.data.Resize(info_.num_nonzero_);
|
|
||||||
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(), missing);
|
|
||||||
|
|
||||||
|
info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(),
|
||||||
|
missing, &sparse_page_);
|
||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
info_.num_row_ = adapter->NumRows();
|
info_.num_row_ = adapter->NumRows();
|
||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
|
|||||||
78
src/data/simple_dmatrix.cuh
Normal file
78
src/data/simple_dmatrix.cuh
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2019-2021 by XGBoost Contributors
|
||||||
|
* \file simple_dmatrix.cuh
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||||
|
#define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||||
|
|
||||||
|
#include <thrust/copy.h>
|
||||||
|
#include <thrust/scan.h>
|
||||||
|
#include <thrust/execution_policy.h>
|
||||||
|
#include "device_adapter.cuh"
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace data {
|
||||||
|
|
||||||
|
/*!
 * \brief Device functor that converts the idx-th element of an adapter batch
 *        into an Entry built from the element's column index and value.
 */
template <typename AdapterBatchT>
struct COOToEntryOp {
  // Batch whose elements are converted; stored by value in the functor.
  AdapterBatchT batch;
  __device__ Entry operator()(size_t idx) {
    auto const& element = batch.GetElement(idx);
    Entry converted(element.column_idx, element.value);
    return converted;
  }
};
|
||||||
|
|
||||||
|
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data.
//
// Views `batch` as a sequence of Entry objects (one per element index in
// [0, batch.Size())) and hands that sequence to dh::CopyIf with
// IsValidFunctor(missing) as the predicate, writing the result to the start of
// `data`. `data` must already be sized by the caller.
template <typename AdapterBatchT>
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
                       float missing) {
  // Index sequence 0, 1, 2, ... mapped lazily to entries by the transform op.
  auto counting = thrust::make_counting_iterator(0llu);
  COOToEntryOp<decltype(batch)> transform_op{batch};
  thrust::transform_iterator<decltype(transform_op), decltype(counting)>
      transform_iter(counting, transform_op);
  auto begin_output = thrust::device_pointer_cast(data.data());
  dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
             IsValidFunctor(missing));
}
|
||||||
|
|
||||||
|
/*!
 * \brief Count the valid elements of every row and convert the counts into offsets.
 *
 * \param batch      Adapter batch, read through Size() and GetElement(idx).
 * \param offset     Device span of per-row counters. Counts are atomically added to
 *                   the values already stored, so the caller must pass zeroed
 *                   storage. The span is then overwritten in place with its
 *                   exclusive prefix sum.
 * \param device_idx CUDA device made current before any work is launched.
 * \param missing    Value handed to IsValidFunctor to decide which elements count.
 */
template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
                     int device_idx, float missing) {
  // The atomicAdd below reinterprets each counter as unsigned long long; make
  // that assumption a compile-time requirement instead of a silent one.
  static_assert(sizeof(bst_row_t) == sizeof(unsigned long long),  // NOLINT
                "bst_row_t must be as wide as unsigned long long for atomicAdd.");
  dh::safe_cuda(cudaSetDevice(device_idx));
  IsValidFunctor is_valid(missing);
  // Count elements per row
  dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
    auto element = batch.GetElement(idx);
    if (is_valid(element)) {
      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                    &offset[element.row_idx]),
                static_cast<unsigned long long>(1));  // NOLINT
    }
  });

  // In-place scan: input and output ranges are the same storage.
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::exclusive_scan(thrust::cuda::par(alloc),
                         thrust::device_pointer_cast(offset.data()),
                         thrust::device_pointer_cast(offset.data() + offset.size()),
                         thrust::device_pointer_cast(offset.data()));
}
|
||||||
|
|
||||||
|
/*!
 * \brief Fill a sparse page with the row offsets and entries of an adapter batch.
 *
 * \param batch   Adapter batch; must provide NumRows() in addition to what
 *                CountRowOffsets and CopyDataToDMatrix read from it.
 * \param device  CUDA device ordinal assigned to both page vectors and used for
 *                counting.
 * \param missing Value used to decide which elements are dropped.
 * \param page    Output page; its offset vector is resized to NumRows() + 1 and
 *                its data vector to the number of retained entries.
 *
 * \return The last row offset, i.e. the size given to page->data.
 */
template <typename AdapterBatchT>
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
                        SparsePage* page) {
  page->offset.SetDevice(device);
  page->data.SetDevice(device);
  page->offset.Resize(batch.NumRows() + 1);
  // CountRowOffsets adds to whatever is already stored. Resize is assumed to
  // keep existing elements (TODO confirm), so zero the offsets explicitly to
  // keep a previously used page from contaminating the counts.
  page->offset.Fill(0);
  auto s_offset = page->offset.DeviceSpan();
  CountRowOffsets(batch, s_offset, device, missing);
  // After the exclusive scan the final offset is the total number of entries.
  auto num_nonzero = page->offset.HostVector().back();
  page->data.Resize(num_nonzero);
  CopyDataToDMatrix(batch, page->data.DeviceSpan(), missing);

  return num_nonzero;
}
|
||||||
|
} // namespace data
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
|
||||||
Loading…
x
Reference in New Issue
Block a user