Support dmatrix construction from cupy array (#5206)

This commit is contained in:
Rory Mitchell
2020-01-22 13:15:27 +13:00
committed by GitHub
parent 2a071cebc5
commit 9c56480c61
19 changed files with 522 additions and 158 deletions

View File

@@ -78,6 +78,35 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
}
}
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}
float missing;
__device__ bool operator()(const Entry& x) const {
return IsValid(x.fvalue, missing);
}
};
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterT>
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
int device_idx, float missing,
common::Span<size_t> row_ptr) {
auto& batch = adapter->Value();
auto transform_f = [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value);
}; // NOLINT
auto counting = thrust::make_counting_iterator(0llu);
thrust::transform_iterator<decltype(transform_f), decltype(counting), Entry>
transform_iter(counting, transform_f);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::copy_if(
thrust::cuda::par(alloc), transform_iter, transform_iter + batch.Size(),
thrust::device_pointer_cast(data.data()), IsValidFunctor(missing));
}
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
@@ -102,11 +131,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back();
mat.page_.data.Resize(mat.info.num_nonzero_);
if (adapter->IsRowMajor()) {
LOG(FATAL) << "Not implemented.";
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
} else {
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
}
// Sync
mat.page_.data.HostVector();
mat.info.num_col_ = adapter->NumColumns();
mat.info.num_row_ = adapter->NumRows();
@@ -116,5 +148,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
int nthread);
} // namespace data
} // namespace xgboost