Optimized EvaluateSplut function (#5138)
* Add block based threading utilities.
This commit is contained in:
parent
04db125699
commit
7b17e76c5b
123
src/common/threading_utils.h
Executable file
123
src/common/threading_utils.h
Executable file
@ -0,0 +1,123 @@
|
||||
/*!
|
||||
* Copyright 2015-2019 by Contributors
|
||||
* \file common.h
|
||||
* \brief Threading utilities
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
#define XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
// Represent simple range of indexes [begin, end)
|
||||
// Inspired by tbb::blocked_range
|
||||
class Range1d {
|
||||
public:
|
||||
Range1d(size_t begin, size_t end): begin_(begin), end_(end) {
|
||||
CHECK_LT(begin, end);
|
||||
}
|
||||
|
||||
size_t begin() {
|
||||
return begin_;
|
||||
}
|
||||
|
||||
size_t end() {
|
||||
return end_;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t begin_;
|
||||
size_t end_;
|
||||
};
|
||||
|
||||
|
||||
// Split 2d space to balanced blocks
|
||||
// Implementation of the class is inspired by tbb::blocked_range2d
|
||||
// However, TBB provides only (n x m) 2d range (matrix) separated by blocks. Example:
|
||||
// [ 1,2,3 ]
|
||||
// [ 4,5,6 ]
|
||||
// [ 7,8,9 ]
|
||||
// But the class is able to work with different sizes in each 'row'. Example:
|
||||
// [ 1,2 ]
|
||||
// [ 3,4,5,6 ]
|
||||
// [ 7,8,9]
|
||||
// If grain_size is 2: It produces following blocks:
|
||||
// [1,2], [3,4], [5,6], [7,8], [9]
|
||||
// The class helps to process data in several tree nodes (non-balanced usually) in parallel
|
||||
// Using nested parallelism (by nodes and by data in each node)
|
||||
// it helps to improve CPU resources utilization
|
||||
class BlockedSpace2d {
|
||||
public:
|
||||
// Example of space:
|
||||
// [ 1,2 ]
|
||||
// [ 3,4,5,6 ]
|
||||
// [ 7,8,9]
|
||||
// BlockedSpace2d will create following blocks (tasks) if grain_size=2:
|
||||
// 1-block: first_dimension = 0, range of indexes in a 'row' = [0,2) (includes [1,2] values)
|
||||
// 2-block: first_dimension = 1, range of indexes in a 'row' = [0,2) (includes [3,4] values)
|
||||
// 3-block: first_dimension = 1, range of indexes in a 'row' = [2,4) (includes [5,6] values)
|
||||
// 4-block: first_dimension = 2, range of indexes in a 'row' = [0,2) (includes [7,8] values)
|
||||
// 5-block: first_dimension = 2, range of indexes in a 'row' = [2,3) (includes [9] values)
|
||||
// Arguments:
|
||||
// dim1 - size of the first dimension in the space
|
||||
// getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index
|
||||
// grain_size - max size of produced blocks
|
||||
template<typename Func>
|
||||
BlockedSpace2d(size_t dim1, Func getter_size_dim2, size_t grain_size) {
|
||||
for (size_t i = 0; i < dim1; ++i) {
|
||||
const size_t size = getter_size_dim2(i);
|
||||
const size_t n_blocks = size/grain_size + !!(size % grain_size);
|
||||
for (size_t iblock = 0; iblock < n_blocks; ++iblock) {
|
||||
const size_t begin = iblock * grain_size;
|
||||
const size_t end = std::min(begin + grain_size, size);
|
||||
AddBlock(i, begin, end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Amount of blocks(tasks) in a space
|
||||
size_t Size() const {
|
||||
return ranges_.size();
|
||||
}
|
||||
|
||||
// get index of the first dimension of i-th block(task)
|
||||
size_t GetFirstDimension(size_t i) const {
|
||||
CHECK_LT(i, first_dimension_.size());
|
||||
return first_dimension_[i];
|
||||
}
|
||||
|
||||
// get a range of indexes for the second dimension of i-th block(task)
|
||||
Range1d GetRange(size_t i) const {
|
||||
CHECK_LT(i, ranges_.size());
|
||||
return ranges_[i];
|
||||
}
|
||||
|
||||
private:
|
||||
void AddBlock(size_t first_dimension, size_t begin, size_t end) {
|
||||
first_dimension_.push_back(first_dimension);
|
||||
ranges_.emplace_back(begin, end);
|
||||
}
|
||||
|
||||
std::vector<Range1d> ranges_;
|
||||
std::vector<size_t> first_dimension_;
|
||||
};
|
||||
|
||||
|
||||
// Wrapper to implement nested parallelism with simple omp parallel for
|
||||
template<typename Func>
|
||||
void ParallelFor2d(const BlockedSpace2d& space, Func func) {
|
||||
const int num_blocks_in_space = static_cast<int>(space.Size());
|
||||
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < num_blocks_in_space; i++) {
|
||||
func(space.GetFirstDimension(i), space.GetRange(i));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
@ -28,6 +28,8 @@
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/row_set.h"
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -191,9 +193,11 @@ void QuantileHistMaker::Builder::EvaluateSplits(
|
||||
int depth,
|
||||
unsigned *timestamp,
|
||||
std::vector<ExpandEntry> *temp_qexpand_depth) {
|
||||
this->EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree);
|
||||
|
||||
for (auto const& entry : qexpand_depth_wise_) {
|
||||
int nid = entry.nid;
|
||||
this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
|
||||
|
||||
if (snode_[nid].best.loss_chg < kRtEps ||
|
||||
(param_.max_depth > 0 && depth == param_.max_depth) ||
|
||||
(param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
|
||||
@ -223,7 +227,8 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
|
||||
int num_leaves = 0;
|
||||
|
||||
// in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
|
||||
qexpand_depth_wise_.emplace_back(ExpandEntry(0, p_tree->GetDepth(0), 0.0, timestamp++));
|
||||
qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid,
|
||||
p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
|
||||
++num_leaves;
|
||||
for (int depth = 0; depth < param_.max_depth + 1; depth++) {
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
@ -257,14 +262,18 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
||||
unsigned timestamp = 0;
|
||||
int num_leaves = 0;
|
||||
|
||||
hist_.AddHistRow(0);
|
||||
BuildHist(gpair_h, row_set_collection_[0], gmat, gmatb, hist_[0], true);
|
||||
hist_.AddHistRow(ExpandEntry::kRootNid);
|
||||
BuildHist(gpair_h, row_set_collection_[ExpandEntry::kRootNid], gmat, gmatb,
|
||||
hist_[ExpandEntry::kRootNid], true);
|
||||
|
||||
this->InitNewNode(0, gmat, gpair_h, *p_fmat, *p_tree);
|
||||
this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);
|
||||
|
||||
this->EvaluateSplit(0, gmat, hist_, *p_fmat, *p_tree);
|
||||
qexpand_loss_guided_->push(ExpandEntry(0, p_tree->GetDepth(0),
|
||||
snode_[0].best.loss_chg, timestamp++));
|
||||
ExpandEntry node(ExpandEntry::kRootNid, p_tree->GetDepth(ExpandEntry::kRootNid),
|
||||
snode_[ExpandEntry::kRootNid].best.loss_chg, timestamp++);
|
||||
this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree);
|
||||
node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;
|
||||
|
||||
qexpand_loss_guided_->push(node);
|
||||
++num_leaves;
|
||||
|
||||
while (!qexpand_loss_guided_->empty()) {
|
||||
@ -304,15 +313,17 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
||||
snode_[cleft].weight, snode_[cright].weight);
|
||||
interaction_constraints_.Split(nid, featureid, cleft, cright);
|
||||
|
||||
this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree);
|
||||
this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree);
|
||||
ExpandEntry left_node(cleft, p_tree->GetDepth(cleft),
|
||||
snode_[cleft].best.loss_chg, timestamp++);
|
||||
ExpandEntry right_node(cright, p_tree->GetDepth(cright),
|
||||
snode_[cright].best.loss_chg, timestamp++);
|
||||
|
||||
qexpand_loss_guided_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft),
|
||||
snode_[cleft].best.loss_chg,
|
||||
timestamp++));
|
||||
qexpand_loss_guided_->push(ExpandEntry(cright, p_tree->GetDepth(cright),
|
||||
snode_[cright].best.loss_chg,
|
||||
timestamp++));
|
||||
this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree);
|
||||
left_node.loss_chg = snode_[cleft].best.loss_chg;
|
||||
right_node.loss_chg = snode_[cright].best.loss_chg;
|
||||
|
||||
qexpand_loss_guided_->push(left_node);
|
||||
qexpand_loss_guided_->push(right_node);
|
||||
|
||||
++num_leaves; // give two and take one, as parent is no longer a leaf
|
||||
}
|
||||
@ -553,40 +564,77 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
builder_monitor_.Stop("InitData");
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
|
||||
// if sum of statistics for non-missing values in the node
|
||||
// is equal to sum of statistics for all values:
|
||||
// then - there are no missing values
|
||||
// else - there are missing values
|
||||
bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e,
|
||||
const NodeEntry& snode) {
|
||||
if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// nodes_set - set of nodes to be processed in parallel
|
||||
void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& nodes_set,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
builder_monitor_.Start("EvaluateSplit");
|
||||
// start enumeration
|
||||
const MetaInfo& info = fmat.Info();
|
||||
auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
|
||||
auto const& feature_set = p_feature_set->HostVector();
|
||||
const auto nfeature = static_cast<bst_feature_t>(feature_set.size());
|
||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
best_split_tloc_.resize(nthread);
|
||||
#pragma omp parallel for schedule(static) num_threads(nthread)
|
||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
||||
best_split_tloc_[tid] = snode_[nid].best;
|
||||
}
|
||||
GHistRow node_hist = hist[nid];
|
||||
|
||||
#pragma omp parallel for schedule(dynamic) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < nfeature; ++i) { // NOLINT(*)
|
||||
const auto feature_id = static_cast<bst_uint>(feature_set[i]);
|
||||
const auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||
const auto node_id = static_cast<bst_uint>(nid);
|
||||
if (interaction_constraints_.Query(node_id, feature_id)) {
|
||||
this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info,
|
||||
&best_split_tloc_[tid], feature_id, node_id);
|
||||
this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info,
|
||||
&best_split_tloc_[tid], feature_id, node_id);
|
||||
const size_t n_nodes_in_set = nodes_set.size();
|
||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
|
||||
using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
|
||||
std::vector<FeatureSetType> features_sets(n_nodes_in_set);
|
||||
best_split_tloc_.resize(nthread * n_nodes_in_set);
|
||||
|
||||
// Generate feature set for each tree node
|
||||
for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
|
||||
const int32_t nid = nodes_set[nid_in_set].nid;
|
||||
features_sets[nid_in_set] = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
|
||||
|
||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||
best_split_tloc_[nthread*nid_in_set + tid] = snode_[nid].best;
|
||||
}
|
||||
}
|
||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||
snode_[nid].best.Update(best_split_tloc_[tid]);
|
||||
|
||||
// Create 2D space (# of nodes to process x # of features to process)
|
||||
// to process them in parallel
|
||||
common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
|
||||
return features_sets[nid_in_set]->Size();
|
||||
}, 1);
|
||||
|
||||
// Start parallel enumeration for all tree nodes in the set and all features
|
||||
common::ParallelFor2d(space, [&](size_t nid_in_set, common::Range1d r) {
|
||||
const int32_t nid = nodes_set[nid_in_set].nid;
|
||||
const auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||
GHistRow node_hist = hist[nid];
|
||||
|
||||
for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) {
|
||||
const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
|
||||
if (interaction_constraints_.Query(nid, fid)) {
|
||||
auto grad_stats = this->EnumerateSplit<+1>(gmat, node_hist, snode_[nid], fmat.Info(),
|
||||
&best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
|
||||
if (SplitContainsMissingValues(grad_stats, snode_[nid])) {
|
||||
this->EnumerateSplit<-1>(gmat, node_hist, snode_[nid], fmat.Info(),
|
||||
&best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Find Best Split across threads for each node in nodes set
|
||||
for (unsigned nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
|
||||
const int32_t nid = nodes_set[nid_in_set].nid;
|
||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||
snode_[nid].best.Update(best_split_tloc_[nthread*nid_in_set + tid]);
|
||||
}
|
||||
}
|
||||
|
||||
builder_monitor_.Stop("EvaluateSplit");
|
||||
}
|
||||
|
||||
@ -830,8 +878,12 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
|
||||
builder_monitor_.Stop("InitNewNode");
|
||||
}
|
||||
|
||||
// enumerate the split values of specific feature
|
||||
void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
|
||||
|
||||
// Enumerate the split values of specific feature.
|
||||
// Returns the sum of gradients corresponding to the data points that contains a non-missing value
|
||||
// for the particular feature fid.
|
||||
template<int d_step>
|
||||
GradStats QuantileHistMaker::Builder::EnumerateSplit(
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistRow& hist,
|
||||
const NodeEntry& snode,
|
||||
@ -903,6 +955,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
|
||||
}
|
||||
}
|
||||
p_best->Update(best);
|
||||
|
||||
return e;
|
||||
}
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker")
|
||||
|
||||
@ -183,6 +183,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
protected:
|
||||
/* tree growing policies */
|
||||
struct ExpandEntry {
|
||||
static const int kRootNid = 0;
|
||||
int nid;
|
||||
int depth;
|
||||
bst_float loss_chg;
|
||||
@ -197,7 +198,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
|
||||
void EvaluateSplit(const int nid,
|
||||
void EvaluateSplit(const std::vector<ExpandEntry>& nodes_set,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
@ -232,8 +233,11 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
|
||||
// enumerate the split values of specific feature
|
||||
void EnumerateSplit(int d_step,
|
||||
// Enumerate the split values of specific feature
|
||||
// Returns the sum of gradients corresponding to the data points that contains a non-missing
|
||||
// value for the particular feature fid.
|
||||
template<int d_step>
|
||||
GradStats EnumerateSplit(
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistRow& hist,
|
||||
const NodeEntry& snode,
|
||||
@ -242,6 +246,12 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
bst_uint fid,
|
||||
bst_uint nodeID);
|
||||
|
||||
// if sum of statistics for non-missing values in the node
|
||||
// is equal to sum of statistics for all values:
|
||||
// then - there are no missing values
|
||||
// else - there are missing values
|
||||
bool SplitContainsMissingValues(const GradStats e, const NodeEntry& snode);
|
||||
|
||||
void ExpandWithDepthWise(const GHistIndexMatrix &gmat,
|
||||
const GHistIndexBlockMatrix &gmatb,
|
||||
const ColumnMatrix &column_matrix,
|
||||
|
||||
82
tests/cpp/common/test_threading_utils.cc
Executable file
82
tests/cpp/common/test_threading_utils.cc
Executable file
@ -0,0 +1,82 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/common/column_matrix.h"
|
||||
#include "../../../src/common/threading_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(CreateBlockedSpace2d, Test) {
|
||||
constexpr size_t kDim1 = 5;
|
||||
constexpr size_t kDim2 = 3;
|
||||
constexpr size_t kGrainSize = 1;
|
||||
|
||||
BlockedSpace2d space(kDim1, [&](size_t i) {
|
||||
return kDim2;
|
||||
}, kGrainSize);
|
||||
|
||||
ASSERT_EQ(kDim1 * kDim2, space.Size());
|
||||
|
||||
for (auto i = 0; i < kDim1; i++) {
|
||||
for (auto j = 0; j < kDim2; j++) {
|
||||
ASSERT_EQ(space.GetFirstDimension(i*kDim2 + j), i);
|
||||
ASSERT_EQ(j, space.GetRange(i*kDim2 + j).begin());
|
||||
ASSERT_EQ(j + kGrainSize, space.GetRange(i*kDim2 + j).end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ParallelFor2d, Test) {
|
||||
constexpr size_t kDim1 = 100;
|
||||
constexpr size_t kDim2 = 15;
|
||||
constexpr size_t kGrainSize = 2;
|
||||
|
||||
// working space is matrix of size (kDim1 x kDim2)
|
||||
std::vector<int> matrix(kDim1 * kDim2, 0);
|
||||
BlockedSpace2d space(kDim1, [&](size_t i) {
|
||||
return kDim2;
|
||||
}, kGrainSize);
|
||||
|
||||
ParallelFor2d(space, [&](size_t i, Range1d r) {
|
||||
for (auto j = r.begin(); j < r.end(); ++j) {
|
||||
matrix[i*kDim2 + j] += 1;
|
||||
}
|
||||
});
|
||||
|
||||
for (auto i = 0; i < kDim1 * kDim2; i++) {
|
||||
ASSERT_EQ(matrix[i], 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ParallelFor2dNonUniform, Test) {
|
||||
constexpr size_t kDim1 = 5;
|
||||
constexpr size_t kGrainSize = 256;
|
||||
|
||||
// here are quite non-uniform distribution in space
|
||||
// but ParallelFor2d should split them by blocks with max size = kGrainSize
|
||||
// and process in balanced manner (optimal performance)
|
||||
std::vector<size_t> dim2 { 1024, 500, 255, 5, 10000 };
|
||||
BlockedSpace2d space(kDim1, [&](size_t i) {
|
||||
return dim2[i];
|
||||
}, kGrainSize);
|
||||
|
||||
std::vector<std::vector<int>> working_space(kDim1);
|
||||
for (auto i = 0; i < kDim1; i++) {
|
||||
working_space[i].resize(dim2[i], 0);
|
||||
}
|
||||
|
||||
ParallelFor2d(space, [&](size_t i, Range1d r) {
|
||||
for (auto j = r.begin(); j < r.end(); ++j) {
|
||||
working_space[i][j] += 1;
|
||||
}
|
||||
});
|
||||
|
||||
for (auto i = 0; i < kDim1; i++) {
|
||||
for (auto j = 0; j < dim2[i]; j++) {
|
||||
ASSERT_EQ(working_space[i][j], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -211,7 +211,8 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
/* Now compare against result given by EvaluateSplit() */
|
||||
RealImpl::EvaluateSplit(0, gmat, hist_, *(*dmat), tree);
|
||||
ExpandEntry node(0, tree.GetDepth(0), snode_[0].best.loss_chg, 0);
|
||||
RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree);
|
||||
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
|
||||
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user