[jvm-packages] refine tracker (#10313)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2024-05-23 12:46:21 +08:00
committed by GitHub
parent 966dc81788
commit 932d7201f9
8 changed files with 71 additions and 92 deletions

View File

@@ -1,5 +1,5 @@
/*
-Copyright (c) 2021-2022 by Contributors
+Copyright (c) 2021-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.functions.{col, collect_list, struct}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
.groupBy(groupName)
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
-implicit val encoder = RowEncoder(schema)
+implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
// Expand the grouped rows after repartition
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
new Iterator[Row] {