Use ellpack for prediction only when sparsepage doesn't exist. (#5504)
This commit is contained in:
parent
ad826e913f
commit
6671b42dd4
@ -181,12 +181,6 @@ class GradientBooster : public Model, public Configurable {
|
||||
const std::string& name,
|
||||
GenericParameter const* generic_param,
|
||||
LearnerModelParam const* learner_model_param);
|
||||
|
||||
static void AssertGPUSupport() {
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@ -85,7 +85,7 @@ XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA";
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA";
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -521,7 +521,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
|
||||
float const** out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA.";
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
@ -535,7 +535,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
|
||||
const float **out_result) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(FATAL) << "XGBoost not compiled with CUDA.";
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@ -147,6 +147,13 @@ class Range {
|
||||
};
|
||||
|
||||
int AllVisibleGPUs();
|
||||
|
||||
inline void AssertGPUSupport() {
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
LOG(FATAL) << "XGBoost version not compiled with GPU support.";
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_COMMON_H_
|
||||
|
||||
@ -293,7 +293,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
LOG(FATAL) << "XGBoost version is not compiled with GPU support";
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
|
||||
#include "gblinear_model.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
@ -68,7 +69,7 @@ class GBLinear : public GradientBooster {
|
||||
updater_->Configure(cfg);
|
||||
monitor_.Init("GBLinear");
|
||||
if (param_.updater == "gpu_coord_descent") {
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -172,7 +172,7 @@ void GBTree::ConfigureUpdaters() {
|
||||
tparam_.updater_seq = "grow_quantile_histmaker";
|
||||
break;
|
||||
case TreeMethod::kGPUHist: {
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
tparam_.updater_seq = "grow_gpu_hist";
|
||||
break;
|
||||
}
|
||||
@ -391,17 +391,21 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
#else
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
CHECK(cpu_predictor_);
|
||||
return cpu_predictor_;
|
||||
}
|
||||
|
||||
auto on_device =
|
||||
f_dmat &&
|
||||
(f_dmat->PageExists<EllpackPage>() ||
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead());
|
||||
// Data comes from Device DMatrix.
|
||||
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
|
||||
!f_dmat->PageExists<SparsePage>();
|
||||
// Data comes from device memory, like CuDF or CuPy.
|
||||
auto is_from_device =
|
||||
f_dmat && f_dmat->PageExists<SparsePage>() &&
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
|
||||
auto on_device = is_ellpack || is_from_device;
|
||||
|
||||
// Use GPU Predictor if data is already on device and gpu_id is set.
|
||||
if (on_device && generic_param_->gpu_id >= 0) {
|
||||
@ -434,7 +438,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
#else
|
||||
this->AssertGPUSupport();
|
||||
common::AssertGPUSupport();
|
||||
return cpu_predictor_;
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
@ -348,7 +348,14 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
|
||||
if (dmat->PageExists<EllpackPage>()) {
|
||||
if (dmat->PageExists<SparsePage>()) {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->PredictInternal(batch, model.learner_model_param->num_feature,
|
||||
out_preds, batch_offset);
|
||||
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
|
||||
}
|
||||
} else {
|
||||
size_t batch_offset = 0;
|
||||
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
||||
this->PredictInternal(
|
||||
@ -356,13 +363,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
batch_offset);
|
||||
batch_offset += page.Impl()->n_rows;
|
||||
}
|
||||
} else {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->PredictInternal(batch, model.learner_model_param->num_feature,
|
||||
out_preds, batch_offset);
|
||||
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ TEST(CAPI, Version) {
|
||||
|
||||
TEST(CAPI, ConfigIO) {
|
||||
size_t constexpr kRows = 10;
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
||||
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||
std::vector<bst_float> labels(kRows);
|
||||
for (size_t i = 0; i < labels.size(); ++i) {
|
||||
@ -115,7 +115,7 @@ TEST(CAPI, JsonModelIO) {
|
||||
size_t constexpr kRows = 10;
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
||||
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||
std::vector<bst_float> labels(kRows);
|
||||
for (size_t i = 0; i < labels.size(); ++i) {
|
||||
|
||||
@ -13,7 +13,7 @@ TEST(DenseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
@ -61,7 +61,7 @@ TEST(SparseColumn, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
@ -102,7 +102,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
for (size_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), max_num_bin);
|
||||
ColumnMatrix column_matrix;
|
||||
|
||||
@ -128,7 +128,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
std::vector<bst_int> group(kNumGroups);
|
||||
group[0] = 2;
|
||||
@ -155,7 +155,7 @@ TEST(SparseCuts, SingleThreadedBuild) {
|
||||
size_t constexpr kCols = 31;
|
||||
size_t constexpr kBins = 256;
|
||||
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), kBins);
|
||||
@ -206,12 +206,12 @@ TEST(SparseCuts, MultiThreadedBuild) {
|
||||
};
|
||||
|
||||
{
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
Compare(p_fmat.get());
|
||||
}
|
||||
|
||||
{
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.0001).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.0001).GenerateDMatrix();
|
||||
Compare(p_fmat.get());
|
||||
}
|
||||
|
||||
@ -360,7 +360,7 @@ TEST(HistUtil, IndexBinBound) {
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
@ -381,7 +381,7 @@ TEST(HistUtil, SparseIndexBinBound) {
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||
@ -404,7 +404,7 @@ TEST(HistUtil, IndexBinData) {
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
uint32_t* offsets = hmat.index.Offset();
|
||||
@ -434,7 +434,7 @@ TEST(HistUtil, SparseIndexBinData) {
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatix();
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(hmat.index.Offset(), nullptr);
|
||||
|
||||
@ -68,7 +68,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
|
||||
}
|
||||
|
||||
TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
|
||||
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();
|
||||
|
||||
std::vector<int> ridx_set = {1, 3, 5};
|
||||
data::DMatrixSliceAdapter adapter(p_dmat.get(),
|
||||
|
||||
@ -23,7 +23,7 @@ TEST(DeviceDMatrix, RowMajor) {
|
||||
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
|
||||
|
||||
data::DeviceDMatrix dmat(&adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, 256);
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, 256);
|
||||
|
||||
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
|
||||
auto impl = batch.Impl();
|
||||
@ -60,7 +60,7 @@ TEST(DeviceDMatrix, RowMajorMissing) {
|
||||
EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue());
|
||||
EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue());
|
||||
// null values get placed after valid values in a row
|
||||
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
|
||||
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
|
||||
EXPECT_EQ(dmat.Info().num_col_, num_columns);
|
||||
EXPECT_EQ(dmat.Info().num_row_, num_rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows*num_columns-3);
|
||||
|
||||
@ -17,7 +17,7 @@ namespace xgboost {
|
||||
TEST(EllpackPage, EmptyDMatrix) {
|
||||
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
|
||||
constexpr float kSparsity = 0;
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
|
||||
auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin}).begin();
|
||||
auto impl = page.Impl();
|
||||
ASSERT_EQ(impl->row_stride, 0);
|
||||
|
||||
@ -220,7 +220,7 @@ TEST(SimpleDMatrix, FromFile) {
|
||||
TEST(SimpleDMatrix, Slice) {
|
||||
const int kRows = 6;
|
||||
const int kCols = 2;
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatrix();
|
||||
auto &labels = p_dmat->Info().labels_.HostVector();
|
||||
auto &weights = p_dmat->Info().weights_.HostVector();
|
||||
auto &base_margin = p_dmat->Info().base_margin_.HostVector();
|
||||
|
||||
@ -55,7 +55,7 @@ TEST(GBTree, WrongUpdater) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
@ -67,10 +67,11 @@ TEST(GBTree, WrongUpdater) {
|
||||
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
TEST(GBTree, ChoosePredictor) {
|
||||
// The test ensures data don't get pulled into device.
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
auto& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
@ -195,7 +196,7 @@ TEST(Dart, JsonIO) {
|
||||
TEST(Dart, Prediction) {
|
||||
size_t constexpr kRows = 16, kCols = 10;
|
||||
|
||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
std::vector<bst_float> labels (kRows);
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
|
||||
@ -260,8 +260,8 @@ void RandomDataGenerator::GenerateCSR(
|
||||
}
|
||||
|
||||
std::shared_ptr<DMatrix>
|
||||
RandomDataGenerator::GenerateDMatix(bool with_label, bool float_label,
|
||||
size_t classes) const {
|
||||
RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
|
||||
size_t classes) const {
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
@ -399,7 +399,7 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
|
||||
std::unique_ptr<GradientBooster> gbm {
|
||||
GradientBooster::Create(name, generic_param, learner_model_param)};
|
||||
gbm->Configure(kwargs);
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
std::vector<float> labels(kRows);
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
|
||||
17
tests/cpp/helpers.cu
Normal file
17
tests/cpp/helpers.cu
Normal file
@ -0,0 +1,17 @@
|
||||
#include "helpers.h"
|
||||
#include "../../src/data/device_adapter.cuh"
|
||||
#include "../../src/data/device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label,
|
||||
bool float_label,
|
||||
size_t classes) {
|
||||
std::vector<HostDeviceVector<float>> storage(cols_);
|
||||
std::string arr = this->GenerateColumnarArrayInterface(&storage);
|
||||
auto adapter = data::CudfAdapter(arr);
|
||||
std::shared_ptr<DMatrix> m {
|
||||
new data::DeviceDMatrix{&adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, 256}};
|
||||
return m;
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -178,13 +178,15 @@ class RandomDataGenerator {
|
||||
int32_t device_;
|
||||
int32_t seed_;
|
||||
|
||||
size_t bins_;
|
||||
|
||||
Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
|
||||
size_t cols) const;
|
||||
|
||||
public:
|
||||
RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity)
|
||||
: rows_{rows}, cols_{cols}, sparsity_{sparsity}, lower_{0.0f}, upper_{1.0f},
|
||||
device_{-1}, seed_{0} {}
|
||||
device_{-1}, seed_{0}, bins_{0} {}
|
||||
|
||||
RandomDataGenerator &Lower(float v) {
|
||||
lower_ = v;
|
||||
@ -202,6 +204,10 @@ class RandomDataGenerator {
|
||||
seed_ = s;
|
||||
return *this;
|
||||
}
|
||||
RandomDataGenerator& Bins(size_t b) {
|
||||
bins_ = b;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void GenerateDense(HostDeviceVector<float>* out) const;
|
||||
std::string GenerateArrayInterface(HostDeviceVector<float>* storage) const;
|
||||
@ -210,9 +216,14 @@ class RandomDataGenerator {
|
||||
void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
|
||||
HostDeviceVector<bst_feature_t>* columns) const;
|
||||
|
||||
std::shared_ptr<DMatrix> GenerateDMatix(bool with_label = false,
|
||||
bool float_label = true,
|
||||
size_t classes = 1) const;
|
||||
std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
|
||||
bool float_label = true,
|
||||
size_t classes = 1) const;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label = false,
|
||||
bool float_label = true,
|
||||
size_t classes = 1);
|
||||
#endif
|
||||
};
|
||||
|
||||
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
|
||||
|
||||
@ -22,7 +22,7 @@ class HistogramCutsWrapper : public common::HistogramCuts {
|
||||
|
||||
inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
|
||||
int n_rows, int n_cols, bst_float sparsity= 0) {
|
||||
auto dmat = RandomDataGenerator(n_rows, n_cols, sparsity).Seed(3).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(n_rows, n_cols, sparsity).Seed(3).GenerateDMatrix();
|
||||
const SparsePage& batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
|
||||
|
||||
HistogramCutsWrapper cmat;
|
||||
|
||||
@ -15,7 +15,7 @@ TEST(Linear, Shotgun) {
|
||||
size_t constexpr kRows = 10;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
LearnerModelParam mparam;
|
||||
@ -51,7 +51,7 @@ TEST(Linear, coordinate) {
|
||||
size_t constexpr kRows = 10;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_fmat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
LearnerModelParam mparam;
|
||||
|
||||
@ -12,7 +12,7 @@ TEST(Linear, GPUCoordinate) {
|
||||
size_t constexpr kRows = 10;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
auto mat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto mat = xgboost::RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
LearnerModelParam mparam;
|
||||
|
||||
@ -259,7 +259,7 @@ TEST(Objective, CPU_vs_CUDA) {
|
||||
|
||||
constexpr size_t kRows = 400;
|
||||
constexpr size_t kCols = 100;
|
||||
auto pdmat = RandomDataGenerator(kRows, kCols, 0).Seed(0).GenerateDMatix();
|
||||
auto pdmat = RandomDataGenerator(kRows, kCols, 0).Seed(0).GenerateDMatrix();
|
||||
HostDeviceVector<float> preds;
|
||||
preds.Resize(kRows);
|
||||
auto& h_preds = preds.HostVector();
|
||||
|
||||
@ -26,7 +26,7 @@ TEST(CpuPredictor, Basic) {
|
||||
|
||||
gbm::GBTreeModel model = CreateTestModel(¶m);
|
||||
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
@ -31,7 +31,7 @@ TEST(GPUPredictor, Basic) {
|
||||
|
||||
for (size_t i = 1; i < 33; i *= 2) {
|
||||
int n_row = i, n_col = i;
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix();
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = n_col;
|
||||
@ -58,16 +58,33 @@ TEST(GPUPredictor, Basic) {
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackBasic) {
|
||||
size_t constexpr kCols {8};
|
||||
for (size_t bins = 2; bins < 258; bins += 16) {
|
||||
size_t rows = bins * 16;
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
|
||||
auto p_m = RandomDataGenerator{rows, kCols, 0.0}
|
||||
.Bins(bins)
|
||||
.Device(0)
|
||||
.GenerateDeviceDMatrix(true);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, kCols, p_m);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, kCols, p_m);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackTraining) {
|
||||
size_t constexpr kRows { 128 };
|
||||
TestTrainingPrediction(kRows, "gpu_hist");
|
||||
size_t constexpr kRows { 128 }, kCols { 16 }, kBins { 64 };
|
||||
auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
|
||||
.Bins(kBins)
|
||||
.Device(0)
|
||||
.GenerateDeviceDMatrix(true);
|
||||
std::vector<HostDeviceVector<float>> storage(kCols);
|
||||
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}
|
||||
.Device(0)
|
||||
.GenerateColumnarArrayInterface(&storage);
|
||||
auto adapter = data::CudfAdapter(columnar);
|
||||
std::shared_ptr<DMatrix> p_full {
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
|
||||
};
|
||||
TestTrainingPrediction(kRows, "gpu_hist", p_full, p_ellpack);
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, ExternalMemoryTest) {
|
||||
|
||||
@ -21,7 +21,7 @@ TEST(Predictor, PredictionCache) {
|
||||
DMatrix* m;
|
||||
// Add a cache that is immediately expired.
|
||||
auto add_cache = [&]() {
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
container.Cache(p_dmat, GenericParameter::kCpuId);
|
||||
m = p_dmat.get();
|
||||
};
|
||||
@ -32,17 +32,16 @@ TEST(Predictor, PredictionCache) {
|
||||
EXPECT_ANY_THROW(container.Entry(m));
|
||||
}
|
||||
|
||||
// Only run this test when CUDA is enabled.
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method) {
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method,
|
||||
std::shared_ptr<DMatrix> p_full,
|
||||
std::shared_ptr<DMatrix> p_hist) {
|
||||
size_t constexpr kCols = 16;
|
||||
size_t constexpr kClasses = 3;
|
||||
size_t constexpr kIters = 3;
|
||||
|
||||
std::unique_ptr<Learner> learner;
|
||||
auto train = [&](std::string predictor, HostDeviceVector<float>* out) {
|
||||
auto p_m = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
||||
|
||||
auto &h_label = p_m->Info().labels_.HostVector();
|
||||
auto train = [&](std::string predictor, HostDeviceVector<float> *out) {
|
||||
auto &h_label = p_hist->Info().labels_.HostVector();
|
||||
h_label.resize(rows);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
@ -52,30 +51,31 @@ void TestTrainingPrediction(size_t rows, std::string tree_method) {
|
||||
learner.reset(Learner::Create({}));
|
||||
learner->SetParam("tree_method", tree_method);
|
||||
learner->SetParam("objective", "multi:softprob");
|
||||
learner->SetParam("predictor", predictor);
|
||||
learner->SetParam("num_feature", std::to_string(kCols));
|
||||
learner->SetParam("num_class", std::to_string(kClasses));
|
||||
learner->Configure();
|
||||
|
||||
for (size_t i = 0; i < kIters; ++i) {
|
||||
learner->UpdateOneIter(i, p_m);
|
||||
learner->UpdateOneIter(i, p_hist);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> from_full;
|
||||
learner->Predict(p_full, false, &from_full);
|
||||
|
||||
HostDeviceVector<float> from_hist;
|
||||
learner->Predict(p_hist, false, &from_hist);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
EXPECT_NEAR(from_hist.ConstHostVector()[i],
|
||||
from_full.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
learner->Predict(p_m, false, out);
|
||||
};
|
||||
// Alternate the predictor, CPU predictor can not use ellpack while GPU predictor can
|
||||
// not use CPU histogram index. So it's guaranteed one of the following is not
|
||||
// predicting from histogram index. Note: As of writing only GPU supports predicting
|
||||
// from gradient index, the test is written for future portability.
|
||||
|
||||
HostDeviceVector<float> predictions_0;
|
||||
train("cpu_predictor", &predictions_0);
|
||||
|
||||
HostDeviceVector<float> predictions_1;
|
||||
train("gpu_predictor", &predictions_1);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
EXPECT_NEAR(predictions_1.ConstHostVector()[i],
|
||||
predictions_0.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
}
|
||||
|
||||
void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
@ -83,7 +83,7 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
int32_t device) {
|
||||
size_t constexpr kClasses { 4 };
|
||||
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
|
||||
std::shared_ptr<DMatrix> m = gen.GenerateDMatix(true, false, kClasses);
|
||||
std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);
|
||||
|
||||
std::unique_ptr<Learner> learner {
|
||||
Learner::Create({m})
|
||||
|
||||
@ -8,11 +8,12 @@
|
||||
|
||||
namespace xgboost {
|
||||
template <typename Page>
|
||||
void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
|
||||
constexpr size_t kCols { 8 }, kClasses { 3 };
|
||||
void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
|
||||
std::shared_ptr<DMatrix> p_hist) {
|
||||
constexpr size_t kClasses { 3 };
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_feature = cols;
|
||||
param.num_output_group = kClasses;
|
||||
param.base_score = 0.5;
|
||||
|
||||
@ -25,16 +26,10 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
|
||||
gbm::GBTreeModel model = CreateTestModel(¶m, kClasses);
|
||||
|
||||
{
|
||||
auto p_ellpack = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
||||
// Use same number of bins as rows.
|
||||
for (auto const &page DMLC_ATTRIBUTE_UNUSED :
|
||||
p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
|
||||
}
|
||||
|
||||
auto p_precise = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
||||
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
||||
|
||||
PredictionCacheEntry approx_out_predictions;
|
||||
predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
|
||||
predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0);
|
||||
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
|
||||
@ -49,14 +44,17 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins
|
||||
// Predictor should never try to create the histogram index by itself. As only
|
||||
// histogram index from training data is valid and predictor doesn't known which
|
||||
// matrix is used for training.
|
||||
auto p_dmat = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
||||
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
||||
}
|
||||
}
|
||||
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method);
|
||||
// p_full and p_hist should come from the same data set.
|
||||
void TestTrainingPrediction(size_t rows, std::string tree_method,
|
||||
std::shared_ptr<DMatrix> p_full,
|
||||
std::shared_ptr<DMatrix> p_hist);
|
||||
|
||||
void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
bst_row_t rows, bst_feature_t cols,
|
||||
|
||||
@ -7,7 +7,7 @@ namespace xgboost {
|
||||
TEST(RandomDataGenerator, DMatrix) {
|
||||
size_t constexpr kRows { 16 }, kCols { 32 };
|
||||
float constexpr kSparsity { 0.4f };
|
||||
auto p_dmatrix = RandomDataGenerator{kRows, kCols, kSparsity}.GenerateDMatix();
|
||||
auto p_dmatrix = RandomDataGenerator{kRows, kCols, kSparsity}.GenerateDMatrix();
|
||||
|
||||
HostDeviceVector<float> csr_value;
|
||||
HostDeviceVector<bst_row_t> csr_rptr;
|
||||
|
||||
@ -16,7 +16,7 @@ namespace xgboost {
|
||||
TEST(Learner, Basic) {
|
||||
using Arg = std::pair<std::string, std::string>;
|
||||
auto args = {Arg("tree_method", "exact")};
|
||||
auto mat_ptr = RandomDataGenerator{10, 10, 0.0f}.GenerateDMatix();
|
||||
auto mat_ptr = RandomDataGenerator{10, 10, 0.0f}.GenerateDMatrix();
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({mat_ptr}));
|
||||
learner->SetParams(args);
|
||||
|
||||
@ -34,7 +34,7 @@ TEST(Learner, ParameterValidation) {
|
||||
ConsoleLogger::Configure({{"verbosity", "2"}});
|
||||
size_t constexpr kRows = 1;
|
||||
size_t constexpr kCols = 1;
|
||||
auto p_mat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatix();
|
||||
auto p_mat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
||||
learner->SetParam("validate_parameters", "1");
|
||||
@ -56,7 +56,7 @@ TEST(Learner, CheckGroup) {
|
||||
bst_feature_t constexpr kNumCols = 15;
|
||||
|
||||
std::shared_ptr<DMatrix> p_mat{
|
||||
RandomDataGenerator{kNumRows, kNumCols, 0.0f}.GenerateDMatix()};
|
||||
RandomDataGenerator{kNumRows, kNumCols, 0.0f}.GenerateDMatrix()};
|
||||
std::vector<bst_float> weight(kNumGroups);
|
||||
std::vector<bst_int> group(kNumGroups);
|
||||
group[0] = 2;
|
||||
@ -137,7 +137,7 @@ TEST(Learner, JsonModelIO) {
|
||||
int32_t constexpr kIters = 4;
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{
|
||||
RandomDataGenerator{kRows, 10, 0}.GenerateDMatix()};
|
||||
RandomDataGenerator{kRows, 10, 0}.GenerateDMatrix()};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
CHECK_NE(p_dmat->Info().num_col_, 0);
|
||||
|
||||
@ -179,7 +179,7 @@ TEST(Learner, JsonModelIO) {
|
||||
TEST(Learner, BinaryModelIO) {
|
||||
size_t constexpr kRows = 8;
|
||||
int32_t constexpr kIters = 4;
|
||||
auto p_dmat = RandomDataGenerator{kRows, 10, 0}.GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator{kRows, 10, 0}.GenerateDMatrix();
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
@ -213,7 +213,7 @@ TEST(Learner, BinaryModelIO) {
|
||||
TEST(Learner, GPUConfiguration) {
|
||||
using Arg = std::pair<std::string, std::string>;
|
||||
size_t constexpr kRows = 10;
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
||||
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
|
||||
std::vector<bst_float> labels(kRows);
|
||||
for (size_t i = 0; i < labels.size(); ++i) {
|
||||
|
||||
@ -156,7 +156,7 @@ class SerializationTest : public ::testing::Test {
|
||||
protected:
|
||||
~SerializationTest() override = default;
|
||||
void SetUp() override {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix();
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
p_dmat_->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat_->Info().labels_.HostVector();
|
||||
@ -352,7 +352,7 @@ TEST_F(SerializationTest, GPUCoordDescent) {
|
||||
class LogitSerializationTest : public SerializationTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix();
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{p_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
@ -487,7 +487,7 @@ class MultiClassesSerializationTest : public SerializationTest {
|
||||
size_t constexpr static kClasses = 4;
|
||||
|
||||
void SetUp() override {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatix();
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{p_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
@ -11,7 +11,7 @@ void TestDeterminsticHistogram() {
|
||||
size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16;
|
||||
float constexpr kLower = -1e-2, kUpper = 1e2;
|
||||
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, 0.5).GenerateDMatix();
|
||||
auto matrix = RandomDataGenerator(kRows, kCols, 0.5).GenerateDMatrix();
|
||||
BatchParam batch_param{0, static_cast<int32_t>(kBins), 0};
|
||||
|
||||
for (auto const& batch : matrix->GetBatches<EllpackPage>(batch_param)) {
|
||||
|
||||
@ -314,7 +314,7 @@ TEST(GpuHist, MinSplitLoss) {
|
||||
constexpr size_t kRows = 32;
|
||||
constexpr size_t kCols = 16;
|
||||
constexpr float kSparsity = 0.6;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, kSparsity).Seed(3).GenerateDMatrix();
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
|
||||
{
|
||||
|
||||
@ -15,7 +15,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
||||
GenericParameter param;
|
||||
param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
|
||||
auto p_dmat = RandomDataGenerator{kRows, kCols, 0.6f}.Seed(3).GenerateDMatix();
|
||||
auto p_dmat = RandomDataGenerator{kRows, kCols, 0.6f}.Seed(3).GenerateDMatrix();
|
||||
|
||||
HostDeviceVector<GradientPair> gradients (kRows);
|
||||
std::vector<GradientPair>& h_gradients = gradients.HostVector();
|
||||
|
||||
@ -29,7 +29,7 @@ TEST(Updater, Prune) {
|
||||
{ {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f},
|
||||
{0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} };
|
||||
std::shared_ptr<DMatrix> p_dmat {
|
||||
RandomDataGenerator{32, 10, 0}.GenerateDMatix() };
|
||||
RandomDataGenerator{32, 10, 0}.GenerateDMatrix() };
|
||||
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
|
||||
@ -139,7 +139,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
|
||||
size_t constexpr kMaxBins = 4;
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||
// dense, no missing values
|
||||
|
||||
common::GHistIndexMatrix gmat;
|
||||
@ -238,7 +238,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
cfg_{args} {
|
||||
QuantileHistMaker::Configure(args);
|
||||
spliteval_->Init(¶m_);
|
||||
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatix();
|
||||
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
builder_.reset(
|
||||
new BuilderMock(
|
||||
param_,
|
||||
|
||||
@ -22,7 +22,7 @@ TEST(Updater, Refresh) {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
std::shared_ptr<DMatrix> p_dmat{
|
||||
RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatix()};
|
||||
RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatrix()};
|
||||
std::vector<std::pair<std::string, std::string>> cfg{
|
||||
{"reg_alpha", "0.0"},
|
||||
{"num_feature", std::to_string(kCols)},
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user