使用无UDF的Spark数据集的加权平均值
问题描述:
虽然有人已经询问了有关计算Weighted Average in Spark的问题,但在此问题中,我询问的是使用数据集/数据框而不是RDD。使用无UDF的Spark数据集的加权平均值
如何计算Spark中的加权平均值?我有两列:计数和以前的平均值:
case class Stat(name:String, count: Int, average: Double)
val statset = spark.createDataset(Seq(Stat("NY", 1,5.0),
Stat("NY",2,1.5),
Stat("LA",12,1.0),
Stat("LA",15,3.0)))
我想能够计算的加权平均值是这样的:
display(statset.groupBy($"name").agg(sum($"count").as("count"),
weightedAverage($"count",$"average").as("average")))
可以使用一个UDF亲近:
val weightedAverage = udf(
(row:Row)=>{
val counts = row.getAs[WrappedArray[Int]](0)
val averages = row.getAs[WrappedArray[Double]](1)
val (count,total) = (counts zip averages).foldLeft((0,0.0)){
case((cumcount:Int,cumtotal:Double),(newcount:Int,newaverage:Double))=>(cumcount+newcount,cumtotal+newcount*newaverage)}
(total/count) // Tested by returning count here and then extracting. Got same result as sum.
}
)
display(statset.groupBy($"name").agg(sum($"count").as("count"),
weightedAverage(struct(collect_list($"count"),
collect_list($"average"))).as("average")))
(感谢回答Passing a list of tuples as a parameter to a spark udf in scala帮忙写这)
福利局ies:使用这些进口:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection.mutable.WrappedArray
是否有一种方法可以通过内置列函数而不是UDF来完成此操作? UDF感觉笨重,如果数字变大,你必须将Int's转换成Long's。
答
看起来你可以分两次做到这一点:
val totalCount = statset.select(sum($"count")).collect.head.getLong(0)
statset.select(lit(totalCount) as "count", sum($"average" * $"count"/lit(totalCount)) as "average").show
或者,包括您刚才添加的GROUPBY:
display(statset.groupBy($"name").agg(sum($"count").as("count"),
sum($"count"*$"average").as("total"))
.select($"name",$"count",($"total"/$"count")))
在我实际的代码我有一个GROUPBY ......不过,这可能会工作... –
我会在第二次聚合中添加总数作为另一列,然后在最后进行分割。第二遍需要通过少得多的数据。 –
@MichelLemay:谢谢!这正是我需要慢慢思考的地方。我建议您对答案进行编辑,这也适用于groupBy。 –