Fix loading DMatrix binary in distributed env. (#8149)

- Try to load DMatrix binary before trying to parse text input.
- Remove some unmaintained code.
Jiaming Yuan 2022-08-10 22:53:16 +08:00 committed by GitHub
parent 8fc60b31bc
commit 446d536c23
3 changed files with 71 additions and 80 deletions
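
The first bullet is the behavioural fix: previously, DMatrix::Load only consulted the binary DMatrix format when the data was not loaded in row-split mode, so a worker in a distributed job could not reload a DMatrix it had saved with save_binary. Below is a minimal single-process sketch of the save/reload round trip that the reordered check supports; the file name, shapes, and temporary directory are illustrative only and not part of the change.

    import os
    import tempfile

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    X = rng.random((100, 4))
    y = rng.random(100)
    Xy = xgb.DMatrix(X, label=y)

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "train.bin")
        # Serialize to XGBoost's binary DMatrix format.
        Xy.save_binary(path)

        # Loading by file path goes through DMatrix::Load; with this fix the
        # binary format is checked first, before any text parser is created.
        reloaded = xgb.DMatrix(path)
        assert reloaded.num_row() == 100
        assert reloaded.num_col() == 4

The new dask test added below exercises the same round trip per rabit rank, with one binary file saved and reloaded by each worker.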


@@ -203,9 +203,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
   API_END();
 }
 
-XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
-                                    int silent,
-                                    DMatrixHandle *out) {
+XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
   API_BEGIN();
   bool load_row_split = false;
 #if defined(XGBOOST_USE_FEDERATED)


@@ -381,35 +381,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
   return out;
 }
 
-// try to load group information from file, if exists
-inline bool MetaTryLoadGroup(const std::string& fname,
-                             std::vector<unsigned>* group) {
-  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
-  if (fi == nullptr) return false;
-  dmlc::istream is(fi.get());
-  group->clear();
-  group->push_back(0);
-  unsigned nline = 0;
-  while (is >> nline) {
-    group->push_back(group->back() + nline);
-  }
-  return true;
-}
-
-// try to load weight information from file, if exists
-inline bool MetaTryLoadFloatInfo(const std::string& fname,
-                                 std::vector<bst_float>* data) {
-  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
-  if (fi == nullptr) return false;
-  dmlc::istream is(fi.get());
-  data->clear();
-  bst_float value;
-  while (is >> value) {
-    data->push_back(value);
-  }
-  return true;
-}
-
 namespace {
 template <int32_t D, typename T>
 void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
@@ -805,9 +776,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
   return nullptr;
 }
 
-DMatrix* DMatrix::Load(const std::string& uri,
-                       bool silent,
-                       bool load_row_split,
+DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
                        const std::string& file_format) {
   std::string fname, cache_file;
   size_t dlm_pos = uri.find('#');
@@ -840,42 +809,39 @@ DMatrix* DMatrix::Load(const std::string& uri,
   } else {
     fname = uri;
   }
 
-  int partid = 0, npart = 1;
-  if (load_row_split) {
-    partid = rabit::GetRank();
-    npart = rabit::GetWorldSize();
-  } else {
-    // test option to load in part
-    npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
-  }
-
-  if (npart != 1) {
-    LOG(CONSOLE) << "Load part of data " << partid
-                 << " of " << npart << " parts";
-  }
-
   // legacy handling of binary data loading
-  if (file_format == "auto" && npart == 1) {
+  if (file_format == "auto") {
     DMatrix* loaded = TryLoadBinary(fname, silent);
     if (loaded) {
       return loaded;
     }
   }
 
+  int partid = 0, npart = 1;
+  if (load_row_split) {
+    partid = rabit::GetRank();
+    npart = rabit::GetWorldSize();
+  } else {
+    // test option to load in part
+    npart = 1;
+  }
+
+  if (npart != 1) {
+    LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
+  }
+
   DMatrix* dmat {nullptr};
   try {
     if (cache_file.empty()) {
       std::unique_ptr<dmlc::Parser<uint32_t>> parser(
-          dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart,
-                                         file_format.c_str()));
+          dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
       data::FileAdapter adapter(parser.get());
-      dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
-                             1, cache_file);
+      dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, cache_file);
     } else {
       data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
                               file_format};
-      dmat = new data::SparsePageDMatrix{
-          &iter,
+      dmat = new data::SparsePageDMatrix{&iter,
                                          iter.Proxy(),
                                          data::fileiter::Reset,
                                          data::fileiter::Next,
@@ -911,24 +877,6 @@ DMatrix* DMatrix::Load(const std::string& uri,
    * partitioned data will fail the train/val validation check
    * since partitioned data not knowing the real number of features. */
   rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
 
-  // backward compatiblity code.
-  if (!load_row_split) {
-    MetaInfo& info = dmat->Info();
-    if (MetaTryLoadGroup(fname + ".group", &info.group_ptr_) && !silent) {
-      LOG(CONSOLE) << info.group_ptr_.size() - 1
-                   << " groups are loaded from " << fname << ".group";
-    }
-    if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) &&
-        !silent) {
-      LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname
-                   << ".base_margin";
-    }
-    if (MetaTryLoadFloatInfo
-        (fname + ".weight", &info.weights_.HostVector()) && !silent) {
-      LOG(CONSOLE) << info.weights_.Size()
-                   << " weights are loaded from " << fname << ".weight";
-    }
-  }
   return dmat;
 }


@@ -118,9 +118,10 @@ def make_categorical(
 
 
 def generate_array(
-    with_weights: bool = False
-) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
-           Optional[xgb.dask._DaskCollection]]:
+    with_weights: bool = False,
+) -> Tuple[
+    xgb.dask._DataT, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
+]:
     chunk_size = 20
     rng = da.random.RandomState(1994)
     X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
@@ -1265,6 +1266,50 @@ def test_dask_iteration_range(client: "Client"):
 
 
 class TestWithDask:
+    def test_dmatrix_binary(self, client: "Client") -> None:
+        def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
+            with xgb.dask.RabitContext(rabit_args):
+                rank = xgb.rabit.get_rank()
+                X, y = tm.make_categorical(100, 4, 4, False)
+                Xy = xgb.DMatrix(X, y, enable_categorical=True)
+                path = os.path.join(tmpdir, f"{rank}.bin")
+                Xy.save_binary(path)
+
+        def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
+            with xgb.dask.RabitContext(rabit_args):
+                rank = xgb.rabit.get_rank()
+                path = os.path.join(tmpdir, f"{rank}.bin")
+                Xy = xgb.DMatrix(path)
+                assert Xy.num_row() == 100
+                assert Xy.num_col() == 4
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            workers = _get_client_workers(client)
+            rabit_args = client.sync(
+                xgb.dask._get_rabit_args, len(workers), None, client
+            )
+            futures = []
+            for w in workers:
+                # same argument for each worker, must set pure to False otherwise dask
+                # will try to reuse the result from the first worker and hang waiting
+                # for it.
+                f = client.submit(
+                    save_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
+                )
+                futures.append(f)
+            client.gather(futures)
+
+            rabit_args = client.sync(
+                xgb.dask._get_rabit_args, len(workers), None, client
+            )
+            futures = []
+            for w in workers:
+                f = client.submit(
+                    load_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
+                )
+                futures.append(f)
+            client.gather(futures)
+
     @pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
     def test_global_config(
         self,