[coll] Allreduce. (#9679)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/common.h"
|
||||
#include "../common/error_msg.h" // for NoF128
|
||||
#include "xgboost/base.h"
|
||||
@@ -104,7 +104,20 @@ struct ArrayInterfaceErrors {
|
||||
*/
|
||||
class ArrayInterfaceHandler {
|
||||
public:
|
||||
enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
enum Type : std::int8_t {
|
||||
kF2 = 0,
|
||||
kF4 = 1,
|
||||
kF8 = 2,
|
||||
kF16 = 3,
|
||||
kI1 = 4,
|
||||
kI2 = 5,
|
||||
kI4 = 6,
|
||||
kI8 = 7,
|
||||
kU1 = 8,
|
||||
kU2 = 9,
|
||||
kU4 = 10,
|
||||
kU8 = 11,
|
||||
};
|
||||
|
||||
template <typename PtrType>
|
||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||
@@ -587,6 +600,57 @@ class ArrayInterface {
|
||||
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
|
||||
};
|
||||
|
||||
template <typename Fn>
|
||||
auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
|
||||
switch (dtype) {
|
||||
case ArrayInterfaceHandler::kF2: {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
return dispatch(__half{});
|
||||
#else
|
||||
LOG(FATAL) << "half type is only supported for CUDA input.";
|
||||
break;
|
||||
#endif
|
||||
}
|
||||
case ArrayInterfaceHandler::kF4: {
|
||||
return dispatch(float{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kF8: {
|
||||
return dispatch(double{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kF16: {
|
||||
using T = long double;
|
||||
CHECK(sizeof(T) == 16) << error::NoF128();
|
||||
return dispatch(T{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kI1: {
|
||||
return dispatch(std::int8_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kI2: {
|
||||
return dispatch(std::int16_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kI4: {
|
||||
return dispatch(std::int32_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kI8: {
|
||||
return dispatch(std::int64_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kU1: {
|
||||
return dispatch(std::uint8_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kU2: {
|
||||
return dispatch(std::uint16_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kU4: {
|
||||
return dispatch(std::uint32_t{});
|
||||
}
|
||||
case ArrayInterfaceHandler::kU8: {
|
||||
return dispatch(std::uint64_t{});
|
||||
}
|
||||
}
|
||||
|
||||
return std::result_of_t<Fn(std::int8_t)>();
|
||||
}
|
||||
|
||||
template <std::int32_t D, typename Fn>
|
||||
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
|
||||
// Only used for cuDF at the moment.
|
||||
@@ -602,60 +666,7 @@ void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
|
||||
std::numeric_limits<std::size_t>::max()},
|
||||
array.shape, array.strides, device});
|
||||
};
|
||||
switch (array.type) {
|
||||
case ArrayInterfaceHandler::kF2: {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dispatch(__half{});
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kF4: {
|
||||
dispatch(float{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kF8: {
|
||||
dispatch(double{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kF16: {
|
||||
using T = long double;
|
||||
CHECK(sizeof(long double) == 16) << error::NoF128();
|
||||
dispatch(T{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kI1: {
|
||||
dispatch(std::int8_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kI2: {
|
||||
dispatch(std::int16_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kI4: {
|
||||
dispatch(std::int32_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kI8: {
|
||||
dispatch(std::int64_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kU1: {
|
||||
dispatch(std::uint8_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kU2: {
|
||||
dispatch(std::uint16_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kU4: {
|
||||
dispatch(std::uint32_t{});
|
||||
break;
|
||||
}
|
||||
case ArrayInterfaceHandler::kU8: {
|
||||
dispatch(std::uint64_t{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
DispatchDType(array.type, dispatch);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user