Make sure metrics work with column-wise distributed training (#9020)
This commit is contained in:
@@ -39,6 +39,18 @@
|
||||
#define GPUIDX -1
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define DeclareUnifiedDistributedTest(name) MGPU ## name
|
||||
#else
|
||||
#define DeclareUnifiedDistributedTest(name) name
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define WORLD_SIZE_FOR_TEST (xgboost::common::AllVisibleGPUs())
|
||||
#else
|
||||
#define WORLD_SIZE_FOR_TEST (3)
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
class ObjFunction;
|
||||
class Metric;
|
||||
@@ -92,13 +104,15 @@ xgboost::bst_float GetMetricEval(
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
|
||||
std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>());
|
||||
std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>(),
|
||||
xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow);
|
||||
|
||||
double GetMultiMetricEval(xgboost::Metric* metric,
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
|
||||
xgboost::linalg::Tensor<float, 2> const& labels,
|
||||
std::vector<xgboost::bst_float> weights = {},
|
||||
std::vector<xgboost::bst_uint> groups = {});
|
||||
std::vector<xgboost::bst_uint> groups = {},
|
||||
xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow);
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -496,4 +510,17 @@ void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
class DeclareUnifiedDistributedTest(MetricTest) : public ::testing::Test {
|
||||
protected:
|
||||
int world_size_;
|
||||
|
||||
void SetUp() override {
|
||||
world_size_ = WORLD_SIZE_FOR_TEST;
|
||||
if (world_size_ <= 1) {
|
||||
GTEST_SKIP() << "Skipping MGPU test with # GPUs = " << world_size_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user