首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将数组传递给星火中的用户定义的聚合函数(UDAF)

如何将数组传递给星火中的用户定义的聚合函数(UDAF)
EN

Stack Overflow用户
提问于 2019-05-31 09:25:42
回答 1查看 377关注 0票数 0

我想将Array作为输入模式传递到一个UDAF中。

我给出的例子非常简单,它只是两个向量之和。实际上,我的用例更复杂,我需要使用一个联新议程。

代码语言:javascript
复制
import sc.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._

val df = Seq(
  (1, Array(10.2, 12.3, 11.2)),
  (1, Array(11.2, 12.6, 10.8)),
  (2, Array(12.1, 11.2, 10.1)),
  (2, Array(10.1, 16.0, 9.3)) 
  ).toDF("siteId", "bidRevenue")


class BidAggregatorBySiteId() extends UserDefinedAggregateFunction {

  def inputSchema: StructType = StructType(Array(StructField("bidRevenue", ArrayType(DoubleType))))

  def bufferSchema = StructType(Array(StructField("sumArray", ArrayType(DoubleType))))

  def dataType: DataType = ArrayType(DoubleType)

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, Array(0.0, 0.0, 0.0))
      }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
      val seqBuffer = buffer(0).asInstanceOf[IndexedSeq[Double]]
      val seqInput = input(0).asInstanceOf[IndexedSeq[Double]]
      buffer(0) = seqBuffer.zip(seqInput).map{ case (x, y) => x + y }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
     val seqBuffer1 = buffer1(0).asInstanceOf[IndexedSeq[Double]]
     val seqBuffer2 = buffer2(0).asInstanceOf[IndexedSeq[Double]]
     buffer1(0) = seqBuffer1.zip(seqBuffer2).map{ case (x, y) => x + y }
  }

  def evaluate(buffer: Row) = { 
    buffer
  }
}
val fun = new BidAggregatorBySiteId()

df.select($"siteId", $"bidRevenue" cast(ArrayType(DoubleType)))
.groupBy("siteId").agg(fun($"bidRevenue"))
.show

在“显示”操作之前,所有这些都可以用于转换。但是这个节目引发了一个错误:

org.apache.spark.sql.catalyst.CatalystTypeConverters$ArrayConverter.toCatalystImpl(CatalystTypeConverters.scala:160)的scala.MatchError:WrappedArray(21.4, 24.9, 22.0)

我的数据结构是:

代码语言:javascript
复制
root
 |-- siteId: integer (nullable = false)
 |-- bidRevenue: array (nullable = true)
 |    |-- element: double (containsNull = true)

df.dtypes = Array(String,String) = Array(("siteId","IntegerType"),("bidRevenue","ArrayType(DoubleType,true)“)

坦克给你很有价值的帮助。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-02 06:40:15

代码语言:javascript
复制
def evaluate(buffer: Row): Any

一旦一个组被完全处理,就会调用上面的方法来获得最终的结果。在初始化和更新缓冲区的第0次索引时

代码语言:javascript
复制
i.e. buffer(0)  

因此,您需要在末尾返回第0索引值,因为聚合的结果存储在0索引处。

代码语言:javascript
复制
  def evaluate(buffer: Row) = {
    buffer.get(0)
  }

上述对评估()方法的修改将导致:

代码语言:javascript
复制
// +------+---------------------------------+
// |siteId|bidaggregatorbysiteid(bidRevenue)|
// +------+---------------------------------+
// |     1|               [21.4, 24.9, 22.0]|
// |     2|               [22.2, 27.2, 19.4]|
// +------+---------------------------------+
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56392286

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档