Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -25,16 +25,11 @@ class DataIterProxy {
|
||||
NextFn* next_;
|
||||
|
||||
public:
|
||||
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next) :
|
||||
iter_{iter},
|
||||
reset_{reset}, next_{next} {}
|
||||
DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next)
|
||||
: iter_{iter}, reset_{reset}, next_{next} {}
|
||||
|
||||
bool Next() {
|
||||
return next_(iter_);
|
||||
}
|
||||
void Reset() {
|
||||
reset_(iter_);
|
||||
}
|
||||
bool Next() { return next_(iter_); }
|
||||
void Reset() { reset_(iter_); }
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -68,9 +63,8 @@ class DMatrixProxy : public DMatrix {
|
||||
}
|
||||
|
||||
void SetArrayData(char const* c_interface);
|
||||
void SetCSRData(char const *c_indptr, char const *c_indices,
|
||||
char const *c_values, bst_feature_t n_features,
|
||||
bool on_host);
|
||||
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
|
||||
bst_feature_t n_features, bool on_host);
|
||||
|
||||
MetaInfo& Info() override { return info_; }
|
||||
MetaInfo const& Info() const override { return info_; }
|
||||
@@ -81,6 +75,12 @@ class DMatrixProxy : public DMatrix {
|
||||
bool GHistIndexExists() const override { return false; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
|
||||
template <typename Page>
|
||||
BatchSet<Page> NoBatch() {
|
||||
LOG(FATAL) << "Proxy DMatrix cannot return data batch.";
|
||||
return BatchSet<Page>(BatchIterator<Page>(nullptr));
|
||||
}
|
||||
|
||||
DMatrix* Slice(common::Span<int32_t const> /*ridxs*/) override {
|
||||
LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix.";
|
||||
return nullptr;
|
||||
@@ -89,29 +89,19 @@ class DMatrixProxy : public DMatrix {
|
||||
LOG(FATAL) << "Slicing DMatrix columns is not supported for Proxy DMatrix.";
|
||||
return nullptr;
|
||||
}
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
BatchSet<SparsePage> GetRowBatches() override { return NoBatch<SparsePage>(); }
|
||||
BatchSet<CSCPage> GetColumnBatches(Context const*) override { return NoBatch<CSCPage>(); }
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const*) override {
|
||||
return NoBatch<SortedCSCPage>();
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
||||
BatchSet<EllpackPage> GetEllpackBatches(Context const*, BatchParam const&) override {
|
||||
return NoBatch<EllpackPage>();
|
||||
}
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(Context const*, BatchParam const&) override {
|
||||
return NoBatch<GHistIndexMatrix>();
|
||||
}
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(nullptr));
|
||||
}
|
||||
BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
|
||||
}
|
||||
BatchSet<ExtSparsePage> GetExtBatches(BatchParam const&) override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||
BatchSet<ExtSparsePage> GetExtBatches(Context const*, BatchParam const&) override {
|
||||
return NoBatch<ExtSparsePage>();
|
||||
}
|
||||
std::any Adapter() const { return batch_; }
|
||||
};
|
||||
@@ -144,8 +134,7 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||
}
|
||||
return std::result_of_t<Fn(
|
||||
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
|
||||
return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
Reference in New Issue
Block a user