Improve test coverage with predictor configuration. (#9354)

* Improve test coverage with predictor configuration.

- Test with ext memory.
- Test with QDM.
- Test with dart.
This commit is contained in:
Jiaming Yuan
2023-07-05 15:17:22 +08:00
committed by GitHub
parent 6c9c8a9001
commit 645037e376
17 changed files with 280 additions and 79 deletions

View File

@@ -9,9 +9,10 @@
#include <xgboost/logging.h> // for CHECK_GE
#include <xgboost/parameter.h> // for XGBoostParameter
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string
#include <type_traits> // for invoke_result_t, is_same_v
namespace xgboost {
@@ -152,6 +153,25 @@ struct Context : public XGBoostParameter<Context> {
ctx.gpu_id = kCpuId;
return ctx;
}
/**
* @brief Call function based on the current device.
*/
template <typename CPUFn, typename CUDAFn>
decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
switch (this->Device().device) {
case DeviceOrd::kCPU:
return cpu_fn();
case DeviceOrd::kCUDA:
return cuda_fn();
default:
// Do not use the device name as this is likely an internal error, the name
// wouldn't be valid.
LOG(FATAL) << "Unknown device type:" << static_cast<std::int16_t>(this->Device().device);
break;
}
return std::invoke_result_t<CPUFn>();
}
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {

View File

@@ -6,24 +6,22 @@
*/
#pragma once
#include <xgboost/base.h>
#include <xgboost/cache.h> // DMatrixCache
#include <xgboost/cache.h> // for DMatrixCache
#include <xgboost/context.h> // for Context
#include <xgboost/context.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <functional> // std::function
#include <memory>
#include <functional> // for function
#include <memory> // for shared_ptr
#include <string>
#include <thread> // for get_id
#include <utility> // for make_pair
#include <vector>
// Forward declarations
namespace xgboost {
namespace gbm {
namespace xgboost::gbm {
struct GBTreeModel;
} // namespace gbm
} // namespace xgboost
} // namespace xgboost::gbm
namespace xgboost {
/**