[Python] Require black and isort for new Python files. (#8096)

* [Python] Require black and isort for new Python files.

- Require black and isort for the spark and dask modules. These files are relatively new and already conform closely to the black style; we will convert the rest of the library as we move forward. Other projects, including dask/distributed and optuna, use the same formatting style with an even stricter standard. The black formatter is quite nice, and automating it helps us unify the code style.
- Gather the Python checks into a single script.
This commit is contained in:
parent f23cc92130
commit 8bdea72688
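As a usage sketch (assuming a local environment with black, isort, and mypy installed), the consolidated checks introduced below can be run with `python tests/ci_build/lint_python.py --format=1 --type-check=1`; either class of check can be disabled by passing 0 for its flag.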
.github/workflows/main.yml (18 changes)

@@ -146,24 +146,6 @@ jobs:
       run: |
         make lint
 
-  mypy:
-    runs-on: ubuntu-latest
-    name: Type checking for Python
-    steps:
-    - uses: actions/checkout@v2
-      with:
-        submodules: 'true'
-    - uses: actions/setup-python@v2
-      with:
-        python-version: "3.8"
-        architecture: 'x64'
-    - name: Install Python packages
-      run: |
-        python -m pip install wheel setuptools mypy pandas dask[complete] distributed
-    - name: Run mypy
-      run: |
-        make mypy
-
   doxygen:
     runs-on: ubuntu-latest
     name: Generate C/C++ API doc using Doxygen
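The standalone mypy job removed here is not dropped: it reappears in .github/workflows/python_tests.yml below, folded into a combined type-and-format job.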
.github/workflows/python_tests.yml (31 changes)

@@ -3,6 +3,37 @@ name: XGBoost-Python-Tests
 on: [push, pull_request]
 
 jobs:
+  python-mypy-lint:
+    runs-on: ubuntu-latest
+    name: Type and format checks for the Python package
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+    steps:
+    - uses: actions/checkout@v2
+      with:
+        submodules: 'true'
+    - uses: conda-incubator/setup-miniconda@v2
+      with:
+        auto-update-conda: true
+        python-version: ${{ matrix.python-version }}
+        activate-environment: python_lint
+        environment-file: tests/ci_build/conda_env/python_lint.yml
+    - name: Display Conda env
+      shell: bash -l {0}
+      run: |
+        conda info
+        conda list
+    - name: Run mypy
+      shell: bash -l {0}
+      run: |
+        python tests/ci_build/lint_python.py --format=0 --type-check=1
+    - name: Run formatter
+      shell: bash -l {0}
+      run: |
+        python tests/ci_build/lint_python.py --format=1 --type-check=0
+
   python-sdist-test:
     runs-on: ${{ matrix.os }}
     name: Test installing XGBoost Python source package on ${{ matrix.os }}
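To reproduce this job locally, one can create and activate the same environment with `conda env create -f tests/ci_build/conda_env/python_lint.yml` followed by `conda activate python_lint`, then invoke the two lint_python.py commands shown above.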
Makefile (13 changes)

@@ -87,19 +87,6 @@ cover: check
 endif
 
 
-# dask is required to pass, others are not
-# If any of the dask tests failed, contributor won't see the other error.
-mypy:
-	cd python-package; \
-	mypy . && \
-	mypy ../demo/guide-python/external_memory.py && \
-	mypy ../demo/guide-python/categorical.py && \
-	mypy ../demo/guide-python/cat_in_the_dat.py && \
-	mypy ../tests/python-gpu/test_gpu_with_dask.py && \
-	mypy ../tests/python/test_data_iterator.py && \
-	mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \
-	mypy . || true ;
-
 clean:
 	$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o #xgboost
 	$(RM) -rf build_tests *.gcov tests/cpp/xgboost_test
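The removed target chained the checks with `&&` and ended in `mypy . || true`, so, as its own comments note, a dask-related failure hid every later error. The new tests/ci_build/lint_python.py below replaces the shell chain with one subprocess invocation per target.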
python-package/xgboost/dask.py

@@ -31,49 +31,76 @@ Optional dask configuration
     dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})
 
 """
-import platform
-import logging
 import collections
+import logging
+import platform
 import socket
-from contextlib import contextmanager
 from collections import defaultdict
-from threading import Thread
+from contextlib import contextmanager
 from functools import partial, update_wrapper
-from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
-from typing import Sequence
-from typing import Awaitable, Generator, TypeVar
+from threading import Thread
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import numpy
 
-from . import rabit, config
-from .callback import TrainingCallback
-
-from .compat import LazyLoader
-from .compat import scipy_sparse
-from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
-from .compat import lazy_isinstance
-
+from . import config, rabit
 from ._typing import FeatureNames, FeatureTypes
-from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
-from .core import Objective, Metric
-from .core import _deprecate_positional_args, _has_categorical
-from .training import train as worker_train
+from .callback import TrainingCallback
+from .compat import (
+    PANDAS_INSTALLED,
+    DataFrame,
+    LazyLoader,
+    Series,
+    lazy_isinstance,
+    pandas_concat,
+    scipy_sparse,
+)
+from .core import (
+    Booster,
+    DataIter,
+    DeviceQuantileDMatrix,
+    DMatrix,
+    Metric,
+    Objective,
+    _deprecate_positional_args,
+    _expect,
+    _has_categorical,
+)
+from .sklearn import (
+    XGBClassifier,
+    XGBClassifierBase,
+    XGBModel,
+    XGBRanker,
+    XGBRankerMixIn,
+    XGBRegressorBase,
+    _check_rf_callback,
+    _cls_predict_proba,
+    _objective_decorator,
+    _wrap_evaluation_matrices,
+    xgboost_model_doc,
+)
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
-from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback
-from .sklearn import XGBRankerMixIn
-from .sklearn import xgboost_model_doc
-from .sklearn import _cls_predict_proba
-from .sklearn import XGBRanker
-
+from .training import train as worker_train
 
 if TYPE_CHECKING:
-    from dask import dataframe as dd
-    from dask import array as da
     import dask
     import distributed
+    from dask import array as da
+    from dask import dataframe as dd
 else:
     dd = LazyLoader("dd", globals(), "dask.dataframe")
     da = LazyLoader("da", globals(), "dask.array")
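The grouped, alphabetized, trailing-comma import blocks above are what `isort --profile=black` produces. A minimal sketch of the transformation (hypothetical fragment, not from this diff):

    # before: repeated single-line imports
    from typing import TYPE_CHECKING, List, Tuple
    from typing import Sequence

    # after isort --profile=black: one sorted, parenthesized block
    from typing import (
        TYPE_CHECKING,
        List,
        Sequence,
        Tuple,
    )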
python-package/xgboost/spark/core.py

@@ -4,23 +4,21 @@
 # pylint: disable=too-few-public-methods
 import numpy as np
 import pandas as pd
-from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
-
-from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml import Estimator, Model
+from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
+from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import (
     HasFeaturesCol,
     HasLabelCol,
-    HasWeightCol,
     HasPredictionCol,
     HasProbabilityCol,
     HasRawPredictionCol,
     HasValidationIndicatorCol,
+    HasWeightCol,
 )
-from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.util import MLReadable, MLWritable
-from pyspark.sql.functions import col, pandas_udf, countDistinct, struct
+from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
 from pyspark.sql.types import (
     ArrayType,
     DoubleType,

@@ -29,35 +27,32 @@ from pyspark.sql.types import (
     LongType,
     ShortType,
 )
+from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
 
 import xgboost
 from xgboost import XGBClassifier, XGBRegressor
 from xgboost.core import Booster
 from xgboost.training import train as worker_train
 
-from .data import (
-    _convert_partition_data_to_dmatrix,
-)
+from .data import _convert_partition_data_to_dmatrix
 from .model import (
-    SparkXGBReader,
-    SparkXGBWriter,
     SparkXGBModelReader,
     SparkXGBModelWriter,
+    SparkXGBReader,
+    SparkXGBWriter,
 )
+from .params import HasArbitraryParamsDict, HasBaseMarginCol
 from .utils import (
-    get_logger, _get_max_num_concurrent_tasks,
-    _get_default_params_from_func,
-    get_class_name,
     RabitContext,
-    _get_rabit_args,
     _get_args_from_message_list,
+    _get_default_params_from_func,
+    _get_gpu_id,
+    _get_max_num_concurrent_tasks,
+    _get_rabit_args,
     _get_spark_session,
     _is_local,
-    _get_gpu_id,
+    get_class_name,
+    get_logger,
 )
-from .params import (
-    HasArbitraryParamsDict,
-    HasBaseMarginCol,
-)
 
 # Put pyspark specific params here, they won't be passed to XGBoost.

@@ -282,10 +277,7 @@ class _SparkXGBParams(
             .get("spark.task.resource.gpu.amount")
         )
 
-        is_local = _is_local(
-            _get_spark_session()
-            .sparkContext
-        )
+        is_local = _is_local(_get_spark_session().sparkContext)
 
         if is_local:
             # checking spark local mode.

@@ -301,7 +293,8 @@ class _SparkXGBParams(
             # gpu numbers and raising the exception.
             get_logger(self.__class__.__name__).warning(
                 "You enabled use_gpu in spark local mode. Please make sure your local node "
-                "has at least %d GPUs", self.getOrDefault(self.num_workers)
+                "has at least %d GPUs",
+                self.getOrDefault(self.num_workers),
             )
         else:
             # checking spark non-local mode.

@@ -316,7 +309,7 @@ class _SparkXGBParams(
             get_logger(self.__class__.__name__).warning(
                 "You configured %s GPU cores for each spark task, but in "
                 "XGBoost training, every Spark task will only use one GPU core.",
-                gpu_per_task
+                gpu_per_task,
             )

@@ -545,7 +538,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 "training is greater than current max number of concurrent "
                 "spark task slots, you need wait until more task slots available "
                 "or you need increase spark cluster workers.",
-                num_workers
+                num_workers,
             )
 
         if self._repartition_needed(dataset):

@@ -568,10 +561,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         booster_params["nthread"] = cpu_per_task
         use_gpu = self.getOrDefault(self.use_gpu)
 
-        is_local = _is_local(
-            _get_spark_session()
-            .sparkContext
-        )
+        is_local = _is_local(_get_spark_session().sparkContext)
 
         def _train_booster(pandas_df_iter):
             """

@@ -584,8 +574,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context.barrier()
 
             if use_gpu:
-                booster_params["gpu_id"] = context.partitionId() if is_local \
-                    else _get_gpu_id(context)
+                booster_params["gpu_id"] = (
+                    context.partitionId() if is_local else _get_gpu_id(context)
+                )
 
             _rabit_args = ""
             if context.partitionId() == 0:
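The later core.py hunks are mechanical black re-wraps: trailing commas on multi-line argument lists, and parentheses instead of backslash continuations. A minimal sketch of the second pattern (hypothetical fragment):

    # old style: backslash continuation
    gpu_id = context.partitionId() if is_local \
        else _get_gpu_id(context)

    # black style: wrap the conditional expression in parentheses
    gpu_id = (
        context.partitionId() if is_local else _get_gpu_id(context)
    )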
python-package/xgboost/spark/data.py

@@ -2,8 +2,10 @@
 """Xgboost pyspark integration submodule for data related functions."""
 # pylint: disable=too-many-arguments
 from typing import Iterator
+
 import numpy as np
 import pandas as pd
+
 from xgboost import DMatrix
 
 
python-package/xgboost/spark/estimator.py

@@ -2,12 +2,14 @@
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
+
 from xgboost import XGBClassifier, XGBRegressor
+
 from .core import (
-    _SparkXGBEstimator,
     SparkXGBClassifierModel,
     SparkXGBRegressorModel,
     _set_pyspark_xgb_cls_param_attrs,
+    _SparkXGBEstimator,
 )
 
 
python-package/xgboost/spark/model.py

@@ -5,13 +5,12 @@ import base64
 import os
 import uuid
 
-from pyspark import cloudpickle
-from pyspark import SparkFiles
-from pyspark.sql import SparkSession
+from pyspark import SparkFiles, cloudpickle
 from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
+from pyspark.sql import SparkSession
 from xgboost.core import Booster
 
-from .utils import get_logger, get_class_name
+from .utils import get_class_name, get_logger
 
 
 def _get_or_create_tmp_dir():

@@ -250,7 +249,9 @@ class SparkXGBModelReader(MLReader):
             self.cls, path, self.sc, self.logger
         )
 
-        xgb_sklearn_params = py_model._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True)
+        xgb_sklearn_params = py_model._gen_xgb_params_dict(
+            gen_xgb_sklearn_estimator_param=True
+        )
         model_load_path = os.path.join(path, "model.json")
 
         ser_xgb_model = (

@@ -263,8 +264,6 @@ class SparkXGBModelReader(MLReader):
         def create_xgb_model():
             return self.cls._xgb_cls()(**xgb_sklearn_params)
 
-        xgb_model = deserialize_xgb_model(
-            ser_xgb_model, create_xgb_model
-        )
+        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
         py_model._xgb_sklearn_model = xgb_model
         return py_model
python-package/xgboost/spark/utils.py

@@ -1,15 +1,15 @@
 # type: ignore
 """Xgboost pyspark integration submodule for helper functions."""
 import inspect
-from threading import Thread
-import sys
 import logging
+import sys
+from threading import Thread
 
 import pyspark
 from pyspark.sql.session import SparkSession
+from xgboost.tracker import RabitTracker
 
 from xgboost import rabit
-from xgboost.tracker import RabitTracker
 
 
 def get_class_name(cls):

@@ -140,13 +140,11 @@ def _get_gpu_id(task_context) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
         # This is a safety check.
-        raise RuntimeError(
-            "_get_gpu_id should not be invoked from driver side."
-        )
+        raise RuntimeError("_get_gpu_id should not be invoked from driver side.")
     resources = task_context.resources()
-    if 'gpu' not in resources:
+    if "gpu" not in resources:
         raise RuntimeError(
             "Couldn't get the gpu id, Please check the GPU resource configuration"
         )
     # return the first gpu id.
-    return int(resources['gpu'].addresses[0].strip())
+    return int(resources["gpu"].addresses[0].strip())
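Besides import sorting, black normalizes string quoting to double quotes, which is why 'gpu' becomes "gpu" in _get_gpu_id above.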
tests/ci_build/conda_env/python_lint.yml (15 lines, new file)

name: python_lint
channels:
- conda-forge
dependencies:
- python=3.8
- wheel
- setuptools
- mypy
- numpy
- scipy
- pandas
- dask
- distributed
- black
- isort
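Note that only the Python version is pinned here; black in particular changes its output between releases, so a stale local environment may disagree with CI until it is recreated from this file.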
tests/ci_build/lint_python.py (63 lines, new file)

import argparse
import os
import subprocess
import sys

from test_utils import DirectoryExcursion

CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
PROJECT_ROOT = os.path.normpath(os.path.join(CURDIR, os.path.pardir, os.path.pardir))


def run_formatter(rel_path: str):
    path = os.path.join(PROJECT_ROOT, rel_path)
    isort_ret = subprocess.run(["isort", "--check", "--profile=black", path]).returncode
    black_ret = subprocess.run(
        ["black", "--check", "./python-package/xgboost/dask.py"]
    ).returncode
    if isort_ret != 0 or black_ret != 0:
        msg = (
            "Please run the following command on your machine to address the format"
            f" errors:\n isort --check --profile=black {rel_path}\n black {rel_path}\n"
        )
        print(msg, file=sys.stdout)
        return False
    return True


def run_mypy(rel_path: str):
    with DirectoryExcursion(os.path.join(PROJECT_ROOT, "python-package")):
        path = os.path.join(PROJECT_ROOT, rel_path)
        ret = subprocess.run(["mypy", path])
        if ret.returncode != 0:
            return False
        return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--format", type=int, choices=[0, 1], default=1)
    parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
    args = parser.parse_args()
    if args.format == 1:
        if not all(
            [
                run_formatter("python-package/xgboost/dask.py"),
                run_formatter("python-package/xgboost/spark"),
            ]
        ):
            sys.exit(-1)

    if args.type_check == 1:
        if not all(
            run_mypy(path)
            for path in [
                "python-package/xgboost/",
                "demo/guide-python/external_memory.py",
                "demo/guide-python/cat_in_the_dat.py",
                "tests/python/test_data_iterator.py",
                "tests/python-gpu/test_gpu_with_dask.py",
                "tests/python-gpu/test_gpu_data_iterator.py",
            ]
        ):
            sys.exit(-1)
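Note that run_formatter above passes the computed path to isort but hardcodes ./python-package/xgboost/dask.py in the black invocation, so black never actually checks the spark directory; presumably an oversight. A minimal corrected sketch, assuming both tools were meant to check the same target:

    def run_formatter(rel_path: str) -> bool:
        # check the same resolved path with both tools
        path = os.path.join(PROJECT_ROOT, rel_path)
        isort_ret = subprocess.run(["isort", "--check", "--profile=black", path]).returncode
        black_ret = subprocess.run(["black", "--check", path]).returncode
        return isort_ret == 0 and black_ret == 0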
tests/ci_build/test_r_package.py

@@ -1,6 +1,7 @@
 import argparse
 import os
 import subprocess
+from test_utils import DirectoryExcursion
 
 ROOT = os.path.normpath(
     os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir,

@@ -8,18 +9,6 @@ ROOT = os.path.normpath(
 r_package = os.path.join(ROOT, 'R-package')
 
 
-class DirectoryExcursion:
-    def __init__(self, path: os.PathLike):
-        self.path = path
-        self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
-
-    def __enter__(self):
-        os.chdir(self.path)
-
-    def __exit__(self, *args):
-        os.chdir(self.curdir)
-
-
 def get_mingw_bin():
     return os.path.join('c:/rtools40/mingw64/', 'bin')
tests/ci_build/test_utils.py (14 lines, new file)

import os
from typing import Union


class DirectoryExcursion:
    def __init__(self, path: Union[os.PathLike, str]):
        self.path = path
        self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))

    def __enter__(self):
        os.chdir(self.path)

    def __exit__(self, *args):
        os.chdir(self.curdir)
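A usage sketch of the extracted helper (hypothetical snippet, mirroring how lint_python.py calls it): the context manager changes into the given directory on entry and restores the original working directory on exit, even if the body raises.

    import os
    import subprocess

    from test_utils import DirectoryExcursion

    # run mypy from inside python-package, then return to the
    # original working directory automatically
    with DirectoryExcursion(os.path.join("python-package")):
        subprocess.run(["mypy", "."])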