* Fix loading DMatrix binary in distributed environment. (#8149)
  - Try to load the DMatrix binary before trying to parse text input.
  - Remove some unmaintained code.
* Fix.
This commit is contained in:
parent
922d2137dd
commit
0fd6391a77
@ -193,9 +193,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
||||
int silent,
|
||||
DMatrixHandle *out) {
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
bool load_row_split = false;
|
||||
if (rabit::IsDistributed()) {
|
||||
|
||||
@ -378,35 +378,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
||||
return out;
|
||||
}
|
||||
|
||||
// try to load group information from file, if exists
|
||||
inline bool MetaTryLoadGroup(const std::string& fname,
|
||||
std::vector<unsigned>* group) {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||
if (fi == nullptr) return false;
|
||||
dmlc::istream is(fi.get());
|
||||
group->clear();
|
||||
group->push_back(0);
|
||||
unsigned nline = 0;
|
||||
while (is >> nline) {
|
||||
group->push_back(group->back() + nline);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// try to load weight information from file, if exists
|
||||
inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
||||
std::vector<bst_float>* data) {
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
||||
if (fi == nullptr) return false;
|
||||
dmlc::istream is(fi.get());
|
||||
data->clear();
|
||||
bst_float value;
|
||||
while (is >> value) {
|
||||
data->push_back(value);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <int32_t D, typename T>
|
||||
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||
@ -811,9 +782,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri,
|
||||
bool silent,
|
||||
bool load_row_split,
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
const std::string& file_format) {
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
@ -846,42 +815,39 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
||||
} else {
|
||||
fname = uri;
|
||||
}
|
||||
int partid = 0, npart = 1;
|
||||
if (load_row_split) {
|
||||
partid = rabit::GetRank();
|
||||
npart = rabit::GetWorldSize();
|
||||
} else {
|
||||
// test option to load in part
|
||||
npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
|
||||
}
|
||||
|
||||
if (npart != 1) {
|
||||
LOG(CONSOLE) << "Load part of data " << partid
|
||||
<< " of " << npart << " parts";
|
||||
}
|
||||
|
||||
// legacy handling of binary data loading
|
||||
if (file_format == "auto" && npart == 1) {
|
||||
if (file_format == "auto") {
|
||||
DMatrix* loaded = TryLoadBinary(fname, silent);
|
||||
if (loaded) {
|
||||
return loaded;
|
||||
}
|
||||
}
|
||||
|
||||
int partid = 0, npart = 1;
|
||||
if (load_row_split) {
|
||||
partid = rabit::GetRank();
|
||||
npart = rabit::GetWorldSize();
|
||||
} else {
|
||||
// test option to load in part
|
||||
npart = 1;
|
||||
}
|
||||
|
||||
if (npart != 1) {
|
||||
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
|
||||
}
|
||||
|
||||
DMatrix* dmat {nullptr};
|
||||
try {
|
||||
if (cache_file.empty()) {
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart,
|
||||
file_format.c_str()));
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
1, cache_file);
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, cache_file);
|
||||
} else {
|
||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
||||
file_format};
|
||||
dmat = new data::SparsePageDMatrix{
|
||||
&iter,
|
||||
dmat = new data::SparsePageDMatrix{&iter,
|
||||
iter.Proxy(),
|
||||
data::fileiter::Reset,
|
||||
data::fileiter::Next,
|
||||
@ -917,24 +883,6 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
||||
* partitioned data will fail the train/val validation check
|
||||
* since partitioned data not knowing the real number of features. */
|
||||
rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
|
||||
// backward compatibility code.
|
||||
if (!load_row_split) {
|
||||
MetaInfo& info = dmat->Info();
|
||||
if (MetaTryLoadGroup(fname + ".group", &info.group_ptr_) && !silent) {
|
||||
LOG(CONSOLE) << info.group_ptr_.size() - 1
|
||||
<< " groups are loaded from " << fname << ".group";
|
||||
}
|
||||
if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) &&
|
||||
!silent) {
|
||||
LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname
|
||||
<< ".base_margin";
|
||||
}
|
||||
if (MetaTryLoadFloatInfo
|
||||
(fname + ".weight", &info.weights_.HostVector()) && !silent) {
|
||||
LOG(CONSOLE) << info.weights_.Size()
|
||||
<< " weights are loaded from " << fname << ".weight";
|
||||
}
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
template <typename DataIterHandle, typename DMatrixHandle,
|
||||
|
||||
@ -111,9 +111,10 @@ def make_categorical(
|
||||
|
||||
|
||||
def generate_array(
|
||||
with_weights: bool = False
|
||||
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
|
||||
Optional[xgb.dask._DaskCollection]]:
|
||||
with_weights: bool = False,
|
||||
) -> Tuple[
|
||||
xgb.dask._DaskCollection, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
|
||||
]:
|
||||
chunk_size = 20
|
||||
rng = da.random.RandomState(1994)
|
||||
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
|
||||
@ -1190,6 +1191,50 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
X, y = tm.make_categorical(100, 4, 4, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy.save_binary(path)
|
||||
|
||||
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy = xgb.DMatrix(path)
|
||||
assert Xy.num_row() == 100
|
||||
assert Xy.num_col() == 4
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(
|
||||
xgb.dask._get_rabit_args, len(workers), None, client
|
||||
)
|
||||
futures = []
|
||||
for w in workers:
|
||||
# same argument for each worker, must set pure to False otherwise dask
|
||||
# will try to reuse the result from the first worker and hang waiting
|
||||
# for it.
|
||||
f = client.submit(
|
||||
save_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
|
||||
)
|
||||
futures.append(f)
|
||||
client.gather(futures)
|
||||
|
||||
rabit_args = client.sync(
|
||||
xgb.dask._get_rabit_args, len(workers), None, client
|
||||
)
|
||||
futures = []
|
||||
for w in workers:
|
||||
f = client.submit(
|
||||
load_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
|
||||
)
|
||||
futures.append(f)
|
||||
client.gather(futures)
|
||||
|
||||
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
||||
def test_global_config(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user