Make `HistCutMatrix::Init' be aware of groups. (#4115)
* Add checks for group size. * Simple docs. * Search group index during hist cut matrix initialization. Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com> Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
37ddfd7d6e
commit
754fe8142b
@ -101,6 +101,10 @@ The data is stored in a :py:class:`DMatrix <xgboost.DMatrix>` object.
|
|||||||
w = np.random.rand(5, 1)
|
w = np.random.rand(5, 1)
|
||||||
dtrain = xgb.DMatrix(data, label=label, missing=-999.0, weight=w)
|
dtrain = xgb.DMatrix(data, label=label, missing=-999.0, weight=w)
|
||||||
|
|
||||||
|
When performing ranking tasks, the number of weights should be equal
|
||||||
|
to number of groups.
|
||||||
|
|
||||||
|
|
||||||
Setting Parameters
|
Setting Parameters
|
||||||
------------------
|
------------------
|
||||||
XGBoost can use either a list of pairs or a dictionary to set :doc:`parameters </parameter>`. For instance:
|
XGBoost can use either a list of pairs or a dictionary to set :doc:`parameters </parameter>`. For instance:
|
||||||
|
|||||||
@ -24,7 +24,25 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
|
HistCutMatrix::HistCutMatrix() {
|
||||||
|
monitor_.Init("HistCutMatrix");
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t HistCutMatrix::SearchGroupIndFromBaseRow(
|
||||||
|
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const {
|
||||||
|
using KIt = std::vector<bst_uint>::const_iterator;
|
||||||
|
KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
|
||||||
|
// Cannot use CHECK_NE because it will try to print the iterator.
|
||||||
|
bool const found = res != group_ptr.cend() - 1;
|
||||||
|
if (!found) {
|
||||||
|
LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!\n";
|
||||||
|
}
|
||||||
|
size_t group_ind = std::distance(group_ptr.cbegin(), res);
|
||||||
|
return group_ind;
|
||||||
|
}
|
||||||
|
|
||||||
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||||
|
monitor_.Start("Init");
|
||||||
const MetaInfo& info = p_fmat->Info();
|
const MetaInfo& info = p_fmat->Info();
|
||||||
|
|
||||||
// safe factor for better accuracy
|
// safe factor for better accuracy
|
||||||
@ -33,30 +51,50 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
|||||||
|
|
||||||
const int nthread = omp_get_max_threads();
|
const int nthread = omp_get_max_threads();
|
||||||
|
|
||||||
auto nstep = static_cast<unsigned>((info.num_col_ + nthread - 1) / nthread);
|
unsigned const nstep =
|
||||||
auto ncol = static_cast<unsigned>(info.num_col_);
|
static_cast<unsigned>((info.num_col_ + nthread - 1) / nthread);
|
||||||
|
unsigned const ncol = static_cast<unsigned>(info.num_col_);
|
||||||
sketchs.resize(info.num_col_);
|
sketchs.resize(info.num_col_);
|
||||||
for (auto& s : sketchs) {
|
for (auto& s : sketchs) {
|
||||||
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
|
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& weights = info.weights_.HostVector();
|
const auto& weights = info.weights_.HostVector();
|
||||||
|
|
||||||
|
// Data groups, used in ranking.
|
||||||
|
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
|
||||||
|
size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1;
|
||||||
|
// Use group index for weights?
|
||||||
|
bool const use_group_ind = num_groups != 0 && weights.size() != info.num_row_;
|
||||||
|
|
||||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||||
#pragma omp parallel num_threads(nthread)
|
size_t group_ind = 0;
|
||||||
|
if (use_group_ind) {
|
||||||
|
group_ind = this->SearchGroupIndFromBaseRow(group_ptr, batch.base_rowid);
|
||||||
|
}
|
||||||
|
#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group_ind)
|
||||||
{
|
{
|
||||||
CHECK_EQ(nthread, omp_get_num_threads());
|
CHECK_EQ(nthread, omp_get_num_threads());
|
||||||
auto tid = static_cast<unsigned>(omp_get_thread_num());
|
auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||||
unsigned begin = std::min(nstep * tid, ncol);
|
unsigned begin = std::min(nstep * tid, ncol);
|
||||||
unsigned end = std::min(nstep * (tid + 1), ncol);
|
unsigned end = std::min(nstep * (tid + 1), ncol);
|
||||||
|
|
||||||
// do not iterate if no columns are assigned to the thread
|
// do not iterate if no columns are assigned to the thread
|
||||||
if (begin < end && end <= ncol) {
|
if (begin < end && end <= ncol) {
|
||||||
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
|
for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*)
|
||||||
size_t ridx = batch.base_rowid + i;
|
size_t const ridx = batch.base_rowid + i;
|
||||||
SparsePage::Inst inst = batch[i];
|
SparsePage::Inst const inst = batch[i];
|
||||||
for (auto& ins : inst) {
|
if (use_group_ind &&
|
||||||
if (ins.index >= begin && ins.index < end) {
|
group_ptr[group_ind] == ridx &&
|
||||||
sketchs[ins.index].Push(ins.fvalue,
|
// maximum equals to weights.size() - 1
|
||||||
weights.size() > 0 ? weights[ridx] : 1.0f);
|
group_ind < num_groups - 1) {
|
||||||
|
// move to next group
|
||||||
|
group_ind++;
|
||||||
|
}
|
||||||
|
for (auto const& entry : inst) {
|
||||||
|
if (entry.index >= begin && entry.index < end) {
|
||||||
|
size_t w_idx = use_group_ind ? group_ind : ridx;
|
||||||
|
sketchs[entry.index].Push(entry.fvalue, info.GetWeight(w_idx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,6 +103,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Init(&sketchs, max_num_bins);
|
Init(&sketchs, max_num_bins);
|
||||||
|
monitor_.Stop("Init");
|
||||||
}
|
}
|
||||||
|
|
||||||
void HistCutMatrix::Init
|
void HistCutMatrix::Init
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
#include "row_set.h"
|
#include "row_set.h"
|
||||||
#include "../tree/param.h"
|
#include "../tree/param.h"
|
||||||
#include "./quantile.h"
|
#include "./quantile.h"
|
||||||
|
#include "./timer.h"
|
||||||
#include "../include/rabit/rabit.h"
|
#include "../include/rabit/rabit.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -35,6 +36,14 @@ struct HistCutMatrix {
|
|||||||
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
|
void Init(DMatrix* p_fmat, uint32_t max_num_bins);
|
||||||
|
|
||||||
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
|
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
|
||||||
|
|
||||||
|
HistCutMatrix();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual size_t SearchGroupIndFromBaseRow(
|
||||||
|
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) const;
|
||||||
|
|
||||||
|
Monitor monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief Builds the cut matrix on the GPU */
|
/*! \brief Builds the cut matrix on the GPU */
|
||||||
|
|||||||
@ -474,12 +474,16 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||||
monitor_.Start("UpdateOneIter");
|
monitor_.Start("UpdateOneIter");
|
||||||
|
|
||||||
|
// TODO(trivialfis): Merge the duplicated code with BoostOneIter
|
||||||
CHECK(ModelInitialized())
|
CHECK(ModelInitialized())
|
||||||
<< "Always call InitModel or LoadModel before update";
|
<< "Always call InitModel or LoadModel before update";
|
||||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||||
}
|
}
|
||||||
|
this->ValidateDMatrix(train);
|
||||||
this->PerformTreeMethodHeuristic(train);
|
this->PerformTreeMethodHeuristic(train);
|
||||||
|
|
||||||
monitor_.Start("PredictRaw");
|
monitor_.Start("PredictRaw");
|
||||||
this->PredictRaw(train, &preds_);
|
this->PredictRaw(train, &preds_);
|
||||||
monitor_.Stop("PredictRaw");
|
monitor_.Stop("PredictRaw");
|
||||||
@ -493,10 +497,15 @@ class LearnerImpl : public Learner {
|
|||||||
void BoostOneIter(int iter, DMatrix* train,
|
void BoostOneIter(int iter, DMatrix* train,
|
||||||
HostDeviceVector<GradientPair>* in_gpair) override {
|
HostDeviceVector<GradientPair>* in_gpair) override {
|
||||||
monitor_.Start("BoostOneIter");
|
monitor_.Start("BoostOneIter");
|
||||||
|
|
||||||
|
CHECK(ModelInitialized())
|
||||||
|
<< "Always call InitModel or LoadModel before boost.";
|
||||||
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
if (tparam_.seed_per_iteration || rabit::IsDistributed()) {
|
||||||
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
common::GlobalRandom().seed(tparam_.seed * kRandSeedMagic + iter);
|
||||||
}
|
}
|
||||||
|
this->ValidateDMatrix(train);
|
||||||
this->PerformTreeMethodHeuristic(train);
|
this->PerformTreeMethodHeuristic(train);
|
||||||
|
|
||||||
gbm_->DoBoost(train, in_gpair);
|
gbm_->DoBoost(train, in_gpair);
|
||||||
monitor_.Stop("BoostOneIter");
|
monitor_.Stop("BoostOneIter");
|
||||||
}
|
}
|
||||||
@ -711,7 +720,7 @@ class LearnerImpl : public Learner {
|
|||||||
mparam_.num_feature = num_feature;
|
mparam_.num_feature = num_feature;
|
||||||
}
|
}
|
||||||
CHECK_NE(mparam_.num_feature, 0)
|
CHECK_NE(mparam_.num_feature, 0)
|
||||||
<< "0 feature is supplied. Are you using raw Booster?";
|
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||||
// setup
|
// setup
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||||
CHECK(obj_ == nullptr && gbm_ == nullptr);
|
CHECK(obj_ == nullptr && gbm_ == nullptr);
|
||||||
@ -736,6 +745,19 @@ class LearnerImpl : public Learner {
|
|||||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ValidateDMatrix(DMatrix* p_fmat) {
|
||||||
|
MetaInfo const& info = p_fmat->Info();
|
||||||
|
auto const& weights = info.weights_.HostVector();
|
||||||
|
if (info.group_ptr_.size() != 0 && weights.size() != 0) {
|
||||||
|
CHECK(weights.size() == info.group_ptr_.size() - 1)
|
||||||
|
<< "\n"
|
||||||
|
<< "weights size: " << weights.size() << ", "
|
||||||
|
<< "groups size: " << info.group_ptr_.size() -1 << ", "
|
||||||
|
<< "num rows: " << p_fmat->Info().num_row_ << "\n"
|
||||||
|
<< "Number of weights should be equal to number of groups in ranking task.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// model parameter
|
// model parameter
|
||||||
LearnerModelParam mparam_;
|
LearnerModelParam mparam_;
|
||||||
// training parameter
|
// training parameter
|
||||||
|
|||||||
51
tests/cpp/common/test_hist_util.cc
Normal file
51
tests/cpp/common/test_hist_util.cc
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "../../../src/common/hist_util.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
class HistCutMatrixMock : public HistCutMatrix {
|
||||||
|
public:
|
||||||
|
size_t SearchGroupIndFromBaseRow(
|
||||||
|
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
|
||||||
|
return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(HistCutMatrix, SearchGroupInd) {
|
||||||
|
size_t constexpr kNumGroups = 4;
|
||||||
|
size_t constexpr kNumRows = 17;
|
||||||
|
size_t constexpr kNumCols = 15;
|
||||||
|
|
||||||
|
auto pp_mat = CreateDMatrix(kNumRows, kNumCols, 0);
|
||||||
|
|
||||||
|
auto& p_mat = *pp_mat;
|
||||||
|
std::vector<bst_int> group(kNumGroups);
|
||||||
|
group[0] = 2;
|
||||||
|
group[1] = 3;
|
||||||
|
group[2] = 7;
|
||||||
|
group[3] = 5;
|
||||||
|
|
||||||
|
p_mat->Info().SetInfo(
|
||||||
|
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||||
|
|
||||||
|
HistCutMatrixMock hmat;
|
||||||
|
|
||||||
|
size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
|
||||||
|
ASSERT_EQ(group_ind, 0);
|
||||||
|
|
||||||
|
group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
|
||||||
|
ASSERT_EQ(group_ind, 2);
|
||||||
|
|
||||||
|
EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
|
||||||
|
|
||||||
|
delete pp_mat;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace xgboost
|
||||||
@ -6,9 +6,9 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
TEST(learner, Test) {
|
TEST(Learner, Basic) {
|
||||||
typedef std::pair<std::string, std::string> arg;
|
typedef std::pair<std::string, std::string> Arg;
|
||||||
auto args = {arg("tree_method", "exact")};
|
auto args = {Arg("tree_method", "exact")};
|
||||||
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
||||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
||||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||||
@ -17,33 +17,33 @@ TEST(learner, Test) {
|
|||||||
delete mat_ptr;
|
delete mat_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(learner, SelectTreeMethod) {
|
TEST(Learner, SelectTreeMethod) {
|
||||||
using arg = std::pair<std::string, std::string>;
|
using Arg = std::pair<std::string, std::string>;
|
||||||
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
auto mat_ptr = CreateDMatrix(10, 10, 0);
|
||||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {*mat_ptr};
|
||||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||||
|
|
||||||
// Test if `tree_method` can be set
|
// Test if `tree_method` can be set
|
||||||
learner->Configure({arg("tree_method", "approx")});
|
learner->Configure({Arg("tree_method", "approx")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_histmaker,prune");
|
"grow_histmaker,prune");
|
||||||
learner->Configure({arg("tree_method", "exact")});
|
learner->Configure({Arg("tree_method", "exact")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_colmaker,prune");
|
"grow_colmaker,prune");
|
||||||
learner->Configure({arg("tree_method", "hist")});
|
learner->Configure({Arg("tree_method", "hist")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_quantile_histmaker");
|
"grow_quantile_histmaker");
|
||||||
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "hist"}});
|
learner->Configure({Arg{"booster", "dart"}, Arg{"tree_method", "hist"}});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_quantile_histmaker");
|
"grow_quantile_histmaker");
|
||||||
#ifdef XGBOOST_USE_CUDA
|
#ifdef XGBOOST_USE_CUDA
|
||||||
learner->Configure({arg("tree_method", "gpu_exact")});
|
learner->Configure({Arg("tree_method", "gpu_exact")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu,prune");
|
"grow_gpu,prune");
|
||||||
learner->Configure({arg("tree_method", "gpu_hist")});
|
learner->Configure({Arg("tree_method", "gpu_hist")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu_hist");
|
"grow_gpu_hist");
|
||||||
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "gpu_hist"}});
|
learner->Configure({Arg{"booster", "dart"}, Arg{"tree_method", "gpu_hist"}});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu_hist");
|
"grow_gpu_hist");
|
||||||
#endif
|
#endif
|
||||||
@ -51,4 +51,45 @@ TEST(learner, SelectTreeMethod) {
|
|||||||
delete mat_ptr;
|
delete mat_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Learner, CheckGroup) {
|
||||||
|
using Arg = std::pair<std::string, std::string>;
|
||||||
|
size_t constexpr kNumGroups = 4;
|
||||||
|
size_t constexpr kNumRows = 17;
|
||||||
|
size_t constexpr kNumCols = 15;
|
||||||
|
|
||||||
|
auto pp_mat = CreateDMatrix(kNumRows, kNumCols, 0);
|
||||||
|
auto& p_mat = *pp_mat;
|
||||||
|
std::vector<bst_float> weight(kNumGroups);
|
||||||
|
std::vector<bst_int> group(kNumGroups);
|
||||||
|
group[0] = 2;
|
||||||
|
group[1] = 3;
|
||||||
|
group[2] = 7;
|
||||||
|
group[3] = 5;
|
||||||
|
std::vector<bst_float> labels (kNumRows);
|
||||||
|
for (size_t i = 0; i < kNumRows; ++i) {
|
||||||
|
labels[i] = i % 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
p_mat->Info().SetInfo(
|
||||||
|
"weight", static_cast<void*>(weight.data()), DataType::kFloat32, kNumGroups);
|
||||||
|
p_mat->Info().SetInfo(
|
||||||
|
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||||
|
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
|
||||||
|
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||||
|
learner->Configure({Arg{"objective", "rank:pairwise"}});
|
||||||
|
learner->InitModel();
|
||||||
|
|
||||||
|
EXPECT_NO_THROW(learner->UpdateOneIter(0, p_mat.get()));
|
||||||
|
|
||||||
|
group.resize(kNumGroups+1);
|
||||||
|
group[3] = 4;
|
||||||
|
group[4] = 1;
|
||||||
|
p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
|
||||||
|
EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat.get()));
|
||||||
|
|
||||||
|
delete pp_mat;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user