Use generic dispatching routine for array interface. (#6672)
This commit is contained in:
@@ -315,40 +315,50 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values;
|
||||
template <typename Fn>
|
||||
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
p_values = reinterpret_cast<float *>(data) + offset;
|
||||
return func(reinterpret_cast<float *>(data));
|
||||
break;
|
||||
case kF8:
|
||||
p_values = reinterpret_cast<double *>(data) + offset;
|
||||
return func(reinterpret_cast<double *>(data));
|
||||
break;
|
||||
case kI1:
|
||||
p_values = reinterpret_cast<int8_t *>(data) + offset;
|
||||
return func(reinterpret_cast<int8_t *>(data));
|
||||
break;
|
||||
case kI2:
|
||||
p_values = reinterpret_cast<int16_t *>(data) + offset;
|
||||
return func(reinterpret_cast<int16_t *>(data));
|
||||
break;
|
||||
case kI4:
|
||||
p_values = reinterpret_cast<int32_t *>(data) + offset;
|
||||
return func(reinterpret_cast<int32_t *>(data));
|
||||
break;
|
||||
case kI8:
|
||||
p_values = reinterpret_cast<int64_t *>(data) + offset;
|
||||
return func(reinterpret_cast<int64_t *>(data));
|
||||
break;
|
||||
case kU1:
|
||||
p_values = reinterpret_cast<uint8_t *>(data) + offset;
|
||||
return func(reinterpret_cast<uint8_t *>(data));
|
||||
break;
|
||||
case kU2:
|
||||
p_values = reinterpret_cast<uint16_t *>(data) + offset;
|
||||
return func(reinterpret_cast<uint16_t *>(data));
|
||||
break;
|
||||
case kU4:
|
||||
p_values = reinterpret_cast<uint32_t *>(data) + offset;
|
||||
return func(reinterpret_cast<uint32_t *>(data));
|
||||
break;
|
||||
case kU8:
|
||||
p_values = reinterpret_cast<uint64_t *>(data) + offset;
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
break;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values{nullptr};
|
||||
this->DispatchCall([&p_values, offset](auto *ptr) {
|
||||
p_values = ptr + offset;
|
||||
});
|
||||
|
||||
ArrayInterface ret = *this;
|
||||
ret.data = p_values;
|
||||
return ret;
|
||||
@@ -390,6 +400,12 @@ class ArrayInterface {
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE size_t ElementSize() {
|
||||
return this->DispatchCall([](auto* p_values) {
|
||||
return sizeof(std::remove_pointer_t<decltype(p_values)>);
|
||||
});
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
|
||||
Reference in New Issue
Block a user