Use hypothesis (#5759)

* Use hypothesis

* Allow int64 array interface for groups

* Add packages to Windows CI

* Add to travis

* Make sure device index is set correctly

* Fix dask-cudf test

* appveyor
This commit is contained in:
Rory Mitchell
2020-06-16 12:45:59 +12:00
committed by GitHub
parent 02884b08aa
commit b47b5ac771
17 changed files with 411 additions and 439 deletions

View File

@@ -34,6 +34,30 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
});
}
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
<< "Expected integer metainfo";
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};
auto ptr_device = SetDeviceToPtr(column.data);
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data();
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement(idx);
});
auto length = column.num_rows;
out->resize(length + 1);
out->at(0) = 0;
thrust::copy(temp.data(), temp.data() + length, out->begin() + 1);
std::partial_sum(out->begin(), out->end(), out->begin());
}
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
@@ -53,16 +77,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "base_margin") {
CopyInfoImpl(array_interface, &base_margin_);
} else if (key == "group") {
// Ranking is not performed on device.
thrust::device_ptr<uint32_t> p_src{
reinterpret_cast<uint32_t*>(array_interface.data)};
auto length = array_interface.num_rows;
group_ptr_.resize(length + 1);
group_ptr_[0] = 0;
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
std::partial_sum(group_ptr_.begin(), group_ptr_.end(), group_ptr_.begin());
CopyGroupInfoImpl(array_interface, &group_ptr_);
return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;