Support column split in gpu hist updater (#9384)
commit 9bab06cbca (parent ccfc90e4c6)

src/collective/aggregator.cuh (new file, 40 lines)
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ *
+ * Higher level functions built on top of the Communicator API, taking care of behavioral
+ * differences between row-split vs column-split distributed training, and horizontal vs
+ * vertical federated learning.
+ */
+#pragma once
+#include <xgboost/data.h>
+
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "communicator-inl.cuh"
+
+namespace xgboost {
+namespace collective {
+
+/**
+ * @brief Find the global sum of the given values across all workers.
+ *
+ * This only applies when the data is split row-wise (horizontally). When data is split
+ * column-wise (vertically), the original values are returned.
+ *
+ * @tparam T The type of the values.
+ * @param info MetaInfo about the DMatrix.
+ * @param device The device id.
+ * @param values Pointer to the inputs to sum.
+ * @param size Number of values to sum.
+ */
+template <typename T>
+void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
+  if (info.IsRowSplit()) {
+    collective::AllReduce<collective::Operation::kSum>(device, values, size);
+  }
+}
+}  // namespace collective
+}  // namespace xgboost
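A minimal usage sketch of the new helper (illustrative code, not part of the commit): under a row-wise split every worker holds partial sums of the same quantities, so they are all-reduced; under a column-wise split each worker already holds the full values, so GlobalSum leaves them untouched. The function and variable names below are made up for the example.

// Sketch only: sum this worker's partial values, respecting the split type.
void ExampleGlobalSum(xgboost::MetaInfo const& info, int device) {
  double partials[2] = {1.5, 2.5};  // per-worker partial sums (made-up values)
  // Row split: partials become the global sums across all workers.
  // Column split: partials are returned unchanged.
  xgboost::collective::GlobalSum(info, device, partials, 2);
}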
@@ -418,7 +418,8 @@ void GPUHistEvaluator::EvaluateSplits(

   // Reduce to get the best candidate from all workers.
   dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
-    for (auto rank = 0; rank < world_size; rank++) {
+    out_splits[i] = all_candidates[i];
+    for (auto rank = 1; rank < world_size; rank++) {
       out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
     }
   });
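For context (my reading of the hunk, not text from the commit): the gathered buffer is laid out rank-major, so worker r's candidate for node slot i sits at all_candidates[r * out_splits.size() + i], and operator+ on split candidates is assumed here to select the better of its two operands. The rewritten loop seeds the output from rank 0's candidate instead of folding into whatever out_splits[i] previously held. A host-side sketch of the same fold, with stand-in names:

#include <cstddef>

// Candidate stands in for DeviceSplitCandidate; operator+ keeps the better one.
template <typename Candidate>
void ReduceCandidates(Candidate* out, Candidate const* all, std::size_t n, int world_size) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = all[i];  // seed with rank 0's candidate
    for (int rank = 1; rank < world_size; ++rank) {
      out[i] = out[i] + all[rank * n + i];  // fold in the remaining workers
    }
  }
}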
@@ -8,6 +8,7 @@
 #include <cstdint>  // uint32_t
 #include <limits>

+#include "../../collective/aggregator.h"
 #include "../../common/deterministic.cuh"
 #include "../../common/device_helpers.cuh"
 #include "../../data/ellpack_page.cuh"
@@ -52,7 +53,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
  *
  * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
  */
-GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
+GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -64,11 +65,11 @@ GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
   // Treat pair as array of 4 primitive types to allreduce
   using ReduceT = typename decltype(p.first)::ValueT;
   static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
-  collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
+  collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
   GradientPair positive_sum{p.first}, negative_sum{p.second};

   std::size_t total_rows = gpair.size();
-  collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
+  collective::GlobalSum(info, &total_rows, 1);

   auto histogram_rounding =
       GradientSumT{common::CreateRoundingFactor<T>(
@@ -39,7 +39,7 @@ private:
   GradientPairPrecise to_floating_point_;

 public:
-  explicit GradientQuantiser(common::Span<GradientPair const> gpair);
+  GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info);
   XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
     auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
                                       gpair.GetHess() * to_fixed_point_.GetHess());
@@ -129,7 +129,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
     int batch_idx;
     std::size_t item_idx;
     AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
-    auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
+    auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
     return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
   });
   size_t temp_bytes = 0;
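The extra batch_idx argument threads the position of the node within the current batch through to the caller's callback; the column-split path needs it to address per-candidate state. A sketch of a callback under the new three-argument shape, mirroring the updated unit tests later in this diff (the decision logic is a toy):

// New op shape: (row index, index of the split within the batch, per-node data).
auto op = [=] __device__(bst_uint ridx, int split_index, int node_data) {
  // split_index lets column-split code index per-candidate storage, e.g. a
  // bit vector addressed by ridx * num_candidates + split_index.
  return ridx % 2 == 0;  // toy decision
};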
@@ -12,7 +12,8 @@
 #include <utility>  // for move
 #include <vector>

-#include "../collective/communicator-inl.cuh"
+#include "../collective/aggregator.h"
+#include "../collective/aggregator.cuh"
 #include "../common/bitfield.h"
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"  // CUDAContext
@@ -161,6 +162,7 @@ struct GPUHistMakerDevice {
   GPUHistEvaluator evaluator_;
   Context const* ctx_;
   std::shared_ptr<common::ColumnSampler> column_sampler_;
+  MetaInfo const& info_;

 public:
   EllpackPageImpl const* page{nullptr};
@@ -193,13 +195,14 @@ struct GPUHistMakerDevice {
   GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
                      common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
                      TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
-                     uint32_t n_features, BatchParam batch_param)
+                     uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
       : evaluator_{_param, n_features, ctx->gpu_id},
         ctx_(ctx),
         feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler_(std::move(column_sampler)),
-        interaction_constraints(param, n_features) {
+        interaction_constraints(param, n_features),
+        info_{info} {
     sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
                                                      param.sampling_method, is_external_memory);
     if (!param.monotone_constraints.empty()) {
@@ -245,7 +248,7 @@ struct GPUHistMakerDevice {
     this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
                            dmat->Info().IsColumnSplit(), ctx_->gpu_id);

-    quantiser = std::make_unique<GradientQuantiser>(this->gpair);
+    quantiser = std::make_unique<GradientQuantiser>(this->gpair, dmat->Info());

     row_partitioner.reset();  // Release the device memory first before reallocating
     row_partitioner = std::make_unique<RowPartitioner>(ctx_->gpu_id, sample.sample_rows);
@@ -369,6 +372,66 @@ struct GPUHistMakerDevice {
     common::KCatBitField node_cats;
   };

+  void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
+                                 std::vector<NodeSplitData> const& split_data,
+                                 std::vector<bst_node_t> const& nidx,
+                                 std::vector<bst_node_t> const& left_nidx,
+                                 std::vector<bst_node_t> const& right_nidx) {
+    auto const num_candidates = split_data.size();
+
+    using BitVector = LBitField64;
+    using BitType = BitVector::value_type;
+    auto const size = BitVector::ComputeStorageSize(d_matrix.n_rows * num_candidates);
+    dh::TemporaryArray<BitType> decision_storage(size, 0);
+    dh::TemporaryArray<BitType> missing_storage(size, 0);
+    BitVector decision_bits{dh::ToSpan(decision_storage)};
+    BitVector missing_bits{dh::ToSpan(missing_storage)};
+
+    dh::TemporaryArray<NodeSplitData> split_data_storage(num_candidates);
+    dh::safe_cuda(cudaMemcpyAsync(split_data_storage.data().get(), split_data.data(),
+                                  num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault));
+    auto d_split_data = dh::ToSpan(split_data_storage);
+
+    dh::LaunchN(d_matrix.n_rows, [=] __device__(std::size_t ridx) mutable {
+      for (auto i = 0; i < num_candidates; i++) {
+        auto const& data = d_split_data[i];
+        auto const cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
+        if (isnan(cut_value)) {
+          missing_bits.Set(ridx * num_candidates + i);
+        } else {
+          bool go_left;
+          if (data.split_type == FeatureType::kCategorical) {
+            go_left = common::Decision(data.node_cats.Bits(), cut_value);
+          } else {
+            go_left = cut_value <= data.split_node.SplitCond();
+          }
+          if (go_left) {
+            decision_bits.Set(ridx * num_candidates + i);
+          }
+        }
+      }
+    });
+
+    collective::AllReduce<collective::Operation::kBitwiseOR>(
+        ctx_->gpu_id, decision_storage.data().get(), decision_storage.size());
+    collective::AllReduce<collective::Operation::kBitwiseAND>(
+        ctx_->gpu_id, missing_storage.data().get(), missing_storage.size());
+    collective::Synchronize(ctx_->gpu_id);
+
+    row_partitioner->UpdatePositionBatch(
+        nidx, left_nidx, right_nidx, split_data,
+        [=] __device__(bst_uint ridx, int split_index, NodeSplitData const& data) {
+          auto const index = ridx * num_candidates + split_index;
+          bool go_left;
+          if (missing_bits.Check(index)) {
+            go_left = data.split_node.DefaultLeft();
+          } else {
+            go_left = decision_bits.Check(index);
+          }
+          return go_left;
+        });
+  }
+
   void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
     if (candidates.empty()) {
       return;
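A note on the two all-reduces above (my reading of the code, not text from the commit): with a column split, only the worker owning the split feature can evaluate the condition for a row; every other worker reads NaN from GetFvalue and marks the row missing. Decision bits are therefore combined with bitwise OR (only the owner can set them), while missing bits are combined with bitwise AND (a row is truly missing only if even the owner saw NaN). A toy host-side model of the per-(row, candidate) combination, with hypothetical names:

#include <cstddef>
#include <vector>

// Toy model: combine per-worker votes for one (row, candidate) pair.
// decision[w] / missing[w] are worker w's local bits; names are illustrative.
bool GoLeft(std::vector<bool> const& decision, std::vector<bool> const& missing,
            bool default_left) {
  bool any_decision = false;  // bitwise OR across workers
  bool all_missing = true;    // bitwise AND across workers
  for (std::size_t w = 0; w < decision.size(); ++w) {
    any_decision = any_decision || decision[w];
    all_missing = all_missing && missing[w];
  }
  // Rows missing on every worker (including the feature's owner) follow the
  // default direction; otherwise the owner's decision bit wins.
  return all_missing ? default_left : any_decision;
}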
@@ -392,9 +455,15 @@ struct GPUHistMakerDevice {
     }

     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
+
+    if (info_.IsColumnSplit()) {
+      UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
+      return;
+    }
+
     row_partitioner->UpdatePositionBatch(
         nidx, left_nidx, right_nidx, split_data,
-        [=] __device__(bst_uint ridx, const NodeSplitData& data) {
+        [=] __device__(bst_uint ridx, int split_index, const NodeSplitData& data) {
           // given a row index, returns the node id it belongs to
           float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
           // Missing value
@@ -544,9 +613,8 @@ struct GPUHistMakerDevice {
     monitor.Start("AllReduce");
     auto d_node_hist = hist.GetNodeHistogram(nidx).data();
     using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
-    collective::AllReduce<collective::Operation::kSum>(
-        ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
-        page->Cuts().TotalBins() * 2 * num_histograms);
+    collective::GlobalSum(info_, ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
+                          page->Cuts().TotalBins() * 2 * num_histograms);

     monitor.Stop("AllReduce");
   }
@@ -663,8 +731,7 @@ struct GPUHistMakerDevice {
     dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
                GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
-    collective::Allreduce<collective::Operation::kSum>(
-        reinterpret_cast<ReduceT *>(&root_sum_quantised), 2);
+    collective::GlobalSum(info_, reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);

     hist.AllocateHistograms({kRootNIdx});
     this->BuildHist(kRootNIdx);
@@ -801,7 +868,7 @@ class GPUHistMaker : public TreeUpdater {
     info_->feature_types.SetDevice(ctx_->gpu_id);
     maker = std::make_unique<GPUHistMakerDevice>(
         ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
-        *param, column_sampler_, info_->num_col_, batch_param);
+        *param, column_sampler_, info_->num_col_, batch_param, dmat->Info());

     p_last_fmat_ = dmat;
     initialised_ = true;
@@ -915,7 +982,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
     maker_ = std::make_unique<GPUHistMakerDevice>(
         ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
-        *param, column_sampler_, info.num_col_, batch);
+        *param, column_sampler_, info.num_col_, batch, p_fmat->Info());

     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
@@ -24,7 +24,7 @@ auto ZeroParam() {
 inline GradientQuantiser DummyRoundingFactor() {
   thrust::device_vector<GradientPair> gpair(1);
   gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
-  return GradientQuantiser(dh::ToSpan(gpair));
+  return {dh::ToSpan(gpair), MetaInfo()};
 }

 thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
@@ -39,7 +39,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
   FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
                                sizeof(GradientPairInt64));

-  auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
   BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                          feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx, d_histogram,
                          quantiser);
@@ -53,7 +53,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
     dh::device_vector<GradientPairInt64> new_histogram(num_bins);
     auto d_new_histogram = dh::ToSpan(new_histogram);

-    auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+    auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
     BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                            feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
                            d_new_histogram, quantiser);
@@ -131,7 +131,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
   gpair.SetDevice(0);
-  auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
   /**
    * Generate hist with cat data.
    */
@@ -30,7 +30,7 @@ void TestUpdatePositionBatch() {
   std::vector<int> extra_data = {0};
   // Send the first five training instances to the right node
   // and the second 5 to the left node
-  rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
+  rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
     return ridx > 4;
   });
   rows = rp.GetRowsHost(1);
@@ -43,7 +43,7 @@ void TestUpdatePositionBatch() {
   }

   // Split the left node again
-  rp.UpdatePositionBatch({1}, {3}, {4}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
+  rp.UpdatePositionBatch({1}, {3}, {4}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
     return ridx < 7;
   });
   EXPECT_EQ(rp.GetRows(3).size(), 2);
@@ -57,7 +57,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
   thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
   thrust::device_vector<bst_uint> counts(segments.size());

-  auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; };
+  auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
   std::vector<int> op_data(segments.size());
   std::vector<PerNodeData<int>> h_batch_info(segments.size());
   dh::TemporaryArray<PerNodeData<int>> d_batch_info(segments.size());
@@ -93,7 +93,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   Context ctx{MakeCUDACtx(0)};
   auto cs = std::make_shared<common::ColumnSampler>(0);
   GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
-                           batch_param);
+                           batch_param, MetaInfo());
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);
@@ -111,7 +111,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.hist.AllocateHistograms({0});

   maker.gpair = gpair.DeviceSpan();
-  maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair);
+  maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair, MetaInfo());
   maker.page = page.get();

   maker.InitFeatureGroupsOnce();
@@ -165,7 +165,7 @@ HistogramCutsWrapper GetHostCutMatrix () {
 inline GradientQuantiser DummyRoundingFactor() {
   thrust::device_vector<GradientPair> gpair(1);
   gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
-  return GradientQuantiser(dh::ToSpan(gpair));
+  return {dh::ToSpan(gpair), MetaInfo()};
 }

 void TestHistogramIndexImpl() {
@@ -426,4 +426,54 @@ TEST(GpuHist, MaxDepth) {

   ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error);
 }
+
+namespace {
+RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
+  ObjInfo task{ObjInfo::kRegression};
+  GPUHistMaker hist_maker{ctx, &task};
+  hist_maker.Configure(Args{});
+
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{});
+
+  linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Ordinal());
+  gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
+
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+  RegTree tree;
+  hist_maker.Update(&param, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
+                    {&tree});
+  return tree;
+}
+
+void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
+  Context ctx(MakeCUDACtx(GPUIDX));
+
+  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
+
+  RegTree tree = GetUpdatedTree(&ctx, sliced.get());
+
+  Json json{Object{}};
+  tree.SaveModel(&json);
+  Json expected_json{Object{}};
+  expected_tree.SaveModel(&expected_json);
+  ASSERT_EQ(json, expected_json);
+}
+}  // anonymous namespace
+
+class MGPUHistTest : public BaseMGPUTest {};
+
+TEST_F(MGPUHistTest, GPUHistColumnSplit) {
+  auto constexpr kRows = 32;
+  auto constexpr kCols = 16;
+
+  Context ctx(MakeCUDACtx(0));
+  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
+  RegTree expected_tree = GetUpdatedTree(&ctx, dmat.get());
+
+  DoTest(VerifyColumnSplit, kRows, kCols, expected_tree);
+}
 }  // namespace xgboost::tree