Use context in SetInfo. (#7687)

* Use the name `Context`.
* Pass a context object into `SetInfo`.
* Add context to proxy matrix.
* Add context to iterative DMatrix.

This is to remove the use of the default number of threads during `SetInfo` as a follow-up on
removing the global omp variable while preparing for CUDA stream semantic.  Currently, XGBoost
uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
Jiaming Yuan
2022-03-24 22:16:26 +08:00
committed by GitHub
parent f5b20286e2
commit 64575591d8
19 changed files with 142 additions and 142 deletions

View File

@@ -148,13 +148,13 @@ class MetaInfo {
* \param dtype The type of the source data.
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
*/
void SetInfo(StringView key, StringView interface_str);
void SetInfo(Context const& ctx, StringView key, StringView interface_str);
void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const;
@@ -176,8 +176,8 @@ class MetaInfo {
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
private:
void SetInfoFromHost(StringView key, Json arr);
void SetInfoFromCUDA(StringView key, Json arr);
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
@@ -478,12 +478,13 @@ class DMatrix {
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
size_t num) {
this->Info().SetInfo(key, dptr, dtype, num);
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, dptr, dtype, num);
}
virtual void SetInfo(const char* key, std::string const& interface_str) {
this->Info().SetInfo(key, StringView{interface_str});
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});
}
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
@@ -494,7 +495,7 @@ class DMatrix {
* \brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
virtual GenericParameter const* Ctx() const = 0;
virtual Context const* Ctx() const = 0;
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.

View File

@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
.describe("Enable checking whether parameters are used or not.");
}
};
using Context = GenericParameter;
} // namespace xgboost
#endif // XGBOOST_GENERIC_PARAMETERS_H_