Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -3,30 +3,50 @@
|
||||
*/
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include <rabit/c_api.h>
|
||||
#include <algorithm> // for copy
|
||||
#include <cinttypes> // for strtoimax
|
||||
#include <cmath> // for nan
|
||||
#include <cstring> // for strcmp
|
||||
#include <fstream> // for operator<<, basic_ostream, ios, stringstream
|
||||
#include <functional> // for less
|
||||
#include <limits> // for numeric_limits
|
||||
#include <map> // for operator!=, _Rb_tree_const_iterator, _Rb_tre...
|
||||
#include <memory> // for shared_ptr, allocator, __shared_ptr_access
|
||||
#include <string> // for char_traits, basic_string, operator==, string
|
||||
#include <system_error> // for errc
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h"
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/global_config.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h" // StringView
|
||||
#include "xgboost/version_config.h"
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
|
||||
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
|
||||
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "rabit/c_api.h" // for RabitLinkTag
|
||||
#include "rabit/rabit.h" // for CheckPoint, LoadCheckPoint
|
||||
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||
#include "xgboost/feature_map.h" // for FeatureMap
|
||||
#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal...
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/intrusive_ptr.h" // for xgboost
|
||||
#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String
|
||||
#include "xgboost/learner.h" // for Learner, PredictionType
|
||||
#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ
|
||||
#include "xgboost/predictor.h" // for PredictionCacheEntry
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
@@ -341,10 +361,10 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -746,7 +766,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
|
||||
|
||||
CHECK_LE(p_m->Info().num_col_, std::numeric_limits<unsigned>::max());
|
||||
|
||||
for (auto const &page : p_m->GetBatches<ExtSparsePage>()) {
|
||||
for (auto const &page : p_m->GetBatches<ExtSparsePage>(p_m->Ctx(), BatchParam{})) {
|
||||
CHECK(page.page);
|
||||
auto const &h_offset = page.page->offset.ConstHostVector();
|
||||
std::copy(h_offset.cbegin(), h_offset.cend(), out_indptr);
|
||||
|
||||
Reference in New Issue
Block a user