[pyspark] Re-work _fit function (#8630)

parent beefd28471
commit d3ad0524e7
@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
     ]
 
 
+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
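Note (not part of the diff): FeatureProp bundles the three feature-handling flags that previously traveled through _fit as loose local variables, so the helper methods introduced below can return them as one value. A minimal sketch of how the container behaves; the field values are hypothetical:

    from collections import namedtuple

    FeatureProp = namedtuple(
        "FeatureProp",
        ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
    )

    # Hypothetical values, for illustration only.
    prop = FeatureProp(
        enable_sparse_data_optim=False,
        has_validation_col=True,
        features_cols_names=["f0", "f1"],
    )
    print(prop.has_validation_col)   # True -- fields are accessed by name
    print(prop.features_cols_names)  # ['f0', 'f1']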
@@ -641,9 +649,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         return booster_params, kwargs_params
 
-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
 
         select_cols = [label_col]
@@ -698,6 +706,18 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
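Note (not part of the diff): _prepare_input_columns_and_feature_prop returns the pruned column list, and _prepare_input applies it with select(*select_cols) so only the columns XGBoost actually consumes survive. A standalone sketch of that pruning step, assuming a local SparkSession and hypothetical column names:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame(
        [(1.0, 0.5, "a"), (0.0, 1.5, "b")], ["target", "feature", "extra"]
    )
    # Mirrors the aliasing done against internal names such as alias.label;
    # "label" here is just an illustrative alias.
    select_cols = [col("target").alias("label"), col("feature")]
    pruned = df.select(*select_cols)  # "extra" never reaches training
    print(pruned.columns)  # ['label', 'feature']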
@@ -732,11 +752,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
 
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
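Note (not part of the diff): ranking objectives need rows that share a query id to be contiguous and ordered inside each partition, and sortWithinPartitions achieves that without a cross-partition shuffle. A small sketch with synthetic data:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame(
        [(2, 0.1), (1, 0.4), (2, 0.7), (1, 0.2)], ["qid", "label"]
    ).repartition(2, "qid")  # co-locate each query's rows in one partition
    # Sort by qid inside each partition only; partition boundaries are untouched.
    df_sorted = df.sortWithinPartitions("qid", ascending=True)
    df_sorted.show()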
@@ -749,9 +771,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-
-        is_local = _is_local(_get_spark_session().sparkContext)
 
         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
@@ -760,7 +779,25 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
 
-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+
+        is_local = _is_local(_get_spark_session().sparkContext)
+
+        num_workers = self.getOrDefault(self.num_workers)
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
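Note (not part of the diff): after this hunk, _fit reads as a short pipeline over the three extracted helpers. A simplified sketch of the resulting control flow; only method names that appear in the diff are used, and the training step is elided:

    def _fit(self, dataset):
        self._validate_params()

        # 1. Prune/alias columns, repartition, and sort by qid when needed.
        dataset, feature_prop = self._prepare_input(dataset)

        # 2. Translate estimator params into xgboost.train() arguments.
        (
            booster_params,
            train_call_kwargs_params,
            dmatrix_kwargs,
        ) = self._get_xgb_parameters(dataset)

        # 3. Run _train_booster on each barrier-mode partition (elided).
        ...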
@@ -772,6 +809,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             context = BarrierTaskContext.get()
 
             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
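Note (not part of the diff): the tree_method probe moves inside _train_booster so it is evaluated where the DMatrix is built; that it feeds the quantile-DMatrix (use_qdm) decision in the surrounding code is an assumption worth checking against the full file. The predicate itself is simple:

    booster_params = {"tree_method": "gpu_hist", "max_depth": 6}  # hypothetical
    use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
    print(use_hist)  # True -- hist-family tree methods qualify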
@@ -814,12 +853,12 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
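Note (not part of the diff): when a validation DMatrix exists, both matrices are handed to xgboost.train as named evals, which is what produces per-round "training"/"validation" metrics. A minimal local sketch of the same pattern with synthetic NumPy data:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    dtrain = xgb.DMatrix(rng.random((32, 4)), label=rng.integers(0, 2, 32))
    dvalid = xgb.DMatrix(rng.random((16, 4)), label=rng.integers(0, 2, 16))

    dval = [(dtrain, "training"), (dvalid, "validation")]
    booster = xgb.train(
        {"objective": "binary:logistic", "tree_method": "hist"},
        dtrain,
        num_boost_round=5,
        evals=dval,  # logs metrics for both sets each round
    )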