Use context in SetInfo. (#7687)
* Use the name `Context`. * Pass a context object into `SetInfo`. * Add context to proxy matrix. * Add context to iterative DMatrix. This is to remove the use of the default number of threads during `SetInfo` as a follow-up on removing the global omp variable while preparing for CUDA stream semantic. Currently, XGBoost uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
@@ -148,13 +148,13 @@ class MetaInfo {
|
||||
* \param dtype The type of the source data.
|
||||
* \param num Number of elements in the source array.
|
||||
*/
|
||||
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
|
||||
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
|
||||
/*!
|
||||
* \brief Set information in the meta info with array interface.
|
||||
* \param key The key of the information.
|
||||
* \param interface_str String representation of json format array interface.
|
||||
*/
|
||||
void SetInfo(StringView key, StringView interface_str);
|
||||
void SetInfo(Context const& ctx, StringView key, StringView interface_str);
|
||||
|
||||
void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
const void** out_dptr) const;
|
||||
@@ -176,8 +176,8 @@ class MetaInfo {
|
||||
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
|
||||
|
||||
private:
|
||||
void SetInfoFromHost(StringView key, Json arr);
|
||||
void SetInfoFromCUDA(StringView key, Json arr);
|
||||
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
||||
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
|
||||
|
||||
/*! \brief argsort of labels */
|
||||
mutable std::vector<size_t> label_order_cache_;
|
||||
@@ -478,12 +478,13 @@ class DMatrix {
|
||||
DMatrix() = default;
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual MetaInfo& Info() = 0;
|
||||
virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
|
||||
size_t num) {
|
||||
this->Info().SetInfo(key, dptr, dtype, num);
|
||||
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
||||
auto const& ctx = *this->Ctx();
|
||||
this->Info().SetInfo(ctx, key, dptr, dtype, num);
|
||||
}
|
||||
virtual void SetInfo(const char* key, std::string const& interface_str) {
|
||||
this->Info().SetInfo(key, StringView{interface_str});
|
||||
auto const& ctx = *this->Ctx();
|
||||
this->Info().SetInfo(ctx, key, StringView{interface_str});
|
||||
}
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual const MetaInfo& Info() const = 0;
|
||||
@@ -494,7 +495,7 @@ class DMatrix {
|
||||
* \brief Get the context object of this DMatrix. The context is created during construction of
|
||||
* DMatrix with user specified `nthread` parameter.
|
||||
*/
|
||||
virtual GenericParameter const* Ctx() const = 0;
|
||||
virtual Context const* Ctx() const = 0;
|
||||
|
||||
/**
|
||||
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
|
||||
|
||||
@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||
.describe("Enable checking whether parameters are used or not.");
|
||||
}
|
||||
};
|
||||
|
||||
using Context = GenericParameter;
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_GENERIC_PARAMETERS_H_
|
||||
|
||||
Reference in New Issue
Block a user