Fix loading DMatrix binary in distributed env. (#8149)
- Try to load DMatrix binary before trying to parse text input.
- Remove some unmaintained code.
This commit is contained in:
parent
8fc60b31bc
commit
446d536c23
@ -203,9 +203,7 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
|
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
|
||||||
int silent,
|
|
||||||
DMatrixHandle *out) {
|
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
bool load_row_split = false;
|
bool load_row_split = false;
|
||||||
#if defined(XGBOOST_USE_FEDERATED)
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
|||||||
@ -381,35 +381,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to load group information from file, if exists
|
|
||||||
inline bool MetaTryLoadGroup(const std::string& fname,
|
|
||||||
std::vector<unsigned>* group) {
|
|
||||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
|
||||||
if (fi == nullptr) return false;
|
|
||||||
dmlc::istream is(fi.get());
|
|
||||||
group->clear();
|
|
||||||
group->push_back(0);
|
|
||||||
unsigned nline = 0;
|
|
||||||
while (is >> nline) {
|
|
||||||
group->push_back(group->back() + nline);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// try to load weight information from file, if exists
|
|
||||||
inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
|
||||||
std::vector<bst_float>* data) {
|
|
||||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
|
|
||||||
if (fi == nullptr) return false;
|
|
||||||
dmlc::istream is(fi.get());
|
|
||||||
data->clear();
|
|
||||||
bst_float value;
|
|
||||||
while (is >> value) {
|
|
||||||
data->push_back(value);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <int32_t D, typename T>
|
template <int32_t D, typename T>
|
||||||
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
|
||||||
@ -805,9 +776,7 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
DMatrix* DMatrix::Load(const std::string& uri,
|
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||||
bool silent,
|
|
||||||
bool load_row_split,
|
|
||||||
const std::string& file_format) {
|
const std::string& file_format) {
|
||||||
std::string fname, cache_file;
|
std::string fname, cache_file;
|
||||||
size_t dlm_pos = uri.find('#');
|
size_t dlm_pos = uri.find('#');
|
||||||
@ -840,42 +809,39 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
} else {
|
} else {
|
||||||
fname = uri;
|
fname = uri;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// legacy handling of binary data loading
|
||||||
|
if (file_format == "auto") {
|
||||||
|
DMatrix* loaded = TryLoadBinary(fname, silent);
|
||||||
|
if (loaded) {
|
||||||
|
return loaded;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int partid = 0, npart = 1;
|
int partid = 0, npart = 1;
|
||||||
if (load_row_split) {
|
if (load_row_split) {
|
||||||
partid = rabit::GetRank();
|
partid = rabit::GetRank();
|
||||||
npart = rabit::GetWorldSize();
|
npart = rabit::GetWorldSize();
|
||||||
} else {
|
} else {
|
||||||
// test option to load in part
|
// test option to load in part
|
||||||
npart = dmlc::GetEnv("XGBOOST_TEST_NPART", 1);
|
npart = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (npart != 1) {
|
if (npart != 1) {
|
||||||
LOG(CONSOLE) << "Load part of data " << partid
|
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
|
||||||
<< " of " << npart << " parts";
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacy handling of binary data loading
|
|
||||||
if (file_format == "auto" && npart == 1) {
|
|
||||||
DMatrix *loaded = TryLoadBinary(fname, silent);
|
|
||||||
if (loaded) {
|
|
||||||
return loaded;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DMatrix* dmat {nullptr};
|
DMatrix* dmat {nullptr};
|
||||||
try {
|
try {
|
||||||
if (cache_file.empty()) {
|
if (cache_file.empty()) {
|
||||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart,
|
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||||
file_format.c_str()));
|
|
||||||
data::FileAdapter adapter(parser.get());
|
data::FileAdapter adapter(parser.get());
|
||||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, cache_file);
|
||||||
1, cache_file);
|
|
||||||
} else {
|
} else {
|
||||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
||||||
file_format};
|
file_format};
|
||||||
dmat = new data::SparsePageDMatrix{
|
dmat = new data::SparsePageDMatrix{&iter,
|
||||||
&iter,
|
|
||||||
iter.Proxy(),
|
iter.Proxy(),
|
||||||
data::fileiter::Reset,
|
data::fileiter::Reset,
|
||||||
data::fileiter::Next,
|
data::fileiter::Next,
|
||||||
@ -883,7 +849,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
1,
|
1,
|
||||||
cache_file};
|
cache_file};
|
||||||
}
|
}
|
||||||
} catch (dmlc::Error &e) {
|
} catch (dmlc::Error& e) {
|
||||||
std::vector<std::string> splited = common::Split(fname, '#');
|
std::vector<std::string> splited = common::Split(fname, '#');
|
||||||
std::vector<std::string> args = common::Split(splited.front(), '?');
|
std::vector<std::string> args = common::Split(splited.front(), '?');
|
||||||
std::string format {file_format};
|
std::string format {file_format};
|
||||||
@ -911,24 +877,6 @@ DMatrix* DMatrix::Load(const std::string& uri,
|
|||||||
* partitioned data will fail the train/val validation check
|
* partitioned data will fail the train/val validation check
|
||||||
* since partitioned data not knowing the real number of features. */
|
* since partitioned data not knowing the real number of features. */
|
||||||
rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
|
rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
|
||||||
// backward compatibility code.
|
|
||||||
if (!load_row_split) {
|
|
||||||
MetaInfo& info = dmat->Info();
|
|
||||||
if (MetaTryLoadGroup(fname + ".group", &info.group_ptr_) && !silent) {
|
|
||||||
LOG(CONSOLE) << info.group_ptr_.size() - 1
|
|
||||||
<< " groups are loaded from " << fname << ".group";
|
|
||||||
}
|
|
||||||
if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_.Data()->HostVector()) &&
|
|
||||||
!silent) {
|
|
||||||
LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname
|
|
||||||
<< ".base_margin";
|
|
||||||
}
|
|
||||||
if (MetaTryLoadFloatInfo
|
|
||||||
(fname + ".weight", &info.weights_.HostVector()) && !silent) {
|
|
||||||
LOG(CONSOLE) << info.weights_.Size()
|
|
||||||
<< " weights are loaded from " << fname << ".weight";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dmat;
|
return dmat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -118,9 +118,10 @@ def make_categorical(
|
|||||||
|
|
||||||
|
|
||||||
def generate_array(
|
def generate_array(
|
||||||
with_weights: bool = False
|
with_weights: bool = False,
|
||||||
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
|
) -> Tuple[
|
||||||
Optional[xgb.dask._DaskCollection]]:
|
xgb.dask._DataT, xgb.dask._DaskCollection, Optional[xgb.dask._DaskCollection]
|
||||||
|
]:
|
||||||
chunk_size = 20
|
chunk_size = 20
|
||||||
rng = da.random.RandomState(1994)
|
rng = da.random.RandomState(1994)
|
||||||
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
|
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
|
||||||
@ -1265,6 +1266,50 @@ def test_dask_iteration_range(client: "Client"):
|
|||||||
|
|
||||||
|
|
||||||
class TestWithDask:
|
class TestWithDask:
|
||||||
|
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||||
|
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||||
|
with xgb.dask.RabitContext(rabit_args):
|
||||||
|
rank = xgb.rabit.get_rank()
|
||||||
|
X, y = tm.make_categorical(100, 4, 4, False)
|
||||||
|
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||||
|
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||||
|
Xy.save_binary(path)
|
||||||
|
|
||||||
|
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||||
|
with xgb.dask.RabitContext(rabit_args):
|
||||||
|
rank = xgb.rabit.get_rank()
|
||||||
|
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||||
|
Xy = xgb.DMatrix(path)
|
||||||
|
assert Xy.num_row() == 100
|
||||||
|
assert Xy.num_col() == 4
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
workers = _get_client_workers(client)
|
||||||
|
rabit_args = client.sync(
|
||||||
|
xgb.dask._get_rabit_args, len(workers), None, client
|
||||||
|
)
|
||||||
|
futures = []
|
||||||
|
for w in workers:
|
||||||
|
# same argument for each worker, must set pure to False otherwise dask
|
||||||
|
# will try to reuse the result from the first worker and hang waiting
|
||||||
|
# for it.
|
||||||
|
f = client.submit(
|
||||||
|
save_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
|
||||||
|
)
|
||||||
|
futures.append(f)
|
||||||
|
client.gather(futures)
|
||||||
|
|
||||||
|
rabit_args = client.sync(
|
||||||
|
xgb.dask._get_rabit_args, len(workers), None, client
|
||||||
|
)
|
||||||
|
futures = []
|
||||||
|
for w in workers:
|
||||||
|
f = client.submit(
|
||||||
|
load_dmatrix, rabit_args, tmpdir, workers=[w], pure=False
|
||||||
|
)
|
||||||
|
futures.append(f)
|
||||||
|
client.gather(futures)
|
||||||
|
|
||||||
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
||||||
def test_global_config(
|
def test_global_config(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user