Use generic dispatching routine for array interface. (#6672)

This commit is contained in:
Jiaming Yuan
2021-02-05 09:23:38 +08:00
committed by GitHub
parent a4101de678
commit 1e949110da
4 changed files with 65 additions and 31 deletions

View File

@@ -315,40 +315,50 @@ class ArrayInterface {
}
}
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values;
template <typename Fn>
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
switch (type) {
case kF4:
p_values = reinterpret_cast<float *>(data) + offset;
return func(reinterpret_cast<float *>(data));
break;
case kF8:
p_values = reinterpret_cast<double *>(data) + offset;
return func(reinterpret_cast<double *>(data));
break;
case kI1:
p_values = reinterpret_cast<int8_t *>(data) + offset;
return func(reinterpret_cast<int8_t *>(data));
break;
case kI2:
p_values = reinterpret_cast<int16_t *>(data) + offset;
return func(reinterpret_cast<int16_t *>(data));
break;
case kI4:
p_values = reinterpret_cast<int32_t *>(data) + offset;
return func(reinterpret_cast<int32_t *>(data));
break;
case kI8:
p_values = reinterpret_cast<int64_t *>(data) + offset;
return func(reinterpret_cast<int64_t *>(data));
break;
case kU1:
p_values = reinterpret_cast<uint8_t *>(data) + offset;
return func(reinterpret_cast<uint8_t *>(data));
break;
case kU2:
p_values = reinterpret_cast<uint16_t *>(data) + offset;
return func(reinterpret_cast<uint16_t *>(data));
break;
case kU4:
p_values = reinterpret_cast<uint32_t *>(data) + offset;
return func(reinterpret_cast<uint32_t *>(data));
break;
case kU8:
p_values = reinterpret_cast<uint64_t *>(data) + offset;
return func(reinterpret_cast<uint64_t *>(data));
break;
}
SPAN_CHECK(false);
return func(reinterpret_cast<uint64_t *>(data));
}
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values{nullptr};
this->DispatchCall([&p_values, offset](auto *ptr) {
p_values = ptr + offset;
});
ArrayInterface ret = *this;
ret.data = p_values;
return ret;
@@ -390,6 +400,12 @@ class ArrayInterface {
return reinterpret_cast<float*>(data)[idx];
}
XGBOOST_DEVICE size_t ElementSize() {
return this->DispatchCall([](auto* p_values) {
return sizeof(std::remove_pointer_t<decltype(p_values)>);
});
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;