sync upstream code

This commit is contained in:
Hui Liu
2024-03-20 16:14:38 -07:00
75 changed files with 754 additions and 312 deletions

View File

@@ -28,7 +28,7 @@ class ColumnSplitHelper {
public:
ColumnSplitHelper() = default;
ColumnSplitHelper(bst_row_t num_row,
ColumnSplitHelper(bst_idx_t num_row,
common::PartitionBuilder<kPartitionBlockSize>* partition_builder,
common::RowSetCollection* row_set_collection)
: partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
@@ -85,10 +85,10 @@ class ColumnSplitHelper {
class CommonRowPartitioner {
public:
bst_row_t base_rowid = 0;
bst_idx_t base_rowid = 0;
CommonRowPartitioner() = default;
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
CommonRowPartitioner(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid,
bool is_col_split)
: base_rowid{_base_rowid}, is_col_split_{is_col_split} {
row_set_collection_.Clear();

View File

@@ -277,7 +277,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
common::Span<GradientPair> gpair,
DMatrix* dmat) {
auto cuctx = ctx->CUDACtx();
bst_row_t n_rows = dmat->Info().num_row_;
bst_idx_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);

View File

@@ -54,7 +54,7 @@ inline void SampleGradient(Context const* ctx, TrainParam param,
if (param.subsample >= 1.0) {
return;
}
bst_row_t n_samples = out.Shape(0);
bst_idx_t n_samples = out.Shape(0);
auto& rnd = common::GlobalRandom();
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG

View File

@@ -192,7 +192,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<FeatureGroups> feature_groups;
GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
common::Span<FeatureType const> _feature_types, bst_idx_t _n_rows,
TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
: evaluator_{_param, n_features, ctx->Device()},