sync upstream code
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "coll.h"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for is_floating_point_v, is_same_v
|
||||
#include <utility> // for move
|
||||
|
||||
@@ -60,6 +61,8 @@ bool constexpr IsFloatingPointV() {
|
||||
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
|
||||
};
|
||||
|
||||
std::string msg{"Floating point is not supported for bit wise collective operations."};
|
||||
|
||||
auto rc = DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
switch (op) {
|
||||
@@ -74,21 +77,21 @@ bool constexpr IsFloatingPointV() {
|
||||
}
|
||||
case Op::kBitwiseAND: {
|
||||
if constexpr (IsFloatingPointV<T>()) {
|
||||
return Fail("Invalid type.");
|
||||
return Fail(msg);
|
||||
} else {
|
||||
return fn(std::bit_and<>{}, t);
|
||||
}
|
||||
}
|
||||
case Op::kBitwiseOR: {
|
||||
if constexpr (IsFloatingPointV<T>()) {
|
||||
return Fail("Invalid type.");
|
||||
return Fail(msg);
|
||||
} else {
|
||||
return fn(std::bit_or<>{}, t);
|
||||
}
|
||||
}
|
||||
case Op::kBitwiseXOR: {
|
||||
if constexpr (IsFloatingPointV<T>()) {
|
||||
return Fail("Invalid type.");
|
||||
return Fail(msg);
|
||||
} else {
|
||||
return fn(std::bit_xor<>{}, t);
|
||||
}
|
||||
|
||||
@@ -75,9 +75,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
} << [&] {
|
||||
return next->NonBlocking(true);
|
||||
} << [&] {
|
||||
SockAddrV4 addr;
|
||||
SockAddress addr;
|
||||
return listener->Accept(prev.get(), &addr);
|
||||
} << [&] { return prev->NonBlocking(true); };
|
||||
} << [&] {
|
||||
return prev->NonBlocking(true);
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -157,10 +159,13 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
}
|
||||
|
||||
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
|
||||
SockAddrV4 addr;
|
||||
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
|
||||
rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
|
||||
<< [&] { return peer->RecvTimeout(timeout); };
|
||||
rc = std::move(rc) << [&] {
|
||||
SockAddress addr;
|
||||
return listener->Accept(peer.get(), &addr);
|
||||
} << [&] {
|
||||
return peer->RecvTimeout(timeout);
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -187,7 +192,9 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
||||
nccl_path_{std::move(nccl_path)} {
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (!rc.OK()) {
|
||||
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
|
||||
@@ -247,10 +254,12 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
// get ring neighbors
|
||||
std::string snext;
|
||||
tracker.Recv(&snext);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
|
||||
}
|
||||
auto jnext = Json::Load(StringView{snext});
|
||||
|
||||
proto::PeerInfo ninfo{jnext};
|
||||
|
||||
// get the rank of this worker
|
||||
this->rank_ = BootstrapPrev(ninfo.rank, world);
|
||||
this->tracker_.rank = rank_;
|
||||
@@ -258,7 +267,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
std::vector<std::shared_ptr<TCPSocket>> workers;
|
||||
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Failed to connect to other workers.", std::move(rc));
|
||||
}
|
||||
|
||||
CHECK(this->channels_.empty());
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <netdb.h> // gethostbyname
|
||||
@@ -27,12 +27,14 @@
|
||||
#include "tracker.h"
|
||||
#include "xgboost/collective/result.h" // for Result, Fail, Success
|
||||
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
Tracker::Tracker(Json const& config)
|
||||
: n_workers_{static_cast<std::int32_t>(
|
||||
RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
: sortby_{static_cast<SortBy>(
|
||||
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
|
||||
n_workers_{
|
||||
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
@@ -56,13 +58,15 @@ Result Tracker::WaitUntilReady() const {
|
||||
return Success();
|
||||
}
|
||||
|
||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr)
|
||||
: sock_{std::move(sock)} {
|
||||
std::int32_t rank{0};
|
||||
Json jcmd;
|
||||
std::int32_t port{0};
|
||||
|
||||
rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
|
||||
rc_ = Success() << [&] {
|
||||
return proto::Magic{}.Verify(&sock_);
|
||||
} << [&] {
|
||||
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
|
||||
} << [&] {
|
||||
std::string cmd;
|
||||
@@ -83,8 +87,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
auto host = addr.Addr();
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
if (addr.IsV4()) {
|
||||
auto host = addr.V4().Addr();
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
} else {
|
||||
auto host = addr.V6().Addr();
|
||||
info_ = proto::PeerInfo{host, port, rank};
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
@@ -92,19 +101,19 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
||||
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
||||
std::string self;
|
||||
auto rc = collective::GetHostAddress(&self);
|
||||
auto host = OptionalArg<String>(config, "host", self);
|
||||
host_ = OptionalArg<String>(config, "host", self);
|
||||
|
||||
host_ = host;
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
|
||||
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
|
||||
rc = listener_.Bind(host_, &this->port_);
|
||||
SafeColl(rc);
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
auto& workers = *p_workers;
|
||||
|
||||
std::sort(workers.begin(), workers.end(), WorkerCmp{});
|
||||
std::sort(workers.begin(), workers.end(), WorkerCmp{this->sortby_});
|
||||
|
||||
std::vector<std::thread> bootstrap_threads;
|
||||
for (std::int32_t r = 0; r < n_workers_; ++r) {
|
||||
@@ -224,7 +233,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
|
||||
while (state.ShouldContinue()) {
|
||||
TCPSocket sock;
|
||||
SockAddrV4 addr;
|
||||
SockAddress addr;
|
||||
this->ready_ = true;
|
||||
auto rc = listener_.Accept(&sock, &addr);
|
||||
if (!rc.OK()) {
|
||||
@@ -291,7 +300,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
|
||||
[[nodiscard]] Json RabitTracker::WorkerArgs() const {
|
||||
auto rc = this->WaitUntilReady();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
@@ -36,6 +36,16 @@ namespace xgboost::collective {
|
||||
* signal an error to the tracker and the tracker will notify other workers.
|
||||
*/
|
||||
class Tracker {
|
||||
protected:
|
||||
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
|
||||
// setting, multiple workers can occupy the same host, in which case one should sort
|
||||
// workers by task. Due to compatibility reason, the task ID is not always available, so
|
||||
// we use host as the default.
|
||||
enum class SortBy : std::int8_t {
|
||||
kHost = 0,
|
||||
kTask = 1,
|
||||
} sortby_;
|
||||
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
@@ -76,7 +86,7 @@ class RabitTracker : public Tracker {
|
||||
Result rc_;
|
||||
|
||||
public:
|
||||
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
|
||||
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr);
|
||||
WorkerProxy(WorkerProxy const& that) = delete;
|
||||
WorkerProxy(WorkerProxy&& that) = default;
|
||||
WorkerProxy& operator=(WorkerProxy const&) = delete;
|
||||
@@ -96,11 +106,14 @@ class RabitTracker : public Tracker {
|
||||
|
||||
void Send(StringView value) { this->sock_.Send(value); }
|
||||
};
|
||||
// provide an ordering for workers, this helps us get deterministic topology.
|
||||
// Provide an ordering for workers, this helps us get deterministic topology.
|
||||
struct WorkerCmp {
|
||||
SortBy sortby;
|
||||
explicit WorkerCmp(SortBy sortby) : sortby{sortby} {}
|
||||
|
||||
[[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
|
||||
auto const& lh = lhs.Host();
|
||||
auto const& rh = rhs.Host();
|
||||
auto const& lh = sortby == Tracker::SortBy::kHost ? lhs.Host() : lhs.TaskID();
|
||||
auto const& rh = sortby == Tracker::SortBy::kHost ? rhs.Host() : rhs.TaskID();
|
||||
|
||||
if (lh != rh) {
|
||||
return lh < rh;
|
||||
|
||||
@@ -72,7 +72,7 @@ class SparseColumnIter : public Column<BinIdxT> {
|
||||
|
||||
public:
|
||||
SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
|
||||
common::Span<const size_t> row_ind, bst_row_t first_row_idx)
|
||||
common::Span<const size_t> row_ind, bst_idx_t first_row_idx)
|
||||
: Base{index, least_bin_idx}, row_ind_(row_ind) {
|
||||
// first_row_id is the first row in the leaf partition
|
||||
const size_t* row_data = RowIndices();
|
||||
@@ -301,7 +301,7 @@ class ColumnMatrix {
|
||||
}
|
||||
|
||||
template <typename BinIdxType>
|
||||
auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const {
|
||||
auto SparseColumn(bst_feature_t fidx, bst_idx_t first_row_idx) const {
|
||||
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
|
||||
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
|
||||
common::Span<const BinIdxType> bin_index = {
|
||||
@@ -325,7 +325,7 @@ class ColumnMatrix {
|
||||
// all columns are dense column and has no missing value
|
||||
// FIXME(jiamingy): We don't need a column matrix if there's no missing value.
|
||||
template <typename RowBinIdxT>
|
||||
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
|
||||
void SetIndexNoMissing(bst_idx_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
|
||||
const size_t n_features, int32_t n_threads) {
|
||||
missing_.GrowTo(feature_offsets_[n_features], false);
|
||||
|
||||
|
||||
@@ -21,11 +21,9 @@
|
||||
#include <thrust/unique.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstddef> // for size_t
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_allocator.cuh>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
@@ -33,7 +31,6 @@
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "common.h"
|
||||
#include "xgboost/global_config.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
@@ -34,7 +34,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
|
||||
HistogramCuts out;
|
||||
auto const &info = m->Info();
|
||||
auto n_threads = ctx->Threads();
|
||||
std::vector<bst_row_t> reduced(info.num_col_, 0);
|
||||
std::vector<bst_idx_t> reduced(info.num_col_, 0);
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
auto const &entries_per_column =
|
||||
CalcColumnSize(data::SparsePageAdapterBatch{page.GetView()}, info.num_col_, n_threads,
|
||||
@@ -209,10 +209,10 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||
CHECK(offsets);
|
||||
}
|
||||
|
||||
auto get_row_ptr = [&](bst_row_t ridx) {
|
||||
auto get_row_ptr = [&](bst_idx_t ridx) {
|
||||
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
};
|
||||
auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
|
||||
auto get_rid = [&](bst_idx_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
|
||||
|
||||
const size_t n_features =
|
||||
get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
|
||||
@@ -275,10 +275,10 @@ void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||
auto const &row_ptr = gmat.row_ptr.data();
|
||||
auto base_rowid = gmat.base_rowid;
|
||||
const uint32_t *offsets = gmat.index.Offset();
|
||||
auto get_row_ptr = [&](bst_row_t ridx) {
|
||||
auto get_row_ptr = [&](bst_idx_t ridx) {
|
||||
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
|
||||
};
|
||||
auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
|
||||
auto get_rid = [&](bst_idx_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
|
||||
|
||||
const size_t n_features = gmat.cut.Ptrs().size() - 1;
|
||||
const size_t n_columns = n_features;
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -39,7 +37,7 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns,
|
||||
size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns,
|
||||
size_t max_bins, size_t nnz) {
|
||||
auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
|
||||
auto if_dense = num_columns * per_column;
|
||||
@@ -47,7 +45,7 @@ size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns,
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
size_t num_bins, bool with_weights) {
|
||||
size_t peak = 0;
|
||||
// 0. Allocate cut pointer in quantile container by increasing: n_columns + 1
|
||||
@@ -85,7 +83,7 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
return peak;
|
||||
}
|
||||
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows,
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows,
|
||||
bst_feature_t columns, size_t nnz, int device, size_t num_cuts,
|
||||
bool has_weight) {
|
||||
auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
|
||||
@@ -123,7 +121,7 @@ void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* s
|
||||
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
|
||||
}
|
||||
|
||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
|
||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
|
||||
dh::device_vector<Entry>* p_sorted_entries,
|
||||
dh::device_vector<float>* p_sorted_weights,
|
||||
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
|
||||
@@ -210,7 +208,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
|
||||
sorted_entries = dh::device_vector<Entry>(h_data.begin() + begin, h_data.begin() + end);
|
||||
}
|
||||
|
||||
bst_row_t base_rowid = page.base_rowid;
|
||||
bst_idx_t base_rowid = page.base_rowid;
|
||||
|
||||
dh::device_vector<float> entry_weight;
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
|
||||
@@ -187,7 +187,7 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
|
||||
* directly if it's not 0.
|
||||
*/
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
bst_row_t num_rows, bst_feature_t columns,
|
||||
bst_idx_t num_rows, bst_feature_t columns,
|
||||
size_t nnz, int device,
|
||||
size_t num_cuts, bool has_weight);
|
||||
|
||||
@@ -210,7 +210,7 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows);
|
||||
*
|
||||
* \return The estimated bytes
|
||||
*/
|
||||
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
size_t num_bins, bool with_weights);
|
||||
|
||||
// Count the valid entries in each column and copy them out.
|
||||
@@ -241,7 +241,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
dh::device_vector<Entry>* sorted_entries);
|
||||
|
||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_row_t> d_cuts_ptr,
|
||||
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
|
||||
dh::device_vector<Entry>* p_sorted_entries,
|
||||
dh::device_vector<float>* p_sorted_weights,
|
||||
dh::caching_device_vector<size_t>* p_column_sizes_scan);
|
||||
|
||||
@@ -178,7 +178,7 @@ template class HostDeviceVector<uint8_t>;
|
||||
template class HostDeviceVector<int8_t>;
|
||||
template class HostDeviceVector<FeatureType>;
|
||||
template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<bst_idx_t>;
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
|
||||
#if defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||
|
||||
@@ -416,7 +416,7 @@ template class HostDeviceVector<uint8_t>;
|
||||
template class HostDeviceVector<int8_t>;
|
||||
template class HostDeviceVector<FeatureType>;
|
||||
template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<bst_idx_t>;
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Node>;
|
||||
template class HostDeviceVector<RegTree::CategoricalSplitMatrix::Segment>;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
namespace xgboost::common {
|
||||
template <typename WQSketch>
|
||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
|
||||
std::vector<bst_row_t> columns_size,
|
||||
std::vector<bst_idx_t> columns_size,
|
||||
int32_t max_bins,
|
||||
Span<FeatureType const> feature_types,
|
||||
bool use_group)
|
||||
@@ -120,8 +120,8 @@ namespace {
|
||||
template <typename T>
|
||||
struct QuantileAllreduce {
|
||||
common::Span<T> global_values;
|
||||
common::Span<size_t> worker_indptr;
|
||||
common::Span<size_t> feature_indptr;
|
||||
common::Span<bst_idx_t> worker_indptr;
|
||||
common::Span<bst_idx_t> feature_indptr;
|
||||
size_t n_features{0};
|
||||
/**
|
||||
* \brief Get sketch values of the a feature from a worker.
|
||||
@@ -147,7 +147,7 @@ template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
Context const *ctx, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<bst_idx_t> *p_worker_segments, std::vector<bst_idx_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches) {
|
||||
auto &worker_segments = *p_worker_segments;
|
||||
worker_segments.resize(1, 0);
|
||||
@@ -156,7 +156,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
auto n_columns = sketches_.size();
|
||||
|
||||
// get the size of each feature.
|
||||
std::vector<bst_row_t> sketch_size;
|
||||
std::vector<bst_idx_t> sketch_size;
|
||||
for (size_t i = 0; i < reduced.size(); ++i) {
|
||||
if (IsCat(feature_types_, i)) {
|
||||
sketch_size.push_back(0);
|
||||
@@ -165,7 +165,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
}
|
||||
}
|
||||
// turn the size into CSC indptr
|
||||
std::vector<bst_row_t> &sketches_scan = *p_sketches_scan;
|
||||
std::vector<bst_idx_t> &sketches_scan = *p_sketches_scan;
|
||||
sketches_scan.resize((n_columns + 1) * world, 0);
|
||||
size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker.
|
||||
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);
|
||||
@@ -226,7 +226,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
|
||||
CHECK_EQ(feature_ptr.front(), 0);
|
||||
|
||||
// gather all feature ptrs from workers
|
||||
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
|
||||
std::vector<bst_idx_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
|
||||
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
|
||||
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
|
||||
auto rc = collective::GlobalSum(
|
||||
@@ -241,7 +241,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
|
||||
}
|
||||
|
||||
// indptr for indexing workers
|
||||
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
|
||||
std::vector<bst_idx_t> global_worker_ptr(world_size + 1, 0);
|
||||
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
|
||||
rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(global_worker_ptr.data(), global_worker_ptr.size()));
|
||||
@@ -298,14 +298,14 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
reduced.resize(sketches_.size());
|
||||
|
||||
// Prune the intermediate num cuts for synchronization.
|
||||
std::vector<bst_row_t> global_column_size(columns_size_);
|
||||
std::vector<bst_idx_t> global_column_size(columns_size_);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
|
||||
int32_t intermediate_num_cuts = static_cast<int32_t>(
|
||||
std::min(global_column_size[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
|
||||
std::min(global_column_size[i], static_cast<bst_idx_t>(max_bins_ * WQSketch::kFactor)));
|
||||
if (global_column_size[i] == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -327,8 +327,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<size_t> worker_segments(1, 0); // CSC pointer to sketches.
|
||||
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
|
||||
std::vector<bst_idx_t> worker_segments(1, 0); // CSC pointer to sketches.
|
||||
std::vector<bst_idx_t> sketches_scan((n_columns + 1) * world, 0);
|
||||
|
||||
std::vector<typename WQSketch::Entry> global_sketches;
|
||||
this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches);
|
||||
@@ -452,11 +452,11 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
HostSketchContainer::HostSketchContainer(Context const *ctx, bst_bin_t max_bins,
|
||||
common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group)
|
||||
std::vector<bst_idx_t> columns_size, bool use_group)
|
||||
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||
monitor_.Init(__func__);
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[i]);
|
||||
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
|
||||
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
|
||||
if (!IsCat(this->feature_types_, i)) {
|
||||
|
||||
@@ -115,16 +115,16 @@ void CopyTo(Span<T> out, Span<U> src) {
|
||||
|
||||
// Compute the merge path.
|
||||
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
||||
Span<SketchEntry const> const &d_x, Span<bst_row_t const> const &x_ptr,
|
||||
Span<SketchEntry const> const &d_y, Span<bst_row_t const> const &y_ptr,
|
||||
Span<SketchEntry> out, Span<bst_row_t> out_ptr) {
|
||||
Span<SketchEntry const> const &d_x, Span<bst_idx_t const> const &x_ptr,
|
||||
Span<SketchEntry const> const &d_y, Span<bst_idx_t const> const &y_ptr,
|
||||
Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
||||
auto x_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
|
||||
dh::MakeTransformIterator<bst_row_t>(
|
||||
dh::MakeTransformIterator<bst_idx_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] __device__(size_t idx) { return dh::SegmentId(x_ptr, idx); }),
|
||||
d_x.data()));
|
||||
auto y_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
|
||||
dh::MakeTransformIterator<bst_row_t>(
|
||||
dh::MakeTransformIterator<bst_idx_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] __device__(size_t idx) { return dh::SegmentId(y_ptr, idx); }),
|
||||
d_y.data()));
|
||||
@@ -175,13 +175,13 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
||||
|
||||
auto scan_key_it = dh::MakeTransformIterator<size_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
|
||||
[=] XGBOOST_DEVICE(size_t idx) { return dh::SegmentId(out_ptr, idx); });
|
||||
|
||||
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
|
||||
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
|
||||
merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
|
||||
auto ind = get_ind(t); // == 0 if element is from x
|
||||
// x_counter, y_counter
|
||||
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
|
||||
return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
|
||||
});
|
||||
|
||||
// Compute the index for both x and y (which of the element in a and b are used in each
|
||||
@@ -208,8 +208,8 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
|
||||
// run it in 2 passes to obtain the merge path and then customize the standard merge
|
||||
// algorithm.
|
||||
void MergeImpl(DeviceOrd device, Span<SketchEntry const> const &d_x,
|
||||
Span<bst_row_t const> const &x_ptr, Span<SketchEntry const> const &d_y,
|
||||
Span<bst_row_t const> const &y_ptr, Span<SketchEntry> out, Span<bst_row_t> out_ptr) {
|
||||
Span<bst_idx_t const> const &x_ptr, Span<SketchEntry const> const &d_y,
|
||||
Span<bst_idx_t const> const &y_ptr, Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
CHECK_EQ(d_x.size() + d_y.size(), out.size());
|
||||
CHECK_EQ(x_ptr.size(), out_ptr.size());
|
||||
|
||||
@@ -32,13 +32,13 @@ struct SketchUnique {
|
||||
class SketchContainer {
|
||||
public:
|
||||
static constexpr float kFactor = WQSketch::kFactor;
|
||||
using OffsetT = bst_row_t;
|
||||
using OffsetT = bst_idx_t;
|
||||
static_assert(sizeof(OffsetT) == sizeof(size_t), "Wrong type for sketch element offset.");
|
||||
|
||||
private:
|
||||
Monitor timer_;
|
||||
HostDeviceVector<FeatureType> feature_types_;
|
||||
bst_row_t num_rows_;
|
||||
bst_idx_t num_rows_;
|
||||
bst_feature_t num_columns_;
|
||||
int32_t num_bins_;
|
||||
DeviceOrd device_;
|
||||
@@ -94,7 +94,7 @@ class SketchContainer {
|
||||
* \param device GPU ID.
|
||||
*/
|
||||
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin,
|
||||
bst_feature_t num_columns, bst_row_t num_rows, DeviceOrd device)
|
||||
bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
|
||||
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
|
||||
CHECK(device.IsCUDA());
|
||||
// Initialize Sketches for this dmatrix
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file quantile.h
|
||||
* \brief util to compute quantiles
|
||||
* \author Tianqi Chen
|
||||
@@ -701,12 +701,12 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
auto n_groups = group_ptr.size() - 1;
|
||||
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
|
||||
|
||||
bst_row_t n_samples = info.num_row_;
|
||||
bst_idx_t n_samples = info.num_row_;
|
||||
std::vector<float> results(n_samples);
|
||||
CHECK_EQ(group_ptr.back(), n_samples)
|
||||
<< error::GroupSize() << " the number of rows from the data.";
|
||||
size_t cur_group = 0;
|
||||
for (bst_row_t i = 0; i < n_samples; ++i) {
|
||||
for (bst_idx_t i = 0; i < n_samples; ++i) {
|
||||
results[i] = group_weights[cur_group];
|
||||
if (i == group_ptr[cur_group + 1]) {
|
||||
cur_group++;
|
||||
@@ -719,9 +719,9 @@ inline std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
class HistogramCuts;
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
std::vector<bst_row_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
|
||||
std::vector<bst_idx_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_columns,
|
||||
size_t const n_threads, IsValid &&is_valid) {
|
||||
std::vector<std::vector<bst_row_t>> column_sizes_tloc(n_threads);
|
||||
std::vector<std::vector<bst_idx_t>> column_sizes_tloc(n_threads);
|
||||
for (auto &column : column_sizes_tloc) {
|
||||
column.resize(n_columns, 0);
|
||||
}
|
||||
@@ -759,7 +759,7 @@ std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_featu
|
||||
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);
|
||||
|
||||
// Need to calculate the size for each batch.
|
||||
std::vector<bst_row_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
|
||||
std::vector<bst_idx_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
|
||||
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
|
||||
size_t count{0};
|
||||
size_t current_thread{1};
|
||||
@@ -791,8 +791,8 @@ class SketchContainerImpl {
|
||||
std::vector<std::set<float>> categories_;
|
||||
std::vector<FeatureType> const feature_types_;
|
||||
|
||||
std::vector<bst_row_t> columns_size_;
|
||||
int32_t max_bins_;
|
||||
std::vector<bst_idx_t> columns_size_;
|
||||
bst_bin_t max_bins_;
|
||||
bool use_group_ind_{false};
|
||||
int32_t n_threads_;
|
||||
bool has_categorical_{false};
|
||||
@@ -805,7 +805,7 @@ class SketchContainerImpl {
|
||||
* \param max_bins maximum number of bins for each feature.
|
||||
* \param use_group whether is assigned to group to data instance.
|
||||
*/
|
||||
SketchContainerImpl(Context const *ctx, std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group);
|
||||
|
||||
static bool UseGroup(MetaInfo const &info) {
|
||||
@@ -829,8 +829,8 @@ class SketchContainerImpl {
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<bst_idx_t> *p_worker_segments,
|
||||
std::vector<bst_idx_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
||||
// Merge sketches from all workers.
|
||||
void AllReduce(Context const *ctx, MetaInfo const &info,
|
||||
@@ -901,7 +901,7 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
||||
|
||||
public:
|
||||
HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group);
|
||||
std::vector<bst_idx_t> columns_size, bool use_group);
|
||||
|
||||
template <typename Batch>
|
||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
||||
@@ -998,7 +998,7 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
|
||||
public:
|
||||
explicit SortedSketchContainer(Context const *ctx, int32_t max_bins,
|
||||
common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group)
|
||||
std::vector<bst_idx_t> columns_size, bool use_group)
|
||||
: SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} {
|
||||
monitor_.Init(__func__);
|
||||
sketches_.resize(columns_size.size());
|
||||
|
||||
@@ -73,11 +73,11 @@ constexpr size_t kAdapterUnknownSize = std::numeric_limits<size_t >::max();
|
||||
|
||||
struct COOTuple {
|
||||
COOTuple() = default;
|
||||
XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value)
|
||||
XGBOOST_DEVICE COOTuple(bst_idx_t row_idx, bst_idx_t column_idx, float value)
|
||||
: row_idx(row_idx), column_idx(column_idx), value(value) {}
|
||||
|
||||
size_t row_idx{0};
|
||||
size_t column_idx{0};
|
||||
bst_idx_t row_idx{0};
|
||||
bst_idx_t column_idx{0};
|
||||
float value{0};
|
||||
};
|
||||
|
||||
@@ -136,12 +136,8 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
class Line {
|
||||
public:
|
||||
Line(size_t row_idx, size_t size, const unsigned* feature_idx,
|
||||
const float* values)
|
||||
: row_idx_(row_idx),
|
||||
size_(size),
|
||||
feature_idx_(feature_idx),
|
||||
values_(values) {}
|
||||
Line(bst_idx_t row_idx, bst_idx_t size, const unsigned* feature_idx, const float* values)
|
||||
: row_idx_(row_idx), size_(size), feature_idx_(feature_idx), values_(values) {}
|
||||
|
||||
size_t Size() const { return size_; }
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
@@ -149,8 +145,8 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
private:
|
||||
size_t row_idx_;
|
||||
size_t size_;
|
||||
bst_idx_t row_idx_;
|
||||
bst_idx_t size_;
|
||||
const unsigned* feature_idx_;
|
||||
const float* values_;
|
||||
};
|
||||
@@ -178,29 +174,25 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
|
||||
|
||||
class CSRAdapter : public detail::SingleBatchDataIter<CSRAdapterBatch> {
|
||||
public:
|
||||
CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx,
|
||||
const float* values, size_t num_rows, size_t num_elements,
|
||||
size_t num_features)
|
||||
: batch_(row_ptr, feature_idx, values, num_rows, num_elements,
|
||||
num_features),
|
||||
CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx, const float* values,
|
||||
bst_idx_t num_rows, bst_idx_t num_elements, size_t num_features)
|
||||
: batch_(row_ptr, feature_idx, values, num_rows, num_elements, num_features),
|
||||
num_rows_(num_rows),
|
||||
num_columns_(num_features) {}
|
||||
const CSRAdapterBatch& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return num_rows_; }
|
||||
size_t NumColumns() const { return num_columns_; }
|
||||
bst_idx_t NumRows() const { return num_rows_; }
|
||||
bst_idx_t NumColumns() const { return num_columns_; }
|
||||
|
||||
private:
|
||||
CSRAdapterBatch batch_;
|
||||
size_t num_rows_;
|
||||
size_t num_columns_;
|
||||
bst_idx_t num_rows_;
|
||||
bst_idx_t num_columns_;
|
||||
};
|
||||
|
||||
class DenseAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
DenseAdapterBatch(const float* values, size_t num_rows, size_t num_features)
|
||||
: values_(values),
|
||||
num_rows_(num_rows),
|
||||
num_features_(num_features) {}
|
||||
DenseAdapterBatch(const float* values, bst_idx_t num_rows, bst_idx_t num_features)
|
||||
: values_(values), num_rows_(num_rows), num_features_(num_features) {}
|
||||
|
||||
private:
|
||||
class Line {
|
||||
@@ -910,7 +902,7 @@ class SparsePageAdapterBatch {
|
||||
struct Line {
|
||||
Entry const* inst;
|
||||
size_t n;
|
||||
bst_row_t ridx;
|
||||
bst_idx_t ridx;
|
||||
COOTuple GetElement(size_t idx) const { return {ridx, inst[idx].index, inst[idx].fvalue}; }
|
||||
size_t Size() const { return n; }
|
||||
};
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
#include "simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "sparse_page_writer.h" // for SparsePageFormatReg
|
||||
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
|
||||
#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong
|
||||
#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/learner.h" // for HostDeviceVector
|
||||
@@ -996,7 +996,7 @@ template DMatrix* DMatrix::Create(
|
||||
|
||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||
SparsePage transpose;
|
||||
common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
|
||||
common::ParallelGroupBuilder<Entry, bst_idx_t> builder(&transpose.offset.HostVector(),
|
||||
&transpose.data.HostVector());
|
||||
builder.InitBudget(num_columns, n_threads);
|
||||
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
|
||||
@@ -1192,7 +1192,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
|
||||
void SparsePage::PushCSC(const SparsePage &batch) {
|
||||
std::vector<xgboost::Entry>& self_data = data.HostVector();
|
||||
std::vector<bst_row_t>& self_offset = offset.HostVector();
|
||||
std::vector<bst_idx_t>& self_offset = offset.HostVector();
|
||||
|
||||
auto const& other_data = batch.data.ConstHostVector();
|
||||
auto const& other_offset = batch.offset.ConstHostVector();
|
||||
@@ -1211,7 +1211,7 @@ void SparsePage::PushCSC(const SparsePage &batch) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<bst_row_t> offset(other_offset.size());
|
||||
std::vector<bst_idx_t> offset(other_offset.size());
|
||||
offset[0] = 0;
|
||||
|
||||
std::vector<xgboost::Entry> data(self_data.size() + other_data.size());
|
||||
|
||||
@@ -39,7 +39,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
return {row_idx, column_idx, value};
|
||||
}
|
||||
|
||||
[[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
|
||||
[[nodiscard]] __device__ float GetElement(bst_idx_t ridx, bst_feature_t fidx) const {
|
||||
auto const& column = columns_[fidx];
|
||||
float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
|
||||
? column(ridx)
|
||||
@@ -47,8 +47,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
return value;
|
||||
}
|
||||
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return num_rows_; }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return columns_.size(); }
|
||||
|
||||
private:
|
||||
common::Span<ArrayInterface<1>> columns_;
|
||||
@@ -168,13 +168,13 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
|
||||
float value = array_interface_(row_idx, column_idx);
|
||||
return {row_idx, column_idx, value};
|
||||
}
|
||||
[[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
|
||||
[[nodiscard]] __device__ float GetElement(bst_idx_t ridx, bst_feature_t fidx) const {
|
||||
float value = array_interface_(ridx, fidx);
|
||||
return value;
|
||||
}
|
||||
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return array_interface_.Shape(0); }
|
||||
[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return array_interface_.Shape(1); }
|
||||
|
||||
private:
|
||||
ArrayInterface<2> array_interface_;
|
||||
@@ -208,8 +208,8 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, DeviceOrd device,
|
||||
float missing) {
|
||||
bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_idx_t> offset, DeviceOrd device,
|
||||
float missing) {
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
IsValidFunctor is_valid(missing);
|
||||
dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));
|
||||
@@ -231,7 +231,7 @@ std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offs
|
||||
|
||||
// Count elements per row
|
||||
dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) {
|
||||
bst_row_t cnt{0};
|
||||
bst_idx_t cnt{0};
|
||||
auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
|
||||
SPAN_CHECK(ridx < n_samples);
|
||||
for (bst_feature_t fidx = fbeg; fidx < n_features; fidx += stride) {
|
||||
@@ -246,10 +246,10 @@ std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offs
|
||||
});
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
bst_row_t row_stride =
|
||||
bst_idx_t row_stride =
|
||||
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
||||
static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
|
||||
static_cast<bst_idx_t>(0), thrust::maximum<bst_idx_t>());
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
|
||||
@@ -175,11 +175,10 @@ struct WriteCompressedEllpackFunctor {
|
||||
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
__device__ size_t operator()(Tuple out) {
|
||||
auto e = batch.GetElement(out.get<2>());
|
||||
auto e = batch.GetElement(thrust::get<2>(out));
|
||||
if (is_valid(e)) {
|
||||
// -1 because the scan is inclusive
|
||||
size_t output_position =
|
||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||
size_t output_position = accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1;
|
||||
uint32_t bin_idx = 0;
|
||||
if (common::IsCat(feature_types, e.column_idx)) {
|
||||
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
|
||||
@@ -196,8 +195,8 @@ template <typename Tuple>
|
||||
struct TupleScanOp {
|
||||
__device__ Tuple operator()(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.template get<0>() == b.template get<0>()) {
|
||||
b.template get<1>() += a.template get<1>();
|
||||
if (thrust::get<0>(a) == thrust::get<0>(b)) {
|
||||
thrust::get<1>(b) += thrust::get<1>(a);
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
|
||||
@@ -193,7 +193,7 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
|
||||
float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
|
||||
std::vector<float> const &values, std::vector<float> const &mins,
|
||||
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
|
||||
bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const {
|
||||
if (is_cat) {
|
||||
auto gidx = GetGindex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
|
||||
@@ -149,7 +149,7 @@ class GHistIndexMatrix {
|
||||
/** @brief max_bin for each feature. */
|
||||
bst_bin_t max_numeric_bins_per_feat;
|
||||
/** @brief base row index for current page (used by external memory) */
|
||||
bst_row_t base_rowid{0};
|
||||
bst_idx_t base_rowid{0};
|
||||
|
||||
[[nodiscard]] bst_bin_t MaxNumBinPerFeat() const {
|
||||
return std::max(static_cast<bst_bin_t>(cut.MaxCategory() + 1), max_numeric_bins_per_feat);
|
||||
@@ -230,7 +230,7 @@ class GHistIndexMatrix {
|
||||
*/
|
||||
[[nodiscard]] std::size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
|
||||
|
||||
[[nodiscard]] bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
|
||||
[[nodiscard]] bst_idx_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
|
||||
[[nodiscard]] bst_feature_t Features() const { return cut.Ptrs().size() - 1; }
|
||||
|
||||
[[nodiscard]] bool ReadColumnPage(common::AlignedResourceReadStream* fi);
|
||||
@@ -243,7 +243,7 @@ class GHistIndexMatrix {
|
||||
[[nodiscard]] float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
||||
[[nodiscard]] float GetFvalue(std::vector<std::uint32_t> const& ptrs,
|
||||
std::vector<float> const& values, std::vector<float> const& mins,
|
||||
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;
|
||||
bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const;
|
||||
|
||||
[[nodiscard]] common::HistogramCuts& Cuts() { return cut; }
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; }
|
||||
|
||||
@@ -132,7 +132,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
||||
return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
|
||||
};
|
||||
|
||||
std::vector<std::size_t> column_sizes;
|
||||
std::vector<bst_idx_t> column_sizes;
|
||||
auto const is_valid = data::IsValidFunctor{missing};
|
||||
auto nnz_cnt = [&]() {
|
||||
return HostAdapterDispatch(proxy, [&](auto const& value) {
|
||||
|
||||
@@ -59,7 +59,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
auto& h_data = out_page.data.HostVector();
|
||||
auto& h_offset = out_page.offset.HostVector();
|
||||
size_t rptr{0};
|
||||
for (bst_row_t i = 0; i < this->Info().num_row_; i++) {
|
||||
for (bst_idx_t i = 0; i < this->Info().num_row_; i++) {
|
||||
auto inst = batch[i];
|
||||
auto prev_size = h_data.size();
|
||||
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data),
|
||||
|
||||
@@ -54,7 +54,7 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset, DeviceOrd device,
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_idx_t> offset, DeviceOrd device,
|
||||
float missing) {
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
IsValidFunctor is_valid(missing);
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
#include <cstdint> // for int32_t, uint32_t, int64_t, uint64_t
|
||||
#include <cstdlib> // for atoi
|
||||
#include <cstring> // for memcpy, size_t, memset
|
||||
#include <functional> // for less
|
||||
#include <iomanip> // for operator<<, setiosflags
|
||||
#include <iterator> // for back_insert_iterator, distance, back_inserter
|
||||
#include <limits> // for numeric_limits
|
||||
|
||||
@@ -184,7 +184,7 @@ void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset,
|
||||
static std::size_t constexpr kUnroll = 8;
|
||||
|
||||
struct SparsePageView {
|
||||
bst_row_t base_rowid;
|
||||
bst_idx_t base_rowid;
|
||||
HostSparsePageView view;
|
||||
|
||||
explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); }
|
||||
@@ -193,7 +193,7 @@ struct SparsePageView {
|
||||
};
|
||||
|
||||
struct SingleInstanceView {
|
||||
bst_row_t base_rowid{};
|
||||
bst_idx_t base_rowid{};
|
||||
SparsePage::Inst const &inst;
|
||||
|
||||
explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {}
|
||||
@@ -214,7 +214,7 @@ struct GHistIndexMatrixView {
|
||||
std::vector<float> const& values_;
|
||||
|
||||
public:
|
||||
size_t base_rowid;
|
||||
bst_idx_t base_rowid;
|
||||
|
||||
public:
|
||||
GHistIndexMatrixView(GHistIndexMatrix const &_page, uint64_t n_feat,
|
||||
@@ -292,7 +292,7 @@ class AdapterView {
|
||||
|
||||
[[nodiscard]] size_t Size() const { return adapter_->NumRows(); }
|
||||
|
||||
bst_row_t const static base_rowid = 0; // NOLINT
|
||||
bst_idx_t const static base_rowid = 0; // NOLINT
|
||||
};
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
|
||||
@@ -67,12 +67,12 @@ struct TreeView {
|
||||
|
||||
struct SparsePageView {
|
||||
common::Span<const Entry> d_data;
|
||||
common::Span<const bst_row_t> d_row_ptr;
|
||||
common::Span<const bst_idx_t> d_row_ptr;
|
||||
bst_feature_t num_features;
|
||||
|
||||
SparsePageView() = default;
|
||||
XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
|
||||
common::Span<const bst_row_t> row_ptr,
|
||||
common::Span<const bst_idx_t> row_ptr,
|
||||
bst_feature_t num_features)
|
||||
: d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
|
||||
[[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
@@ -113,7 +113,7 @@ struct SparsePageLoader {
|
||||
float* smem;
|
||||
|
||||
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
|
||||
bst_row_t num_rows, size_t entry_start, float)
|
||||
bst_idx_t num_rows, size_t entry_start, float)
|
||||
: use_shared(use_shared),
|
||||
data(data) {
|
||||
extern __shared__ float _smem[];
|
||||
@@ -146,7 +146,7 @@ struct SparsePageLoader {
|
||||
|
||||
struct EllpackLoader {
|
||||
EllpackDeviceAccessor const& matrix;
|
||||
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_row_t,
|
||||
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t,
|
||||
size_t, float)
|
||||
: matrix{m} {}
|
||||
[[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
@@ -177,7 +177,7 @@ struct DeviceAdapterLoader {
|
||||
using BatchT = Batch;
|
||||
|
||||
XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
||||
bst_feature_t num_features, bst_row_t num_rows,
|
||||
bst_feature_t num_features, bst_idx_t num_rows,
|
||||
size_t entry_start, float missing)
|
||||
: batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} {
|
||||
extern __shared__ float _smem[];
|
||||
@@ -215,7 +215,7 @@ struct DeviceAdapterLoader {
|
||||
};
|
||||
|
||||
template <bool has_missing, bool has_categorical, typename Loader>
|
||||
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
|
||||
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const &tree,
|
||||
Loader *loader) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree.d_tree[nidx];
|
||||
@@ -230,7 +230,7 @@ __device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
|
||||
}
|
||||
|
||||
template <bool has_missing, typename Loader>
|
||||
__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree,
|
||||
__device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree,
|
||||
Loader *loader) {
|
||||
bst_node_t nidx = -1;
|
||||
if (tree.HasCategoricalSplit()) {
|
||||
@@ -255,7 +255,7 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start, bool use_shared,
|
||||
float missing) {
|
||||
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (ridx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
@@ -670,7 +670,7 @@ __global__ void MaskBitVectorKernel(
|
||||
}
|
||||
}
|
||||
|
||||
__device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tree,
|
||||
__device__ bst_node_t GetLeafIndexByBitVector(bst_idx_t ridx, TreeView const& tree,
|
||||
BitVector const& decision_bits,
|
||||
BitVector const& missing_bits, std::size_t num_nodes,
|
||||
std::size_t tree_offset) {
|
||||
@@ -688,7 +688,7 @@ __device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tr
|
||||
return nidx;
|
||||
}
|
||||
|
||||
__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
|
||||
__device__ float GetLeafWeightByBitVector(bst_idx_t ridx, TreeView const& tree,
|
||||
BitVector const& decision_bits,
|
||||
BitVector const& missing_bits, std::size_t num_nodes,
|
||||
std::size_t tree_offset) {
|
||||
@@ -1177,7 +1177,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device());
|
||||
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
bst_row_t num_rows = info.num_row_;
|
||||
bst_idx_t num_rows = info.num_row_;
|
||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
@@ -1202,7 +1202,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(ctx_->Device());
|
||||
batch.offset.SetDevice(ctx_->Device());
|
||||
bst_row_t batch_offset = 0;
|
||||
bst_idx_t batch_offset = 0;
|
||||
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature};
|
||||
size_t num_rows = batch.Size();
|
||||
@@ -1225,7 +1225,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
} else {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||
bst_row_t batch_offset = 0;
|
||||
bst_idx_t batch_offset = 0;
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
|
||||
size_t num_rows = batch.Size();
|
||||
auto grid =
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <string> // for string, to_string
|
||||
|
||||
#include "../gbm/gbtree_model.h" // for GBTreeModel
|
||||
#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_row_t
|
||||
#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_idx_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
@@ -34,7 +34,7 @@ Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||
}
|
||||
|
||||
template <int32_t D>
|
||||
void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_row_t n_samples,
|
||||
void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_idx_t n_samples,
|
||||
bst_group_t n_groups) {
|
||||
// FIXME: Bindings other than Python doesn't have shape.
|
||||
std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) +
|
||||
|
||||
@@ -28,7 +28,7 @@ class ColumnSplitHelper {
|
||||
public:
|
||||
ColumnSplitHelper() = default;
|
||||
|
||||
ColumnSplitHelper(bst_row_t num_row,
|
||||
ColumnSplitHelper(bst_idx_t num_row,
|
||||
common::PartitionBuilder<kPartitionBlockSize>* partition_builder,
|
||||
common::RowSetCollection* row_set_collection)
|
||||
: partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
|
||||
@@ -85,10 +85,10 @@ class ColumnSplitHelper {
|
||||
|
||||
class CommonRowPartitioner {
|
||||
public:
|
||||
bst_row_t base_rowid = 0;
|
||||
bst_idx_t base_rowid = 0;
|
||||
|
||||
CommonRowPartitioner() = default;
|
||||
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
|
||||
CommonRowPartitioner(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid,
|
||||
bool is_col_split)
|
||||
: base_rowid{_base_rowid}, is_col_split_{is_col_split} {
|
||||
row_set_collection_.Clear();
|
||||
|
||||
@@ -277,7 +277,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
|
||||
common::Span<GradientPair> gpair,
|
||||
DMatrix* dmat) {
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
bst_row_t n_rows = dmat->Info().num_row_;
|
||||
bst_idx_t n_rows = dmat->Info().num_row_;
|
||||
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
|
||||
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ inline void SampleGradient(Context const* ctx, TrainParam param,
|
||||
if (param.subsample >= 1.0) {
|
||||
return;
|
||||
}
|
||||
bst_row_t n_samples = out.Shape(0);
|
||||
bst_idx_t n_samples = out.Shape(0);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
|
||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
|
||||
@@ -192,7 +192,7 @@ struct GPUHistMakerDevice {
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
|
||||
GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
|
||||
common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
|
||||
common::Span<FeatureType const> _feature_types, bst_idx_t _n_rows,
|
||||
TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
|
||||
: evaluator_{_param, n_features, ctx->Device()},
|
||||
|
||||
Reference in New Issue
Block a user