[jvm-packages] refine tracker (#10313)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
|
||||
import org.apache.spark.sql.functions.{col, collect_list, struct}
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||
@@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
|
||||
.groupBy(groupName)
|
||||
.agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list")
|
||||
|
||||
implicit val encoder = RowEncoder(schema)
|
||||
implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false))
|
||||
// Expand the grouped rows after repartition
|
||||
repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => {
|
||||
new Iterator[Row] {
|
||||
|
||||
Reference in New Issue
Block a user