Fix IsDense. (#5702)
This commit is contained in:
parent
e35ad8a074
commit
8438c7d0e4
@ -33,16 +33,17 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int
|
|||||||
size_t row_stride =
|
size_t row_stride =
|
||||||
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
||||||
|
|
||||||
ellpack_page_.reset(new EllpackPage());
|
|
||||||
*ellpack_page_->Impl() =
|
|
||||||
EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
|
|
||||||
row_counts_span, row_stride);
|
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
||||||
row_counts.begin(), row_counts.end());
|
row_counts.begin(), row_counts.end());
|
||||||
info_.num_col_ = adapter->NumColumns();
|
info_.num_col_ = adapter->NumColumns();
|
||||||
info_.num_row_ = adapter->NumRows();
|
info_.num_row_ = adapter->NumRows();
|
||||||
|
|
||||||
|
ellpack_page_.reset(new EllpackPage());
|
||||||
|
*ellpack_page_->Impl() =
|
||||||
|
EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
|
||||||
|
row_counts_span, row_stride);
|
||||||
|
|
||||||
// Synchronise worker columns
|
// Synchronise worker columns
|
||||||
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -129,3 +129,22 @@ TEST(DeviceDMatrix, Equivalent) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DeviceDMatrix, IsDense) {
|
||||||
|
int num_bins = 16;
|
||||||
|
auto test = [num_bins] (float sparsity) {
|
||||||
|
HostDeviceVector<float> data;
|
||||||
|
std::string interface_str = RandomDataGenerator{10, 10, sparsity}
|
||||||
|
.Device(0).GenerateArrayInterface(&data);
|
||||||
|
data::CupyAdapter x{interface_str};
|
||||||
|
std::unique_ptr<data::DeviceDMatrix> device_dmat{ new data::DeviceDMatrix(
|
||||||
|
&x, std::numeric_limits<float>::quiet_NaN(), 1, num_bins) };
|
||||||
|
if (sparsity == 0.0) {
|
||||||
|
ASSERT_TRUE(device_dmat->IsDense()) << sparsity;
|
||||||
|
} else {
|
||||||
|
ASSERT_FALSE(device_dmat->IsDense());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
test(0.0);
|
||||||
|
test(0.1);
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user