Support dmatrix construction from cupy array (#5206)
This commit is contained in:
@@ -78,6 +78,35 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Unary predicate over Entry, usable with thrust algorithms.
 *
 * Forwards the entry's feature value together with the stored missing value
 * to IsValid and returns the result.
 */
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
  // Missing-value sentinel handed to IsValid on every invocation.
  float missing;

  explicit IsValidFunctor(float missing_value) : missing(missing_value) {}

  __device__ bool operator()(const Entry& entry) const {
    const bool keep = IsValid(entry.fvalue, missing);
    return keep;
  }
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
auto& batch = adapter->Value();
|
||||
auto transform_f = [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}; // NOLINT
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
thrust::transform_iterator<decltype(transform_f), decltype(counting), Entry>
|
||||
transform_iter(counting, transform_f);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::copy_if(
|
||||
thrust::cuda::par(alloc), transform_iter, transform_iter + batch.Size(),
|
||||
thrust::device_pointer_cast(data.data()), IsValidFunctor(missing));
|
||||
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
@@ -102,11 +131,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back();
|
||||
mat.page_.data.Resize(mat.info.num_nonzero_);
|
||||
if (adapter->IsRowMajor()) {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
}
|
||||
// Sync
|
||||
mat.page_.data.HostVector();
|
||||
|
||||
mat.info.num_col_ = adapter->NumColumns();
|
||||
mat.info.num_row_ = adapter->NumRows();
|
||||
@@ -116,5 +148,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
|
||||
// Explicit instantiations of the templated SimpleDMatrix constructor for the
// two adapter types named here: CudfAdapter and CupyAdapter.
template SimpleDMatrix::SimpleDMatrix(
    CudfAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(
    CupyAdapter* adapter, float missing, int nthread);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user