Demo of federated learning using NVFlare (#7879)

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Rong Ou
2022-05-14 07:45:41 -07:00
committed by GitHub
parent 11e46e4bc0
commit af907e2d0d
9 changed files with 298 additions and 14 deletions

View File

@@ -1,6 +1,7 @@
"""Distributed XGBoost Rabit related API."""
import ctypes
from enum import IntEnum, unique
import logging
import pickle
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
@@ -8,6 +9,8 @@ import numpy as np
from .core import _LIB, c_str, _check_call
LOGGER = logging.getLogger("[xgboost.rabit]")
def _init_rabit() -> None:
"""internal library initializer."""
@@ -224,5 +227,21 @@ def version_number() -> int:
return ret
class RabitContext:
"""A context controlling rabit initialization and finalization."""
def __init__(self, args: List[bytes]) -> None:
self.args = args
def __enter__(self) -> None:
init(self.args)
assert is_distributed()
LOGGER.debug("-------------- rabit say hello ------------------")
def __exit__(self, *args: List) -> None:
finalize()
LOGGER.debug("--------------- rabit say bye ------------------")
# initialization script
_init_rabit()